Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Preserve NNSight access order in save_activations

Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-03-30 19:14:16 -0400
Commit
bd3d1b47978cfba1caa6f7a04a2646c4ce64dc73
src/more_nnsight/saved_activation.py
index f952903..946103f 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -374,6 +374,19 @@ class SavedActivation:
         return "".join(parts)
 
     @classmethod
+    def _target_order_key(cls, target: ActivationTarget) -> tuple[Any, ...]:
+        """Orders concrete targets by model traversal path, then by token position."""
+        path_key: list[tuple[int, Any]] = []
+        for segment in target.activation_path:
+            if isinstance(segment, AttrSegment):
+                path_key.append((0, segment.name))
+            elif isinstance(segment, IndexSegment):
+                path_key.append((1, segment.value))
+            else:
+                raise ValueError("Targets must be concrete before access ordering is computed.")
+        return (*path_key, (2, target.position))
+
+    @classmethod
     def _expand_target(
         cls,
         model: Any,
@@ -457,10 +470,15 @@ class SavedActivation:
 def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation:
     """Captures requested activations inside an already-active NNSight trace/invoke context."""
     parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
+    concrete_targets = [
+        concrete_target
+        for parsed_target in parsed_targets
+        for concrete_target in SavedActivation._expand_target(model, parsed_target)
+    ]
+
     saved_values: dict[ActivationTarget, Any] = {}
-    for parsed_target in parsed_targets:
-        for target in SavedActivation._expand_target(model, parsed_target):
-            activation = SavedActivation._resolve_activation(model, target.activation_path)
-            SavedActivation._validate_activation_shape(activation)
-            saved_values[target] = activation[:, target.position, :].save()
+    for target in sorted(concrete_targets, key=SavedActivation._target_order_key):
+        activation = SavedActivation._resolve_activation(model, target.activation_path)
+        SavedActivation._validate_activation_shape(activation)
+        saved_values[target] = activation[:, target.position, :].save()
     return SavedActivation._from_targets(saved_values)
tests/test_saved_activation_gpt2.py
index a338ebc..db40ca8 100644
--- a/tests/test_saved_activation_gpt2.py
+++ b/tests/test_saved_activation_gpt2.py
@@ -62,6 +62,35 @@ def test_save_with_gpt2_layer_slice_matches_direct_nnsight_loop(gpt2_model: Lang
         assert torch.allclose(saved.get(f"model.transformer.h[{layer}].output[10]"), direct_tensor)
 
 
+# Saving several token positions across all layers must still traverse layers in graph order.
+def test_save_with_gpt2_layer_slice_and_multiple_positions_matches_direct_loop(
+    gpt2_model: LanguageModel,
+) -> None:
+    prompts = [
+        "One two three four five six seven eight nine ten eleven twelve.",
+        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
+    ]
+    positions = [-4, -3, -2, -1]
+    paths = [f"model.transformer.h[:].output[{position}]" for position in positions]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, paths)
+
+    direct: dict[str, torch.Tensor] = {}
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            for layer in range(len(gpt2_model.transformer.h)):
+                for position in positions:
+                    direct[f"model.transformer.h[{layer}].output[{position}]"] = (
+                        gpt2_model.transformer.h[layer].output[:, position, :].save()
+                    )
+
+    assert saved.keys() == list(direct)
+    for path, direct_tensor in direct.items():
+        assert torch.allclose(saved.get(path), direct_tensor)
+
+
 # Unsaved GPT-2 paths should still error even when nearby paths were captured.
 def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None:
     prompts = [