Interactive explainer and manim visualization of Percepta's 'Can LLMs Be Computers?'
- Interactive web app (interactive/) explaining how transformer weights execute deterministic WASM programs: softmax sharpening, 2D parabola trick for exact memory lookup, stack machine step-through, and full execution trace visualization - Manim animation script (manim_project/scene.py) with 9 scenes covering the article's key concepts Co-authored-by: Ona <no-reply@ona.com>
This commit is contained in:
17
.gitignore
vendored
Normal file
17
.gitignore
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
# Manim build artifacts
|
||||
manim_project/media/
|
||||
manim_project/output.mp4
|
||||
manim_project/__pycache__/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Ona
|
||||
.ona/
|
||||
|
||||
# Devcontainer
|
||||
.devcontainer/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
818
interactive/app.js
Normal file
818
interactive/app.js
Normal file
@@ -0,0 +1,818 @@
|
||||
// ── Tab Navigation ──
// Clicking a tab deactivates every tab and panel, then activates the clicked
// tab plus the panel whose id is given by the tab's data-tab attribute.
for (const tabBtn of document.querySelectorAll('.tab')) {
  tabBtn.addEventListener('click', () => {
    for (const other of document.querySelectorAll('.tab')) other.classList.remove('active');
    for (const panel of document.querySelectorAll('.panel')) panel.classList.remove('active');
    tabBtn.classList.add('active');
    document.getElementById(tabBtn.dataset.tab).classList.add('active');
  });
}
// "Next" buttons delegate to the matching tab's click handler and scroll
// back to the top so the new panel is read from the start.
for (const nextBtn of document.querySelectorAll('.next-btn')) {
  nextBtn.addEventListener('click', () => {
    document.querySelector(`.tab[data-tab="${nextBtn.dataset.next}"]`).click();
    window.scrollTo({ top: 0, behavior: 'smooth' });
  });
}
|
||||
|
||||
// ── Helpers ──

/**
 * Numerically stable softmax: shifts scores by their max before
 * exponentiating (so large inputs cannot overflow), then normalizes
 * so the outputs sum to 1.
 * @param {number[]} scores - raw (unnormalized) scores
 * @returns {number[]} probabilities in [0, 1] that sum to 1
 */
function softmax(scores) {
  const peak = Math.max(...scores);
  const expd = [];
  let total = 0;
  for (const s of scores) {
    const e = Math.exp(s - peak);
    expd.push(e);
    total += e;
  }
  return expd.map((e) => e / total);
}
|
||||
|
||||
// ══════════════════════════════════════════
|
||||
// TAB 1: Softmax Temperature Demo
|
||||
// ══════════════════════════════════════════
|
||||
// Fixed demo scores for Tab 1. Slot 2 holds the largest raw score (3.8), so
// it is the slot that sharpened (high-temperature) softmax should isolate.
const rawScores = [1.2, 0.5, 3.8, 0.9, 1.5]; // index 2 is the "target"
const scoreLabels = ['slot 0', 'slot 1', 'slot 2', 'slot 3', 'slot 4'];
|
||||
|
||||
/**
 * Redraws the Tab-1 bar chart for a given temperature multiplier and updates
 * the insight caption under it. Scores are scaled by `temp` before softmax,
 * so larger temp sharpens the distribution toward slot 2.
 * @param {number} temp - softmax temperature (multiplier on raw scores)
 */
function renderSoftmaxBars(temp) {
  const weights = softmax(rawScores.map((s) => s * temp));
  const topWeight = Math.max(...weights);

  const host = document.getElementById('softmaxBars');
  host.innerHTML = '';

  // Small factory to cut down on createElement/className boilerplate.
  const div = (cls) => {
    const node = document.createElement('div');
    node.className = cls;
    return node;
  };

  weights.forEach((w, i) => {
    const bar = div('bar' + (w === topWeight ? ' winner' : ''));
    bar.style.height = (w * 130) + 'px';

    const wrapper = div('bar-wrapper');
    wrapper.appendChild(bar);

    const val = div('bar-value');
    val.textContent = (w * 100).toFixed(1) + '%';

    const lbl = div('bar-label');
    lbl.textContent = scoreLabels[i] + (i === 2 ? ' ★' : '');

    const col = div('bar-col');
    col.append(val, wrapper, lbl);
    host.appendChild(col);
  });

  // Caption tracks three regimes: diffuse, sharpening, near-one-hot.
  const insight = document.getElementById('tempInsight');
  if (temp < 5) insight.textContent = 'At low temperature, attention is spread out — fuzzy, not useful for exact computation.';
  else if (temp < 20) insight.textContent = 'Getting sharper! The target slot is winning, but there\'s still leakage to other slots.';
  else insight.textContent = 'Nearly 100% on the target. The softmax now acts like an exact array read — this is how weights produce deterministic lookups.';
}
|
||||
|
||||
// Temperature slider: keep the numeric readout and the bar chart in sync on
// every input event.
document.getElementById('tempSlider').addEventListener('input', e => {
  const v = +e.target.value;
  document.getElementById('tempVal').textContent = v;
  renderSoftmaxBars(v);
});
// Initial render at temperature 1 (assumes the slider's start value is 1 —
// TODO confirm against the HTML).
renderSoftmaxBars(1);
|
||||
|
||||
// ══════════════════════════════════════════
|
||||
// TAB 2: Memory Lookup via Attention
|
||||
// ══════════════════════════════════════════
|
||||
// Simulated memory contents for Tab 2 — one value per address 0..5.
const memValues = [42, 17, 99, 8, 55, 73];
// Address currently being "read"; updated by clicking a memory slot.
let queryTarget = 2;
|
||||
|
||||
/**
 * Builds a DOM element that renders `values` as a bracketed column vector.
 * @param {Array} values - cell contents, top to bottom
 * @param {string} cls - extra class name(s) appended after "col-vec"
 * @param {string} [label] - optional HTML label shown with the vector
 * @returns {HTMLElement} the assembled vector element
 */
function makeColVec(values, cls, label) {
  const wrap = document.createElement('div');
  wrap.className = 'col-vec ' + cls;
  // Left/right bracket glyphs are pure CSS decorations.
  wrap.innerHTML = '<div class="bracket"></div><div class="bracket-r"></div>';
  for (const entry of values) {
    const cell = document.createElement('div');
    cell.className = 'cell';
    cell.textContent = entry;
    wrap.appendChild(cell);
  }
  if (label) {
    const tag = document.createElement('div');
    tag.className = 'vec-label';
    tag.innerHTML = label;
    wrap.appendChild(tag);
  }
  return wrap;
}
|
||||
|
||||
// Renders the entire Tab-2 "memory lookup via attention" view:
//  1. the clickable memory slots (clicking one re-targets the query),
//  2. the query vector q = (i, 1) for the selected address,
//  3. one key vector k_j = (2j, -j²) per slot with its dot product,
//     softmax weight and a mini score bar,
//  4. the summary line showing the retrieved value.
// Scores use the parabola identity q·k_j = 2ij - j² = i² - (i-j)², so the
// exact match j = i always wins.
function renderMemory() {
  // Memory slots
  const container = document.getElementById('memorySlots');
  container.innerHTML = '';
  memValues.forEach((v, i) => {
    const slot = document.createElement('div');
    slot.className = 'mem-slot' + (i === queryTarget ? ' active' : '');
    slot.innerHTML = `<div class="idx">addr ${i}</div><div class="val">${v}</div>`;
    // Clicking a slot re-targets the read and redraws the whole panel.
    slot.addEventListener('click', () => { queryTarget = i; renderMemory(); });
    container.appendChild(slot);
  });

  // Query vector
  const qEl = document.getElementById('queryVec');
  qEl.innerHTML = '';
  const qVec = makeColVec([queryTarget, 1], 'query', `q`);
  const qNote = document.createElement('span');
  qNote.style.cssText = 'font-size:0.82rem;color:var(--dim);margin-left:12px';
  qNote.innerHTML = `= (i, 1) where i = ${queryTarget} ← "I want to read address ${queryTarget}"`;
  qEl.appendChild(qVec);
  qEl.appendChild(qNote);

  // Key vectors + dot products
  const vecEl = document.getElementById('vecColumns');
  vecEl.innerHTML = '';

  // Score for slot j: q·k_j = 2ij - j² (maximized exactly at j = i).
  const scores = memValues.map((_, j) => 2 * queryTarget * j - j * j);
  const maxScore = Math.max(...scores);
  const minScore = Math.min(...scores);
  const scoreRange = maxScore - minScore || 1; // avoid divide-by-zero when all equal
  const weights = softmax(scores.map(s => s * 10)); // ×10 = sharpening temperature

  memValues.forEach((val, j) => {
    const isWin = j === queryTarget;
    const k = [2 * j, -(j * j)];

    const group = document.createElement('div');
    group.className = 'vec-group' + (isWin ? ' winner' : '');

    // Column vector
    const vec = makeColVec(k, isWin ? 'winner' : '', `k<sub>${j}</sub>`);
    group.appendChild(vec);

    // Dot product computation
    const comp = document.createElement('div');
    comp.className = 'dot-computation';
    comp.innerHTML = `${queryTarget}×${k[0]} + 1×${k[1] >= 0 ? k[1] : '(' + k[1] + ')'}`;
    group.appendChild(comp);

    // Score + weight
    const dpLine = document.createElement('div');
    dpLine.className = 'dot-product-line';
    dpLine.innerHTML = `<span class="dp-val">${scores[j]}</span><span class="dp-weight">${(weights[j] * 100).toFixed(1)}%</span>`;
    group.appendChild(dpLine);

    // Mini bar (min width 2px so even the worst score stays visible)
    const bar = document.createElement('div');
    bar.className = 'dp-bar-mini';
    bar.style.width = Math.max(2, ((scores[j] - minScore) / scoreRange) * 60) + 'px';
    group.appendChild(bar);

    // Value stored
    const valLabel = document.createElement('div');
    valLabel.style.cssText = `font-size:0.7rem;margin-top:4px;color:${isWin ? 'var(--gold)' : 'var(--dim)'};font-family:monospace`;
    valLabel.textContent = `val=${val}`;
    group.appendChild(valLabel);

    vecEl.appendChild(group);

    // Add "·" or "=" separator between groups (except last)
    // NOTE(review): textContent is set to the empty string, so this is
    // currently just a spacer — confirm whether a glyph was intended.
    if (j < memValues.length - 1) {
      const sep = document.createElement('div');
      sep.style.cssText = 'align-self:center;color:var(--border);font-size:1.2rem;padding:0 2px';
      sep.textContent = '';
      vecEl.appendChild(sep);
    }
  });

  // Read result
  document.getElementById('readResult').innerHTML =
    `<strong style="color:var(--gold)">Read result:</strong> mem[${queryTarget}] = <strong>${memValues[queryTarget]}</strong> — key k<sub>${queryTarget}</sub> gets score <strong>${maxScore}</strong> (softmax weight ${(weights[queryTarget] * 100).toFixed(2)}%), all others are penalized by −(i−j)²`;
}

// Initial paint of the Tab-2 panel.
renderMemory();
|
||||
|
||||
// ══════════════════════════════════════════
|
||||
// TAB 2b: Side-by-Side Comparison
|
||||
// ══════════════════════════════════════════
|
||||
// Data for the Tab-2b side-by-side comparison: fuzzy semantic attention vs.
// exact parabola-key lookup over the same five tokens.
const sbsWords = ['The', 'cake', 'delicious', 'was', 'very'];
// Simulated high-dim embeddings (4D slice for display) — designed to show semantic similarity spread
const sbsTradKeys = [
  [0.2, -0.1, 0.8, 0.3], // The
  [0.9, 0.7, 0.1, -0.2], // cake
  [0.8, 0.9, -0.1, 0.3], // delicious
  [0.1, -0.3, 0.7, 0.5], // was
  [0.3, 0.1, 0.2, 0.9], // very
];
// Queries are similar to the target but with overlap to neighbors (semantic similarity)
const sbsTradQueries = [
  [0.3, -0.2, 0.9, 0.2], // attending to "The"
  [0.8, 0.6, 0.2, -0.1], // attending to "cake"
  [0.7, 0.8, 0.0, 0.4], // attending to "delicious"
  [0.2, -0.2, 0.8, 0.4], // attending to "was"
  [0.4, 0.2, 0.1, 0.8], // attending to "very"
];

// Index of the token currently attended to; driven by the slider below.
let sbsTarget = 2;
|
||||
|
||||
/**
 * Column-vector builder for the side-by-side demo. Like makeColVec, but
 * non-integer numeric cells are rendered with one decimal place.
 * @param {Array} values - cell contents, top to bottom
 * @param {string} cls - extra class name(s) appended after "col-vec sbs-vec"
 * @param {string} [label] - optional HTML label shown with the vector
 * @returns {HTMLElement}
 */
