Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Canonicalize SavedActivation target order

Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-03-30 19:27:23 -0400
Commit
4e9062996482d974222bd9d8cf6d22bdf6b64c64
src/more_nnsight/saved_activation.py
index 946103f..08cae19 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -161,7 +161,8 @@ class SavedActivation:
     ) -> None:
         """Creates a saved-activation tree from nested values or explicit targets."""
         values = {} if values is None else values
-        self._targets = targets if targets is not None else self._extract_targets(values)
+        extracted_targets = targets if targets is not None else self._extract_targets(values)
+        self._targets = self._canonicalize_targets(extracted_targets)
         self.values = SavedActivationValues(values)
         self.slice = _SavedActivationBatchSlicer(self)
 
@@ -334,6 +335,11 @@ class SavedActivation:
         if Globals.stack > 0:
             nnsight_save(value)
 
+    @classmethod
+    def _canonicalize_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[ActivationTarget, Any]:
+        """Normalizes target order so every SavedActivation iterates in graph-safe model order."""
+        return dict(sorted(targets.items(), key=lambda item: cls._target_order_key(item[0])))
+
     @staticmethod
     def _slice_value(value: Any, ranges: list[range]) -> Any:
         """Slices saved batch rows, concatenating multiple ranges when requested."""
tests/test_saved_activation_gpt2.py
index db40ca8..cc20a49 100644
--- a/tests/test_saved_activation_gpt2.py
+++ b/tests/test_saved_activation_gpt2.py
@@ -131,6 +131,47 @@ def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None:
     assert not torch.equal(before_patch, after_patch)
 
 
+# Applying multiple patches must replay them in model order even if the SavedActivation keys are reordered.
+def test_apply_with_gpt2_reordered_patch_set_preserves_nnsight_access_order(
+    gpt2_model: LanguageModel,
+) -> None:
+    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
+    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
+    patch_paths = [
+        "model.transformer.h[2].output[9]",
+        "model.transformer.h[3].output[8]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([clean_prompt]):
+            patch = save_activations(gpt2_model, patch_paths)
+
+    reordered_patch = patch.subset(
+        [
+            "model.transformer.h[3].output[8]",
+            "model.transformer.h[2].output[9]",
+        ]
+    )
+
+    assert reordered_patch.keys() == [
+        "model.transformer.h[2].output[9]",
+        "model.transformer.h[3].output[8]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([corrupted_prompt]):
+            reordered_patch.apply(gpt2_model)
+            patched_logits = gpt2_model.lm_head.output.save()
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([corrupted_prompt]):
+            gpt2_model.transformer.h[2].output[:, 9, :] = patch.get("model.transformer.h[2].output[9]")
+            gpt2_model.transformer.h[3].output[:, 8, :] = patch.get("model.transformer.h[3].output[8]")
+            direct_logits = gpt2_model.lm_head.output.save()
+
+    assert torch.allclose(patched_logits, direct_logits)
+
+
 # `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
 def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
     prompts = [