Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Refactor saved activation capture API

Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-03-30 09:22:21 -0400
Commit
3d29f84f0d7fcd3c239e54c9d238fc833f3d32db
scripts/demo_activation_patching.py
index 02fca2f..7b8c0c8 100644
--- a/scripts/demo_activation_patching.py
+++ b/scripts/demo_activation_patching.py
@@ -6,7 +6,7 @@ from pathlib import Path
 import torch
 from nnsight import LanguageModel
 
-from more_nnsight import SavedActivation
+from more_nnsight import SavedActivation, save_activations
 
 
 def _token_id(model: LanguageModel, token: str) -> int:
@@ -61,7 +61,9 @@ def main() -> None:
         "model.transformer.h[2].output[9]",
         "model.transformer.h[3].output[9]",
     ]
-    clean_saved = SavedActivation.save(model, [clean_prompt], patch_paths)
+    with model.trace() as tracer:
+        with tracer.invoke([clean_prompt]):
+            clean_saved = save_activations(model, patch_paths)
     focused_patch = clean_saved.subset(["model.transformer.h[2].output[9]"])
 
     clean_diff = logit_diff(model, clean_prompt, " John", " Mary")
scripts/demo_activation_steering.py
index cf9f800..842e39c 100644
--- a/scripts/demo_activation_steering.py
+++ b/scripts/demo_activation_steering.py
@@ -6,7 +6,7 @@ from pathlib import Path
 import torch
 from nnsight import LanguageModel
 
-from more_nnsight import SavedActivation
+from more_nnsight import SavedActivation, save_activations
 
 
 def _token_id(model: LanguageModel, token: str) -> int:
@@ -56,12 +56,18 @@ def main() -> None:
         "The vacation was horrible and it made me feel",
     ]
     neutral_prompt = "The day was long and by the end I felt"
+    prompts = positive_prompts + negative_prompts + [neutral_prompt]
 
-    positive = SavedActivation.save(model, positive_prompts, [steering_path], batch_size=2)
-    negative = SavedActivation.save(model, negative_prompts, [steering_path], batch_size=2)
-    neutral = SavedActivation.save(model, [neutral_prompt], [steering_path])
-
-    steering_vector = positive.mean() - negative.mean()
+    with model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(model, [steering_path])
+            positive = saved.slice[0 : len(positive_prompts)].mean()
+            negative = saved.slice[
+                len(positive_prompts) : len(positive_prompts) + len(negative_prompts)
+            ].mean()
+            neutral = saved.slice[len(prompts) - 1]
+
+    steering_vector = positive - negative
     steered = neutral + 1.5 * steering_vector
 
     baseline_happy, baseline_sad = candidate_logits(model, neutral_prompt, " happy", " sad")
src/more_nnsight/__init__.py
index e34a8f0..681a71b 100644
--- a/src/more_nnsight/__init__.py
+++ b/src/more_nnsight/__init__.py
@@ -1,6 +1,6 @@
-from .saved_activation import SavedActivation
+from .saved_activation import SavedActivation, save_activations
 
-__all__ = ["SavedActivation"]
+__all__ = ["SavedActivation", "save_activations"]
 
 
 def main() -> None:
src/more_nnsight/_saved_activation_values.py
new file mode 100644
index 0000000..e1aabb5
--- /dev/null
+++ b/src/more_nnsight/_saved_activation_values.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+class SavedActivationValues:
+    """Provides dynamic attribute and index access over a sparse activation tree."""
+
+    def __init__(self, values: dict[Any, Any], path: str = "saved_activation") -> None:
+        """Initializes one view into the saved activation hierarchy."""
+        self._values = values
+        self._path = path
+
+    def __getattr__(self, name: str) -> Any:
+        """Exposes saved attribute-style paths while rejecting unsaved ones."""
+        if name.startswith("_"):
+            raise AttributeError(name)
+        try:
+            value = self._values[name]
+        except KeyError as exc:
+            raise AttributeError(f"Unknown saved activation path: {self._path}.{name}") from exc
+        return self._wrap(value, f"{self._path}.{name}")
+
+    def __getitem__(self, index: int) -> Any:
+        """Exposes saved index-style paths while rejecting unsaved ones."""
+        try:
+            value = self._values[index]
+        except KeyError as exc:
+            raise KeyError(f"Unknown saved activation path: {self._path}[{index}]") from exc
+        return self._wrap(value, f"{self._path}[{index}]")
+
+    @staticmethod
+    def _wrap(value: Any, path: str) -> Any:
+        """Returns child nodes for nested paths and tensors for saved leaves."""
+        if isinstance(value, dict):
+            return SavedActivationValues(value, path)
+        return value
src/more_nnsight/saved_activation.py
index 0256ee7..ae896a3 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Callable
 
 import torch
 from lark import Lark, Token, Transformer
 from nnsight import LanguageModel