function makeSbsColVec(values, cls, label) {
  // Integers print as-is; other numbers get one decimal; non-numbers pass through.
  const fmt = (v) => {
    if (typeof v !== 'number') return v;
    return Number.isInteger(v) ? v : v.toFixed(1);
  };
  const wrap = document.createElement('div');
  wrap.className = 'col-vec sbs-vec ' + cls;
  wrap.innerHTML = '<div class="bracket"></div><div class="bracket-r"></div>';
  for (const entry of values) {
    const cell = document.createElement('div');
    cell.className = 'cell';
    cell.textContent = fmt(entry);
    wrap.appendChild(cell);
  }
  if (label) {
    const tag = document.createElement('div');
    tag.className = 'vec-label';
    tag.innerHTML = label;
    wrap.appendChild(tag);
  }
  return wrap;
}
|
||||
|
||||
// Renders both halves of the Tab-2b comparison for the current sbsTarget:
//  - Traditional side: dense embedding queries/keys → dot products → a soft
//    blend over semantically similar tokens.
//  - Lookup side: q = (target, 1) against parabola keys (2j, -j²) → a
//    near-one-hot read of a single address.
function renderSBS() {
  const target = sbsTarget;
  document.getElementById('sbsTargetLabel').textContent = sbsWords[target];

  // ── Traditional side ──
  const tradQEl = document.getElementById('sbsTradQ');
  tradQEl.innerHTML = '';
  const tradQLabel = document.createElement('span');
  tradQLabel.style.cssText = 'font-size:0.75rem;color:var(--dim);margin-right:4px';
  tradQLabel.textContent = 'query:';
  tradQEl.appendChild(tradQLabel);
  tradQEl.appendChild(makeSbsColVec(sbsTradQueries[target], 'query', 'q'));

  const tradKeysEl = document.getElementById('sbsTradKeys');
  tradKeysEl.innerHTML = '';

  // Plain dot product of the 4D query against each 4D key.
  const tradScores = sbsTradKeys.map(k =>
    k.reduce((sum, ki, d) => sum + ki * sbsTradQueries[target][d], 0)
  );
  const tradWeights = softmax(tradScores.map(s => s * 4)); // moderate temperature
  const tradMaxW = Math.max(...tradWeights);

  sbsWords.forEach((word, j) => {
    const grp = document.createElement('div');
    const isTop = tradWeights[j] === tradMaxW;
    grp.className = 'sbs-key-group' + (isTop ? ' trad-winner' : '');

    const wordEl = document.createElement('div');
    wordEl.className = 'sbs-word';
    wordEl.textContent = word;
    grp.appendChild(wordEl);

    grp.appendChild(makeSbsColVec(sbsTradKeys[j], '', `k<sub>${j}</sub>`));

    const score = document.createElement('div');
    score.className = 'sbs-score';
    score.textContent = tradScores[j].toFixed(2);
    grp.appendChild(score);

    const weight = document.createElement('div');
    weight.className = 'sbs-weight';
    weight.textContent = (tradWeights[j] * 100).toFixed(1) + '%';
    grp.appendChild(weight);

    // Bar widths are normalized so the top weight spans 50px.
    const bar = document.createElement('div');
    bar.className = 'sbs-weight-bar';
    bar.style.width = (tradWeights[j] / tradMaxW * 50) + 'px';
    grp.appendChild(bar);

    tradKeysEl.appendChild(grp);
  });

  // Summary: the two largest weights to illustrate the fuzzy blend.
  const tradResult = document.getElementById('sbsTradResult');
  const topTrad = tradWeights.map((w, i) => ({w, i})).sort((a, b) => b.w - a.w);
  tradResult.innerHTML = `<span style="color:var(--accent2)">Output:</span> blend of "${sbsWords[topTrad[0].i]}" (${(topTrad[0].w*100).toFixed(0)}%) + "${sbsWords[topTrad[1].i]}" (${(topTrad[1].w*100).toFixed(0)}%) + others<br><span style="color:var(--dim)">→ A fuzzy mix of semantically related tokens</span>`;

  // ── Lookup side ──
  const lookupQEl = document.getElementById('sbsLookupQ');
  lookupQEl.innerHTML = '';
  const lookupQLabel = document.createElement('span');
  lookupQLabel.style.cssText = 'font-size:0.75rem;color:var(--dim);margin-right:4px';
  lookupQLabel.textContent = 'query:';
  lookupQEl.appendChild(lookupQLabel);
  lookupQEl.appendChild(makeSbsColVec([target, 1], 'query', 'q'));

  const lookupKeysEl = document.getElementById('sbsLookupKeys');
  lookupKeysEl.innerHTML = '';

  // Parabola scores 2ij - j² as in Tab 2; ×10 temperature sharpens to one-hot.
  const lookupScores = sbsWords.map((_, j) => 2 * target * j - j * j);
  const lookupWeights = softmax(lookupScores.map(s => s * 10));
  const lookupMaxW = Math.max(...lookupWeights);

  sbsWords.forEach((word, j) => {
    const grp = document.createElement('div');
    const isWin = lookupWeights[j] === lookupMaxW;
    grp.className = 'sbs-key-group' + (isWin ? ' winner' : '');

    const wordEl = document.createElement('div');
    wordEl.className = 'sbs-word';
    wordEl.textContent = `addr ${j}`;
    grp.appendChild(wordEl);

    grp.appendChild(makeSbsColVec([2 * j, -(j * j)], isWin ? 'winner' : '', `k<sub>${j}</sub>`));

    const score = document.createElement('div');
    score.className = 'sbs-score';
    score.textContent = lookupScores[j];
    grp.appendChild(score);

    const weight = document.createElement('div');
    weight.className = 'sbs-weight';
    weight.textContent = (lookupWeights[j] * 100).toFixed(1) + '%';
    grp.appendChild(weight);

    const bar = document.createElement('div');
    bar.className = 'sbs-weight-bar';
    bar.style.width = (lookupWeights[j] / (lookupMaxW || 1) * 50) + 'px';
    grp.appendChild(bar);

    lookupKeysEl.appendChild(grp);
  });

  const lookupResult = document.getElementById('sbsLookupResult');
  lookupResult.innerHTML = `<span style="color:var(--gold)">Output:</span> value at addr ${target} with <strong>${(lookupWeights[target] * 100).toFixed(2)}%</strong> weight<br><span style="color:var(--dim)">→ An exact read of one specific address</span>`;
}

// Target slider drives which token/address both halves attend to.
document.getElementById('sbsTargetSlider').addEventListener('input', e => {
  sbsTarget = +e.target.value;
  renderSBS();
});
renderSBS();
|
||||
|
||||
// ══════════════════════════════════════════
|
||||
// TAB 3: Parabola Visualization
|
||||
// ══════════════════════════════════════════
|
||||
// Draws the Tab-3 parabola plot on #parabolaCanvas: the n = 8 key points
// (2j, -j²) lie on a downward parabola, the selected index queryIdx is
// highlighted, and a dashed arrow suggests the query direction q = (i, 1).
// @param {number} queryIdx - index 0..7 being "read"
function drawParabola(queryIdx) {
  const canvas = document.getElementById('parabolaCanvas');
  const ctx = canvas.getContext('2d');
  const W = canvas.width, H = canvas.height;
  ctx.clearRect(0, 0, W, H);

  const n = 8;                                   // number of key points
  const pad = 40;                                // canvas margin on all sides
  const xScale = (W - 2 * pad) / (2 * (n - 1));  // maps x = 2j into the plot area
  const maxJ2 = (n - 1) * (n - 1);
  const yScale = (H - 2 * pad) / maxJ2;          // maps y = j² into the plot area

  // Axes
  ctx.strokeStyle = '#2a2a44';
  ctx.lineWidth = 1;
  ctx.beginPath();
  ctx.moveTo(pad, H - pad);
  ctx.lineTo(W - pad, H - pad);
  ctx.moveTo(pad, H - pad);
  ctx.lineTo(pad, pad);
  ctx.stroke();

  ctx.fillStyle = '#666680';
  ctx.font = '10px monospace';
  ctx.fillText('2j →', W - pad - 20, H - pad + 15);
  ctx.fillText('−j²', pad - 5, pad - 5);

  // Parabola curve
  ctx.strokeStyle = '#333355';
  ctx.lineWidth = 1.5;
  ctx.beginPath();
  for (let j = 0; j < n; j++) {
    const x = pad + (2 * j) * xScale;
    const y = H - pad - (j * j) * yScale;
    j === 0 ? ctx.moveTo(x, y) : ctx.lineTo(x, y);
  }
  ctx.stroke();

  // Points (selected index is drawn larger and in gold)
  for (let j = 0; j < n; j++) {
    const x = pad + (2 * j) * xScale;
    const y = H - pad - (j * j) * yScale;
    ctx.beginPath();
    ctx.arc(x, y, j === queryIdx ? 8 : 5, 0, Math.PI * 2);
    ctx.fillStyle = j === queryIdx ? '#ffd54f' : '#4fc3f7';
    ctx.fill();
    ctx.fillStyle = '#999';
    ctx.font = '10px monospace';
    ctx.fillText(`j=${j}`, x - 8, y + 18);
  }

  // Query direction arrow (dashed, from a fixed anchor to the chosen point)
  const qx = pad + (2 * queryIdx) * xScale;
  const qy = H - pad - (queryIdx * queryIdx) * yScale;
  ctx.strokeStyle = '#ff7043';
  ctx.lineWidth = 2;
  ctx.setLineDash([4, 3]);
  ctx.beginPath();
  ctx.moveTo(pad + W / 4, H - pad);
  ctx.lineTo(qx, qy);
  ctx.stroke();
  ctx.setLineDash([]); // reset dash so later strokes are solid

  ctx.fillStyle = '#ff7043';
  ctx.font = 'bold 11px sans-serif';
  ctx.fillText(`q=(${queryIdx},1)`, pad + W / 4 - 15, H - pad + 15);
}
|
||||
|
||||
// Draws the Tab-3 score bar chart on #scoresCanvas: the attention score
// 2·i·j − j² for each j in 0..7 given query index i = queryIdx, with the
// winning (exact-match) bar highlighted, and updates the insight caption.
// @param {number} queryIdx - query index i in 0..7
function drawScores(queryIdx) {
  const canvas = document.getElementById('scoresCanvas');
  const ctx = canvas.getContext('2d');
  const W = canvas.width, H = canvas.height;
  ctx.clearRect(0, 0, W, H);

  const n = 8;
  const pad = 40;
  const barW = (W - 2 * pad) / n - 4; // per-bar width minus gutter

  // Score for slot j: 2ij - j², maximized exactly at j = i.
  const scores = [];
  for (let j = 0; j < n; j++) {
    scores.push(2 * queryIdx * j - j * j);
  }
  const maxS = Math.max(...scores);
  const minS = Math.min(...scores);
  const range = maxS - minS || 1; // avoid divide-by-zero when all equal

  // Axes
  ctx.strokeStyle = '#2a2a44';
  ctx.lineWidth = 1;
  ctx.beginPath();
  ctx.moveTo(pad, H - pad);
  ctx.lineTo(W - pad, H - pad);
  ctx.stroke();

  ctx.fillStyle = '#666680';
  ctx.font = '10px monospace';
  ctx.fillText('j →', W - pad - 15, H - pad + 15);
  ctx.fillText('score', pad - 5, pad + 10);

  // Bars (heights normalized to the score range; winner drawn in gold)
  for (let j = 0; j < n; j++) {
    const x = pad + j * ((W - 2 * pad) / n) + 2;
    const h = ((scores[j] - minS) / range) * (H - 2 * pad - 20);
    const y = H - pad - h;
    ctx.fillStyle = j === queryIdx ? '#ffd54f' : '#4fc3f7';
    ctx.fillRect(x, y, barW, h);
    ctx.fillStyle = '#999';
    ctx.font = '9px monospace';
    ctx.fillText(j.toString(), x + barW / 2 - 3, H - pad + 12);
    ctx.fillStyle = j === queryIdx ? '#ffd54f' : '#aaa';
    ctx.font = '9px monospace';
    ctx.fillText(scores[j].toString(), x + barW / 2 - 6, y - 4);
  }

  document.getElementById('parabolaInsight').textContent =
    `Index ${queryIdx} gets the highest score (${maxS}). The penalty −(i−j)² ensures only the exact match wins.`;
}
|
||||
|
||||
// Query-index slider: redraw both Tab-3 canvases and the readout on input.
document.getElementById('queryIdxSlider').addEventListener('input', e => {
  const v = +e.target.value;
  document.getElementById('queryIdxVal').textContent = v;
  drawParabola(v);
  drawScores(v);
});

