Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -11,7 +11,7 @@ the API and shows the equivalent direct NNSight code. ## Imports ```python -from more_nnsight import SavedActivation, save_activations +from more_nnsight import SavedActivation, save_activations, updates ``` ## Core Rule @@ -267,6 +267,42 @@ with model.trace() as tracer: `SavedActivation.apply(model)` performs that assignment for every saved key. +## Layer-Major Updates + +When you need to read the current activation and write back an updated value in +the same invoke, use `updates(model, keys)`. + +```python +for key, current_value, update in updates(model, saved.keys()): + update(current_value + 2.0 * saved.get(key)) +``` + +This is the interleaving-safe pattern for multi-layer current-pass updates. It +walks the keys in canonical layer-major order and gives you: + +- `key`: the saved-activation path string +- `current_value`: the live activation slice from the current forward pass +- `update(new_value)`: a callback that writes a new value back to that same + slice + +`updates(...)` requires concrete saved keys in canonical order. In practice, +`saved.keys()` is the intended input. + +Direct NNSight equivalent for one path: + +```python +current = model.transformer.h[2].output[:, 10, :] +model.transformer.h[2].output[:, 10, :] = current + 2.0 * saved_tensor +``` + +Use `updates(...)` when the new value depends on the current forward-pass +activation. Use `saved.apply(model)` when you just want to replay stored +activations unchanged. + +`save_activations(...)` and `updates(...)` should usually happen in different +invocations. Saving layer 2 and then revisiting layer 2 after later layers have +already been touched violates NNSight's interleaving rules. + ## `saved.save()` `SavedActivation.save()` is different from `save_activations(...)`. 
@@ -374,6 +410,17 @@ with model.trace() as tracer: logits = model.lm_head.output.save() ``` +If you want to modify the current activation instead of replacing it outright, +use `updates(...)` in the later run: + +```python +with model.trace() as tracer: + with tracer.invoke([corrupted_prompt]): + for key, current_value, update in updates(model, focused_patch.keys()): + update(current_value + 2.0 * focused_patch.get(key)) + logits = model.lm_head.output.save() +``` + ## Other Models Use the model's real path names. For example, a Qwen-style decoder stack can
@@ -1,6 +1,6 @@ -from .saved_activation import SavedActivation, save_activations +from .saved_activation import SavedActivation, save_activations, updates -__all__ = ["SavedActivation", "save_activations"] +__all__ = ["SavedActivation", "save_activations", "updates"] def main() -> None:
@@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, Iterator import torch from lark import Lark, Token, Transformer @@ -227,10 +227,8 @@ class SavedActivation: def apply(self, model: Any) -> None: """Writes the stored activations into the current traced run of the model.""" - for target, saved_tensor in self._targets.items(): - activation = self._resolve_activation(model, target.activation_path) - self._validate_activation_shape(activation) - activation[:, target.position, :] = saved_tensor + for key, _current_value, update in updates(model, self.keys()): + update(self.get(key)) def __add__(self, other: "SavedActivation") -> "SavedActivation": """Combines two saved activation sets elementwise when they cover the same paths.""" @@ -508,3 +506,32 @@ def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation SavedActivation._validate_activation_shape(activation) saved_values[target] = activation[:, target.position, :].save() return SavedActivation._from_targets(saved_values) + + +def updates( + model: Any, + activation_paths: list[str], +) -> Iterator[tuple[str, Any, Callable[[Any], None]]]: + """Returns layer-major live update handles for concrete saved-activation paths.""" + parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths] + + for path, target in zip(activation_paths, parsed_targets, strict=True): + if SavedActivation._contains_slice(target.activation_path): + raise ValueError(f"updates() requires concrete keys, got slice path: {path}") + + if len(set(parsed_targets)) != len(parsed_targets): + raise ValueError("updates() requires unique keys.") + + canonical_targets = sorted(parsed_targets, key=SavedActivation._target_order_key) + if tuple(parsed_targets) != tuple(canonical_targets): + raise ValueError("updates() requires keys in canonical layer-major order.") + + for key, target in zip(activation_paths, 
parsed_targets, strict=True): + activation = SavedActivation._resolve_activation(model, target.activation_path) + SavedActivation._validate_activation_shape(activation) + current_value = activation[:, target.position, :] + + def update(new_value: Any, activation: Any = activation, position: int = target.position) -> None: + activation[:, position, :] = new_value + + yield (key, current_value, update)
@@ -3,7 +3,7 @@ from __future__ import annotations import pytest import torch -from more_nnsight import SavedActivation +from more_nnsight import SavedActivation, updates from more_nnsight.saved_activation import ActivationTarget, AttrSegment, IndexSegment @@ -169,3 +169,39 @@ def test_union_merges_disjoint_paths_and_rejects_overlap() -> None: with pytest.raises(ValueError, match="disjoint keys"): _ = left.union(overlap) + + +def test_updates_reject_duplicate_keys() -> None: + with pytest.raises(ValueError, match="unique keys"): + list( + updates( + object(), + [ + "model.transformer.h[2].output[10]", + "model.transformer.h[2].output[10]", + ], + ) + ) + + +def test_updates_reject_noncanonical_key_order() -> None: + with pytest.raises(ValueError, match="canonical layer-major order"): + list( + updates( + object(), + [ + "model.transformer.h[3].output[10]", + "model.transformer.h[2].output[10]", + ], + ) + ) + + +def test_updates_reject_slice_paths() -> None: + with pytest.raises(ValueError, match="concrete keys"): + list( + updates( + object(), + ["model.transformer.h[:].output[10]"], + ) + )
@@ -9,7 +9,7 @@ from nnsight import LanguageModel from nnsight.modeling.base import NNsight from transformers import AutoConfig, AutoModelForCausalLM -from more_nnsight import SavedActivation, save_activations +from more_nnsight import SavedActivation, save_activations, updates @@ -203,6 +203,108 @@ def test_union_with_gpt2_combines_disjoint_patches(gpt2_model: LanguageModel) -> assert torch.allclose(union_logits, direct_logits) + + +# Reading and updating a multi-layer last-token patch saved in an earlier trace should work within one invoke. +def test_updates_with_gpt2_supports_current_plus_scaled_saved_activation_in_single_invoke( + gpt2_model: LanguageModel, +) -> None: + prompt = "After John and Mary went to the store, John gave a bottle of milk to" + patch_paths = [ + "model.transformer.h[2].output[-1]", + "model.transformer.h[3].output[-1]", + ] + + with gpt2_model.trace() as tracer: + with tracer.invoke([prompt]): + saved = save_activations(gpt2_model, patch_paths) + + with gpt2_model.trace() as tracer: + with tracer.invoke([prompt]): + layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save() + gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0]) + layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save() + gpt2_model.transformer.h[3].output[:, -1, :] = layer_3 + 2.0 * saved.get(patch_paths[1]) + direct_logits = gpt2_model.lm_head.output.save() + + with gpt2_model.trace() as tracer: + with tracer.invoke([prompt]): + for key, value, update in updates(gpt2_model, saved.keys()): + update(value + 2.0 * saved.get(key)) + patched_logits = gpt2_model.lm_head.output.save() + + assert torch.allclose(patched_logits, direct_logits) + + +def test_updates_with_gpt2_supports_key_specific_update_logic( + gpt2_model: LanguageModel, +) -> None: + prompt = "After John and Mary went to the store, John gave a bottle of milk to" + patch_paths = [ + "model.transformer.h[2].output[-1]", + "model.transformer.h[3].output[-1]", 
+ ] + + with gpt2_model.trace() as tracer: + with tracer.invoke([prompt]): + saved = save_activations(gpt2_model, patch_paths) + + with gpt2_model.trace() as tracer: + with tracer.invoke([prompt]): + layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save() + gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0]) + layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save() + gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1]) + direct_logits = gpt2_model.lm_head.output.save() + + with gpt2_model.trace() as tracer: + with tracer.invoke([prompt]): + for key, value, update in updates(gpt2_model, saved.keys()): + if key == "model.transformer.h[2].output[-1]": + update(value + 2.0 * saved.get(key)) + else: + update(saved.get(key)) + patched_logits = gpt2_model.lm_head.output.save() + + assert torch.allclose(patched_logits, direct_logits) + + +def test_updates_with_gpt2_supports_key_specific_batch_reductions( + gpt2_model: LanguageModel, +) -> None: + prompts = [ + "The movie was great and I felt happy.", + "The dinner was excellent and I felt joyful.", + "The trip was terrible and I felt sad.", + ] + patch_paths = [ + "model.transformer.h[2].output[-1]", + "model.transformer.h[3].output[-1]", + ] + + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(gpt2_model, patch_paths) + + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save() + gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0]) + gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])[[0, 2]].mean( + dim=0, keepdim=True + ) + direct_logits = gpt2_model.lm_head.output.save() + + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + for key, value, update in updates(gpt2_model, saved.keys()): + if key == "model.transformer.h[2].output[-1]": + update(value + 2.0 * 
saved.get(key)) + else: + update(saved.get(key)[[0, 2]].mean(dim=0, keepdim=True)) + patched_logits = gpt2_model.lm_head.output.save() + + assert torch.allclose(patched_logits, direct_logits) + + # `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector. def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None: prompts = [