Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -6,7 +6,7 @@ from pathlib import Path import torch from nnsight import LanguageModel -from more_nnsight import SavedActivation +from more_nnsight import SavedActivation, save_activations def _token_id(model: LanguageModel, token: str) -> int: @@ -61,7 +61,9 @@ def main() -> None: "model.transformer.h[2].output[9]", "model.transformer.h[3].output[9]", ] - clean_saved = SavedActivation.save(model, [clean_prompt], patch_paths) + with model.trace() as tracer: + with tracer.invoke([clean_prompt]): + clean_saved = save_activations(model, patch_paths) focused_patch = clean_saved.subset(["model.transformer.h[2].output[9]"]) clean_diff = logit_diff(model, clean_prompt, " John", " Mary")
@@ -6,7 +6,7 @@ from pathlib import Path import torch from nnsight import LanguageModel -from more_nnsight import SavedActivation +from more_nnsight import SavedActivation, save_activations def _token_id(model: LanguageModel, token: str) -> int: @@ -56,12 +56,18 @@ def main() -> None: "The vacation was horrible and it made me feel", ] neutral_prompt = "The day was long and by the end I felt" + prompts = positive_prompts + negative_prompts + [neutral_prompt] - positive = SavedActivation.save(model, positive_prompts, [steering_path], batch_size=2) - negative = SavedActivation.save(model, negative_prompts, [steering_path], batch_size=2) - neutral = SavedActivation.save(model, [neutral_prompt], [steering_path]) - - steering_vector = positive.mean() - negative.mean() + with model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(model, [steering_path]) + positive = saved.slice[0 : len(positive_prompts)].mean() + negative = saved.slice[ + len(positive_prompts) : len(positive_prompts) + len(negative_prompts) + ].mean() + neutral = saved.slice[len(prompts) - 1] + + steering_vector = positive - negative steered = neutral + 1.5 * steering_vector baseline_happy, baseline_sad = candidate_logits(model, neutral_prompt, " happy", " sad")
@@ -1,6 +1,6 @@ -from .saved_activation import SavedActivation +from .saved_activation import SavedActivation, save_activations -__all__ = ["SavedActivation"] +__all__ = ["SavedActivation", "save_activations"] def main() -> None:
from collections.abc import Iterator
from typing import Any


class SavedActivationValues:
    """Read-only view over one level of a sparse, nested activation tree.

    The tree is a plain ``dict``. String keys are reached with attribute
    access (``node.transformer``) and all other keys with subscripting
    (``node[2]``). A value that is itself a ``dict`` is returned wrapped in a
    new ``SavedActivationValues`` so the access chain can continue; any other
    value is treated as a saved leaf and returned unchanged.

    Only paths that were actually stored resolve. Anything else raises
    immediately with the full dotted/bracketed path in the message, instead
    of silently producing an empty node.
    """

    def __init__(self, values: dict[Any, Any], path: str = "saved_activation") -> None:
        """Create a view over ``values``.

        Args:
            values: Mapping for this level of the tree. It is referenced, not
                copied.
            path: Human-readable location of this node, used only to build
                error messages and the ``repr``.
        """
        self._values = values
        self._path = path

    def __getattr__(self, name: str) -> Any:
        """Resolve an attribute-style path segment.

        Raises:
            AttributeError: If ``name`` is private (leading underscore) or is
                not a key at this level.
        """
        # ``__getattr__`` only runs after normal lookup fails. Refusing
        # underscore names keeps protocol probes (copy, pickle, ``hasattr``)
        # from being mistaken for saved paths, and prevents infinite recursion
        # when ``_values`` itself is not set yet.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            value = self._values[name]
        except KeyError as exc:
            raise AttributeError(f"Unknown saved activation path: {self._path}.{name}") from exc
        return self._wrap(value, f"{self._path}.{name}")

    def __getitem__(self, index: int) -> Any:
        """Resolve an index-style path segment.

        Raises:
            KeyError: If ``index`` is not a key at this level.
        """
        try:
            value = self._values[index]
        except KeyError as exc:
            raise KeyError(f"Unknown saved activation path: {self._path}[{index}]") from exc
        return self._wrap(value, f"{self._path}[{index}]")

    def __contains__(self, key: object) -> bool:
        """Report whether ``key`` is a saved segment directly under this node."""
        # Without this method ``key in node`` falls back to calling
        # ``__getitem__(0)``, ``__getitem__(1)``, ... and leaks a KeyError.
        try:
            return key in self._values
        except TypeError:
            # Unhashable objects can never be keys of the underlying dict.
            return False

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the keys stored directly under this node."""
        # Defined explicitly so iteration does not use the legacy
        # ``__getitem__``-based sequence protocol, which would raise KeyError.
        return iter(self._values)

    def __dir__(self) -> list[str]:
        """List regular attributes plus the attribute-style keys at this level."""
        names = set(super().__dir__())
        names.update(
            key for key in self._values if isinstance(key, str) and key.isidentifier()
        )
        return sorted(names)

    def __repr__(self) -> str:
        """Show the node's path and its immediate keys, never the leaf payloads."""
        return f"{type(self).__name__}(path={self._path!r}, keys={list(self._values)!r})"

    @staticmethod
    def _wrap(value: Any, path: str) -> Any:
        """Return a child view for nested dicts and the raw value for leaves."""
        if isinstance(value, dict):
            return SavedActivationValues(value, path)
        return value