// Defer canvas drawing until tab is visible
// A MutationObserver watches #tab3's class attribute; the first time the
// 'active' class appears the canvases are drawn once (at query index 3)
// and the observer disconnects itself.
const tab3Observer = new MutationObserver(() => {
  if (document.getElementById('tab3').classList.contains('active')) {
    drawParabola(3);
    drawScores(3);
    tab3Observer.disconnect();
  }
});
tab3Observer.observe(document.getElementById('tab3'), { attributes: true, attributeFilter: ['class'] });
|
||||
|
||||
// ══════════════════════════════════════════
|
||||
// TAB 4: Write-Read Trace Demo
|
||||
// ══════════════════════════════════════════
|
||||
// Script for the Tab-4 write/read trace demo. Each step either WRITEs a new
// trace token (key/value shown as it would be produced by W_K / W_V) or
// READs past tokens via attention. `addr` is the stack depth the token's
// key encodes; `action` is the HTML narration shown for the step.
const wrSteps = [
  {
    type: 'write',
    instr: 'i32.const 3',
    desc: 'Push 3 onto stack',
    token: { label: 'token 0', instr: 'const 3', k: 'k=[2, −1]', v: 'v=3', addr: 1, val: 3 },
    action: '<div class="wr-step-title">WRITE: i32.const 3</div>The model emits a new trace token. W<sub>K</sub> maps it to key <strong style="color:var(--accent)">[2, −1]</strong> (stack depth 1 on the parabola). W<sub>V</sub> extracts value <strong style="color:var(--green)">3</strong>. The token now sits in the sequence — that\'s the write. No memory chip needed.',
  },
  {
    type: 'write',
    instr: 'i32.const 5',
    desc: 'Push 5 onto stack',
    token: { label: 'token 1', instr: 'const 5', k: 'k=[4, −4]', v: 'v=5', addr: 2, val: 5 },
    action: '<div class="wr-step-title">WRITE: i32.const 5</div>Another trace token emitted. Key <strong style="color:var(--accent)">[4, −4]</strong> (stack depth 2). Value <strong style="color:var(--green)">5</strong>. Now two tokens sit in the sequence = two stack entries.',
  },
  {
    type: 'read',
    instr: 'i32.add (read operands)',
    desc: 'Read top two stack values',
    readTargets: [1, 0], // indices into tokens array
    action: '<div class="wr-step-title">READ: i32.add needs operands</div>The add instruction needs the top two stack values. The stack head produces query <strong style="color:var(--warn)">q=[2, 1]</strong> → attention scans all past tokens\' keys → finds token 1 (score = 2×2×2 − 4 = 4, highest) → retrieves value <strong style="color:var(--gold)">5</strong>. Then query <strong style="color:var(--warn)">q=[1, 1]</strong> → finds token 0 → retrieves <strong style="color:var(--gold)">3</strong>. No memory was accessed — just attention over past tokens.',
  },
  {
    type: 'write',
    instr: 'i32.add (result)',
    desc: 'Push result 8',
    token: { label: 'token 2', instr: 'add → 8', k: 'k=[2, −1]', v: 'v=8', addr: 1, val: 8 },
    shadowIdx: 0, // token 0 gets overshadowed (same addr)
    action: '<div class="wr-step-title">WRITE: push result 8</div>The FFN computed 3 + 5 = 8. A new token is emitted with key <strong style="color:var(--accent)">[2, −1]</strong> (stack depth 1 — the stack shrank by 1). Value <strong style="color:var(--green)">8</strong>.<br><br>Notice: token 0 also had key [2, −1] (depth 1, value 3). But token 2 is <em>later</em> in the sequence, so the parabola trick gives it a higher score. <strong>The old value 3 is overshadowed, not erased.</strong>',
  },
  {
    type: 'read',
    instr: 'output',
    desc: 'Read top of stack',
    readTargets: [2],
    action: '<div class="wr-step-title">READ: output top of stack</div>Query <strong style="color:var(--warn)">q=[1, 1]</strong> for stack depth 1. Both token 0 and token 2 have key [2, −1], but token 2 is later → higher score → attention returns <strong style="color:var(--gold)">8</strong> (not the old 3). Output: <strong>8</strong>. ✓',
  },
];

// Mutable demo state: next step index, emitted tokens, and the set of token
// indices whose values have been overshadowed by a later write.
let wrStep = 0;
let wrTokens = [];
let wrShadowed = new Set();
|
||||
|
||||
/** Resets the write/read trace demo to its initial empty state and repaints. */
function wrReset() {
  [wrStep, wrTokens, wrShadowed] = [0, [], new Set()];
  renderWR();
}
|
||||
|
||||
// Advances the write/read trace demo by one scripted step: write steps
// append a token (tagged isNew for the entry animation) and may mark an
// older token as overshadowed; read steps only change what renderWR
// highlights. No-op once the script is exhausted.
function wrStepExec() {
  if (wrStep >= wrSteps.length) return;
  const step = wrSteps[wrStep];
  if (step.type === 'write') {
    // Copy the token so repeated resets/steps never mutate the script data.
    wrTokens.push({ ...step.token, isNew: true });
    if (step.shadowIdx !== undefined) wrShadowed.add(step.shadowIdx);
  }
  wrStep++;
  renderWR();
  // Clear "new" flag after animation
  setTimeout(() => {
    wrTokens.forEach(t => t.isNew = false);
    // Don't re-render, just let CSS animation finish
  }, 700);
}
|
||||
|
||||
// Repaints the Tab-4 trace: one card per emitted token (with its key/value),
// highlighting tokens read by the most recent step and dimming overshadowed
// ones, then shows the narration for the last executed step.
function renderWR() {
  const traceEl = document.getElementById('wrTrace');
  traceEl.innerHTML = '';

  // The step whose effects we are displaying is the one just executed.
  const step = wrStep > 0 ? wrSteps[wrStep - 1] : null;
  const isRead = step && step.type === 'read';
  const readSet = isRead ? new Set(step.readTargets) : new Set();

  if (wrTokens.length === 0) {
    traceEl.innerHTML = '<div style="color:var(--dim);font-size:0.82rem;padding:1rem;text-align:center">No tokens yet. The sequence is empty — no memory exists.<br>Click Step to emit the first trace token.</div>';
  }

  wrTokens.forEach((tok, i) => {
    const div = document.createElement('div');
    // Accumulate state classes: entry animation, read highlight, shadowing.
    let cls = 'wr-token';
    if (tok.isNew) cls += ' new-token';
    if (isRead && readSet.has(i)) cls += ' found';
    if (wrShadowed.has(i)) cls += ' shadowed';
    div.className = cls;

    let inner = `<div class="wr-tok-label">${tok.label}</div>`;
    inner += `<div class="wr-tok-instr">${tok.instr}</div>`;
    inner += `<div class="wr-tok-kv"><span class="wr-k">${tok.k}</span><br><span class="wr-v">${tok.v}</span></div>`;
    if (isRead && readSet.has(i)) {
      inner += `<div class="wr-read-arrow">↑ READ</div>`;
    }
    // The "overshadowed" badge is suppressed while the token is being read.
    if (wrShadowed.has(i) && !readSet.has(i)) {
      inner += `<div style="font-size:0.6rem;color:var(--warn);margin-top:2px">overshadowed</div>`;
    }
    div.innerHTML = inner;
    traceEl.appendChild(div);
  });

  // Narration panel: last step's HTML, or the idle prompt before any step.
  const actionEl = document.getElementById('wrAction');
  if (step) {
    actionEl.innerHTML = step.action;
  } else {
    actionEl.innerHTML = '<span style="color:var(--dim)">Click Step to begin execution. Watch how each token becomes a memory cell.</span>';
  }
}
|
||||
|
||||
// Wire the write/read demo buttons and paint the initial (empty) state.
document.getElementById('wrStep').addEventListener('click', wrStepExec);
document.getElementById('wrReset').addEventListener('click', wrReset);
renderWR();
|
||||
|
||||
// ══════════════════════════════════════════
|
||||
// TAB 4: Stack Machine Step-Through
|
||||
// ══════════════════════════════════════════
|
||||
// Fixed WASM-like program for the stack-machine step-through:
// computes (3 + 5) − 10 = −2, outputs it, then halts.
const smProgram = [
  { op: 'i32.const', arg: 3, desc: 'Push 3' },
  { op: 'i32.const', arg: 5, desc: 'Push 5' },
  { op: 'i32.add', arg: null, desc: 'Pop two, push sum' },
  { op: 'i32.const', arg: 10, desc: 'Push 10' },
  { op: 'i32.sub', arg: null, desc: 'Pop two, subtract' },
  { op: 'output', arg: null, desc: 'Output top of stack' },
  { op: 'halt', arg: null, desc: 'Stop' },
];

// VM state: instruction pointer, operand stack, last output, halt flag, and
// the attention-head activity log for the most recent step.
let smState = { ip: 0, stack: [], output: null, done: false, headLog: [] };
|
||||
|
||||
/** Restores the stack-machine VM to its pristine pre-execution state and repaints. */
function smReset() {
  smState = {
    ip: 0,
    stack: [],
    output: null,
    done: false,
    headLog: [],
  };
  renderSM();
}
|
||||
|
||||
// Executes one instruction of the stack-machine VM, recording which
// "attention heads" were active (IP head, stack head reads/writes, FFN/ALU)
// for display, then advances the instruction pointer and repaints.
// No-op once the VM has halted.
function smStepExec() {
  if (smState.done) return;
  const instr = smProgram[smState.ip];
  const heads = [];

  // Shared path for the two binary ALU ops (i32.add / i32.sub): pop both
  // operands, log the two attention reads and the FFN computation, push the
  // result. Depths are reported relative to the pre-pop stack, hence +2/+1.
  const execBinOp = (symbol, pathName, fn) => {
    const b = smState.stack.pop(), a = smState.stack.pop();
    const r = fn(a, b);
    heads.push({ name: 'Stack Head ×2', action: `READ depth ${smState.stack.length + 2} → ${b}, depth ${smState.stack.length + 1} → ${a}`, detail: `q=(${smState.stack.length + 2},1) → ${b}; q=(${smState.stack.length + 1},1) → ${a}` });
    heads.push({ name: 'FFN (ALU)', action: `${a} ${symbol} ${b} = ${r}`, detail: `ReLU gate selects ${pathName} path` });
    smState.stack.push(r);
  };

  // IP head: cumulative sum
  heads.push({ name: 'IP Head', action: `sum of deltas → IP = ${smState.ip}`, detail: `query: uniform avg × t = ${smState.ip}` });

  if (instr.op === 'i32.const') {
    smState.stack.push(instr.arg);
    heads.push({ name: 'Stack Head', action: `WRITE ${instr.arg} at depth ${smState.stack.length}`, detail: `key=(${2 * smState.stack.length}, -${smState.stack.length ** 2}) val=${instr.arg}` });
    heads.push({ name: 'FFN (ALU)', action: 'passthrough (no arithmetic)', detail: 'gate=1, val=input' });
  } else if (instr.op === 'i32.add') {
    execBinOp('+', 'ADD', (a, b) => a + b);
  } else if (instr.op === 'i32.sub') {
    execBinOp('-', 'SUB', (a, b) => a - b);
  } else if (instr.op === 'output') {
    const top = smState.stack[smState.stack.length - 1];
    smState.output = top;
    heads.push({ name: 'Stack Head', action: `READ top (depth ${smState.stack.length}) → ${top}`, detail: `q=(${smState.stack.length},1) → ${top}` });
  } else if (instr.op === 'halt') {
    smState.done = true;
    heads.push({ name: 'Control Head', action: 'HALT detected', detail: 'opcode matches halt pattern' });
  }

  smState.headLog = heads;
  smState.ip++;
  renderSM();
}
|
||||
|
||||
// Repaints the three stack-machine panels: the program listing (current /
// completed instruction highlighting), the VM state (IP, stack, output),
// and the attention-head activity log from the last step.
function renderSM() {
  // Program listing
  const progEl = document.getElementById('smProgram');
  progEl.innerHTML = '<h4>Program</h4>';
  smProgram.forEach((instr, i) => {
    const line = document.createElement('div');
    // The last-executed instruction is ip-1 (smStepExec increments after
    // executing). Fix: the original expression tested `i === smState.ip - 1`
    // twice with complementary !done/done guards — together they reduce to a
    // single unconditional check, with identical class strings in all cases.
    line.className = 'instr-line' + (i === smState.ip - 1 ? ' current' : '') + (i < smState.ip - 1 ? ' done' : '');
    line.textContent = `${i}: ${instr.op}${instr.arg !== null ? ' ' + instr.arg : ''}`;
    progEl.appendChild(line);
  });

  // State
  const stateEl = document.getElementById('smState');
  stateEl.innerHTML = '<h4>VM State</h4>';
  const rows = [
    ['IP', smState.ip >= smProgram.length ? 'HALT' : smState.ip],
    ['Stack depth', smState.stack.length],
    ['Output', smState.output !== null ? smState.output : '—'],
  ];
  rows.forEach(([l, v]) => {
    const row = document.createElement('div');
    row.className = 'state-row';
    row.innerHTML = `<span class="label">${l}</span><span class="value">${v}</span>`;
    stateEl.appendChild(row);
  });

  const stackLabel = document.createElement('div');
  stackLabel.style.cssText = 'font-size:0.75rem;color:var(--dim);margin-top:8px;margin-bottom:4px';
  stackLabel.textContent = 'Stack (top ↑):';
  stateEl.appendChild(stackLabel);

  // Render top-of-stack first; copy before reversing (reverse mutates).
  [...smState.stack].reverse().forEach((v, i) => {
    const item = document.createElement('div');
    item.className = 'stack-item' + (i === 0 ? ' top' : '');
    item.textContent = v;
    stateEl.appendChild(item);
  });

  // Heads
  const headsEl = document.getElementById('smHeads');
  headsEl.innerHTML = '<h4>Attention Heads Active</h4>';
  if (smState.headLog.length === 0) {
    headsEl.innerHTML += '<div style="color:var(--dim);font-size:0.8rem">Click Step to begin</div>';
  }
  smState.headLog.forEach(h => {
    const div = document.createElement('div');
    div.className = 'head-info';
    div.innerHTML = `<div class="head-name">${h.name}</div><div class="head-action">${h.action}</div><div class="head-detail">${h.detail}</div>`;
    headsEl.appendChild(div);
  });
}
|
||||
|
||||
// Wire the stack-machine buttons and paint the initial VM state.
document.getElementById('smStep').addEventListener('click', smStepExec);
document.getElementById('smReset').addEventListener('click', smReset);
renderSM();
|
||||
|
||||
// IP demo (mini)
// Renders the instruction pointer as a running sum of per-step deltas: one
// cell per delta showing the signed increment and the resulting IP value.
(function () {
  const host = document.querySelector('#ipDemo .ip-trace');
  let ip = 0;
  for (const delta of [1, 1, 1, 1, -2, 1, 1]) {
    ip += delta;
    const signed = delta > 0 ? `+${delta}` : `${delta}`;
    const cell = document.createElement('div');
    cell.className = 'ip-cell';
    cell.innerHTML = `<div class="delta">${signed}</div><div class="sum">IP=${ip}</div>`;
    host.appendChild(cell);
  }
})();
|
||||
|
||||
// Stack demo (mini): render the three example stack values as static items.
(() => {
  const container = document.querySelector('#stackDemo .stack-vis');
  for (const value of [3, 5, 8]) {
    const item = document.createElement('div');
    item.className = 'sv-item';
    item.textContent = value;
    container.appendChild(item);
  }
})();
|
||||
|
||||
// ══════════════════════════════════════════
// TAB 5: Full Execution Trace
// ══════════════════════════════════════════

