
scripts/equivalence.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

""" Establishing the equivalence between activation steering approaches in nnsight. This script investigates how nnsight's .generate() applies interventions and confirms that using position -1 in a plain invoke context is equivalent to using the concrete last-prompt-token index with tracer.iter[0]. Key findings established by these experiments: 1. INTERVENTION FIRES ONCE (on prompt pass only) nnsight's plain `with tracer.invoke(prompt):` fires the intervention graph exactly once, during the first forward pass over the full prompt. Subsequent generation steps use the KV cache and process only 1 new token per step — the intervention does NOT fire again. 2. POSITION -1 CORRECTLY TARGETS THE LAST PROMPT TOKEN Since the intervention fires once when the full prompt (seq_len=4) is being processed, `activation[-1]` resolves to position 3 (the last prompt token). This is exactly what we want for steering. 3. THE EFFECT PERSISTS VIA THE KV CACHE Corrupting a hidden state at position 3 during the prompt pass alters the key/value pairs stored in the KV cache for that position. All subsequent generation steps attend to these corrupted values, so the effect propagates through the entire generation without re-applying the intervention. 4. iter[0] WITH CONCRETE POSITION GIVES IDENTICAL RESULTS Applying the intervention at step 0 only (via tracer.iter[0]) with position n-1 is bit-for-bit identical to using -1 in the plain invoke context. 5. iter[:] FAILS FOR POSITION > 0 AT STEPS 1+ At generation steps 1 and beyond (KV cache steps), the activation tensor has shape [1, hidden_size] — only the new token. Position 3 is out of bounds. This is proof that nnsight uses the KV cache during generation. Usage: uv run --no-sync python scripts/equivalence.py Expected output (GPT-2, prompt "The weather today is"): === Experiment 1: Which prompt position matters? === Prompt: 4 tokens ['The', 'Ġweather', 'Ġtoday', 'Ġis'] Normal (no zero) : very good, and we Zero [0] : very good, and we Zero [1] : very good, and we Zero [2] : very good, and we Zero [3] = [-1] : , of course, very Only zeroing the last prompt token (position 3 = -1) changes the output when using a late layer (e.g. layer 11). Positions 0-2 have no effect because the last prompt token is what directly determines the next-token logits in an autoregressive model. (At early layers like layer 0, all positions matter because the corruption propagates forward through all subsequent layers.) === Experiment 2: KV cache persistence === Steered 6 tokens: ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ', 'Ċ'] Context for trace: ['The', 'Ġweather', 'Ġtoday', 'Ġis', 'Ċ', 'Ċ', 'Ċ', '"'] Actual 5th token from generate: '\n' (id=198) Predicted by full trace at pos 3: '\n' (id=198) KV cache persistent == full recompute: True === Experiment 3: iter[0] equivalence === iter[0] step shape: torch.Size([4, 768]) Approach 1 (-1, plain invoke): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ'] Approach 2 (iter[0], pos 3): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ'] Identical: True === Experiment 4: iter[:] out-of-bounds at steps 1+ === Shapes seen before error: [(0, (4, 768)), (1, (1, 768))] IndexError at step 1+: index 3 is out of bounds for dimension 0 with size 1 -> Step 0 has seq_len=4, position 3 valid. -> Steps 1+ have seq_len=1 (KV cache), position 3 out of bounds. 
    === Experiment 5: Unsteered differs from steered ===
    Unsteered: ['Ġvery', 'Ġgood', ',', 'Ġand', 'Ġwe']
    Steered: ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Different: True
"""

import torch

torch.cuda.memory.set_per_process_memory_fraction(0.8)

from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
tokenizer = model.tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

prompt = "The weather today is"
prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
n = len(prompt_ids)

# ---------------------------------------------------------------------------
# Experiment 1: Which token position matters?
#
# We zero layer 11's output at each individual token position to determine
# which positions influence the generated output. At a late layer like 11,
# only the last prompt token (position n-1 = -1) changes the output; earlier
# positions do not. This is because causal attention means the last position
# is the only one whose representation directly feeds into the next-token
# logits. (At early layers like layer 0, corrupting any position propagates
# through all subsequent layers and changes the output.)
# ---------------------------------------------------------------------------
print("=" * 60)
print("Experiment 1: Which prompt position matters? (layer 11)")
print("=" * 60)
print(f"Prompt: {n} tokens {tokenizer.convert_ids_to_tokens(prompt_ids.tolist())}")

for label, idx in [
    ("Normal (no zero)", None),
    ("Zero [0]", 0),
    ("Zero [1]", 1),
    ("Zero [2]", 2),
    ("Zero [3] = [-1]", 3),
]:
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            if idx is not None:
                model.transformer.h[11].output[0][idx] = 0.0
            out_ids = model.generator.output.save()
    response = tokenizer.decode(out_ids[0][n:], skip_special_tokens=True)
    print(f"{label:22s}: {repr(response)}")

# ---------------------------------------------------------------------------
# Experiment 2: KV cache persistence.
#
# We confirm that the intervention on the prompt pass persists via the KV
# cache through all subsequent generation steps. Specifically:
#   - Run generate from P with zero at position n-1, collect 6 steered tokens.
#   - Do a fresh full forward pass over P + first 4 steered tokens, with the
#     same zero at position n-1.
#   - The full-pass logits at the last position should predict the same 5th
#     token that generate produced.
#
# This works because transformers with causal attention are equivalent whether
# computed incrementally (KV cache) or all at once, as long as the intervention
# is applied at the same absolute position.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 2: KV cache persistence")
print("=" * 60)

with model.generate(max_new_tokens=6, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        h = model.transformer.h[0].output[0]
        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
        out_ids = model.generator.output.save()

steered = out_ids[0][n:].tolist()
t5_actual = steered[4]
print(f"Steered 6 tokens: {tokenizer.convert_ids_to_tokens(steered)}")

# Full trace over P + first 4 steered tokens, zero at the same absolute position.
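# (The concrete index n-1 is needed here rather than -1: ctx is the prompt
# plus 4 steered tokens, so -1 would zero the last steered token instead of
# the last prompt token.)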
ctx = torch.cat([prompt_ids, torch.tensor(steered[:4])]).unsqueeze(0)
print(f"Context for trace: {tokenizer.convert_ids_to_tokens(ctx[0].tolist())}")

with torch.no_grad():
    with model.trace() as tracer:
        with tracer.invoke(ctx):
            h = model.transformer.h[0].output[0]
            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
            logits_proxy = model.lm_head.output.save()

logits = torch.as_tensor(logits_proxy).float().cpu()  # shape [1, seq, vocab]
last_logits = logits[0, -1]  # logits predicting token after ctx
t5_predicted = torch.argmax(last_logits).item()

print(f"Actual 5th token from generate: {repr(tokenizer.decode([t5_actual]))} (id={t5_actual})")
print(f"Predicted by full trace at pos {n-1}: {repr(tokenizer.decode([t5_predicted]))} (id={t5_predicted})")
print(f"KV cache persistent == full recompute: {t5_actual == t5_predicted}")

# ---------------------------------------------------------------------------
# Experiment 3: iter[0] with concrete position == -1 in plain invoke.
#
# tracer.iter[0] applies the intervention graph only at generation step 0,
# which is the prompt pass. At this step the activation has shape
# [seq_len, hidden_size], so position n-1 is valid. This is bit-for-bit
# identical to using -1 in the plain invoke context.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 3: iter[0] equivalence")
print("=" * 60)

# Approach 1: -1 in plain invoke (fires once on prompt pass)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        h = model.transformer.h[0].output[0]
        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
        out1 = model.generator.output.save()
tokens1 = out1[0][n:].tolist()

# Approach 2: iter[0] with concrete position n-1 (fires once at step 0)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        for step in tracer.iter[0]:
            h = model.transformer.h[0].output[0]
            print(f" iter[0] step shape: {h.shape}")
            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
        out2 = model.generator.output.save()
tokens2 = out2[0][n:].tolist()

print(f"Approach 1 (-1, plain invoke): {tokenizer.convert_ids_to_tokens(tokens1)}")
print(f"Approach 2 (iter[0], pos {n-1}): {tokenizer.convert_ids_to_tokens(tokens2)}")
print(f"Identical: {tokens1 == tokens2}")

# ---------------------------------------------------------------------------
# Experiment 4: iter[:] fails at steps 1+ for position > 0.
#
# tracer.iter[:] (= tracer.all()) fires the intervention at every generation
# step. At step 0 the full prompt is processed (shape [seq_len, hidden]).
# At steps 1+ nnsight uses the KV cache, processing only the new token
# (shape [1, hidden]). Trying to access position n-1 at those steps raises
# an IndexError, proving the KV cache is active.
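#
# (A corollary this script does not test: indexing with -1 inside iter[:]
# presumably would not raise, since -1 resolves to the single new token at
# KV-cache steps -- but that would steer every generated token rather than
# just the prompt pass.)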
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 4: iter[:] shape at each step")
print("=" * 60)

shapes = []
try:
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            for step in tracer.iter[:]:
                h = model.transformer.h[0].output[0]
                shapes.append((step, tuple(h.shape)))
                # This will raise IndexError at step 1+ because size is 1
                model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
            model.generator.output.save()
except IndexError as e:
    print(f"Shapes seen before error: {shapes}")
    print(f"IndexError at step 1+: {e}")
    print(f" -> Step 0 has seq_len={n}, position {n-1} valid.")
    print(f" -> Steps 1+ have seq_len=1 (KV cache), position {n-1} out of bounds.")

# ---------------------------------------------------------------------------
# Experiment 5: Sanity check -- unsteered generation differs from steered.
#
# Confirms the intervention actually has an effect: generation without any
# intervention produces different tokens than generation with position -1
# zeroed. Without this check, the equivalences above could be vacuous.
# ---------------------------------------------------------------------------
print()
print("=" * 60)
print("Experiment 5: Unsteered differs from steered")
print("=" * 60)

with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        out_unsteered = model.generator.output.save()
tokens_unsteered = out_unsteered[0][n:].tolist()
tokens_steered = out1[0][n:].tolist()  # from experiment 3

print(f"Unsteered: {tokenizer.convert_ids_to_tokens(tokens_unsteered)}")
print(f"Steered: {tokenizer.convert_ids_to_tokens(tokens_steered)}")
print(f"Different: {tokens_unsteered != tokens_steered}")
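
# ---------------------------------------------------------------------------
# Illustrative sketch (not one of the experiments above). Given the
# equivalence established here, a reusable steering helper only needs the
# plain-invoke form with position -1: it fires once on the prompt pass and
# the effect persists through generation via the KV cache. The helper name
# `generate_with_last_token_edit` and its signature are illustrative choices,
# not part of the original experiments; it reuses the `model` and `tokenizer`
# objects defined above.
# ---------------------------------------------------------------------------
def generate_with_last_token_edit(prompt_text, layer, edit_fn, max_new_tokens=5):
    """Apply `edit_fn` to the last prompt token's hidden state at `layer`, then generate."""
    with model.generate(max_new_tokens=max_new_tokens, do_sample=False) as tracer:
        with tracer.invoke(prompt_text):
            h = model.transformer.h[layer].output[0]
            # Position -1 targets the last prompt token; the effect persists
            # via the KV cache (Experiments 1-3 above).
            model.transformer.h[layer].output[0][-1] = edit_fn(h[-1])
            out = model.generator.output.save()
    n_prompt = len(tokenizer(prompt_text, add_special_tokens=False)["input_ids"])
    return out[0][n_prompt:].tolist()


# Example usage: zeroing layer 0's last-prompt-token state should reproduce
# the steered tokens from Experiment 3 (approach 1).
sketch_tokens = generate_with_last_token_edit(prompt, 0, torch.zeros_like)
print()
print(f"Sketch helper (zero at layer 0): {tokenizer.convert_ids_to_tokens(sketch_tokens)}")
print(f"Matches Experiment 3 approach 1: {sketch_tokens == tokens1}")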