Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Add layer-major updates helper

Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-03-31 14:00:29 -0400
Commit
91dca5f4170ec9de24b641c5d11bfe23eb007200
SKILL.md
index e958469..45650d2 100644
--- a/SKILL.md
+++ b/SKILL.md
@@ -11,7 +11,7 @@ the API and shows the equivalent direct NNSight code.
 ## Imports
 
 ```python
-from more_nnsight import SavedActivation, save_activations
+from more_nnsight import SavedActivation, save_activations, updates
 ```
 
 ## Core Rule
@@ -267,6 +267,42 @@ with model.trace() as tracer:
 
 `SavedActivation.apply(model)` performs that assignment for every saved key.
 
+## Layer-Major Updates
+
+When you need to read the current activation and write back an updated value in
+the same invoke, use `updates(model, keys)`.
+
+```python
+for key, current_value, update in updates(model, saved.keys()):
+    update(current_value + 2.0 * saved.get(key))
+```
+
+This is the interleaving-safe pattern for multi-layer current-pass updates. It
+walks the keys in canonical layer-major order and gives you:
+
+- `key`: the saved-activation path string
+- `current_value`: the live activation slice from the current forward pass
+- `update(new_value)`: a callback that writes a new value back to that same
+  slice
+
+`updates(...)` requires unique, concrete (slice-free) keys in canonical
+layer-major order; anything else raises `ValueError`. `saved.keys()` satisfies this.
+
+Direct NNSight equivalent for one path:
+
+```python
+current = model.transformer.h[2].output[:, 10, :]
+model.transformer.h[2].output[:, 10, :] = current + 2.0 * saved_tensor
+```
+
+Use `updates(...)` when the new value depends on the current forward-pass
+activation. Use `saved.apply(model)` when you just want to replay stored
+activations unchanged.
+
+`save_activations(...)` and `updates(...)` should usually happen in separate
+`tracer.invoke(...)` blocks. Saving layer 2 and then revisiting layer 2 after
+later layers have already been touched violates NNSight's interleaving rules.
+
 ## `saved.save()`
 
 `SavedActivation.save()` is different from `save_activations(...)`.