// The four-instruction WASM program visualized in tab 5.
//   op    — opcode text shown in the program listing.
//   bytes — raw byte string displayed beside the opcode
//           (looks like a 4-byte little-endian immediate — TODO confirm).
//   desc  — human-readable explanation of the instruction.
const feProgram = [
  { op: 'i32.const', bytes: '03 00 00 00', desc: 'Push 3 onto stack' },
  { op: 'i32.const', bytes: '05 00 00 00', desc: 'Push 5 onto stack' },
  { op: 'i32.add', bytes: '00 00 00 00', desc: 'Pop 3 and 5, push 8' },
  { op: 'output', bytes: '00 00 00 00', desc: 'Output top of stack' },
];
|
||||
|
||||
// One entry per trace token the demo emits, in execution order.
//   tok    — token text rendered in the trace panel.
//   meta   — commit annotation shown beside the token; empty for the
//            output/halt steps. Presumably (+/-n = IP delta, sts = stack
//            size after the op, bt = backtrack flag) — TODO confirm.
//   detail — weight-level narration for the "What Happened" panel.
const feTraceSteps = [
  { tok: '03 00 00 00', meta: 'commit(+1,sts=1,bt=0)', detail: 'IP Head reads instruction 0 → i32.const. Stack Head writes 3 at depth 1. Stack: [3]' },
  { tok: '05 00 00 00', meta: 'commit(+1,sts=2,bt=0)', detail: 'IP Head reads instruction 1 → i32.const. Stack Head writes 5 at depth 2. Stack: [3, 5]' },
  { tok: '08 00 00 00', meta: 'commit(-1,sts=1,bt=0)', detail: 'IP Head reads instruction 2 → i32.add. Stack Head reads depth 2 → 5, depth 1 → 3. FFN computes 3+5=8. Writes 8 at depth 1. Stack: [8]' },
  { tok: 'out(08)', meta: '', detail: 'IP Head reads instruction 3 → output. Stack Head reads top → 8. Output token emitted.' },
  { tok: 'halt', meta: '', detail: 'Program complete. All computation happened inside the transformer\'s forward pass.' },
];

// Number of trace steps revealed so far (0 = before execution starts).
let feStep = 0;
||||
|
||||
/** Rewind the full-execution demo to the very beginning and repaint. */
function feReset() {
  feStep = 0;
  renderFE();
}
|
||||
|
||||
/** Reveal the next trace step (no-op once the trace is exhausted), then repaint. */
function feStepExec() {
  const atEnd = feStep >= feTraceSteps.length;
  if (!atEnd) {
    feStep += 1;
  }
  renderFE();
}
|
||||
|
||||
/** Jump straight to the end of the trace (all steps revealed) and repaint. */
function feRunAll() {
  feStep = feTraceSteps.length;
  renderFE();
}
|
||||
|
||||
/**
 * Repaint all three panels of the full-execution demo from `feStep`:
 * the WASM program listing (#feProgram), the emitted trace tokens
 * (#feTrace), and the weight-level explanation of the most recent
 * step (#feDetail), including chips for the attention heads that fired.
 */