+from nnsight.intervention.tracing.globals import Globals, save as nnsight_save
+
+from ._saved_activation_values import SavedActivationValues
 
 
 GRAMMAR = r"""
@@ -78,92 +81,63 @@ class _ActivationTargetTransformer(Transformer[Token, ActivationTarget | PathSeg
 _PATH_PARSER = Lark(GRAMMAR, parser="lalr", transformer=_ActivationTargetTransformer())
 
 
-class _SavedActivationNode:
-    """Provides dynamic attribute and index access over a sparse activation tree."""
-
-    def __init__(self, values: dict[Any, Any], path: str = "saved_activation") -> None:
-        """Initializes one view into the saved activation hierarchy."""
-        self._values = values
-        self._path = path
-
-    def __getattr__(self, name: str) -> Any:
-        """Exposes saved attribute-style paths while rejecting unsaved ones."""
-        if name.startswith("_"):
-            raise AttributeError(name)
-        try:
-            value = self._values[name]
-        except KeyError as exc:
-            raise AttributeError(f"Unknown saved activation path: {self._path}.{name}") from exc
-        return self._wrap(value, f"{self._path}.{name}")
+class _SavedActivationBatchSlicer:
+    """Provides bracket-based batch slicing for saved activation batches."""
 
-    def __getitem__(self, index: int) -> Any:
-        """Exposes saved index-style paths while rejecting unsaved ones."""
-        try:
-            value = self._values[index]
-        except KeyError as exc:
-            raise KeyError(f"Unknown saved activation path: {self._path}[{index}]") from exc
-        return self._wrap(value, f"{self._path}[{index}]")
+    def __init__(self, saved_activation: "SavedActivation") -> None:
+        """Binds batch-slice syntax to one SavedActivation instance."""
+        self._saved_activation = saved_activation
 
-    @staticmethod
-    def _wrap(value: Any, path: str) -> Any:
-        """Returns child nodes for nested paths and tensors for saved leaves."""
-        if isinstance(value, dict):
-            return _SavedActivationNode(value, path)
-        return value
+    def __getitem__(self, batch_slices: int | range | slice | tuple[int | range | slice, ...]) -> "SavedActivation":
+        """Builds a sliced SavedActivation from one or more batch selectors."""
+        selectors = batch_slices if isinstance(batch_slices, tuple) else (batch_slices,)
+        ranges = [self._saved_activation._normalize_batch_slice(selector) for selector in selectors]
+        sliced_targets = {
+            target: self._saved_activation._slice_value(value, ranges)
+            for target, value in self._saved_activation._targets.items()
+        }
+        return self._saved_activation._from_targets(sliced_targets)
 
 
-class SavedActivation(_SavedActivationNode):
+class SavedActivation:
     """Stores sparse activation slices and can replay them into later traces.
 
     Paths use the model's real attribute/index structure, with the final
     bracket selecting the token position to save or patch. For example:
 
-    - ``model.transformer.h[2].output[10]`` saves layer 2 output at token 10
-    - ``model.transformer.h[3].output[-1]`` saves the last token from layer 3
+    - ``save_activations(model, ["model.transformer.h[2].output[10]"])``
+      captures layer 2 output at token 10
+    - ``save_activations(model, ["model.transformer.h[3].output[-1]"])``
+      captures the last token from layer 3
+
+    ``save_activations`` captures activation values, while ``save`` registers a
+    ``SavedActivation`` object itself with NNSight so it survives trace exit.
 