@@ -374,6 +410,17 @@ with model.trace() as tracer:
         logits = model.lm_head.output.save()
 ```
 
+If you want to modify the current activation instead of replacing it outright,
+use `updates(...)` in the later run:
+
+```python
+with model.trace() as tracer:
+    with tracer.invoke([corrupted_prompt]):
+        for key, current_value, update in updates(model, focused_patch.keys()):
+            update(current_value + 2.0 * focused_patch.get(key))
+        logits = model.lm_head.output.save()
+```
+
 ## Other Models
 
 Use the model's real path names. For example, a Qwen-style decoder stack can
src/more_nnsight/__init__.py
index 681a71b..5f149ac 100644
--- a/src/more_nnsight/__init__.py
+++ b/src/more_nnsight/__init__.py
@@ -1,6 +1,6 @@
-from .saved_activation import SavedActivation, save_activations
+from .saved_activation import SavedActivation, save_activations, updates
 
-__all__ = ["SavedActivation", "save_activations"]
+__all__ = ["SavedActivation", "save_activations", "updates"]
 
 
 def main() -> None:
src/more_nnsight/saved_activation.py
index e83f096..89f0be3 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Callable
+from typing import Any, Callable, Iterator
 
 import torch
 from lark import Lark, Token, Transformer
@@ -227,10 +227,8 @@ class SavedActivation:
 
     def apply(self, model: Any) -> None:
         """Writes the stored activations into the current traced run of the model."""
-        for target, saved_tensor in self._targets.items():
-            activation = self._resolve_activation(model, target.activation_path)
-            self._validate_activation_shape(activation)
-            activation[:, target.position, :] = saved_tensor
+        for key, _current_value, update in updates(model, self.keys()):
+            update(self.get(key))
 
     def __add__(self, other: "SavedActivation") -> "SavedActivation":
         """Combines two saved activation sets elementwise when they cover the same paths."""
@@ -508,3 +506,32 @@ def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation
         SavedActivation._validate_activation_shape(activation)
         saved_values[target] = activation[:, target.position, :].save()
     return SavedActivation._from_targets(saved_values)
+
+
+def updates(
+    model: Any,
+    activation_paths: list[str],
+) -> Iterator[tuple[str, Any, Callable[[Any], None]]]:
+    """Returns layer-major live update handles for concrete saved-activation paths."""
+    parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
+
+    for path, target in zip(activation_paths, parsed_targets, strict=True):
+        if SavedActivation._contains_slice(target.activation_path):
+            raise ValueError(f"updates() requires concrete keys, got slice path: {path}")
+
+    if len(set(parsed_targets)) != len(parsed_targets):
+        raise ValueError("updates() requires unique keys.")
+
+    canonical_targets = sorted(parsed_targets, key=SavedActivation._target_order_key)
+    if tuple(parsed_targets) != tuple(canonical_targets):
+        raise ValueError("updates() requires keys in canonical layer-major order.")
+
+    for key, target in zip(activation_paths, parsed_targets, strict=True):
+        activation = SavedActivation._resolve_activation(model, target.activation_path)
+        SavedActivation._validate_activation_shape(activation)
+        current_value = activation[:, target.position, :]
+
+        def update(new_value: Any, activation: Any = activation, position: int = target.position) -> None:
+            activation[:, position, :] = new_value
+
+        yield (key, current_value, update)
tests/test_saved_activation.py
index a9313a9..af9ce73 100644
--- a/tests/test_saved_activation.py
+++ b/tests/test_saved_activation.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import pytest
 import torch
 
-from more_nnsight import SavedActivation
+from more_nnsight import SavedActivation, updates
 from more_nnsight.saved_activation import ActivationTarget, AttrSegment, IndexSegment
 
 
@@ -169,3 +169,39 @@ def test_union_merges_disjoint_paths_and_rejects_overlap() -> None:
 
     with pytest.raises(ValueError, match="disjoint keys"):
         _ = left.union(overlap)
+
+
+def test_updates_reject_duplicate_keys() -> None:
+    with pytest.raises(ValueError, match="unique keys"):
+        list(
+            updates(
+                object(),
+                [
+                    "model.transformer.h[2].output[10]",
+                    "model.transformer.h[2].output[10]",
+                ],
+            )
+        )
+
+
+def test_updates_reject_noncanonical_key_order() -> None:
+    with pytest.raises(ValueError, match="canonical layer-major order"):
+        list(
+            updates(
+                object(),
+                [
+                    "model.transformer.h[3].output[10]",
+                    "model.transformer.h[2].output[10]",
+                ],
+            )
+        )
+
+
+def test_updates_reject_slice_paths() -> None:
+    with pytest.raises(ValueError, match="concrete keys"):
+        list(
+            updates(
+                object(),
+                ["model.transformer.h[:].output[10]"],
+            )
+        )
tests/test_saved_activation_gpt2.py
index 730b730..775ab62 100644
--- a/tests/test_saved_activation_gpt2.py
+++ b/tests/test_saved_activation_gpt2.py
@@ -9,7 +9,7 @@ from nnsight import LanguageModel
 from nnsight.modeling.base import NNsight
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from more_nnsight import SavedActivation, save_activations
+from more_nnsight import SavedActivation, save_activations, updates
 
 
 @pytest.fixture(scope="module")
@@ -203,6 +203,108 @@ def test_union_with_gpt2_combines_disjoint_patches(gpt2_model: LanguageModel) ->
     assert torch.allclose(union_logits, direct_logits)
 
 
+# Saving, reading, and updating a multi-layer last-token patch should work within one invoke.
+def test_updates_with_gpt2_supports_current_plus_scaled_saved_activation_in_single_invoke(
+    gpt2_model: LanguageModel,
+) -> None:
+    prompt = "After John and Mary went to the store, John gave a bottle of milk to"
+    patch_paths = [
+        "model.transformer.h[2].output[-1]",
+        "model.transformer.h[3].output[-1]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([prompt]):
+            saved = save_activations(gpt2_model, patch_paths)
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([prompt]):
+            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
+            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
+            layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
+            gpt2_model.transformer.h[3].output[:, -1, :] = layer_3 + 2.0 * saved.get(patch_paths[1])
+            direct_logits = gpt2_model.lm_head.output.save()
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([prompt]):
+            for key, value, update in updates(gpt2_model, saved.keys()):
+                update(value + 2.0 * saved.get(key))
+            patched_logits = gpt2_model.lm_head.output.save()
+
+    assert torch.allclose(patched_logits, direct_logits)
+
+
+def test_updates_with_gpt2_supports_key_specific_update_logic(
+    gpt2_model: LanguageModel,
+) -> None:
+    prompt = "After John and Mary went to the store, John gave a bottle of milk to"
+    patch_paths = [
+        "model.transformer.h[2].output[-1]",
+        "model.transformer.h[3].output[-1]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([prompt]):
+            saved = save_activations(gpt2_model, patch_paths)
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([prompt]):
+            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
+            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
+            layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
+            gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])
+            direct_logits = gpt2_model.lm_head.output.save()
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([prompt]):
+            for key, value, update in updates(gpt2_model, saved.keys()):
+                if key == "model.transformer.h[2].output[-1]":
+                    update(value + 2.0 * saved.get(key))
+                else:
+                    update(saved.get(key))
+            patched_logits = gpt2_model.lm_head.output.save()
+
+    assert torch.allclose(patched_logits, direct_logits)
+
+
+def test_updates_with_gpt2_supports_key_specific_batch_reductions(
+    gpt2_model: LanguageModel,
+) -> None:
+    prompts = [
+        "The movie was great and I felt happy.",
+        "The dinner was excellent and I felt joyful.",
+        "The trip was terrible and I felt sad.",
+    ]
+    patch_paths = [
+        "model.transformer.h[2].output[-1]",
+        "model.transformer.h[3].output[-1]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, patch_paths)
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
+            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
+            gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])[[0, 2]].mean(
+                dim=0, keepdim=True
+            )
+            direct_logits = gpt2_model.lm_head.output.save()
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            for key, value, update in updates(gpt2_model, saved.keys()):
+                if key == "model.transformer.h[2].output[-1]":
+                    update(value + 2.0 * saved.get(key))
+                else:
+                    update(saved.get(key)[[0, 2]].mean(dim=0, keepdim=True))
+            patched_logits = gpt2_model.lm_head.output.save()
+
+    assert torch.allclose(patched_logits, direct_logits)
+
+
 # `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
 def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
     prompts = [