function renderFE() {
  // ── Program listing: mark executed instructions, highlight the current one.
  const progEl = document.getElementById('feProgram');
  progEl.innerHTML = '<h4>WASM Program</h4>';
  feProgram.forEach((instr, i) => {
    const line = document.createElement('div');
    line.className = 'instr-line' + (i < feStep ? ' done' : '') + (i === feStep - 1 ? ' current' : '');
    line.textContent = `${instr.op} ${instr.bytes}`;
    progEl.appendChild(line);
  });

  // ── Trace tokens emitted so far; the newest one gets the 'new' class.
  const traceEl = document.getElementById('feTrace');
  traceEl.innerHTML = '<h4>Execution Trace (tokens)</h4>';
  for (let i = 0; i < feStep; i++) {
    const t = feTraceSteps[i];
    const div = document.createElement('div');
    div.className = 'trace-token' + (i === feStep - 1 ? ' new' : '');
    div.innerHTML = `<span class="tok">${t.tok}</span><span class="meta">${t.meta}</span>`;
    traceEl.appendChild(div);
  }
  if (feStep === 0) {
    traceEl.innerHTML += '<div style="color:var(--dim);font-size:0.8rem;padding:4px">Click Step to generate trace tokens...</div>';
  }

  // ── Weight-level detail for the latest step, plus head chips.
  const detailEl = document.getElementById('feDetail');
  detailEl.innerHTML = '<h4>What Happened (weight level)</h4>';
  if (feStep > 0) {
    const d = feTraceSteps[feStep - 1];
    const div = document.createElement('div');
    div.style.cssText = 'font-size:0.82rem;line-height:1.6;color:var(--text)';
    div.textContent = d.detail;
    detailEl.appendChild(div);

    // Which heads fired this step — must agree with the narration in
    // feTraceSteps. Step 4's detail says "IP Head reads instruction 3 →
    // output", so the IP Head chip is shown through step 4.
    // (Fix: was `feStep <= 3`, which wrongly hid the IP Head on step 4.)
    const heads = [];
    if (feStep <= 4) heads.push({ name: 'IP Head', color: 'var(--accent)' });
    if (feStep <= 4) heads.push({ name: 'Stack Head', color: 'var(--green)' });
    if (feStep === 3) heads.push({ name: 'FFN (ALU)', color: 'var(--gold)' });
    if (feStep === 5) heads.push({ name: 'Control', color: 'var(--warn)' });

    const hDiv = document.createElement('div');
    hDiv.style.cssText = 'margin-top:12px;display:flex;gap:6px;flex-wrap:wrap';
    heads.forEach(h => {
      const chip = document.createElement('span');
      chip.style.cssText = `font-size:0.75rem;padding:3px 10px;border-radius:12px;border:1px solid ${h.color};color:${h.color}`;
      chip.textContent = h.name;
      hDiv.appendChild(chip);
    });
    detailEl.appendChild(hDiv);
  } else {
    detailEl.innerHTML += '<div style="color:var(--dim);font-size:0.8rem">Each step shows which attention heads fire and what they compute.</div>';
  }
}
|
||||
|
||||
// Wire up the full-execution demo controls, then paint the initial view.
const feControls = [
  ['feStep', feStepExec],
  ['feReset', feReset],
  ['feRunAll', feRunAll],
];
for (const [id, handler] of feControls) {
  document.getElementById(id).addEventListener('click', handler);
}
renderFE();
|
||||
515
interactive/index.html
Normal file
515
interactive/index.html
Normal file
@@ -0,0 +1,515 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>How Transformer Weights Become a Computer</title>
|
||||
<link rel="stylesheet" href="style.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="app">
|
||||
<header>
|
||||
<h1>How Transformer Weights Become a Computer</h1>
|
||||
<p class="subtitle">An interactive exploration of how matrix multiplications can execute deterministic programs</p>
|
||||
</header>
|
||||
|
||||
<nav id="tabs">
|
||||
<button class="tab active" data-tab="tab1">1. The Puzzle</button>
|
||||
<button class="tab" data-tab="tab2">2. Attention = Lookup</button>
|
||||
<button class="tab" data-tab="tab3">3. The 2D Trick</button>
|
||||
<button class="tab" data-tab="tab4">4. Stack Machine</button>
|
||||
<button class="tab" data-tab="tab5">5. Full Execution</button>
|
||||
</nav>
|
||||
|
||||
<!-- TAB 1: The Puzzle -->
|
||||
<section id="tab1" class="panel active">
|
||||
<h2>The Puzzle: How Can Weights Be Deterministic?</h2>
|
||||
<div class="explainer">
|
||||
<p>A transformer is just matrix multiplications and softmax. How can that <em>execute a program</em> exactly?</p>
|
||||
<p>The secret: attention is a <strong>lookup table</strong>. If you set up the keys and queries just right, the softmax becomes so peaked that it acts like an exact <code>array[index]</code> read.</p>
|
||||
</div>
|
||||
<div class="demo-box">
|
||||
<h3>Try it: Softmax Temperature</h3>
|
||||
<p>Below are 5 scores. One is the "correct" lookup target. Watch what happens as you increase the temperature multiplier — the softmax sharpens until it's essentially a hard lookup.</p>
|
||||
<div class="slider-row">
|
||||
<label>Temperature multiplier: <strong id="tempVal">1</strong></label>
|
||||
<input type="range" id="tempSlider" min="1" max="50" value="1" step="1">
|
||||
</div>
|
||||
<div id="softmaxBars"></div>
|
||||
<p class="insight" id="tempInsight">At low temperature, attention is spread out — fuzzy, not useful for exact computation.</p>
|
||||
</div>
|
||||
<div class="explainer">
|
||||
<p><strong>Key insight:</strong> When the score gap is large enough, softmax gives >99.99% weight to one entry. The "weighted average" becomes an exact read. This is how weights produce deterministic behavior — not by magic, but by engineering the scores to be extremely peaked.</p>
|
||||
</div>
|
||||
<button class="next-btn" data-next="tab2">Next: Attention = Lookup →</button>
|
||||
</section>
|
||||
|
||||
<!-- TAB 2: Attention as Lookup -->
|
||||
<section id="tab2" class="panel">
|
||||
<h2>Attention as a Lookup Table</h2>
|
||||
|
||||
<div class="explainer">
|
||||
<h3 style="color:var(--gold);margin-bottom:0.6rem">Start here: What problem are we solving?</h3>
|
||||
<p>A computer needs to <em>read from memory</em>. You say "give me the value at address 3" and you get back whatever is stored there. Simple.</p>
|
||||
<p>But a transformer has no memory. It only has a growing list of past tokens and one operation: <strong>attention</strong> — which computes a weighted average over all past tokens. How do you turn a weighted average into an exact memory read?</p>
|
||||
</div>
|
||||
|
||||
<div class="explainer">
|
||||
<h3 style="color:var(--accent);margin-bottom:0.6rem">First: How tokens become vectors</h3>
|
||||
<p>Before attention can happen, every token goes through the same pipeline. Here it is, step by step:</p>
|
||||
|
||||
<div class="pipeline-diagram">
|
||||
<div class="pipe-step">
|
||||
<div class="pipe-label">Plain text</div>
|
||||
<div class="pipe-box pipe-text">"cake"</div>
|
||||
</div>
|
||||
<div class="pipe-arrow">→<br><span class="pipe-arrow-label">tokenizer</span></div>
|
||||
<div class="pipe-step">
|
||||
<div class="pipe-label">Token ID</div>
|
||||
<div class="pipe-box pipe-id">4821</div>
|
||||
</div>
|
||||
<div class="pipe-arrow">→<br><span class="pipe-arrow-label">embedding table</span></div>
|
||||
<div class="pipe-step">
|
||||
<div class="pipe-label">Token embedding (x)</div>
|
||||
<div class="pipe-box pipe-emb">[0.3, 0.8, −0.1, 0.5, ...]</div>
|
||||
<div class="pipe-dim">one vector per token, d dimensions</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p style="margin-top:0.8rem">Then this embedding gets multiplied by <strong>three separate weight matrices</strong> to produce Q, K, and V:</p>
|
||||
|
||||
<div class="qkv-matmul-diagram">
|
||||
<div class="qkv-row">
|
||||
<div class="qkv-input">x</div>
|
||||
<div class="qkv-times">×</div>
|
||||
<div class="qkv-matrix" style="border-color:var(--warn)">W<sub>Q</sub></div>
|
||||
<div class="qkv-eq">=</div>
|
||||
<div class="qkv-output" style="color:var(--warn)">q</div>
|
||||
<div class="qkv-meaning">← "what address do I need?"</div>
|
||||
</div>
|
||||
<div class="qkv-row">
|
||||
<div class="qkv-input">x</div>
|
||||
<div class="qkv-times">×</div>
|
||||
<div class="qkv-matrix" style="border-color:var(--accent)">W<sub>K</sub></div>
|
||||
<div class="qkv-eq">=</div>
|
||||
<div class="qkv-output" style="color:var(--accent)">k</div>
|
||||
<div class="qkv-meaning">← "what address am I at?"</div>
|
||||
</div>
|
||||
<div class="qkv-row">
|
||||
<div class="qkv-input">x</div>
|
||||
<div class="qkv-times">×</div>
|
||||
<div class="qkv-matrix" style="border-color:var(--green)">W<sub>V</sub></div>
|
||||
<div class="qkv-eq">=</div>
|
||||
<div class="qkv-output" style="color:var(--green)">v</div>
|
||||
<div class="qkv-meaning">← "what data do I hold?"</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="explainer" style="margin-top:0.8rem;border-color:var(--gold);border-width:2px">
|
||||
<h3 style="color:var(--gold);margin-bottom:0.6rem">How Q, K, V work together (this is the key!)</h3>
|
||||
<p>A common misconception: the value <code>v</code> is <strong>not</strong> computed from <code>q · k</code>. The dot product <code>q · k</code> only produces a <em>score</em> — a single number. Here's the actual flow:</p>
|
||||
|
||||
<div class="attn-flow-diagram">
|
||||
<div class="attn-flow-section">
|
||||
<div class="attn-flow-label">Each past token (independently)</div>
|
||||
<div class="attn-flow-row">
|
||||
<div class="attn-flow-box" style="border-color:var(--accent)">k<sub>j</sub> = x<sub>j</sub> × W<sub>K</sub></div>
|
||||
<div class="attn-flow-desc">its address label</div>
|
||||
</div>
|
||||
<div class="attn-flow-row">
|
||||
<div class="attn-flow-box" style="border-color:var(--green)">v<sub>j</sub> = x<sub>j</sub> × W<sub>V</sub></div>
|
||||
<div class="attn-flow-desc">its stored data</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="attn-flow-section">
|
||||
<div class="attn-flow-label">Current token</div>
|
||||
<div class="attn-flow-row">
|
||||
<div class="attn-flow-box" style="border-color:var(--warn)">q = x<sub>now</sub> × W<sub>Q</sub></div>
|
||||
<div class="attn-flow-desc">the address I want</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="attn-flow-section">
|
||||
<div class="attn-flow-label">Step 1: Score every past token</div>
|
||||
<div class="attn-flow-row">
|
||||
<div class="attn-flow-box computed">score<sub>j</sub> = q · k<sub>j</sub></div>
|
||||
<div class="attn-flow-desc">"how well does this key match my query?"<br>→ just a number, not a vector</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="attn-flow-section">
|
||||
<div class="attn-flow-label">Step 2: Pick the winner</div>
|
||||
<div class="attn-flow-row">
|
||||
<div class="attn-flow-box computed">weight<sub>j</sub> = softmax(scores)</div>
|
||||
<div class="attn-flow-desc">highest score → ~100% weight<br>all others → ~0%</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="attn-flow-section">
|
||||
<div class="attn-flow-label">Step 3: Retrieve the value</div>
|
||||
<div class="attn-flow-row">
|
||||
<div class="attn-flow-box" style="border-color:var(--gold);border-width:2px">output = Σ weight<sub>j</sub> × v<sub>j</sub></div>
|
||||
<div class="attn-flow-desc">≈ just <strong>v<sub>winner</sub></strong> (because its weight is ~100%)</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p style="margin-top:0.6rem">Think of it like a library: <strong style="color:var(--accent)">k</strong> is the label on the spine of each book, <strong style="color:var(--warn)">q</strong> is the title you're searching for, and <strong style="color:var(--green)">v</strong> is the content inside the book. You use q and k to <em>find</em> the right book, then you <em>read</em> v from it. The content was always there — the lookup just selects which book to open.</p>
|
||||
</div>
|
||||
|
||||
<div style="background:var(--surface2);border-radius:8px;padding:0.7rem 1rem;margin-top:0.8rem;font-size:0.82rem;line-height:1.6">
|
||||
<strong style="color:var(--gold)">So where is the "memory" physically?</strong><br>
|
||||
There's no separate memory chip. The <strong>past tokens in the sequence are the memory</strong>. Each past token carries a key (its address) and a value (its data), both computed from its residual stream via W<sub>K</sub> and W<sub>V</sub>. When the model "writes 42 to address 3," it emits a trace token whose key encodes address 3 and whose value encodes 42. When a later step "reads address 3," attention finds that token and retrieves its value.
|
||||
</div>
|
||||
|
||||
<div style="background:var(--surface2);border-radius:8px;padding:0.7rem 1rem;margin-top:0.6rem;font-size:0.82rem;line-height:1.6">
|
||||
<strong style="color:var(--gold)">Important nuance about <code>x</code>:</strong> The input to each layer isn't the raw token embedding — it's the <strong>residual stream</strong>: the original embedding <em>plus</em> all outputs from previous layers added on top. So by the time a token reaches layer 3, its <code>x</code> already contains results computed by layers 1 and 2 (like the current instruction pointer, or a value just loaded from the stack).<br><br>
|
||||
This means W<sub>V</sub> doesn't just extract the original token — it extracts <em>whatever information previous layers have written into the residual stream</em>. For example, if layer 2 computed "the value at stack position 3 is <code>42</code>", then layer 3's W<sub>V</sub> can extract that <code>42</code> and make it available as this token's value.
|
||||
</div>
|
||||
|
||||
<div style="background:var(--surface2);border-radius:8px;padding:0.7rem 1rem;margin-top:0.6rem;font-size:0.82rem;line-height:1.6">
|
||||
<strong style="color:var(--gold)">The mechanism is identical in both traditional LLMs and this paper.</strong> The only difference is what the weight matrices contain:<br>
|
||||
<span style="color:var(--accent2)">● Traditional:</span> W<sub>Q</sub>, W<sub>K</sub>, W<sub>V</sub> are <em>learned from data</em> via gradient descent → produce fuzzy semantic vectors<br>
|
||||
<span style="color:var(--gold)">● This paper:</span> W<sub>Q</sub>, W<sub>K</sub>, W<sub>V</sub> are <em>set by construction</em> (compiled) → produce exact address vectors like <code>[i, 1]</code> and <code>[2j, −j²]</code>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="explainer">
|
||||
<h3 style="color:var(--accent);margin-bottom:0.6rem">The three players: Key, Query, Value</h3>
|
||||
<p>Now that you know how they're produced, here's what each one <em>means</em>:</p>
|
||||
<table class="kqv-table">
|
||||
<tr>
|
||||
<td class="kqv-name" style="color:var(--accent)">Key (k)</td>
|
||||
<td>The token's <em>address label</em>. It answers: <strong>"Where am I?"</strong><br>
|
||||
Think of it like the number on a mailbox. Token at memory address 3 gets a key that encodes "I'm address 3."</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="kqv-name" style="color:var(--warn)">Query (q)</td>
|
||||
<td>The current token's <em>request</em>. It answers: <strong>"What address do I need?"</strong><br>
|
||||
If the current instruction needs to read address 3, the query encodes "I want address 3."</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="kqv-name" style="color:var(--green)">Value (v)</td>
|
||||
<td>The token's <em>stored data</em>. It answers: <strong>"What's in me?"</strong><br>
|
||||
This is the actual content — the number 42, or a stack entry, or whatever was written there.</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<div class="explainer">
|
||||
<h3 style="color:var(--accent);margin-bottom:0.6rem">How the lookup works</h3>
|
||||
<p>Attention computes the <strong>dot product</strong> between the query and every key. The dot product is just: multiply corresponding entries and add up. A high dot product means "these two vectors point in a similar direction" — i.e. <em>this key matches what the query is asking for</em>.</p>
|
||||
<p>Then <strong>softmax</strong> turns those scores into weights that sum to 1. If one score is much higher than the rest, that key gets nearly 100% of the weight — and the output is just that key's value. That's an exact read.</p>
|
||||
<div style="background:var(--surface2);border-radius:8px;padding:0.8rem 1rem;margin-top:0.6rem;font-size:0.85rem;line-height:1.7">
|
||||
<span style="color:var(--dim)">1.</span> Compute scores: <code>score_j = q · k_j</code> <span style="color:var(--dim)">← dot product of query with each key</span><br>
|
||||
<span style="color:var(--dim)">2.</span> Normalize: <code>weight_j = softmax(scores)</code> <span style="color:var(--dim)">← highest score gets ~100%</span><br>
|
||||
<span style="color:var(--dim)">3.</span> Read: <code>output = Σ weight_j × value_j</code> <span style="color:var(--dim)">← weighted average ≈ exact value</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="explainer">
|
||||
<h3 style="color:var(--accent);margin-bottom:0.6rem">But what <em>are</em> the key vectors, concretely?</h3>
|
||||
<p>In this paper, each key is a <strong>2-number column vector</strong>. The weight matrix <code>W_K</code> is engineered so that a token at address <code>j</code> gets mapped to:</p>
|
||||
<div style="text-align:center;margin:0.6rem 0;font-size:1.1rem">
|
||||
<code>k<sub>j</sub> = <span style="color:var(--accent)">[2j, −j²]</span></code>
|
||||
</div>
|
||||
<p>Why these specific numbers? Because they sit on a <strong>parabola</strong>, and that shape has a magical property: when you dot-product a query <code>q = [i, 1]</code> with every key, the math simplifies to <code>−(i − j)² + i²</code>. That's an upside-down parabola centered at <code>j = i</code> — meaning the key at the exact target address always wins. (Tab 3 visualizes this in detail.)</p>
|
||||
<p>The key vector isn't something the model "learns" in the usual sense — the weight matrix <code>W_K</code> is <strong>set by construction</strong> to produce these parabola coordinates. It's more like compiling an address decoder into the weights.</p>
|
||||
</div>
|
||||
|
||||
<div class="demo-box">
|
||||
<h3>Interactive: See It In Action</h3>
|
||||
<p>Below is a simulated memory with 6 slots. Click a slot to change which address you're reading. Watch how the query vector changes, and how the dot products with each key vector determine which value gets read.</p>
|
||||
<div id="memorySlots"></div>
|
||||
<h4 style="color:var(--dim);margin:1rem 0 0.5rem;font-size:0.85rem">Query Vector</h4>
|
||||
<div id="queryVec"></div>
|
||||
<h4 style="color:var(--dim);margin:1rem 0 0.5rem;font-size:0.85rem">Key Vectors & Dot Products</h4>
|
||||
<div id="vecColumns"></div>
|
||||
<div id="readResult"></div>
|
||||
</div>
|
||||
<div class="demo-box">
|
||||
<h3>Side by Side: Traditional LLM Attention vs Memory Lookup</h3>
|
||||
<p>Both use the same mechanism (dot product → softmax → weighted average). The difference is <em>what the keys and queries represent</em> and <em>how sharp the result is</em>.</p>
|
||||
<div class="sbs-controls">
|
||||
<div class="slider-row">
|
||||
<label>Input word to attend to: <strong id="sbsTargetLabel">delicious</strong></label>
|
||||
<input type="range" id="sbsTargetSlider" min="0" max="4" value="2" step="1">
|
||||
</div>
|
||||
</div>
|
||||
<div class="sbs-container">
|
||||
<!-- LEFT: Traditional -->
|
||||
<div class="sbs-panel">
|
||||
<div class="sbs-header trad">Traditional LLM Attention</div>
|
||||
<div class="sbs-subtitle">Keys = learned semantic embeddings (high-dimensional, shown as 4D slice)</div>
|
||||
<div id="sbsTradQ" class="sbs-vec-row"></div>
|
||||
<div id="sbsTradKeys" class="sbs-keys"></div>
|
||||
<div id="sbsTradResult" class="sbs-result"></div>
|
||||
<div class="sbs-note">
|
||||
Weights are <strong>spread across multiple tokens</strong> — attention is a soft blend. Good for language understanding ("what words are related?") but useless for exact computation.
|
||||
</div>
|
||||
</div>
|
||||
<!-- RIGHT: Memory Lookup -->
|
||||
<div class="sbs-panel">
|
||||
<div class="sbs-header lookup">Memory Lookup (this paper)</div>
|
||||
<div class="sbs-subtitle">Keys = engineered 2D addresses on a parabola</div>
|
||||
<div id="sbsLookupQ" class="sbs-vec-row"></div>
|
||||
<div id="sbsLookupKeys" class="sbs-keys"></div>
|
||||
<div id="sbsLookupResult" class="sbs-result"></div>
|
||||
<div class="sbs-note">
|
||||
Weight is <strong>100% on one token</strong> — attention is a hard lookup. This is what you need to read <code>stack[3]</code> or <code>mem[addr]</code> exactly.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<p class="insight">Same math, different weights. Traditional LLMs learn W_K and W_Q from data → fuzzy semantic similarity. This paper engineers W_K and W_Q by construction → exact address lookup. The architecture is identical.</p>
|
||||
</div>
|
||||
|
||||
<div class="explainer">
|
||||
<h3 style="color:var(--accent);margin-bottom:0.6rem">But where does the query vector come from?</h3>
|
||||
<p>In both cases, the query is produced the same way: <strong>multiply the current token's embedding by the W<sub>Q</sub> weight matrix</strong>. The difference is what W<sub>Q</sub> contains.</p>
|
||||
|
||||
<div class="query-origin-grid">
|
||||
<!-- Traditional -->
|
||||
<div class="qo-panel">
|
||||
<div class="qo-header" style="color:var(--accent2)">Traditional LLM</div>
|
||||
<div class="qo-pipeline">
|
||||
<div class="qo-step">
|
||||
<div class="qo-label">Current token embedding</div>
|
||||
<div class="qo-box">x = <code>[0.3, 0.8, −0.1, ...]</code></div>
|
||||
<div class="qo-dim">~4096 dimensions</div>
|
||||
</div>
|
||||
<div class="qo-arrow">× W<sub>Q</sub></div>
|
||||
<div class="qo-step">
|
||||
<div class="qo-label">W<sub>Q</sub> weight matrix</div>
|
||||
<div class="qo-box learned">Learned from data</div>
|
||||
<div class="qo-dim">Encodes "what's semantically relevant to me?"</div>
|
||||
</div>
|
||||
<div class="qo-arrow">=</div>
|
||||
<div class="qo-step">
|
||||
<div class="qo-label">Query vector</div>
|
||||
<div class="qo-box">q = <code>[0.7, 0.8, 0.0, ...]</code></div>
|
||||
<div class="qo-dim">Points toward semantically similar keys</div>
|
||||
</div>
|
||||
</div>
|
||||
<p class="qo-note">W<sub>Q</sub> is trained via gradient descent on billions of tokens. It learns to produce queries that find <em>contextually relevant</em> tokens — "delicious" attends to "cake" because they co-occur. The query is a fuzzy semantic direction.</p>
|
||||
</div>
|
||||
|
||||
<!-- Memory lookup -->
|
||||
<div class="qo-panel">
|
||||
<div class="qo-header" style="color:var(--gold)">Memory Lookup (this paper)</div>
|
||||
<div class="qo-pipeline">
|
||||
<div class="qo-step">
|
||||
<div class="qo-label">Current state (from prior layers)</div>
|
||||
<div class="qo-box">x includes the <em>target address</em> i</div>
|
||||
<div class="qo-dim">Prior attention heads computed which address to read</div>
|
||||
</div>
|
||||
<div class="qo-arrow">× W<sub>Q</sub></div>
|
||||
<div class="qo-step">
|
||||
<div class="qo-label">W<sub>Q</sub> weight matrix</div>
|
||||
<div class="qo-box engineered">Engineered by construction</div>
|
||||
<div class="qo-dim">Extracts address i and formats it as (i, 1)</div>
|
||||
</div>
|
||||
<div class="qo-arrow">=</div>
|
||||
<div class="qo-step">
|
||||
<div class="qo-label">Query vector</div>
|
||||
<div class="qo-box">q = <code>[i, 1]</code></div>
|
||||
<div class="qo-dim">Points exactly at key k<sub>i</sub> on the parabola</div>
|
||||
</div>
|
||||
</div>
|
||||
<p class="qo-note">W<sub>Q</sub> is <strong>not learned</strong> — it's set so that it extracts the target address from the current token's state and formats it as <code>[i, 1]</code>. The address <code>i</code> itself comes from earlier layers: e.g. the instruction pointer head computes the current IP, and the stack head uses the current stack depth. Each head's W<sub>Q</sub> is wired to extract the specific piece of state it needs.</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="background:var(--surface2);border-radius:8px;padding:0.8rem 1rem;margin-top:0.8rem;font-size:0.85rem;line-height:1.7">
|
||||
<strong style="color:var(--accent)">The chain of computation across layers:</strong><br>
|
||||
<span style="color:var(--dim)">Layer 1:</span> A head computes the instruction pointer (cumulative sum of deltas) → "we're at instruction 5"<br>
|
||||
<span style="color:var(--dim)">Layer 2:</span> A head uses the IP to look up instruction 5 from the program → "it's <code>i32.add</code>"<br>
|
||||
<span style="color:var(--dim)">Layer 3:</span> A head uses the stack depth to read the top two stack values → operands<br>
|
||||
<span style="color:var(--dim)">Layer 4:</span> The feed-forward network does the arithmetic → result<br>
|
||||
<span style="color:var(--dim)">Layer 5:</span> A head writes the result back to the stack at the new depth<br>
|
||||
<br>
|
||||
Each layer's W<sub>Q</sub> is wired to extract exactly the right piece of state from the previous layer's output. The query vector is never a vague "what's relevant?" — it's always a precise "give me the value at this specific address."
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="explainer">
|
||||
<p>Each attention head in the transformer does exactly this — but the "indices" and "values" are determined by the weight matrices <code>W_Q</code>, <code>W_K</code>, <code>W_V</code>. The weights are set so that:</p>
|
||||
<ul>
|
||||
<li><strong>W_K</strong> maps each token to a 2D "address" on a parabola</li>
|
||||
<li><strong>W_Q</strong> maps the current step to a "query direction" — extracting the target address from the current state</li>
|
||||
<li><strong>W_V</strong> extracts the stored value</li>
|
||||
</ul>
|
||||
<p>Multiple heads = multiple independent arrays. That's enough to build registers, a stack, and memory.</p>
|
||||
</div>
|
||||
<button class="next-btn" data-next="tab3">Next: The 2D Trick →</button>
|
||||
</section>
|
||||
|
||||
<!-- TAB 3: The 2D Parabola Trick -->
|
||||
<section id="tab3" class="panel">
|
||||
<h2>The 2D Parabola Trick: Exact Index Lookup</h2>
|
||||
<div class="explainer">
|
||||
<p>Here's the mathematical trick that makes exact lookup work with just 2D keys.</p>
|
||||
<p>Store index <code>j</code> as the 2D key: <strong>k<sub>j</sub> = (2j, −j²)</strong></p>
|
||||
<p>To look up index <code>i</code>, query with direction: <strong>q = (i, 1)</strong></p>
|
||||
<p>The dot product becomes: <code>q · k<sub>j</sub> = 2ij − j² = −(i − j)² + i²</code></p>
|
||||
<p>Since <code>i²</code> is constant for a given query, the argmax is at <code>j = i</code> — exact match!</p>
|
||||
</div>
|
||||
<div class="demo-box">
|
||||
<h3>Interactive: See the Parabola</h3>
|
||||
<p>The keys sit on a parabola. Drag the query slider to change which index you're looking up. The dot product scores form an inverted parabola peaked at the target.</p>
|
||||
<div class="slider-row">
|
||||
<label>Query index i = <strong id="queryIdxVal">3</strong></label>
|
||||
<input type="range" id="queryIdxSlider" min="0" max="7" value="3" step="1">
|
||||
</div>
|
||||
<div class="two-col">
|
||||
<div>
|
||||
<h4>Keys on Parabola (2D space)</h4>
|
||||
<canvas id="parabolaCanvas" width="380" height="300"></canvas>
|
||||
</div>
|
||||
<div>
|
||||
<h4>Dot Product Scores (q · k<sub>j</sub>)</h4>
|
||||
<canvas id="scoresCanvas" width="380" height="300"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
<p class="insight" id="parabolaInsight">Index 3 gets the highest score. The penalty −(i−j)² ensures only the exact match wins.</p>
|
||||
</div>
|
||||
<div class="explainer">
|
||||
<p><strong>Why this matters:</strong> The weight matrix <code>W_K</code> is set so it maps token positions to points on this parabola. <code>W_Q</code> produces the query direction. The result: attention performs an exact array read — no scanning, no approximation. And because the keys are 2D, you can use a convex hull to find the max in O(log n) time instead of checking every key.</p>
|
||||
</div>
|
||||
<button class="next-btn" data-next="tab4">Next: Stack Machine →</button>
|
||||
</section>
|
||||
|
||||
<!-- TAB 4: Stack Machine -->
|
||||
<section id="tab4" class="panel">
|
||||
<h2>Building a Stack Machine from Attention Heads</h2>
|
||||
<div class="explainer">
|
||||
<p>A WebAssembly VM needs: an <strong>instruction pointer</strong>, an <strong>operand stack</strong>, and <strong>memory</strong>. Here's how attention heads provide each:</p>
|
||||
</div>
|
||||
<div class="mechanism-grid">
|
||||
<div class="mechanism">
|
||||
<h3>📍 Instruction Pointer</h3>
|
||||
<p><strong>Mechanism:</strong> Cumulative sum via attention</p>
|
||||
<p>Each trace token emits a delta (+1 for next instruction, or a jump offset). One head averages all deltas uniformly, then multiplies by the token count to recover the running sum = current IP.</p>
|
||||
<div class="mini-demo" id="ipDemo">
|
||||
<div class="ip-trace"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mechanism">
|
||||
<h3>📚 Operand Stack</h3>
|
||||
<p><strong>Mechanism:</strong> Index lookup via 2D keys</p>
|
||||
<p>Stack depth is tracked as a cumulative sum of push/pop deltas. To read the top of stack, the head queries for the current depth using the parabola trick — exact lookup of the most recent value at that depth.</p>
|
||||
<div class="mini-demo" id="stackDemo">
|
||||
<div class="stack-vis"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mechanism">
|
||||
<h3>💾 Memory</h3>
|
||||
<p><strong>Mechanism:</strong> Index lookup via 2D keys</p>
|
||||
<p>Memory addresses work the same way as stack indices. To read <code>mem[addr]</code>, the head queries with direction <code>(addr, 1)</code>. The most recent write to that address has the highest score because it has the largest position component.</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="explainer" style="border-color:var(--gold);border-width:2px">
|
||||
<h3 style="color:var(--gold);margin-bottom:0.6rem">But how does "writing" work? There's no memory!</h3>
|
||||
<p>This is the crucial insight: <strong>writing = emitting a new token</strong>. There is no separate memory. The growing sequence of trace tokens <em>is</em> the memory.</p>
|
||||
|
||||
<div class="write-read-demo">
|
||||
<p style="font-size:0.82rem;color:var(--dim);margin-bottom:0.6rem">Click "Step" to watch the token sequence grow. Each token carries a key (address) and value (data). Reading is just attention looking back at past tokens.</p>
|
||||
<div id="wrTrace" class="wr-trace"></div>
|
||||
<div id="wrAction" class="wr-action"></div>
|
||||
<div class="btn-row" style="margin-top:0.6rem">
|
||||
<button id="wrStep" class="action-btn">Step →</button>
|
||||
<button id="wrReset" class="action-btn secondary">Reset</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="background:var(--surface2);border-radius:8px;padding:0.7rem 1rem;margin-top:0.8rem;font-size:0.82rem;line-height:1.6">
|
||||
<strong style="color:var(--gold)">Summary:</strong><br>
|
||||
<strong>Write</strong> = emit a token. Its W<sub>K</sub> gives it an address, its W<sub>V</sub> gives it data. It just sits there in the sequence.<br>
|
||||
<strong>Read</strong> = a later token's W<sub>Q</sub> produces a query. Attention scans all past tokens' keys, finds the match, returns that token's value.<br>
|
||||
<strong>Overwrite</strong> = emit another token with the same address. Because it's later in the sequence, the parabola trick gives it a higher score than the old one (the position component breaks ties).<br><br>
|
||||
There is no "write back" step. The sequence only grows. Old values are never erased — they're just overshadowed by newer tokens at the same address.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="demo-box">
|
||||
<h3>Interactive: Step Through a Stack Machine</h3>
|
||||
<p>Click "Step" to execute WASM instructions. Watch how each head reads/writes.</p>
|
||||
<div id="stackMachineVis">
|
||||
<div id="smProgram"></div>
|
||||
<div id="smState"></div>
|
||||
<div id="smHeads"></div>
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<button id="smStep" class="action-btn">Step →</button>
|
||||
<button id="smReset" class="action-btn secondary">Reset</button>
|
||||
</div>
|
||||
</div>
|
||||
<button class="next-btn" data-next="tab5">Next: Full Execution →</button>
|
||||
</section>
|
||||
|
||||
<!-- TAB 5: Full Execution -->
|
||||
<section id="tab5" class="panel">
|
||||
<h2>Putting It All Together: 3 + 5 Inside a Transformer</h2>
|
||||
<div class="explainer">
|
||||
<p>Now let's see the complete picture. The transformer receives a WASM program as tokens. Then it generates an execution trace — each token is produced by the attention heads doing lookups into the growing trace.</p>
|
||||
</div>
|
||||
<div class="demo-box full-width">
|
||||
<h3>Full Execution Trace: 3 + 5</h3>
|
||||
<div id="fullExec">
|
||||
<div id="feProgram"></div>
|
||||
<div id="feTrace"></div>
|
||||
<div id="feDetail"></div>
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<button id="feStep" class="action-btn">Step →</button>
|
||||
<button id="feReset" class="action-btn secondary">Reset</button>
|
||||
<button id="feRunAll" class="action-btn secondary">Run All</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="explainer">
|
||||
<h3>Why This Is Different From Tool Use</h3>
|
||||
<div class="comparison">
|
||||
<div class="comp-col">
|
||||
<h4>🔧 Tool Use</h4>
|
||||
<p>LLM writes <code>print(3+5)</code></p>
|
||||
<p>Pauses, sends to external Python</p>
|
||||
<p>Gets back "8" — a black box</p>
|
||||
<p>Execution happened <em>outside</em></p>
|
||||
</div>
|
||||
<div class="comp-col highlight">
|
||||
<h4>⚡ In-Model Execution</h4>
|
||||
<p>LLM emits WASM tokens</p>
|
||||
<p>Each next token = one VM step</p>
|
||||
<p>Attention heads ARE the CPU</p>
|
||||
<p>Every step visible in the trace</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="explainer">
|
||||
<h3>The Weight-Level Summary</h3>
|
||||
<div class="summary-grid">
|
||||
<div class="summary-item">
|
||||
<strong>W<sub>K</sub></strong> maps positions → parabola points<br>
|
||||
<span class="dim">Encodes "what address am I?"</span>
|
||||
</div>
|
||||
<div class="summary-item">
|
||||
<strong>W<sub>Q</sub></strong> maps current state → query direction<br>
|
||||
<span class="dim">Encodes "what address do I need?"</span>
|
||||
</div>
|
||||
<div class="summary-item">
|
||||
<strong>W<sub>V</sub></strong> extracts stored values<br>
|
||||
<span class="dim">Encodes "what data is here?"</span>
|
||||
</div>
|
||||
<div class="summary-item">
|
||||
<strong>W<sub>ff</sub></strong> (feed-forward) does ALU ops<br>
|
||||
<span class="dim">Addition, comparison, branching logic</span>
|
||||
</div>
|
||||
</div>
|
||||
<p>The weights aren't learned from data in the usual sense — they're <strong>engineered</strong> (or compiled) so that the transformer's forward pass literally executes a WASM interpreter. The architecture is a standard PyTorch transformer. Only the weight values are special.</p>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
|
||||
<script src="app.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
1237
interactive/style.css
Normal file
1237
interactive/style.css
Normal file
File diff suppressed because it is too large
Load Diff
523
manim_project/scene.py
Normal file
523
manim_project/scene.py
Normal file
@@ -0,0 +1,523 @@
|
||||
from manim import *
import numpy as np