-    After saving, the same structure is exposed on the object itself, so
-    ``saved.transformer.h[2].output[10]`` returns the stored tensor.
+    After capture, the same structure is exposed under ``values``, so
+    ``saved.values.transformer.h[2].output[10]`` returns the stored tensor.
     """
 
-    def __init__(self, values: dict[Any, Any], targets: dict[ActivationTarget, torch.Tensor] | None = None) -> None:
+    def __init__(
+        self,
+        values: dict[Any, Any] | None = None,
+        targets: dict[ActivationTarget, Any] | None = None,
+    ) -> None:
         """Creates a saved-activation tree from nested values or explicit targets."""
-        super().__init__(values)
+        values = {} if values is None else values
         self._targets = targets if targets is not None else self._extract_targets(values)
+        self.values = SavedActivationValues(values)
+        self.slice = _SavedActivationBatchSlicer(self)
 
-    @classmethod
-    def save(
-        cls,
-        model: LanguageModel,
-        prompts: list[str],
-        activation_paths: list[str],
-        batch_size: int | None = None,
-    ) -> "SavedActivation":
-        """Runs one or more traces and captures only the requested token-position activations."""
-        targets = [cls.parse_path(path) for path in activation_paths]
-        batch_size = len(prompts) if batch_size is None else batch_size
-        if batch_size <= 0:
-            raise ValueError("batch_size must be a positive integer.")
-
-        saved_batches: dict[ActivationTarget, list[torch.Tensor]] = {target: [] for target in targets}
-        for prompt_batch in cls._chunk_prompts(prompts, batch_size):
-            with model.trace() as tracer:
-                with tracer.invoke(prompt_batch):
-                    for target in targets:
-                        activation = cls._resolve_activation(model, target.activation_path)
-                        cls._validate_activation_shape(activation)
-                        saved_batches[target].append(activation[:, target.position, :].save())
-
-        saved_tensors = {
-            target: torch.cat(batch_tensors, dim=0) for target, batch_tensors in saved_batches.items()
-        }
-
-        return cls._from_targets(saved_tensors)
+    def save(self) -> "SavedActivation":
+        """Registers this SavedActivation object with NNSight inside an active trace."""
+        self._register_object_save(self)
+        return self
 
     def subset(self, activation_paths: list[str]) -> "SavedActivation":
         """Keeps only selected saved paths so patches can be focused or composed."""
         requested_targets = [self.parse_path(path) for path in activation_paths]
 
-        subset_targets: dict[ActivationTarget, torch.Tensor] = {}
+        subset_targets: dict[ActivationTarget, Any] = {}
         for target in requested_targets:
             try:
                 subset_targets[target] = self._targets[target]
@@ -173,10 +147,10 @@ class SavedActivation(_SavedActivationNode):
         return self._from_targets(subset_targets)
 
     def keys(self) -> list[str]:
-        """Returns the saved activation paths in the same syntax used by save."""
+        """Returns the saved activation paths in the same syntax used by save_activations."""
         return [self._format_target(target) for target in self._targets]
 
-    def get(self, activation_path: str) -> torch.Tensor:
+    def get(self, activation_path: str) -> Any:
         """Returns one saved tensor by its user-facing activation path."""
         target = self.parse_path(activation_path)
         try:
@@ -187,7 +161,10 @@ class SavedActivation(_SavedActivationNode):
     def mean(self) -> "SavedActivation":
         """Averages each saved tensor across the batch dimension while keeping one row."""
         return self._from_targets(
-            {target: tensor.mean(dim=0, keepdim=True) for target, tensor in self._targets.items()}
+            {
+                target: self._saved_result(value.mean(dim=0, keepdim=True))
+                for target, value in self._targets.items()
+            }
         )
 
     def apply(self, model: LanguageModel) -> None:
@@ -199,40 +176,59 @@ class SavedActivation(_SavedActivationNode):
 
     def __add__(self, other: "SavedActivation") -> "SavedActivation":
         """Combines two saved activation sets elementwise when they cover the same paths."""
-        self._validate_matching_keys(other)
-        return self._from_targets(
-            {target: self._targets[target] + other._targets[target] for target in self._targets}
-        )
+        return self._binary_op(other, lambda left, right: left + right)
 
     def __sub__(self, other: "SavedActivation") -> "SavedActivation":
         """Subtracts one saved activation set from another when they cover the same paths."""
-        self._validate_matching_keys(other)
-        return self._from_targets(
-            {target: self._targets[target] - other._targets[target] for target in self._targets}
-        )
+        return self._binary_op(other, lambda left, right: left - right)
 
     def __mul__(self, scalar: float) -> "SavedActivation":
         """Scales every saved tensor by the same numeric factor."""
         if not isinstance(scalar, int | float):
             return NotImplemented
-        return self._from_targets({target: tensor * scalar for target, tensor in self._targets.items()})
+        return self._from_targets(
+            {
+                target: self._saved_result(value * scalar)
+                for target, value in self._targets.items()
+            }
+        )
 
     def __rmul__(self, scalar: float) -> "SavedActivation":
         """Allows scalar multiplication with the scalar on the left-hand side."""
         return self * scalar
 
+    def _binary_op(
+        self,
+        other: "SavedActivation",
+        op: Callable[[Any, Any], Any],
+    ) -> "SavedActivation":
+        """Applies an elementwise operation across matching saved activation sets."""
+        self._validate_matching_keys(other)
+        return self._from_targets(
+            {
+                target: self._saved_result(op(self._targets[target], other._targets[target]))
+                for target in self._targets
+            }
+        )
+
     @classmethod
-    def _from_targets(cls, targets: dict[ActivationTarget, torch.Tensor]) -> "SavedActivation":
+    def _from_targets(cls, targets: dict[ActivationTarget, Any]) -> "SavedActivation":
         """Builds the public tree view from the flat target-to-tensor mapping."""
+        result = cls(cls._build_values_from_targets(targets), dict(targets))
+        cls._register_object_save(result)
+        return result
+
+    @classmethod
+    def _build_values_from_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[Any, Any]:
+        """Builds the nested values tree from a flat target-to-value mapping."""
         values: dict[Any, Any] = {}
-        for target, tensor in targets.items():
+        for target, value in targets.items():
             cursor = values
             for segment in target.activation_path:
                 key = segment.name if isinstance(segment, AttrSegment) else segment.value
                 cursor = cursor.setdefault(key, {})
-            cursor[target.position] = tensor
-
-        return cls(values, dict(targets))
+            cursor[target.position] = value
+        return values
 
     @staticmethod
     def parse_path(path: str) -> ActivationTarget:
@@ -264,9 +260,9 @@ class SavedActivation(_SavedActivationNode):
         cls,
         values: dict[Any, Any],
         prefix: tuple[PathSegment, ...] = (),
-    ) -> dict[ActivationTarget, torch.Tensor]:
+    ) -> dict[ActivationTarget, Any]:
         """Reconstructs the flat target mapping from the nested public tree form."""
-        targets: dict[ActivationTarget, torch.Tensor] = {}
+        targets: dict[ActivationTarget, Any] = {}
         for key, value in values.items():
             if isinstance(value, dict):
                 next_segment: PathSegment
@@ -281,23 +277,53 @@ class SavedActivation(_SavedActivationNode):
 
             if not isinstance(key, int):
                 raise TypeError("Saved activation leaves must be keyed by token position integers.")
-            if not isinstance(value, torch.Tensor):
-                raise TypeError("Saved activation leaves must be torch tensors.")
             targets[ActivationTarget(prefix, key)] = value
 
         return targets
 
-    @staticmethod
-    def _chunk_prompts(prompts: list[str], batch_size: int) -> list[list[str]]:
-        """Splits prompts into trace-sized batches so activation saving can scale to larger runs."""
-        return [prompts[index : index + batch_size] for index in range(0, len(prompts), batch_size)]
-
     def _validate_matching_keys(self, other: "SavedActivation") -> None:
         """Ensures elementwise operations only happen across identical saved paths."""
         if tuple(self._targets) != tuple(other._targets):
             raise ValueError("SavedActivation objects must have the same keys for arithmetic.")
 
     @staticmethod
+    def _saved_result(value: Any) -> Any:
+        """Preserves derived activation values after the trace by saving proxies when needed."""
+        save = getattr(value, "save", None)
+        if Globals.stack > 0 and callable(save):
+            return save()
+        return value
+
+    @staticmethod
+    def _register_object_save(value: "SavedActivation") -> None:
+        """Registers a newly created SavedActivation with NNSight when inside a trace."""
+        if Globals.stack > 0:
+            nnsight_save(value)
+
+    @staticmethod
+    def _slice_value(value: Any, ranges: list[range]) -> Any:
+        """Slices saved batch rows, concatenating multiple ranges when requested."""
+        pieces = [value[batch_range] for batch_range in ranges]
+        if len(pieces) == 1:
+            return SavedActivation._saved_result(pieces[0])
+        return SavedActivation._saved_result(torch.cat(pieces, dim=0))
+
+    @staticmethod
+    def _normalize_batch_slice(batch_slice: int | range | slice) -> range:
+        """Converts supported batch selectors into a common range representation."""
+        if isinstance(batch_slice, int):
+            return range(batch_slice, batch_slice + 1)
+        if isinstance(batch_slice, range):
+            return batch_slice
+        if isinstance(batch_slice, slice):
+            if batch_slice.step not in (None, 1):
+                raise ValueError("slice does not support slice steps other than 1.")
+            if batch_slice.start is None or batch_slice.stop is None:
+                raise ValueError("slice requires explicit start and stop values.")
+            return range(batch_slice.start, batch_slice.stop)
+        raise TypeError("slice accepts only ints, ranges, or slice objects.")
+
+    @staticmethod
     def _format_target(target: ActivationTarget) -> str:
         """Turns an internal target back into the user-facing path syntax."""
         parts = ["model"]
@@ -308,3 +334,14 @@ class SavedActivation(_SavedActivationNode):
                 parts.append(f"[{segment.value}]")
         parts.append(f"[{target.position}]")
         return "".join(parts)
+
+
+def save_activations(model: LanguageModel, activation_paths: list[str]) -> SavedActivation:
+    """Captures requested activations inside an already-active NNSight trace/invoke context."""
+    targets = [SavedActivation.parse_path(path) for path in activation_paths]
+    saved_values: dict[ActivationTarget, Any] = {}
+    for target in targets:
+        activation = SavedActivation._resolve_activation(model, target.activation_path)
+        SavedActivation._validate_activation_shape(activation)
+        saved_values[target] = activation[:, target.position, :].save()
+    return SavedActivation._from_targets(saved_values)
tests/test_saved_activation.py
index a1d2f6a..fc0cca0 100644
--- a/tests/test_saved_activation.py
+++ b/tests/test_saved_activation.py
@@ -7,6 +7,7 @@ from more_nnsight import SavedActivation
 from more_nnsight.saved_activation import ActivationTarget, AttrSegment, IndexSegment
 
 
+# The parser should preserve the model's real attribute/index path structure.
 def test_parse_path_uses_real_model_names() -> None:
     target = SavedActivation.parse_path("model.transformer.h[2].output[10]")
 
@@ -21,6 +22,7 @@ def test_parse_path_uses_real_model_names() -> None:
     )
 
 
+# Values should be reachable through the same nested path under `.values`.
 def test_saved_activation_mirrors_model_path() -> None:
     tensor = torch.randn(2, 768)
     saved = SavedActivation(
@@ -35,23 +37,26 @@ def test_saved_activation_mirrors_model_path() -> None:
         }
     )
 
-    assert saved.transformer.h[2].output[10] is tensor
+    assert saved.values.transformer.h[2].output[10] is tensor
 
 
+# Missing attribute hops should fail immediately instead of producing empty nodes.
 def test_missing_attribute_errors() -> None:
     saved = SavedActivation({"transformer": {}})
 
     with pytest.raises(AttributeError, match=r"saved_activation.transformer.h"):
-        _ = saved.transformer.h
+        _ = saved.values.transformer.h
 
 
+# Missing index hops should also fail immediately.
 def test_missing_index_errors() -> None:
     saved = SavedActivation({"transformer": {"h": {2: {}}}})
 
     with pytest.raises(KeyError, match="saved_activation.transformer.h"):
-        _ = saved.transformer.h[3]
+        _ = saved.values.transformer.h[3]
 
 
+# Subsetting should keep only the requested saved path and discard the rest.
 def test_subset_returns_new_saved_activation() -> None:
     tensor_1 = torch.randn(2, 768)
     tensor_2 = torch.randn(2, 768)
@@ -68,11 +73,12 @@ def test_subset_returns_new_saved_activation() -> None:
 
     subset = saved.subset(["model.transformer.h[3].output[-1]"])
 
-    assert subset.transformer.h[3].output[-1] is tensor_2
+    assert subset.values.transformer.h[3].output[-1] is tensor_2
     with pytest.raises(KeyError, match="saved_activation.transformer.h"):
-        _ = subset.transformer.h[2]
+        _ = subset.values.transformer.h[2]
 
 
+# String keys should round-trip through `keys()` and `get(...)`.
 def test_keys_and_get_round_trip_saved_paths() -> None:
     tensor = torch.randn(2, 768)
     saved = SavedActivation({"transformer": {"h": {2: {"output": {10: tensor}}}}})
@@ -81,29 +87,38 @@ def test_keys_and_get_round_trip_saved_paths() -> None:
     assert saved.get("model.transformer.h[2].output[10]") is tensor
 
 
+# Batch slicing should preserve order across mixed slice and single-index selectors.
+def test_slice_keeps_requested_batch_rows() -> None:
+    tensor = torch.arange(20, dtype=torch.float32).reshape(5, 4)
+    saved = SavedActivation({"transformer": {"h": {2: {"output": {10: tensor}}}}})
+
+    sliced = saved.slice[1:3, 4]
+
+    assert sliced.values.transformer.h[2].output[10].shape == (3, 4)
+    assert torch.equal(sliced.values.transformer.h[2].output[10], tensor[[1, 2, 4]])
+
+
+# Mean should return a new SavedActivation reduced across the batch dimension.
 def test_mean_reduces_batch_dimension() -> None:
     tensor = torch.randn(3, 768)
     saved = SavedActivation({"transformer": {"h": {2: {"output": {10: tensor}}}}})
 
     mean_saved = saved.mean()
 
-    assert mean_saved.transformer.h[2].output[10].shape == (1, 768)
-    assert torch.allclose(mean_saved.transformer.h[2].output[10], tensor.mean(dim=0, keepdim=True))
+    assert saved.values.transformer.h[2].output[10].shape == (3, 768)
+    assert mean_saved.values.transformer.h[2].output[10].shape == (1, 768)
+    assert torch.allclose(mean_saved.values.transformer.h[2].output[10], tensor.mean(dim=0, keepdim=True))
 
 
+# Arithmetic should work elementwise on matching keys and reject mismatched ones.
 def test_add_and_subtract_require_matching_keys() -> None:
     left = SavedActivation({"transformer": {"h": {2: {"output": {10: torch.ones(1, 2)}}}}})
     right = SavedActivation({"transformer": {"h": {2: {"output": {10: 2 * torch.ones(1, 2)}}}}})
     other = SavedActivation({"transformer": {"h": {3: {"output": {10: torch.ones(1, 2)}}}}})
 
-    assert torch.equal((left + right).transformer.h[2].output[10], 3 * torch.ones(1, 2))
-    assert torch.equal((right - left).transformer.h[2].output[10], torch.ones(1, 2))
-    assert torch.equal((2.0 * left).transformer.h[2].output[10], 2 * torch.ones(1, 2))
+    assert torch.equal((left + right).values.transformer.h[2].output[10], 3 * torch.ones(1, 2))
+    assert torch.equal((right - left).values.transformer.h[2].output[10], torch.ones(1, 2))
+    assert torch.equal((2.0 * left).values.transformer.h[2].output[10], 2 * torch.ones(1, 2))
 
     with pytest.raises(ValueError, match="same keys"):
         _ = left + other
-
-
-def test_save_rejects_non_positive_batch_size() -> None:
-    with pytest.raises(ValueError, match="batch_size"):
-        SavedActivation.save(None, [], [], batch_size=0)  # type: ignore[arg-type]
tests/test_saved_activation_gpt2.py
index 5c1e55a..d90f0ee 100644
--- a/tests/test_saved_activation_gpt2.py
+++ b/tests/test_saved_activation_gpt2.py
@@ -7,7 +7,7 @@ import pytest
 import torch
 from nnsight import LanguageModel
 
-from more_nnsight import SavedActivation
+from more_nnsight import SavedActivation, save_activations
 
 
 @pytest.fixture(scope="module")
@@ -17,52 +17,54 @@ def gpt2_model() -> LanguageModel:
     return LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
 
 
+# Saving real GPT-2 activations should produce batch-by-hidden tensors at the requested paths.
 def test_save_with_gpt2_returns_saved_tensors_at_requested_paths(gpt2_model: LanguageModel) -> None:
     prompts = [
         "One two three four five six seven eight nine ten eleven twelve.",
         "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
     ]
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(
+                gpt2_model,
+                ["model.transformer.h[2].output[10]", "model.transformer.h[3].output[-1]"],
+            )
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            direct_layer_2 = gpt2_model.transformer.h[2].output[:, 10, :].save()
+            direct_layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
 
-    saved = SavedActivation.save(
-        gpt2_model,
-        prompts,
-        ["model.transformer.h[2].output[10]", "model.transformer.h[3].output[-1]"],
-        batch_size=1,
-    )
-
-    assert saved.transformer.h[2].output[10].shape == (2, 768)
-    assert saved.transformer.h[3].output[-1].shape == (2, 768)
+    assert torch.allclose(saved.get("model.transformer.h[2].output[10]"), direct_layer_2)
+    assert torch.allclose(saved.get("model.transformer.h[3].output[-1]"), direct_layer_3)
 
 
+# Unsaved GPT-2 paths should still error even when nearby paths were captured.
 def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None:
     prompts = [
         "One two three four five six seven eight nine ten eleven twelve.",
         "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
     ]
 
-    saved = SavedActivation.save(
-        gpt2_model,
-        prompts,
-        ["model.transformer.h[2].output[10]"],
-    )
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, ["model.transformer.h[2].output[10]"])
 
     with pytest.raises(KeyError, match="saved_activation.transformer.h\\[2\\].output"):
-        _ = saved.transformer.h[2].output[9]
+        _ = saved.values.transformer.h[2].output[9]
 
     with pytest.raises(KeyError, match="saved_activation.transformer.h"):
-        _ = saved.transformer.h[3]
+        _ = saved.values.transformer.h[3]
 
 
+# Applying a saved activation should overwrite the live activation slice during a later run.
 def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None:
     clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
     corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
-    patch = SavedActivation.save(
-        gpt2_model,
-        [clean_prompt],
-        ["model.transformer.h[2].output[9]"],
-    )
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([clean_prompt]):
+            patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])
 
-    clean_tensor = patch.transformer.h[2].output[9]
+    clean_tensor = patch.values.transformer.h[2].output[9]
 
     with gpt2_model.trace() as tracer:
         with tracer.invoke([corrupted_prompt]):
@@ -76,6 +78,7 @@ def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None:
     assert not torch.equal(before_patch, after_patch)
 
 
+# `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
 def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
     prompts = [
         "The movie was great and I felt happy.",
@@ -83,9 +86,55 @@ def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
     ]
     path = "model.transformer.h[5].output[-1]"
 
-    saved = SavedActivation.save(gpt2_model, prompts, [path], batch_size=1)
-    mean_saved = saved.mean()
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, [path])
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            direct_mean = gpt2_model.transformer.h[5].output[:, -1, :].mean(dim=0, keepdim=True).save()
 
     assert saved.keys() == [path]
     assert saved.get(path).shape == (2, 768)
+    mean_saved = saved.mean()
+    assert mean_saved.get(path).shape == (1, 768)
+    assert torch.allclose(mean_saved.get(path), direct_mean)
+
+
+# Slicing should let one GPT-2 save feed multiple downstream batch views.
+def test_slice_with_gpt2_selects_requested_batch_rows(gpt2_model: LanguageModel) -> None:
+    prompts = [
+        "The movie was great and I felt happy.",
+        "The dinner was excellent and I felt joyful.",
+        "The trip was terrible and I felt sad.",
+    ]
+    path = "model.transformer.h[5].output[-1]"
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, [path])
+            sliced = saved.slice[0, 2:3]
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            direct_slice = gpt2_model.transformer.h[5].output[[0, 2], -1, :].save()
+
+    assert saved.get(path).shape == (3, 768)
+    assert sliced.get(path).shape == (2, 768)
+    assert torch.allclose(sliced.get(path), direct_slice)
+
+
+def test_mean_inside_trace_saves_reduced_result_on_device(gpt2_model: LanguageModel) -> None:
+    """This guards the intended workflow where reduction happens before trace exit, avoiding saving the full batch when only the mean is needed."""
+    prompts = [
+        "The movie was great and I felt happy.",
+        "The dinner was excellent and I felt joyful.",
+    ]
+    path = "model.transformer.h[5].output[-1]"
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, [path])
+            mean_saved = saved.mean()
+
+    assert saved.get(path).shape == (2, 768)
     assert mean_saved.get(path).shape == (1, 768)
+    assert torch.allclose(mean_saved.get(path), saved.get(path).mean(dim=0, keepdim=True))