# grammar-inference-engine/make_charts.py
#
# Renders two comparison charts (chart_context_cost.png and
# chart_token_savings.png) in matplotlib's xkcd style.
#
# Note: this file intentionally contains non-ASCII Unicode characters
# (e.g. the "×" multiplication sign and the "─"/"—" dashes in comments) that
# could be confused with similar-looking ASCII characters.

import matplotlib.pyplot as plt
import numpy as np
# Figure dimensions in inches, shared by both charts below.
FIG_W, FIG_H = 8, 5
# Switch matplotlib into hand-drawn "xkcd" mode. Called without a `with`
# block, so the style stays active for every figure created afterwards.
plt.xkcd(scale=0.7, length=60, randomness=2)
# ── Chart 1: Context cost vs examples ──
# Grouped bar chart: tokens needed in context when supplying raw examples
# versus a Dervish grammar, for an increasing number of examples.
fig1, ax1 = plt.subplots(figsize=(FIG_W, FIG_H))
N = [1, 5, 15, 36]
raw = [100 * n for n in N]  # ~100 tokens/example -> [100, 500, 1500, 3600]
dervish = [40, 60, 60, 200]  # grammar grows only when diversity grows
x = np.arange(len(N))
w = 0.35  # width of each bar; the pair is centred on its tick
bars1 = ax1.bar(x - w/2, raw, w, label='Raw examples', color='#e74c3c', alpha=0.85)
bars2 = ax1.bar(x + w/2, dervish, w, label='Dervish grammar', color='#3498db', alpha=0.85)
ax1.set_xticks(x)
ax1.set_xticklabels([f'{n} examples' for n in N])
ax1.set_ylabel('Tokens needed in context')
ax1.set_title('Context cost: raw examples vs Dervish grammar')
ax1.legend(frameon=False)
# Write each bar's value just above it. The pad is in data units (tokens),
# which is fine because this axis is linear.
label_pad = 80
for bars in (bars1, bars2):
    for bar in bars:
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_pad,
                 f'{int(bar.get_height())}', ha='center', va='bottom', fontsize=9)
# Leave 25% headroom above the tallest bar so its value label is not clipped
# (gives 4500 for the current data).
ax1.set_ylim(0, max(raw + dervish) * 1.25)
fig1.tight_layout()
fig1.savefig('chart_context_cost.png', dpi=200)
plt.close(fig1)
# ── Chart 2: Tokens — Without vs With Dervish (per dataset) ──
# Log-scale grouped bar chart of token counts per dataset, with the
# compression ratio (without / with) annotated above each pair of bars.
fig2, ax2 = plt.subplots(figsize=(FIG_W, FIG_H))
datasets = ['Ansible Galaxy\n(15 roles)', 'Helm\n(6 configs)', 'Go lint\n(6 jobs)']
without = [5000, 3000, 900]
with_derv = [60, 40, 30]
# Rounded compression ratio per dataset, e.g. 5000 / 60 -> '83×'.
ratios = [f'{round(wo / wd)}×' for wo, wd in zip(without, with_derv)]
x2 = np.arange(len(datasets))
w2 = 0.3
bw = ax2.bar(x2 - w2/2, without, w2, label='Without Dervish', color='#e74c3c', alpha=0.85)
bd = ax2.bar(x2 + w2/2, with_derv, w2, label='With Dervish', color='#3498db', alpha=0.85)
ax2.set_xticks(x2)
ax2.set_xticklabels(datasets)
ax2.set_ylabel('Tokens')
ax2.set_title('Token savings per dataset')
ax2.legend(frameon=False)
ax2.set_yscale('log')
# Pad the log axis by a constant factor at both ends so the smallest bar and
# the ratio labels above the tallest bars stay in view ((5, 30000) today).
y_pad = 6
ax2.set_ylim(min(with_derv) / y_pad, max(without) * y_pad)
# Label compression ratios: centred over each pair, a constant factor above
# the "without" bar (a multiplicative offset suits the log axis).
for xi, top, r in zip(x2, without, ratios):
    ax2.text(xi, top * 1.3, r, ha='center', va='bottom', fontsize=11, fontweight='bold',
             bbox=dict(boxstyle='round,pad=0.2', facecolor='white', edgecolor='gray', alpha=0.8))
fig2.tight_layout()
fig2.savefig('chart_token_savings.png', dpi=200)
plt.close(fig2)
# Report the two files written above (paths relative to the working directory).
print("Charts saved: chart_context_cost.png, chart_token_savings.png", end="\n")