# Color palette shared by every scene: dark navy background with
# cyan/purple/green accents. Hex strings are passed straight to manim.
BG = "#0f0f1a"        # scene background (set in each construct())
ACCENT = "#4fc3f7"    # primary accent — cyan, used for titles/boxes
ACCENT2 = "#ab47bc"   # secondary accent — purple, used for numbering
ACCENT3 = "#66bb6a"   # positive/green — "good" items, in-model side
WARN = "#ff7043"      # warning/orange — "bad" items, tool-use side
TEXT_COL = "#e0e0e0"  # default body-text color
DIM = "#666677"       # de-emphasized text, gridlines, arrows
GOLD = "#ffd54f"      # highlight color for key insights
|
||||
|
||||
|
||||
class Scene1_Intro(Scene):
    """The big question: Can LLMs be computers?"""

    def construct(self):
        self.camera.background_color = BG

        # Opening title card.
        headline = Text("Can LLMs Be Computers?", font_size=56, color=ACCENT, weight=BOLD)
        tagline = Text(
            "Executing programs inside transformers",
            font_size=28, color=DIM,
        ).next_to(headline, DOWN, buff=0.4)
        self.play(Write(headline, run_time=1.5))
        self.play(FadeIn(tagline, shift=UP * 0.2))
        self.wait(1.5)
        self.play(FadeOut(headline), FadeOut(tagline))

        # --- Strengths: revealed one item at a time ---
        strengths_heading = Text("LLMs are great at…", font_size=32, color=ACCENT3)
        strengths_heading.to_edge(UP, buff=0.6)
        strength_lines = [
            "✓ Solving hard math (IMO gold-medal level)",
            "✓ Writing code & reasoning about algorithms",
            "✓ Understanding natural language",
        ]
        strengths = VGroup(
            *[Text(line, font_size=24, color=TEXT_COL) for line in strength_lines]
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.25).next_to(strengths_heading, DOWN, buff=0.4)

        self.play(Write(strengths_heading))
        for entry in strengths:
            self.play(FadeIn(entry, shift=RIGHT * 0.3), run_time=0.5)
        self.wait(1)

        # --- Weaknesses: same reveal pattern, warning palette ---
        weaknesses_heading = Text("…but terrible at simple computation", font_size=32, color=WARN)
        weaknesses_heading.next_to(strengths, DOWN, buff=0.6)
        weakness_lines = [
            "✗ Multiplying two numbers reliably",
            "✗ Solving even easy Sudoku puzzles",
            "✗ Any task needing many exact steps",
        ]
        weaknesses = VGroup(
            *[Text(line, font_size=24, color=WARN) for line in weakness_lines]
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.25).next_to(weaknesses_heading, DOWN, buff=0.4)

        self.play(Write(weaknesses_heading))
        for entry in weaknesses:
            self.play(FadeIn(entry, shift=RIGHT * 0.3), run_time=0.5)
        self.wait(2)
        # Clear everything before the next scene.
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene2_ToolUse(Scene):
    """How LLMs currently handle computation: tool use (the airplane analogy)."""

    def construct(self):
        self.camera.background_color = BG

        # --- Airplane analogy: two labeled boxes joined by a "delegates" arrow ---
        analogy_title = Text("The Airplane Analogy", font_size=36, color=ACCENT, weight=BOLD)
        analogy_title.to_edge(UP, buff=0.5)
        self.play(Write(analogy_title))

        human_label = Text("Human", font_size=28, color=TEXT_COL)
        human_box = RoundedRectangle(
            corner_radius=0.15, width=2.2, height=1.2, color=ACCENT
        )
        human_grp = VGroup(human_box, human_label).move_to(LEFT * 3)

        plane_label = Text("Airplane", font_size=28, color=TEXT_COL)
        plane_box = RoundedRectangle(
            corner_radius=0.15, width=2.2, height=1.2, color=ACCENT3
        )
        plane_grp = VGroup(plane_box, plane_label).move_to(RIGHT * 3)

        arrow = Arrow(human_box.get_right(), plane_box.get_left(), color=DIM, buff=0.2)
        arrow_label = Text("delegates flying", font_size=20, color=DIM).next_to(arrow, UP, buff=0.15)

        # Punchline text below the diagram.
        cant = Text("Humans can't fly.", font_size=24, color=WARN).next_to(
            VGroup(human_grp, plane_grp), DOWN, buff=0.8
        )
        but = Text("Building airplanes doesn't change that.", font_size=24, color=WARN).next_to(
            cant, DOWN, buff=0.25
        )

        self.play(FadeIn(human_grp), FadeIn(plane_grp))
        self.play(GrowArrow(arrow), FadeIn(arrow_label))
        self.wait(0.5)
        self.play(Write(cant))
        self.play(Write(but))
        self.wait(2)
        # Keep the title on screen one beat longer than the diagram.
        self.play(*[FadeOut(m) for m in self.mobjects if m != analogy_title])
        self.play(FadeOut(analogy_title))

        # --- Tool-use diagram: LLM <-> interpreter round trip ---
        tool_title = Text("Today: LLMs Use External Tools", font_size=34, color=ACCENT, weight=BOLD)
        tool_title.to_edge(UP, buff=0.5)
        self.play(Write(tool_title))

        llm_box = RoundedRectangle(corner_radius=0.2, width=2.5, height=1.4, color=ACCENT)
        llm_label = Text("LLM", font_size=30, color=ACCENT, weight=BOLD)
        llm_grp = VGroup(llm_box, llm_label).move_to(LEFT * 3.5)

        interp_box = RoundedRectangle(corner_radius=0.2, width=3, height=1.4, color=ACCENT3)
        interp_label = Text("Interpreter", font_size=26, color=ACCENT3)
        interp_grp = VGroup(interp_box, interp_label).move_to(RIGHT * 3)

        # Two opposing arrows, offset vertically so they do not overlap.
        a1 = Arrow(llm_box.get_right(), interp_box.get_left(), color=GOLD, buff=0.2).shift(UP * 0.2)
        a1_label = Text("sends code", font_size=18, color=GOLD).next_to(a1, UP, buff=0.1)
        a2 = Arrow(interp_box.get_left(), llm_box.get_right(), color=ACCENT3, buff=0.2).shift(DOWN * 0.2)
        a2_label = Text("returns result", font_size=18, color=ACCENT3).next_to(a2, DOWN, buff=0.1)

        note = Text(
            "The LLM describes computation\nbut never executes it.",
            font_size=22, color=WARN, line_spacing=1.3,
        ).next_to(VGroup(llm_grp, interp_grp), DOWN, buff=0.9)

        self.play(FadeIn(llm_grp), FadeIn(interp_grp))
        self.play(GrowArrow(a1), FadeIn(a1_label))
        self.play(GrowArrow(a2), FadeIn(a2_label))
        self.wait(0.5)
        self.play(Write(note))
        self.wait(2.5)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene3_KeyIdea(Scene):
    """The breakthrough: building a computer INSIDE the transformer."""

    def construct(self):
        self.camera.background_color = BG

        headline = Text("The Breakthrough", font_size=40, color=GOLD, weight=BOLD)
        headline.to_edge(UP, buff=0.5)
        self.play(Write(headline))

        statement = Text(
            "Build a computer INSIDE the transformer.",
            font_size=30, color=TEXT_COL,
        )
        statement.next_to(headline, DOWN, buff=0.6)
        self.play(Write(statement))
        self.wait(1)

        # Outer transformer box containing a CPU box and a memory box.
        outer = RoundedRectangle(
            corner_radius=0.25, width=9, height=4.5, color=ACCENT, stroke_width=3
        )
        outer.shift(DOWN * 0.5)
        outer_label = Text("Transformer", font_size=22, color=ACCENT).next_to(outer, UP, buff=0.15)

        # Virtual CPU on the left side of the transformer.
        vcpu = RoundedRectangle(
            corner_radius=0.15, width=3, height=2, color=GOLD, stroke_width=2
        )
        vcpu.move_to(outer.get_center() + LEFT * 2.2)
        vcpu_label = Text("Virtual CPU\n(WebAssembly)", font_size=20, color=GOLD, line_spacing=1.2).move_to(vcpu)

        # Memory & stack on the right side.
        memory = RoundedRectangle(
            corner_radius=0.15, width=2.5, height=2, color=ACCENT3, stroke_width=2
        )
        memory.move_to(outer.get_center() + RIGHT * 2)
        memory_label = Text("Memory\n& Stack", font_size=20, color=ACCENT3, line_spacing=1.2).move_to(memory)

        link = Arrow(vcpu.get_right(), memory.get_left(), color=DIM, buff=0.15)

        self.play(Create(outer), Write(outer_label))
        self.play(Create(vcpu), Write(vcpu_label))
        self.play(Create(memory), Write(memory_label))
        self.play(GrowArrow(link))

        punchline = Text(
            "Arbitrary C programs → tokens → executed by the model itself",
            font_size=22, color=ACCENT,
        )
        punchline.next_to(outer, DOWN, buff=0.5)
        self.play(Write(punchline))
        self.wait(2.5)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene4_ToolVsInModel(Scene):
    """Side-by-side: tool use vs in-model execution for 3+5."""

    def construct(self):
        self.camera.background_color = BG

        title = Text("Tool Use vs In-Model Execution", font_size=34, color=ACCENT, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title))

        # Vertical dashed divider splitting the frame into two columns.
        divider = DashedLine(UP * 2.5, DOWN * 2.5, color=DIM, dash_length=0.15).shift(DOWN * 0.3)
        self.play(Create(divider))

        # LEFT column: tool use
        left_title = Text("Tool Use", font_size=26, color=WARN, weight=BOLD).move_to(LEFT * 3.5 + UP * 2)
        self.play(Write(left_title))

        # Each step: (label text, optional code snippet, label color, code color).
        left_steps = [
            ("LLM:", 'print(3+5)', ACCENT, GOLD),
            ("→ send to interpreter", "", DIM, DIM),
            ("← result: 8", "", ACCENT3, ACCENT3),
        ]
        left_grp = VGroup()
        y = 1.0  # manual vertical cursor, decremented per row
        for label_text, code_text, lc, cc in left_steps:
            row = VGroup()
            lab = Text(label_text, font_size=20, color=lc).move_to(LEFT * 5 + UP * y)
            lab.align_to(LEFT * 5.5, LEFT)  # left-justify the column
            row.add(lab)
            if code_text:
                cd = Text(code_text, font_size=18, color=cc, font="Monospace").next_to(lab, RIGHT, buff=0.2)
                row.add(cd)
            left_grp.add(row)
            y -= 0.7

        opaque = Text("Execution is opaque\n(black box)", font_size=18, color=WARN, line_spacing=1.2)
        opaque.move_to(LEFT * 3.5 + DOWN * 1.5)

        # RIGHT column: in-model execution trace
        right_title = Text("In-Model", font_size=26, color=ACCENT3, weight=BOLD).move_to(RIGHT * 3.5 + UP * 2)
        self.play(Write(right_title))

        # The WASM trace for 3+5, one instruction per line.
        right_lines = [
            "i32.const 03",
            "i32.const 05",
            "i32.add → 08",
            "output(08)",
            "halt",
        ]
        right_grp = VGroup()
        y = 1.0  # reuse the vertical cursor for the right column
        for line in right_lines:
            t = Text(line, font_size=18, color=ACCENT3, font="Monospace").move_to(RIGHT * 3.5 + UP * y)
            right_grp.add(t)
            y -= 0.55

        transparent = Text("Every step visible\nin the token stream", font_size=18, color=ACCENT3, line_spacing=1.2)
        transparent.move_to(RIGHT * 3.5 + DOWN * 2)

        # Animate left column first, then right, to emphasize the contrast.
        for row in left_grp:
            self.play(FadeIn(row, shift=RIGHT * 0.2), run_time=0.5)
        self.play(Write(opaque))
        self.wait(0.5)

        for t in right_grp:
            self.play(FadeIn(t, shift=LEFT * 0.2), run_time=0.4)
        self.play(Write(transparent))
        self.wait(2.5)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene5_AppendOnlyTrace(Scene):
    """Computation as an append-only trace — the notebook analogy."""

    def construct(self):
        self.camera.background_color = BG

        title = Text("How It Works: The Append-Only Notebook", font_size=32, color=ACCENT, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title))

        # Notebook frame representing the growing token stream.
        nb_rect = Rectangle(width=7, height=4.5, color=DIM, stroke_width=1.5).shift(DOWN * 0.4)
        nb_label = Text("Notebook (token stream)", font_size=18, color=DIM).next_to(nb_rect, UP, buff=0.1)
        self.play(Create(nb_rect), Write(nb_label))

        # Prompt row: input words; green highlights mark verbs ("runs", "jumps").
        prompt_words = ["the", "cat", "runs", "and", "dog", "jumps", "over"]
        colors = [DIM, DIM, ACCENT3, DIM, DIM, ACCENT3, DIM]
        prompt_label = Text("PROMPT (input words):", font_size=18, color=ACCENT).move_to(
            nb_rect.get_top() + DOWN * 0.4
        )
        self.play(Write(prompt_label))

        word_mobs = VGroup()
        for i, (w, c) in enumerate(zip(prompt_words, colors)):
            t = Text(w, font_size=22, color=c, font="Monospace")
            # Spread the 7 words horizontally, centered on index 3.
            t.move_to(nb_rect.get_top() + DOWN * 0.8 + RIGHT * (i - 3) * 1.0)
            word_mobs.add(t)
        self.play(*[FadeIn(w) for w in word_mobs], run_time=0.8)
        self.wait(0.5)

        # Trace section: generated tokens below the prompt.
        trace_label = Text("EXECUTION TRACE (generated tokens):", font_size=18, color=GOLD).move_to(
            nb_rect.get_center() + UP * 0.3
        )
        self.play(Write(trace_label))

        # Running parity of verbs seen so far (flips at "runs" and "jumps").
        trace_vals = ["odd=0", "odd=0", "odd=1", "odd=1", "odd=1", "odd=0", "odd=0"]
        trace_mobs = VGroup()
        for i, val in enumerate(trace_vals):
            t = Text(val, font_size=20, color=GOLD, font="Monospace")
            # Aligned column-for-column under the matching prompt word.
            t.move_to(nb_rect.get_center() + DOWN * 0.3 + RIGHT * (i - 3) * 1.0)
            trace_mobs.add(t)

        rule = Text(
            "Each new token looks back at (1) the input word and (2) the previous trace token",
            font_size=17, color=TEXT_COL,
        ).move_to(nb_rect.get_bottom() + DOWN * 0.05)

        # Reveal each trace token with its two lookback arrows:
        # one from the prompt word above, one from the previous trace token.
        for i, tm in enumerate(trace_mobs):
            anims = [FadeIn(tm)]
            arr1 = Arrow(
                word_mobs[i].get_bottom(), tm.get_top(),
                color=ACCENT, stroke_width=1.5, buff=0.08, max_tip_length_to_length_ratio=0.15,
            )
            anims.append(GrowArrow(arr1))
            if i > 0:  # the first trace token has no predecessor
                arr2 = Arrow(
                    trace_mobs[i - 1].get_right(), tm.get_left(),
                    color=GOLD, stroke_width=1.5, buff=0.08, max_tip_length_to_length_ratio=0.15,
                )
                anims.append(GrowArrow(arr2))
            self.play(*anims, run_time=0.6)

        self.play(Write(rule))

        key_insight = Text(
            "Key: only 2 lookbacks per step — cost doesn't grow with length!",
            font_size=20, color=ACCENT3, weight=BOLD,
        ).next_to(nb_rect, DOWN, buff=0.6)
        self.play(Write(key_insight))
        self.wait(2.5)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene6_QuadraticProblem(Scene):
    """The quadratic cost problem of standard attention."""

    def construct(self):
        self.camera.background_color = BG

        title = Text("The Problem: Attention Cost Grows", font_size=34, color=ACCENT, weight=BOLD)
        title.to_edge(UP, buff=0.5)
        self.play(Write(title))

        # Axes: x = tokens generated so far, y = per-step attention work.
        ax = Axes(
            x_range=[0, 10, 2], y_range=[0, 100, 20],
            x_length=5, y_length=3.5,
            axis_config={"color": DIM, "include_numbers": False},
        ).shift(DOWN * 0.5)
        x_lab = Text("tokens generated (t)", font_size=18, color=DIM).next_to(ax.x_axis, DOWN, buff=0.3)
        y_lab = Text("work per step", font_size=18, color=DIM).next_to(ax.y_axis, LEFT, buff=0.3).rotate(PI / 2)

        self.play(Create(ax), Write(x_lab), Write(y_lab))

        # Standard attention: linear cost per step (quadratic total).
        # x_range starts at 0.1 to avoid plotting exactly at the origin.
        std_graph = ax.plot(lambda x: x * 10, x_range=[0.1, 10], color=WARN, stroke_width=3)
        std_label = Text("Standard attention: O(t) per step", font_size=18, color=WARN)
        std_label.next_to(ax, RIGHT, buff=0.3).shift(UP * 1)

        self.play(Create(std_graph), Write(std_label), run_time=1.5)
        self.wait(1)

        # Hull attention: logarithmic per-step cost (scaled ×8 for visibility).
        log_graph = ax.plot(lambda x: np.log2(x + 1) * 8, x_range=[0.1, 10], color=ACCENT3, stroke_width=3)
        log_label = Text("Hull attention: O(log t) per step", font_size=18, color=ACCENT3)
        log_label.next_to(std_label, DOWN, buff=0.4)

        self.play(Create(log_graph), Write(log_label), run_time=1.5)

        # Concrete numbers to make the asymptotic gap tangible.
        gap_text = Text(
            "At 1 million steps:\nstandard = 1,000,000 ops/step\nhull = ~20 ops/step",
            font_size=18, color=GOLD, line_spacing=1.3,
        ).next_to(ax, DOWN, buff=0.7)
        self.play(Write(gap_text))
        self.wait(3)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene7_2DHeads(Scene):
    """The key unlock: 2D attention heads and convex hull queries."""

    def construct(self):
        self.camera.background_color = BG

        title = Text("The Key Unlock: 2D Attention Heads", font_size=34, color=ACCENT, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title))

        # Explain the head-dimension restriction.
        expl = Text(
            "Restrict each attention head to dimension 2\n→ keys & queries become 2D points",
            font_size=22, color=TEXT_COL, line_spacing=1.3,
        ).next_to(title, DOWN, buff=0.5)
        self.play(Write(expl))
        self.wait(1)

        # Random 2D scatter of key points (seeded for a deterministic render).
        np.random.seed(42)
        n_pts = 20
        pts = np.random.randn(n_pts, 2) * 1.2
        plane = NumberPlane(
            x_range=[-3, 3, 1], y_range=[-3, 3, 1],
            x_length=5, y_length=4,
            background_line_style={"stroke_color": DIM, "stroke_width": 0.5},
            axis_config={"stroke_color": DIM},
        ).shift(DOWN * 0.7)

        dots = VGroup()
        for pt in pts:
            d = Dot(plane.c2p(pt[0], pt[1]), radius=0.06, color=ACCENT)
            dots.add(d)

        key_label = Text("Keys (past tokens)", font_size=16, color=ACCENT).next_to(plane, LEFT, buff=0.3)
        self.play(Create(plane), Write(key_label))
        self.play(*[FadeIn(d, scale=0.5) for d in dots], run_time=1)
        self.wait(0.5)

        # Convex hull of the key points. scipy is imported locally so the
        # dependency is only pulled in when this scene actually renders.
        from scipy.spatial import ConvexHull  # noqa: E402

        hull = ConvexHull(pts)
        # hull.vertices are indices into pts, in counterclockwise order.
        hull_pts = [plane.c2p(pts[i][0], pts[i][1]) for i in hull.vertices]
        hull_pts.append(hull_pts[0])  # close the polygon
        hull_poly = Polygon(*hull_pts, color=GOLD, stroke_width=2, fill_opacity=0.08, fill_color=GOLD)

        hull_label = Text("Convex Hull", font_size=18, color=GOLD).next_to(plane, RIGHT, buff=0.3).shift(UP)
        self.play(Create(hull_poly), Write(hull_label))
        self.wait(0.5)

        # A single query point to motivate the geometric max-inner-product search.
        q_pt = plane.c2p(1.5, 0.5)
        q_dot = Dot(q_pt, radius=0.1, color=WARN)
        q_label = Text("Query", font_size=16, color=WARN).next_to(q_dot, UR, buff=0.1)
        self.play(FadeIn(q_dot, scale=2), Write(q_label))

        insight = Text(
            "Finding the best-matching key = a geometric query on the hull\n"
            "→ binary search in O(log t) instead of scanning all t keys",
            font_size=18, color=ACCENT3, line_spacing=1.3,
        ).next_to(plane, DOWN, buff=0.5)
        self.play(Write(insight))
        self.wait(3)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
