Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Add SavedActivation union operation

Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-03-30 19:52:06 -0400
Commit
a0b08b59dd6319fc29904b4e3dee53389c9cdea3
src/more_nnsight/saved_activation.py
index 08cae19..5577a2e 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -205,6 +205,15 @@ class SavedActivation:
             }
         )
 
+    def union(self, other: "SavedActivation") -> "SavedActivation":
+        """Combines two disjoint saved activation sets into one patch collection."""
+        overlapping_targets = set(self._targets) & set(other._targets)
+        if overlapping_targets:
+            overlapping_paths = ", ".join(self._format_target(target) for target in overlapping_targets)
+            raise ValueError(f"SavedActivation.union requires disjoint keys, got overlap: {overlapping_paths}")
+
+        return self._from_targets({**self._targets, **other._targets})
+
     def apply(self, model: Any) -> None:
         """Writes the stored activations into the current traced run of the model."""
         for target, saved_tensor in self._targets.items():
tests/test_saved_activation.py
index fc0cca0..524b2d4 100644
--- a/tests/test_saved_activation.py
+++ b/tests/test_saved_activation.py
@@ -122,3 +122,26 @@ def test_add_and_subtract_require_matching_keys() -> None:
 
     with pytest.raises(ValueError, match="same keys"):
         _ = left + other
+
+
+# Union should merge disjoint saved paths and reject overlapping ones.
+def test_union_merges_disjoint_paths_and_rejects_overlap() -> None:
+    left_tensor = torch.ones(1, 2)
+    right_tensor = 2 * torch.ones(1, 2)
+    overlap_tensor = 3 * torch.ones(1, 2)
+
+    left = SavedActivation({"transformer": {"h": {2: {"output": {10: left_tensor}}}}})
+    right = SavedActivation({"transformer": {"h": {3: {"output": {10: right_tensor}}}}})
+    overlap = SavedActivation({"transformer": {"h": {2: {"output": {10: overlap_tensor}}}}})
+
+    union = left.union(right)
+
+    assert union.keys() == [
+        "model.transformer.h[2].output[10]",
+        "model.transformer.h[3].output[10]",
+    ]
+    assert union.get("model.transformer.h[2].output[10]") is left_tensor
+    assert union.get("model.transformer.h[3].output[10]") is right_tensor
+
+    with pytest.raises(ValueError, match="disjoint keys"):
+        _ = left.union(overlap)
tests/test_saved_activation_gpt2.py
index cc20a49..730b730 100644
--- a/tests/test_saved_activation_gpt2.py
+++ b/tests/test_saved_activation_gpt2.py
@@ -172,6 +172,37 @@ def test_apply_with_gpt2_reordered_patch_set_preserves_nnsight_access_order(
     assert torch.allclose(patched_logits, direct_logits)
 
 
+# Union should combine disjoint GPT-2 patches so they replay like one direct multi-assignment patch.
+def test_union_with_gpt2_combines_disjoint_patches(gpt2_model: LanguageModel) -> None:
+    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
+    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([clean_prompt]):
+            layer_2_patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])
+            layer_3_patch = save_activations(gpt2_model, ["model.transformer.h[3].output[8]"])
+
+    union_patch = layer_2_patch.union(layer_3_patch)
+
+    assert union_patch.keys() == [
+        "model.transformer.h[2].output[9]",
+        "model.transformer.h[3].output[8]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([corrupted_prompt]):
+            union_patch.apply(gpt2_model)
+            union_logits = gpt2_model.lm_head.output.save()
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([corrupted_prompt]):
+            gpt2_model.transformer.h[2].output[:, 9, :] = layer_2_patch.get("model.transformer.h[2].output[9]")
+            gpt2_model.transformer.h[3].output[:, 8, :] = layer_3_patch.get("model.transformer.h[3].output[8]")
+            direct_logits = gpt2_model.lm_head.output.save()
+
+    assert torch.allclose(union_logits, direct_logits)
+
+
 # `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
 def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
     prompts = [