"""
Establishing the equivalence between activation steering approaches in nnsight.

This script investigates how nnsight's .generate() applies interventions and confirms
that using position -1 in a plain invoke context is equivalent to using the concrete
last-prompt-token index with tracer.iter[0].

Key findings established by these experiments:

1. INTERVENTION FIRES ONCE (on prompt pass only)

   nnsight's plain `with tracer.invoke(prompt):` fires the intervention graph exactly
   once, during the first forward pass over the full prompt. Subsequent generation
   steps use the KV cache and process only 1 new token per step — the intervention
   does NOT fire again.

2. POSITION -1 CORRECTLY TARGETS THE LAST PROMPT TOKEN

   Since the intervention fires once when the full prompt (seq_len=4) is being
   processed, `activation[-1]` resolves to position 3 (the last prompt token).
   This is exactly what we want for steering.

3. THE EFFECT PERSISTS VIA THE KV CACHE

   Corrupting a hidden state at position 3 during the prompt pass alters the
   key/value pairs stored in the KV cache for that position. All subsequent
   generation steps attend to these corrupted values, so the effect propagates
   through the entire generation without re-applying the intervention.

4. iter[0] WITH A CONCRETE POSITION GIVES IDENTICAL RESULTS

   Applying the intervention at step 0 only (via tracer.iter[0]) with position n-1
   is bit-for-bit identical to using -1 in the plain invoke context (see the
   sketch below).

5. iter[:] FAILS FOR POSITION > 0 AT STEPS 1+

   At generation steps 1 and beyond (KV cache steps), the activation tensor has
   shape [1, hidden_size] — only the new token. Position 3 is out of bounds.
   This is proof that nnsight uses the KV cache during generation.
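
In short, the two approaches below are interchangeable (a minimal sketch
mirroring the experiments in this file; `vec` stands for a hypothetical
steering vector; the experiments below zero the activation instead):

    # Approach A: -1 in a plain invoke (fires once, on the prompt pass)
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            model.transformer.h[11].output[0][-1] += vec
            out = model.generator.output.save()

    # Approach B: tracer.iter[0] with the concrete last-prompt index n-1
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            for _ in tracer.iter[0]:
                model.transformer.h[11].output[0][n - 1] += vec
            out = model.generator.output.save()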
Usage:

    uv run --no-sync python scripts/equivalence.py

Expected output (GPT-2, prompt "The weather today is"):

    === Experiment 1: Which prompt position matters? (layer 11) ===
    Prompt: 4 tokens ['The', 'Ġweather', 'Ġtoday', 'Ġis']
    Normal (no zero)      : ' very good, and we'
    Zero [0]              : ' very good, and we'
    Zero [1]              : ' very good, and we'
    Zero [2]              : ' very good, and we'
    Zero [3] = [-1]       : ', of course, very'

    Only zeroing the last prompt token (position 3 = -1) changes the output
    when using a late layer (e.g. layer 11). Positions 0-2 have no effect
    because the last prompt token is what directly determines the next-token
    logits in an autoregressive model. (At early layers like layer 0, all
    positions matter because the corruption propagates forward through all
    subsequent layers.)

    === Experiment 2: KV cache persistence ===
    Steered 6 tokens: ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ', 'Ċ']
    Context for trace: ['The', 'Ġweather', 'Ġtoday', 'Ġis', 'Ċ', 'Ċ', 'Ċ', '"']
    Actual 5th token from generate: '\n' (id=198)
    Predicted by full trace at pos 3: '\n' (id=198)
    KV cache persistent == full recompute: True

    === Experiment 3: iter[0] equivalence ===
     iter[0] step shape: torch.Size([4, 768])
    Approach 1 (-1, plain invoke): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Approach 2 (iter[0], pos 3): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Identical: True

    === Experiment 4: iter[:] shape at each step ===
    Shapes seen before error: [(0, (4, 768)), (1, (1, 768))]
    IndexError at step 1+: index 3 is out of bounds for dimension 0 with size 1
     -> Step 0 has seq_len=4, position 3 valid.
     -> Steps 1+ have seq_len=1 (KV cache), position 3 out of bounds.

    === Experiment 5: Unsteered differs from steered ===
    Unsteered: ['Ġvery', 'Ġgood', ',', 'Ġand', 'Ġwe']
    Steered: ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Different: True
"""
import torch
# Cap GPU memory usage; skip when no CUDA device is present.
if torch.cuda.is_available():
    torch.cuda.memory.set_per_process_memory_fraction(0.8)
from nnsight import LanguageModel
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
tokenizer = model.tokenizer
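# GPT-2 ships without a pad token; generation utilities expect one, so fall back to EOS.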
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
prompt = "The weather today is"
prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
n = len(prompt_ids)
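# n = number of prompt tokens; generated continuations are sliced as out[0][n:] throughout.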
# ---------------------------------------------------------------------------
# Experiment 1: Which token position matters?
#
# We zero layer 11's output at each individual token position to determine
# which positions influence the generated output. At a late layer like 11,
# only the last prompt token (position n-1 = -1) changes the output; earlier
# positions do not. This is because causal attention means the last position
# is the only one whose representation directly feeds into the next-token
# logits. (At early layers like layer 0, corrupting any position propagates
# through all subsequent layers and changes the output.)
# ---------------------------------------------------------------------------
print("=" * 60)
print("Experiment 1: Which prompt position matters? (layer 11)")
print("=" * 60)
print(f"Prompt: {n} tokens {tokenizer.convert_ids_to_tokens(prompt_ids.tolist())}")
for label, idx in [
    ("Normal (no zero)", None),
    ("Zero [0]", 0),
    ("Zero [1]", 1),
    ("Zero [2]", 2),
    ("Zero [3] = [-1]", 3),
]:
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            if idx is not None:
                # Zero layer 11's hidden state at a single prompt position.
                model.transformer.h[11].output[0][idx] = 0.0
            out_ids = model.generator.output.save()
    response = tokenizer.decode(out_ids[0][n:], skip_special_tokens=True)
    print(f"{label:22s}: {repr(response)}")
# ---------------------------------------------------------------------------
# Experiment 2: KV cache persistence.
#
# We confirm that the intervention on the prompt pass persists via the KV
# cache through all subsequent generation steps. Specifically:
#   - Run generate from the prompt P with a zero at position n-1; collect
#     6 steered tokens.
#   - Do a fresh full forward pass over P + the first 4 steered tokens, with
#     the same zero at position n-1.
#   - The full-pass logits at the last position should predict the same 5th
#     token that generate produced.
#
# This works because transformers with causal attention are equivalent whether
# computed incrementally (KV cache) or all at once, as long as the intervention
# is applied at the same absolute position.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 2: KV cache persistence")
print("=" * 60)
with model.generate(max_new_tokens=6, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        # Zero layer 0's hidden state at the last prompt position (-1).
        h = model.transformer.h[0].output[0]
        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
        out_ids = model.generator.output.save()
steered = out_ids[0][n:].tolist()
t5_actual = steered[4]
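# steered[4] is the 5th generated token; the full-trace recompute below should predict it.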
print(f"Steered 6 tokens: {tokenizer.convert_ids_to_tokens(steered)}")
# Full trace over P + first 4 steered tokens, zero at the same absolute position.
ctx = torch.cat([prompt_ids, torch.tensor(steered[:4])]).unsqueeze(0)
print(f"Context for trace: {tokenizer.convert_ids_to_tokens(ctx[0].tolist())}")
with torch.no_grad():
    with model.trace() as tracer:
        with tracer.invoke(ctx):
            # Same zero, same absolute position n-1, but now in one full forward pass.
            h = model.transformer.h[0].output[0]
            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
            logits_proxy = model.lm_head.output.save()
logits = torch.as_tensor(logits_proxy).float().cpu()  # shape [1, seq, vocab]
last_logits = logits[0, -1]  # logits predicting the token after ctx
t5_predicted = torch.argmax(last_logits).item()
print(f"Actual 5th token from generate: {repr(tokenizer.decode([t5_actual]))} (id={t5_actual})")
print(f"Predicted by full trace at pos {n-1}: {repr(tokenizer.decode([t5_predicted]))} (id={t5_predicted})")
print(f"KV cache persistent == full recompute: {t5_actual == t5_predicted}")
# ---------------------------------------------------------------------------
# Experiment 3: iter[0] with concrete position == -1 in plain invoke.
#
# tracer.iter[0] applies the intervention graph only at generation step 0,
# which is the prompt pass. At this step the activation has shape
# [seq_len, hidden_size] so position n-1 is valid. This is bit-for-bit
# identical to using -1 in the plain invoke context.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 3: iter[0] equivalence")
print("=" * 60)
# Approach 1: -1 in plain invoke (fires once on prompt pass)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        h = model.transformer.h[0].output[0]
        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
        out1 = model.generator.output.save()
tokens1 = out1[0][n:].tolist()
# Approach 2: iter[0] with concrete position n-1 (fires once at step 0)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        for step in tracer.iter[0]:
            h = model.transformer.h[0].output[0]
            print(f" iter[0] step shape: {h.shape}")
            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
        out2 = model.generator.output.save()
tokens2 = out2[0][n:].tolist()
print(f"Approach 1 (-1, plain invoke): {tokenizer.convert_ids_to_tokens(tokens1)}")
print(f"Approach 2 (iter[0], pos {n-1}): {tokenizer.convert_ids_to_tokens(tokens2)}")
print(f"Identical: {tokens1 == tokens2}")
# ---------------------------------------------------------------------------
# Experiment 4: iter[:] fails at steps 1+ for position > 0.
#
# tracer.iter[:] (= tracer.all()) fires the intervention at every generation
# step. At step 0 the full prompt is processed (shape [seq_len, hidden]).
# At steps 1+ nnsight uses the KV cache, processing only the new token
# (shape [1, hidden]). Trying to access position n-1 at those steps raises
# an IndexError, proving the KV cache is active.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 4: iter[:] shape at each step")
print("=" * 60)
shapes = []
try:
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            for step in tracer.iter[:]:
                h = model.transformer.h[0].output[0]
                shapes.append((step, tuple(h.shape)))
                # This will raise IndexError at steps 1+ because dim 0 has size 1.
                model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
            model.generator.output.save()
except IndexError as e:
    print(f"Shapes seen before error: {shapes}")
    print(f"IndexError at step 1+: {e}")
    print(f" -> Step 0 has seq_len={n}, position {n-1} valid.")
    print(f" -> Steps 1+ have seq_len=1 (KV cache), position {n-1} out of bounds.")
# ---------------------------------------------------------------------------
# Experiment 5: Sanity check — unsteered generation differs from steered.
#
# Confirms the intervention actually has an effect: generation without any
# intervention produces different tokens than generation with position -1
# zeroed. Without this check, the equivalences above could be vacuous.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 5: Unsteered differs from steered")
print("=" * 60)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        out_unsteered = model.generator.output.save()
tokens_unsteered = out_unsteered[0][n:].tolist()
tokens_steered = tokens1  # steered tokens from experiment 3 (Approach 1)
print(f"Unsteered: {tokenizer.convert_ids_to_tokens(tokens_unsteered)}")
print(f"Steered: {tokenizer.convert_ids_to_tokens(tokens_steered)}")
print(f"Different: {tokens_unsteered != tokens_steered}")