|
||||
|
||||
class Scene8_Results(Scene):
    """Performance results and what this means."""

    def construct(self):
        self.camera.background_color = BG

        title = Text("Results: What This Enables", font_size=36, color=ACCENT, weight=BOLD)
        title.to_edge(UP, buff=0.5)
        self.play(Write(title))

        # Three headline-metric cards laid out in a row.
        results = VGroup(
            self._result_card("30,000+ tok/s", "Execution speed on CPU", ACCENT3),
            self._result_card("Millions of steps", "Correct execution length", GOLD),
            self._result_card("100% accuracy", "On Sudoku benchmarks\n(incl. world's hardest)", ACCENT),
        ).arrange(RIGHT, buff=0.6).next_to(title, DOWN, buff=0.8)

        for card in results:
            self.play(FadeIn(card, shift=UP * 0.3), run_time=0.7)
            self.wait(0.5)

        # Sudoku callout beneath the cards.
        sudoku_note = Text(
            "Solves Arto Inkala's Sudoku (\"world's hardest\")\nin under 3 minutes — fully inside the transformer",
            font_size=20, color=TEXT_COL, line_spacing=1.3,
        ).next_to(results, DOWN, buff=0.8)
        self.play(Write(sudoku_note))
        self.wait(2.5)
        self.play(*[FadeOut(m) for m in self.mobjects])

    def _result_card(self, big_text, small_text, color):
        # Build one metric card: an outlined rounded box with a large
        # headline figure over a small caption, both centered in the box.
        box = RoundedRectangle(corner_radius=0.15, width=3.5, height=2.5, color=color, stroke_width=2)
        big = Text(big_text, font_size=28, color=color, weight=BOLD)
        small = Text(small_text, font_size=16, color=TEXT_COL, line_spacing=1.2)
        grp = VGroup(big, small).arrange(DOWN, buff=0.3)
        return VGroup(box, grp)