@@ -1,11 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any +from typing import Any, Callable import torch from lark import Lark, Token, Transformer from nnsight import LanguageModel +from nnsight.intervention.tracing.globals import Globals, save as nnsight_save + +from ._saved_activation_values import SavedActivationValues GRAMMAR = r""" @@ -78,92 +81,63 @@ class _ActivationTargetTransformer(Transformer[Token, ActivationTarget | PathSeg _PATH_PARSER = Lark(GRAMMAR, parser="lalr", transformer=_ActivationTargetTransformer()) -class _SavedActivationNode: - """Provides dynamic attribute and index access over a sparse activation tree.""" - - def __init__(self, values: dict[Any, Any], path: str = "saved_activation") -> None: - """Initializes one view into the saved activation hierarchy.""" - self._values = values - self._path = path - - def __getattr__(self, name: str) -> Any: - """Exposes saved attribute-style paths while rejecting unsaved ones.""" - if name.startswith("_"): - raise AttributeError(name) - try: - value = self._values[name] - except KeyError as exc: - raise AttributeError(f"Unknown saved activation path: {self._path}.{name}") from exc - return self._wrap(value, f"{self._path}.{name}") +class _SavedActivationBatchSlicer: + """Provides bracket-based batch slicing for saved activation batches.""" - def __getitem__(self, index: int) -> Any: - """Exposes saved index-style paths while rejecting unsaved ones.""" - try: - value = self._values[index] - except KeyError as exc: - raise KeyError(f"Unknown saved activation path: {self._path}[{index}]") from exc - return self._wrap(value, f"{self._path}[{index}]") + def __init__(self, saved_activation: "SavedActivation") -> None: + """Binds batch-slice syntax to one SavedActivation instance.""" + self._saved_activation = saved_activation - @staticmethod - def _wrap(value: Any, path: str) -> Any: - """Returns child nodes for nested paths and tensors for saved 
leaves.""" - if isinstance(value, dict): - return _SavedActivationNode(value, path) - return value + def __getitem__(self, batch_slices: int | range | slice | tuple[int | range | slice, ...]) -> "SavedActivation": + """Builds a sliced SavedActivation from one or more batch selectors.""" + selectors = batch_slices if isinstance(batch_slices, tuple) else (batch_slices,) + ranges = [self._saved_activation._normalize_batch_slice(selector) for selector in selectors] + sliced_targets = { + target: self._saved_activation._slice_value(value, ranges) + for target, value in self._saved_activation._targets.items() + } + return self._saved_activation._from_targets(sliced_targets) -class SavedActivation(_SavedActivationNode): +class SavedActivation: """Stores sparse activation slices and can replay them into later traces. Paths use the model's real attribute/index structure, with the final bracket selecting the token position to save or patch. For example: - - ``model.transformer.h[2].output[10]`` saves layer 2 output at token 10 - - ``model.transformer.h[3].output[-1]`` saves the last token from layer 3 + - ``SavedActivation.save_activations(model, ["model.transformer.h[2].output[10]"])`` + captures layer 2 output at token 10 + - ``SavedActivation.save_activations(model, ["model.transformer.h[3].output[-1]"])`` + captures the last token from layer 3 + + ``save_activations`` captures activation values, while ``save`` registers a + ``SavedActivation`` object itself with NNSight so it survives trace exit. - After saving, the same structure is exposed on the object itself, so - ``saved.transformer.h[2].output[10]`` returns the stored tensor. + After capture, the same structure is exposed under ``values``, so + ``saved.values.transformer.h[2].output[10]`` returns the stored tensor. 
""" - def __init__(self, values: dict[Any, Any], targets: dict[ActivationTarget, torch.Tensor] | None = None) -> None: + def __init__( + self, + values: dict[Any, Any] | None = None, + targets: dict[ActivationTarget, Any] | None = None, + ) -> None: """Creates a saved-activation tree from nested values or explicit targets.""" - super().__init__(values) + values = {} if values is None else values self._targets = targets if targets is not None else self._extract_targets(values) + self.values = SavedActivationValues(values) + self.slice = _SavedActivationBatchSlicer(self) - @classmethod - def save( - cls, - model: LanguageModel, - prompts: list[str], - activation_paths: list[str], - batch_size: int | None = None, - ) -> "SavedActivation": - """Runs one or more traces and captures only the requested token-position activations.""" - targets = [cls.parse_path(path) for path in activation_paths] - batch_size = len(prompts) if batch_size is None else batch_size - if batch_size <= 0: - raise ValueError("batch_size must be a positive integer.") - - saved_batches: dict[ActivationTarget, list[torch.Tensor]] = {target: [] for target in targets} - for prompt_batch in cls._chunk_prompts(prompts, batch_size): - with model.trace() as tracer: - with tracer.invoke(prompt_batch): - for target in targets: - activation = cls._resolve_activation(model, target.activation_path) - cls._validate_activation_shape(activation) - saved_batches[target].append(activation[:, target.position, :].save()) - - saved_tensors = { - target: torch.cat(batch_tensors, dim=0) for target, batch_tensors in saved_batches.items() - } - - return cls._from_targets(saved_tensors) + def save(self) -> "SavedActivation": + """Registers this SavedActivation object with NNSight inside an active trace.""" + self._register_object_save(self) + return self def subset(self, activation_paths: list[str]) -> "SavedActivation": """Keeps only selected saved paths so patches can be focused or composed.""" requested_targets = 
[self.parse_path(path) for path in activation_paths] - subset_targets: dict[ActivationTarget, torch.Tensor] = {} + subset_targets: dict[ActivationTarget, Any] = {} for target in requested_targets: try: subset_targets[target] = self._targets[target] @@ -173,10 +147,10 @@ class SavedActivation(_SavedActivationNode): return self._from_targets(subset_targets) def keys(self) -> list[str]: - """Returns the saved activation paths in the same syntax used by save.""" + """Returns the saved activation paths in the same syntax used by save_activations.""" return [self._format_target(target) for target in self._targets] - def get(self, activation_path: str) -> torch.Tensor: + def get(self, activation_path: str) -> Any: """Returns one saved tensor by its user-facing activation path.""" target = self.parse_path(activation_path) try: @@ -187,7 +161,10 @@ class SavedActivation(_SavedActivationNode): def mean(self) -> "SavedActivation": """Averages each saved tensor across the batch dimension while keeping one row.""" return self._from_targets( - {target: tensor.mean(dim=0, keepdim=True) for target, tensor in self._targets.items()} + { + target: self._saved_result(value.mean(dim=0, keepdim=True)) + for target, value in self._targets.items() + } ) def apply(self, model: LanguageModel) -> None: @@ -199,40 +176,59 @@ class SavedActivation(_SavedActivationNode): def __add__(self, other: "SavedActivation") -> "SavedActivation": """Combines two saved activation sets elementwise when they cover the same paths.""" - self._validate_matching_keys(other) - return self._from_targets( - {target: self._targets[target] + other._targets[target] for target in self._targets} - ) + return self._binary_op(other, lambda left, right: left + right) def __sub__(self, other: "SavedActivation") -> "SavedActivation": """Subtracts one saved activation set from another when they cover the same paths.""" - self._validate_matching_keys(other) - return self._from_targets( - {target: self._targets[target] - 
other._targets[target] for target in self._targets} - ) + return self._binary_op(other, lambda left, right: left - right) def __mul__(self, scalar: float) -> "SavedActivation": """Scales every saved tensor by the same numeric factor.""" if not isinstance(scalar, int | float): return NotImplemented - return self._from_targets({target: tensor * scalar for target, tensor in self._targets.items()}) + return self._from_targets( + { + target: self._saved_result(value * scalar) + for target, value in self._targets.items() + } + ) def __rmul__(self, scalar: float) -> "SavedActivation": """Allows scalar multiplication with the scalar on the left-hand side.""" return self * scalar + def _binary_op( + self, + other: "SavedActivation", + op: Callable[[Any, Any], Any], + ) -> "SavedActivation": + """Applies an elementwise operation across matching saved activation sets.""" + self._validate_matching_keys(other) + return self._from_targets( + { + target: self._saved_result(op(self._targets[target], other._targets[target])) + for target in self._targets + } + ) + @classmethod - def _from_targets(cls, targets: dict[ActivationTarget, torch.Tensor]) -> "SavedActivation": + def _from_targets(cls, targets: dict[ActivationTarget, Any]) -> "SavedActivation": """Builds the public tree view from the flat target-to-tensor mapping.""" + result = cls(cls._build_values_from_targets(targets), dict(targets)) + cls._register_object_save(result) + return result + + @classmethod + def _build_values_from_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[Any, Any]: + """Builds the nested values tree from a flat target-to-value mapping.""" values: dict[Any, Any] = {} - for target, tensor in targets.items(): + for target, value in targets.items(): cursor = values for segment in target.activation_path: key = segment.name if isinstance(segment, AttrSegment) else segment.value cursor = cursor.setdefault(key, {}) - cursor[target.position] = tensor - - return cls(values, dict(targets)) + 
cursor[target.position] = value + return values @staticmethod def parse_path(path: str) -> ActivationTarget: @@ -264,9 +260,9 @@ class SavedActivation(_SavedActivationNode): cls, values: dict[Any, Any], prefix: tuple[PathSegment, ...] = (), - ) -> dict[ActivationTarget, torch.Tensor]: + ) -> dict[ActivationTarget, Any]: """Reconstructs the flat target mapping from the nested public tree form.""" - targets: dict[ActivationTarget, torch.Tensor] = {} + targets: dict[ActivationTarget, Any] = {} for key, value in values.items(): if isinstance(value, dict): next_segment: PathSegment @@ -281,23 +277,53 @@ class SavedActivation(_SavedActivationNode): if not isinstance(key, int): raise TypeError("Saved activation leaves must be keyed by token position integers.") - if not isinstance(value, torch.Tensor): - raise TypeError("Saved activation leaves must be torch tensors.") targets[ActivationTarget(prefix, key)] = value return targets - @staticmethod - def _chunk_prompts(prompts: list[str], batch_size: int) -> list[list[str]]: - """Splits prompts into trace-sized batches so activation saving can scale to larger runs.""" - return [prompts[index : index + batch_size] for index in range(0, len(prompts), batch_size)] - def _validate_matching_keys(self, other: "SavedActivation") -> None: """Ensures elementwise operations only happen across identical saved paths.""" if tuple(self._targets) != tuple(other._targets): raise ValueError("SavedActivation objects must have the same keys for arithmetic.") @staticmethod + def _saved_result(value: Any) -> Any: + """Preserves derived activation values after the trace by saving proxies when needed.""" + save = getattr(value, "save", None) + if Globals.stack > 0 and callable(save): + return save() + return value + + @staticmethod + def _register_object_save(value: "SavedActivation") -> None: + """Registers a newly created SavedActivation with NNSight when inside a trace.""" + if Globals.stack > 0: + nnsight_save(value) + + @staticmethod + def 
_slice_value(value: Any, ranges: list[range]) -> Any: + """Slices saved batch rows, concatenating multiple ranges when requested.""" + pieces = [value[batch_range] for batch_range in ranges] + if len(pieces) == 1: + return SavedActivation._saved_result(pieces[0]) + return SavedActivation._saved_result(torch.cat(pieces, dim=0)) + + @staticmethod + def _normalize_batch_slice(batch_slice: int | range | slice) -> range: + """Converts supported batch selectors into a common range representation.""" + if isinstance(batch_slice, int): + return range(batch_slice, batch_slice + 1) + if isinstance(batch_slice, range): + return batch_slice + if isinstance(batch_slice, slice): + if batch_slice.step not in (None, 1): + raise ValueError("slice does not support slice steps other than 1.") + if batch_slice.start is None or batch_slice.stop is None: + raise ValueError("slice requires explicit start and stop values.") + return range(batch_slice.start, batch_slice.stop) + raise TypeError("slice accepts only ints, ranges, or slice objects.") + + @staticmethod def _format_target(target: ActivationTarget) -> str: """Turns an internal target back into the user-facing path syntax.""" parts = ["model"] @@ -308,3 +334,14 @@ class SavedActivation(_SavedActivationNode): parts.append(f"[{segment.value}]") parts.append(f"[{target.position}]") return "".join(parts) + + +def save_activations(model: LanguageModel, activation_paths: list[str]) -> SavedActivation: + """Captures requested activations inside an already-active NNSight trace/invoke context.""" + targets = [SavedActivation.parse_path(path) for path in activation_paths] + saved_values: dict[ActivationTarget, Any] = {} + for target in targets: + activation = SavedActivation._resolve_activation(model, target.activation_path) + SavedActivation._validate_activation_shape(activation) + saved_values[target] = activation[:, target.position, :].save() + return SavedActivation._from_targets(saved_values)
@@ -7,6 +7,7 @@ from more_nnsight import SavedActivation from more_nnsight.saved_activation import ActivationTarget, AttrSegment, IndexSegment +# The parser should preserve the model's real attribute/index path structure. def test_parse_path_uses_real_model_names() -> None: target = SavedActivation.parse_path("model.transformer.h[2].output[10]") @@ -21,6 +22,7 @@ def test_parse_path_uses_real_model_names() -> None: ) +# Values should be reachable through the same nested path under `.values`. def test_saved_activation_mirrors_model_path() -> None: tensor = torch.randn(2, 768) saved = SavedActivation( @@ -35,23 +37,26 @@ def test_saved_activation_mirrors_model_path() -> None: } ) - assert saved.transformer.h[2].output[10] is tensor + assert saved.values.transformer.h[2].output[10] is tensor +# Missing attribute hops should fail immediately instead of producing empty nodes. def test_missing_attribute_errors() -> None: saved = SavedActivation({"transformer": {}}) with pytest.raises(AttributeError, match=r"saved_activation.transformer.h"): - _ = saved.transformer.h + _ = saved.values.transformer.h +# Missing index hops should also fail immediately. def test_missing_index_errors() -> None: saved = SavedActivation({"transformer": {"h": {2: {}}}}) with pytest.raises(KeyError, match="saved_activation.transformer.h"): - _ = saved.transformer.h[3] + _ = saved.values.transformer.h[3] +# Subsetting should keep only the requested saved path and discard the rest. 
def test_subset_returns_new_saved_activation() -> None: tensor_1 = torch.randn(2, 768) tensor_2 = torch.randn(2, 768) @@ -68,11 +73,12 @@ def test_subset_returns_new_saved_activation() -> None: subset = saved.subset(["model.transformer.h[3].output[-1]"]) - assert subset.transformer.h[3].output[-1] is tensor_2 + assert subset.values.transformer.h[3].output[-1] is tensor_2 with pytest.raises(KeyError, match="saved_activation.transformer.h"): - _ = subset.transformer.h[2] + _ = subset.values.transformer.h[2] +# String keys should round-trip through `keys()` and `get(...)`. def test_keys_and_get_round_trip_saved_paths() -> None: tensor = torch.randn(2, 768) saved = SavedActivation({"transformer": {"h": {2: {"output": {10: tensor}}}}}) @@ -81,29 +87,38 @@ def test_keys_and_get_round_trip_saved_paths() -> None: assert saved.get("model.transformer.h[2].output[10]") is tensor +# Batch slicing should preserve order across mixed slice and single-index selectors. +def test_slice_keeps_requested_batch_rows() -> None: + tensor = torch.arange(20, dtype=torch.float32).reshape(5, 4) + saved = SavedActivation({"transformer": {"h": {2: {"output": {10: tensor}}}}}) + + sliced = saved.slice[1:3, 4] + + assert sliced.values.transformer.h[2].output[10].shape == (3, 4) + assert torch.equal(sliced.values.transformer.h[2].output[10], tensor[[1, 2, 4]]) + + +# Mean should return a new SavedActivation reduced across the batch dimension. 
def test_mean_reduces_batch_dimension() -> None: tensor = torch.randn(3, 768) saved = SavedActivation({"transformer": {"h": {2: {"output": {10: tensor}}}}}) mean_saved = saved.mean() - assert mean_saved.transformer.h[2].output[10].shape == (1, 768) - assert torch.allclose(mean_saved.transformer.h[2].output[10], tensor.mean(dim=0, keepdim=True)) + assert saved.values.transformer.h[2].output[10].shape == (3, 768) + assert mean_saved.values.transformer.h[2].output[10].shape == (1, 768) + assert torch.allclose(mean_saved.values.transformer.h[2].output[10], tensor.mean(dim=0, keepdim=True)) +# Arithmetic should work elementwise on matching keys and reject mismatched ones. def test_add_and_subtract_require_matching_keys() -> None: left = SavedActivation({"transformer": {"h": {2: {"output": {10: torch.ones(1, 2)}}}}}) right = SavedActivation({"transformer": {"h": {2: {"output": {10: 2 * torch.ones(1, 2)}}}}}) other = SavedActivation({"transformer": {"h": {3: {"output": {10: torch.ones(1, 2)}}}}}) - assert torch.equal((left + right).transformer.h[2].output[10], 3 * torch.ones(1, 2)) - assert torch.equal((right - left).transformer.h[2].output[10], torch.ones(1, 2)) - assert torch.equal((2.0 * left).transformer.h[2].output[10], 2 * torch.ones(1, 2)) + assert torch.equal((left + right).values.transformer.h[2].output[10], 3 * torch.ones(1, 2)) + assert torch.equal((right - left).values.transformer.h[2].output[10], torch.ones(1, 2)) + assert torch.equal((2.0 * left).values.transformer.h[2].output[10], 2 * torch.ones(1, 2)) with pytest.raises(ValueError, match="same keys"): _ = left + other - - -def test_save_rejects_non_positive_batch_size() -> None: - with pytest.raises(ValueError, match="batch_size"): - SavedActivation.save(None, [], [], batch_size=0) # type: ignore[arg-type]
@@ -7,7 +7,7 @@ import pytest import torch from nnsight import LanguageModel -from more_nnsight import SavedActivation +from more_nnsight import SavedActivation, save_activations @pytest.fixture(scope="module") @@ -17,52 +17,54 @@ def gpt2_model() -> LanguageModel: return LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True) +# Saving real GPT-2 activations should produce batch-by-hidden tensors at the requested paths. def test_save_with_gpt2_returns_saved_tensors_at_requested_paths(gpt2_model: LanguageModel) -> None: prompts = [ "One two three four five six seven eight nine ten eleven twelve.", "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.", ] + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations( + gpt2_model, + ["model.transformer.h[2].output[10]", "model.transformer.h[3].output[-1]"], + ) + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + direct_layer_2 = gpt2_model.transformer.h[2].output[:, 10, :].save() + direct_layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save() - saved = SavedActivation.save( - gpt2_model, - prompts, - ["model.transformer.h[2].output[10]", "model.transformer.h[3].output[-1]"], - batch_size=1, - ) - - assert saved.transformer.h[2].output[10].shape == (2, 768) - assert saved.transformer.h[3].output[-1].shape == (2, 768) + assert torch.allclose(saved.get("model.transformer.h[2].output[10]"), direct_layer_2) + assert torch.allclose(saved.get("model.transformer.h[3].output[-1]"), direct_layer_3) +# Unsaved GPT-2 paths should still error even when nearby paths were captured. 
def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None: prompts = [ "One two three four five six seven eight nine ten eleven twelve.", "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.", ] - saved = SavedActivation.save( - gpt2_model, - prompts, - ["model.transformer.h[2].output[10]"], - ) + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(gpt2_model, ["model.transformer.h[2].output[10]"]) with pytest.raises(KeyError, match="saved_activation.transformer.h\\[2\\].output"): - _ = saved.transformer.h[2].output[9] + _ = saved.values.transformer.h[2].output[9] with pytest.raises(KeyError, match="saved_activation.transformer.h"): - _ = saved.transformer.h[3] + _ = saved.values.transformer.h[3] +# Applying a saved activation should overwrite the live activation slice during a later run. def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None: clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to" corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to" - patch = SavedActivation.save( - gpt2_model, - [clean_prompt], - ["model.transformer.h[2].output[9]"], - ) + with gpt2_model.trace() as tracer: + with tracer.invoke([clean_prompt]): + patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"]) - clean_tensor = patch.transformer.h[2].output[9] + clean_tensor = patch.values.transformer.h[2].output[9] with gpt2_model.trace() as tracer: with tracer.invoke([corrupted_prompt]): @@ -76,6 +78,7 @@ def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None: assert not torch.equal(before_patch, after_patch) +# `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector. 
def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None: prompts = [ "The movie was great and I felt happy.", @@ -83,9 +86,55 @@ def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None: ] path = "model.transformer.h[5].output[-1]" - saved = SavedActivation.save(gpt2_model, prompts, [path], batch_size=1) - mean_saved = saved.mean() + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(gpt2_model, [path]) + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + direct_mean = gpt2_model.transformer.h[5].output[:, -1, :].mean(dim=0, keepdim=True).save() assert saved.keys() == [path] assert saved.get(path).shape == (2, 768) + mean_saved = saved.mean() + assert mean_saved.get(path).shape == (1, 768) + assert torch.allclose(mean_saved.get(path), direct_mean) + + +# Slicing should let one GPT-2 save feed multiple downstream batch views. +def test_slice_with_gpt2_selects_requested_batch_rows(gpt2_model: LanguageModel) -> None: + prompts = [ + "The movie was great and I felt happy.", + "The dinner was excellent and I felt joyful.", + "The trip was terrible and I felt sad.", + ] + path = "model.transformer.h[5].output[-1]" + + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(gpt2_model, [path]) + sliced = saved.slice[0, 2:3] + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + direct_slice = gpt2_model.transformer.h[5].output[[0, 2], -1, :].save() + + assert saved.get(path).shape == (3, 768) + assert sliced.get(path).shape == (2, 768) + assert torch.allclose(sliced.get(path), direct_slice) + + +def test_mean_inside_trace_saves_reduced_result_on_device(gpt2_model: LanguageModel) -> None: + """This guards the intended workflow where reduction happens before trace exit, avoiding saving the full batch when only the mean is needed.""" + prompts = [ + "The movie was great and I felt happy.", + "The dinner was excellent and I felt joyful.", + ] + 
path = "model.transformer.h[5].output[-1]" + + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(gpt2_model, [path]) + mean_saved = saved.mean() + + assert saved.get(path).shape == (2, 768) assert mean_saved.get(path).shape == (1, 768) + assert torch.allclose(mean_saved.get(path), saved.get(path).mean(dim=0, keepdim=True))