more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Add ModuleList slice support for saved activations

Author: Arjun Guha <a.guha@northeastern.edu>
Date:   2026-03-30 09:54:21 -0400
Commit: d07b541bb6e2d4160c65a97b527d2480ba224665
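
This change lets a single activation path address every element of a ModuleList-style container: a trailing slice such as ``h[:]`` or ``h[2:5]`` is expanded into one concrete per-layer target before activations are saved. A minimal usage sketch, assuming a GPT-2 checkpoint and the trace/invoke pattern used in the tests below; the checkpoint name and prompt are illustrative, not part of this commit:

from nnsight import LanguageModel

from more_nnsight import save_activations

# Illustrative setup; any model whose blocks live in a ModuleList works the same way.
model = LanguageModel("gpt2")
prompt = "One two three four five six seven eight nine ten eleven twelve."

with model.trace() as tracer:
    with tracer.invoke(prompt):
        # One slice path expands to a concrete target per transformer block.
        saved = save_activations(model, ["model.transformer.h[:].output[10]"])

# Each expanded layer is retrievable under its concrete indexed path.
layer_two = saved.get("model.transformer.h[2].output[10]")
print(len(saved.keys()), layer_two.shape)  # 12 entries for GPT-2 small, each (batch, hidden_size)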
src/more_nnsight/saved_activation.py
index ae896a3..f952903 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -14,8 +14,14 @@ from ._saved_activation_values import SavedActivationValues
 GRAMMAR = r"""
 start: "model" segment+ "[" SIGNED_INT "]"
 
-segment: "." CNAME      -> attr
-       | "[" SIGNED_INT "]" -> index
+segment: "." CNAME               -> attr
+       | "[" SIGNED_INT "]"      -> index
+       | "[" slice_expr "]"      -> slice
+
+slice_expr: SIGNED_INT ":" SIGNED_INT -> bounded_slice
+          | SIGNED_INT ":"            -> start_slice
+          | ":" SIGNED_INT            -> stop_slice
+          | ":"                       -> full_slice
 
 %import common.CNAME
 %import common.SIGNED_INT
@@ -38,7 +44,15 @@ class IndexSegment:
     value: int
 
 
-PathSegment = AttrSegment | IndexSegment
+@dataclass(frozen=True, slots=True)
+class SliceSegment:
+    """Represents a slice access in a model activation path."""
+
+    start: int | None
+    stop: int | None
+
+
+PathSegment = AttrSegment | IndexSegment | SliceSegment
 
 
 @dataclass(frozen=True, slots=True)
@@ -69,6 +83,27 @@ class _ActivationTargetTransformer(Transformer[Token, ActivationTarget | PathSeg
         """Turns a parsed bracketed integer into a path segment."""
         return IndexSegment(items[0])
 
+    def slice(self, items: list[tuple[int | None, int | None]]) -> SliceSegment:
+        """Turns a parsed bracketed slice into a path segment."""
+        start, stop = items[0]
+        return SliceSegment(start, stop)
+
+    def bounded_slice(self, items: list[int]) -> tuple[int | None, int | None]:
+        """Parses a slice with both explicit bounds."""
+        return (items[0], items[1])
+
+    def start_slice(self, items: list[int]) -> tuple[int | None, int | None]:
+        """Parses a slice with only an explicit start bound."""
+        return (items[0], None)
+
+    def stop_slice(self, items: list[int]) -> tuple[int | None, int | None]:
+        """Parses a slice with only an explicit stop bound."""
+        return (None, items[0])
+
+    def full_slice(self, _: list[int]) -> tuple[int | None, int | None]:
+        """Parses a slice that spans the entire container."""
+        return (None, None)
+
     def CNAME(self, token: Token) -> str:
         """Preserves parsed attribute names as plain strings."""
         return str(token)
@@ -105,10 +140,12 @@ class SavedActivation:
     Paths use the model's real attribute/index structure, with the final
     bracket selecting the token position to save or patch. For example:
 
-    - ``SavedActivation.save_activations(model, ["model.transformer.h[2].output[10]"])``
+    - ``save_activations(model, ["model.transformer.h[2].output[10]"])``
       captures layer 2 output at token 10
-    - ``SavedActivation.save_activations(model, ["model.transformer.h[3].output[-1]"])``
+    - ``save_activations(model, ["model.transformer.h[3].output[-1]"])``
       captures the last token from layer 3
+    - ``save_activations(model, ["model.transformer.h[:].output[10]"])``
+      captures token 10 at every layer in a ModuleList-like container
 
     ``save_activations`` captures activation values, while ``save`` registers a
     ``SavedActivation`` object itself with NNSight so it survives trace exit.
@@ -167,7 +204,7 @@ class SavedActivation:
             }
         )
 
-    def apply(self, model: LanguageModel) -> None:
+    def apply(self, model: Any) -> None:
         """Writes the stored activations into the current traced run of the model."""
         for target, saved_tensor in self._targets.items():
             activation = self._resolve_activation(model, target.activation_path)
@@ -236,14 +273,11 @@ class SavedActivation:
         return _PATH_PARSER.parse(path)
 
     @staticmethod
-    def _resolve_activation(model: LanguageModel, activation_path: tuple[PathSegment, ...]) -> Any:
+    def _resolve_activation(model: Any, activation_path: tuple[PathSegment, ...]) -> Any:
         """Finds the live NNSight object identified by a parsed activation path."""
         current: Any = model
         for segment in activation_path:
-            if isinstance(segment, AttrSegment):
-                current = getattr(current, segment.name)
-            else:
-                current = current[segment.value]
+            current = SavedActivation._apply_concrete_segment(current, segment)
         return current
 
     @staticmethod
@@ -330,18 +364,103 @@ class SavedActivation:
         for segment in target.activation_path:
             if isinstance(segment, AttrSegment):
                 parts.append(f".{segment.name}")
-            else:
+            elif isinstance(segment, IndexSegment):
                 parts.append(f"[{segment.value}]")
+            else:
+                start = "" if segment.start is None else str(segment.start)
+                stop = "" if segment.stop is None else str(segment.stop)
+                parts.append(f"[{start}:{stop}]")
         parts.append(f"[{target.position}]")
         return "".join(parts)
 
+    @classmethod
+    def _expand_target(
+        cls,
+        model: Any,
+        target: ActivationTarget,
+    ) -> list[ActivationTarget]:
+        """Expands slice segments into concrete index paths before saving activations."""
+        return cls._expand_segments(
+            cls._static_structure(model),
+            tuple(),
+            target.activation_path,
+            target.position,
+        )
+
+    @classmethod
+    def _expand_segments(
+        cls,
+        current: Any,
+        prefix: tuple[PathSegment, ...],
+        remaining: tuple[PathSegment, ...],
+        position: int,
+    ) -> list[ActivationTarget]:
+        """Recursively expands one parsed activation path into concrete saved targets."""
+        if not remaining:
+            return [ActivationTarget(prefix, position)]
+
+        if not cls._contains_slice(remaining):
+            return [ActivationTarget(prefix + remaining, position)]
+
+        segment, *rest = remaining
+        concrete_targets: list[ActivationTarget] = []
+        for next_current, concrete_segment in cls._expand_segment(current, segment):
+            concrete_targets.extend(
+                cls._expand_segments(next_current, prefix + (concrete_segment,), tuple(rest), position)
+            )
+        return concrete_targets
+
+    @staticmethod
+    def _contains_slice(segments: tuple[PathSegment, ...]) -> bool:
+        """Checks whether a remaining path still needs ModuleList slice expansion."""
+        return any(isinstance(segment, SliceSegment) for segment in segments)
 
-def save_activations(model: LanguageModel, activation_paths: list[str]) -> SavedActivation:
+    @staticmethod
+    def _normalize_path_slice(segment: SliceSegment, length: int) -> range:
+        """Converts a parsed path slice into concrete indices for a module container."""
+        return range(*slice(segment.start, segment.stop).indices(length))
+
+    @staticmethod
+    def _static_structure(value: Any) -> Any:
+        """Returns the underlying module tree used only for slice expansion."""
+        return getattr(value, "_module", value)
+
+    @staticmethod
+    def _apply_concrete_segment(current: Any, segment: PathSegment) -> Any:
+        """Applies one non-slice path segment to either a live envoy or static module."""
+        if isinstance(segment, AttrSegment):
+            return getattr(current, segment.name)
+        if isinstance(segment, IndexSegment):
+            return current[segment.value]
+        raise ValueError("Slice segments must be expanded before resolving a concrete activation.")
+
+    @classmethod
+    def _expand_segment(
+        cls,
+        current: Any,
+        segment: PathSegment,
+    ) -> list[tuple[Any, AttrSegment | IndexSegment]]:
+        """Expands one path segment into concrete child accesses."""
+        if isinstance(segment, SliceSegment):
+            static_current = cls._static_structure(current)
+            if not hasattr(static_current, "__len__") or not hasattr(static_current, "__getitem__"):
+                raise TypeError("Slice syntax is only supported on indexable module containers.")
+            return [
+                (static_current[index], IndexSegment(index))
+                for index in cls._normalize_path_slice(segment, len(static_current))
+            ]
+
+        next_current = cls._apply_concrete_segment(current, segment)
+        return [(cls._static_structure(next_current), segment)]
+
+
+def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation:
     """Captures requested activations inside an already-active NNSight trace/invoke context."""
-    targets = [SavedActivation.parse_path(path) for path in activation_paths]
+    parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
     saved_values: dict[ActivationTarget, Any] = {}
-    for target in targets:
-        activation = SavedActivation._resolve_activation(model, target.activation_path)
-        SavedActivation._validate_activation_shape(activation)
-        saved_values[target] = activation[:, target.position, :].save()
+    for parsed_target in parsed_targets:
+        for target in SavedActivation._expand_target(model, parsed_target):
+            activation = SavedActivation._resolve_activation(model, target.activation_path)
+            SavedActivation._validate_activation_shape(activation)
+            saved_values[target] = activation[:, target.position, :].save()
     return SavedActivation._from_targets(saved_values)
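
The heart of the change is ``_expand_segments``: before anything is saved, a path containing a slice is expanded against the static module tree (``_static_structure``) into one concrete ``IndexSegment`` path per matching submodule, and everything after the last slice is kept verbatim so envoy-only attributes such as ``output`` are never resolved statically. A self-contained sketch of the same idea over a plain ``torch.nn`` module tree; the names ``Attr``, ``Index``, ``Slice``, and ``expand`` are illustrative stand-ins, not the library's API:

from dataclasses import dataclass

import torch.nn as nn


@dataclass(frozen=True)
class Attr:        # ".name" path segment
    name: str


@dataclass(frozen=True)
class Index:       # "[i]" path segment
    value: int


@dataclass(frozen=True)
class Slice:       # "[a:b]" path segment
    start: int | None
    stop: int | None


def expand(module, segments):
    # Once no slice remains, keep the rest of the path as-is; it may name
    # envoy-only attributes (e.g. "output") that the static module tree lacks.
    if not any(isinstance(seg, Slice) for seg in segments):
        return [tuple(segments)]
    head, *rest = segments
    if isinstance(head, Slice):
        # Normalize the slice against the container length: one Index per element.
        indices = range(*slice(head.start, head.stop).indices(len(module)))
        return [(Index(i),) + tail for i in indices for tail in expand(module[i], rest)]
    child = getattr(module, head.name) if isinstance(head, Attr) else module[head.value]
    return [(head,) + tail for tail in expand(child, rest)]


model = nn.Module()
model.h = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))
print(expand(model, [Attr("h"), Slice(None, None), Attr("output")]))
# -> three concrete paths: h[0].output, h[1].output, h[2].output

In the diff itself, the expanded paths are then resolved against the live NNSight envoys by ``_resolve_activation``, and ``_apply_concrete_segment`` deliberately rejects any slice that survives expansion.
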
tests/test_saved_activation_gpt2.py
index d90f0ee..a338ebc 100644
--- a/tests/test_saved_activation_gpt2.py
+++ b/tests/test_saved_activation_gpt2.py
@@ -6,6 +6,8 @@ from pathlib import Path
 import pytest
 import torch
 from nnsight import LanguageModel
+from nnsight.modeling.base import NNsight
+from transformers import AutoConfig, AutoModelForCausalLM
 
 from more_nnsight import SavedActivation, save_activations
 
@@ -38,6 +40,28 @@ def test_save_with_gpt2_returns_saved_tensors_at_requested_paths(gpt2_model: Lan
     assert torch.allclose(saved.get("model.transformer.h[3].output[-1]"), direct_layer_3)
 
 
+# Slice syntax over GPT-2's transformer block list should match a direct per-layer NNSight loop.
+def test_save_with_gpt2_layer_slice_matches_direct_nnsight_loop(gpt2_model: LanguageModel) -> None:
+    prompts = [
+        "One two three four five six seven eight nine ten eleven twelve.",
+        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
+    ]
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, ["model.transformer.h[:].output[10]"])
+    direct: list[torch.Tensor] = []
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            direct.extend(
+                gpt2_model.transformer.h[layer].output[:, 10, :].save()
+                for layer in range(len(gpt2_model.transformer.h))
+            )
+
+    assert len(saved.keys()) == len(direct)
+    for layer, direct_tensor in enumerate(direct):
+        assert torch.allclose(saved.get(f"model.transformer.h[{layer}].output[10]"), direct_tensor)
+
+
 # Unsaved GPT-2 paths should still error even when nearby paths were captured.
 def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None:
     prompts = [
@@ -138,3 +162,31 @@ def test_mean_inside_trace_saves_reduced_result_on_device(gpt2_model: LanguageMo
     assert saved.get(path).shape == (2, 768)
     assert mean_saved.get(path).shape == (1, 768)
     assert torch.allclose(mean_saved.get(path), saved.get(path).mean(dim=0, keepdim=True))
+
+
+# A randomly initialized non-GPT2 architecture should work as long as layers live in a ModuleList.
+def test_save_with_random_init_qwen_layer_slice_matches_direct_loop() -> None:
+    cfg = AutoConfig.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
+    cfg.hidden_size = 64
+    cfg.intermediate_size = 128
+    cfg.num_hidden_layers = 3
+    cfg.num_attention_heads = 4
+    cfg.num_key_value_heads = 4
+    cfg.vocab_size = 256
+
+    model = NNsight(AutoModelForCausalLM.from_config(cfg))
+    input_ids = torch.randint(0, cfg.vocab_size, (2, 6))
+    attention_mask = torch.ones_like(input_ids)
+
+    with model.trace(input_ids=input_ids, attention_mask=attention_mask):
+        saved = save_activations(model, ["model.model.layers[:].output[2]"])
+    direct: list[torch.Tensor] = []
+    with model.trace(input_ids=input_ids, attention_mask=attention_mask):
+        direct.extend(
+            model.model.layers[layer].output[:, 2, :].save()
+            for layer in range(cfg.num_hidden_layers)
+        )
+
+    assert len(saved.keys()) == cfg.num_hidden_layers
+    for layer, direct_tensor in enumerate(direct):
+        assert torch.allclose(saved.get(f"model.model.layers[{layer}].output[2]"), direct_tensor)
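
Both new tests should run under the project's existing pytest setup, e.g. something like ``pytest tests/test_saved_activation_gpt2.py -k slice``; the GPT-2 test relies on the ``gpt2_model`` fixture's checkpoint, while the Qwen-config test only needs ``transformers`` to resolve the tiny, randomly initialized config.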