|
||||
|
||||
|
||||
class Scene9_Recap(Scene):
    """Final recap and takeaway."""

    def construct(self):
        self.camera.background_color = BG

        heading = Text("Recap", font_size=40, color=ACCENT, weight=BOLD)
        heading.to_edge(UP, buff=0.5)
        self.play(Write(heading))

        # (number badge, line text, line color) — one entry per recap point.
        recap_points = [
            ("1", "LLMs struggle with long, exact computation", TEXT_COL),
            ("2", "Today we bolt on external tools (like airplanes for humans)", DIM),
            ("3", "This team built a real computer inside a transformer", ACCENT3),
            ("4", "C code → WebAssembly tokens → executed by the model itself", GOLD),
            ("5", "2D attention heads enable O(log t) lookups (exponentially faster)", ACCENT),
            ("6", "Result: millions of correct steps, 30k+ tokens/sec, 100% Sudoku", ACCENT3),
        ]

        def make_row(badge, line, line_color):
            # One numbered row: purple badge followed by the colored text.
            badge_mob = Text(badge, font_size=24, color=ACCENT2, weight=BOLD)
            line_mob = Text(line, font_size=21, color=line_color)
            return VGroup(badge_mob, line_mob).arrange(RIGHT, buff=0.3)

        rows = VGroup(*[make_row(*point) for point in recap_points])
        rows.arrange(DOWN, aligned_edge=LEFT, buff=0.35).next_to(heading, DOWN, buff=0.6)

        # Reveal the recap points one at a time.
        for row in rows:
            self.play(FadeIn(row, shift=RIGHT * 0.3), run_time=0.6)
            self.wait(0.3)

        self.wait(1)

        takeaway = Text(
            "The model stops being a coordinator of computation\nand becomes a computer itself.",
            font_size=24, color=GOLD, weight=BOLD, line_spacing=1.3,
        ).next_to(rows, DOWN, buff=0.7)
        self.play(Write(takeaway, run_time=2))
        self.wait(3)
        self.play(*[FadeOut(m) for m in self.mobjects])
|
||||
Reference in New Issue
Block a user