Repositories / more_nnsight.git
src/more_nnsight/saved_activation.py
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Iterator
import torch
from lark import Lark, Token, Transformer
from nnsight import LanguageModel
from nnsight.intervention.tracing.globals import Globals, save as nnsight_save
from ._saved_activation_values import SavedActivationValues
# Lark grammar for user-facing activation paths such as
# "model.transformer.h[2].output[10]" or "model.transformer.h[:].output[-1]".
# A path is the literal "model", then one or more segments (".name", "[int]",
# or a "[start:stop]" slice where either bound may be omitted), and finally a
# bracketed integer that is always the token position. Slice steps
# ("[a:b:c]") are not part of the grammar, and whitespace is ignored.
# The rule names and "->" aliases correspond one-to-one to the methods of
# _ActivationTargetTransformer (start, attr, index, slice, bounded_slice,
# start_slice, stop_slice, full_slice), which build the parsed objects.
GRAMMAR = r"""
start: "model" segment+ "[" SIGNED_INT "]"
segment: "." CNAME -> attr
| "[" SIGNED_INT "]" -> index
| "[" slice_expr "]" -> slice
slice_expr: SIGNED_INT ":" SIGNED_INT -> bounded_slice
| SIGNED_INT ":" -> start_slice
| ":" SIGNED_INT -> stop_slice
| ":" -> full_slice
%import common.CNAME
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""
@dataclass(frozen=True, slots=True)
class AttrSegment:
"""Represents an attribute access in a model activation path."""
name: str
@dataclass(frozen=True, slots=True)
class IndexSegment:
"""Represents an index access in a model activation path."""
value: int
@dataclass(frozen=True, slots=True)
class SliceSegment:
"""Represents a slice access in a model activation path."""
start: int | None
stop: int | None
# Any single step of an activation path. Slice steps exist only before
# expansion: targets stored in a SavedActivation contain attribute and index
# steps only.
PathSegment = AttrSegment | IndexSegment | SliceSegment
@dataclass(frozen=True, slots=True)
class ActivationTarget:
"""Identifies one token-position slice of a traced activation tensor."""
activation_path: tuple[PathSegment, ...]
position: int
class _ActivationTargetTransformer(Transformer[Token, ActivationTarget | PathSegment | str | int]):
    """Builds ``ActivationTarget`` objects from parsed path syntax.

    The method names mirror the rule names and ``->`` aliases in ``GRAMMAR``
    (``start``, ``attr``, ``index``, ``slice``, ``bounded_slice``,
    ``start_slice``, ``stop_slice``, ``full_slice``) and its terminals
    (``CNAME``, ``SIGNED_INT``). The slice-expression methods return
    ``(start, stop)`` tuples, which ``slice`` wraps into a ``SliceSegment``.
    """

    def start(self, items: list[PathSegment | int]) -> ActivationTarget:
        """Splits the children into path segments and the trailing token position."""
        *segments, position = items
        if len(segments) == 0:
            raise ValueError("Activation path must include at least one segment before the position.")
        if not isinstance(position, int):
            raise TypeError(f"Expected final position to be an integer, got {type(position)!r}")
        return ActivationTarget(tuple(segments), position)

    def attr(self, items: list[str]) -> AttrSegment:
        """Wraps the name from a ``.name`` step in an AttrSegment."""
        attribute_name = items[0]
        return AttrSegment(attribute_name)

    def index(self, items: list[int]) -> IndexSegment:
        """Wraps the integer from a ``[n]`` step in an IndexSegment."""
        index_value = items[0]
        return IndexSegment(index_value)

    def slice(self, items: list[tuple[int | None, int | None]]) -> SliceSegment:
        """Wraps the ``(start, stop)`` pair from a slice expression in a SliceSegment."""
        lower, upper = items[0]
        return SliceSegment(start=lower, stop=upper)

    def bounded_slice(self, items: list[int]) -> tuple[int | None, int | None]:
        """Handles ``[a:b]``: both bounds are present."""
        lower, upper = items[0], items[1]
        return lower, upper

    def start_slice(self, items: list[int]) -> tuple[int | None, int | None]:
        """Handles ``[a:]``: only the lower bound is present."""
        return items[0], None

    def stop_slice(self, items: list[int]) -> tuple[int | None, int | None]:
        """Handles ``[:b]``: only the upper bound is present."""
        return None, items[0]

    def full_slice(self, _: list[int]) -> tuple[int | None, int | None]:
        """Handles ``[:]``: neither bound is present."""
        return None, None

    def CNAME(self, token: Token) -> str:
        """Converts an identifier token into a plain ``str``."""
        return str(token)

    def SIGNED_INT(self, token: Token) -> int:
        """Converts a (possibly signed) integer token into an ``int``."""
        return int(token)
# Module-level LALR parser shared by every parse_path() call. The transformer
# runs during parsing, so parse() yields an ActivationTarget rather than a tree.
_PATH_PARSER = Lark(
    GRAMMAR,
    transformer=_ActivationTargetTransformer(),
    parser="lalr",
)
class _SavedActivationBatchSlicer:
    """Backs the ``saved.slice[...]`` syntax for selecting batch rows.

    Every selector (an int, a ``range``, or a ``slice`` with explicit bounds)
    is normalized to a ``range``. Each stored tensor is then indexed by those
    ranges, and the pieces are concatenated along dim 0 when several are given.
    """

    def __init__(self, saved_activation: "SavedActivation") -> None:
        """Keeps a reference to the SavedActivation being sliced."""
        self._saved_activation = saved_activation

    def __getitem__(self, batch_slices: int | range | slice | tuple[int | range | slice, ...]) -> "SavedActivation":
        """Returns a new SavedActivation holding only the selected batch rows.

        A tuple selects several row groups at once, e.g. ``saved.slice[0, 2:4]``.
        """
        source = self._saved_activation
        if isinstance(batch_slices, tuple):
            selectors = batch_slices
        else:
            selectors = (batch_slices,)
        # Normalize every selector up front so an invalid one fails before any
        # tensor is sliced.
        batch_ranges = list(map(source._normalize_batch_slice, selectors))
        sliced_targets = {}
        for target, value in source._targets.items():
            sliced_targets[target] = source._slice_value(value, batch_ranges)
        return source._from_targets(sliced_targets)
class SavedActivation:
    """Stores sparse activation slices and can replay them into later traces.

    Paths use the model's real attribute/index structure, with the final
    bracket selecting the token position to save or patch. For example:

    - ``save_activations(model, ["model.transformer.h[2].output[10]"])``
      captures layer 2 output at token 10
    - ``save_activations(model, ["model.transformer.h[3].output[-1]"])``
      captures the last token from layer 3
    - ``save_activations(model, ["model.transformer.h[:].output[10]"])``
      captures token 10 at every layer in a ModuleList-like container

    ``save_activations`` captures activation values, while ``save`` registers a
    ``SavedActivation`` object itself with NNSight so it survives trace exit.

    After capture, the same structure is exposed under ``values``, so
    ``saved.values.transformer.h[2].output[10]`` returns the stored tensor.

    Internally every instance keeps ``_targets``, a flat
    ``ActivationTarget -> value`` dict ordered by ``_target_order_key``.
    Derived instances (``subset``, ``mean``, ``union``, arithmetic, batch
    slicing) are all rebuilt from such a mapping via ``_from_targets``.
    """

    def __init__(
        self,
        values: dict[Any, Any] | None = None,
        targets: dict[ActivationTarget, Any] | None = None,
    ) -> None:
        """Creates a saved-activation tree from nested values or explicit targets.

        Args:
            values: Nested dict keyed by attribute names (``str``) and indices
                (``int``), whose innermost keys are token positions. Defaults
                to an empty dict.
            targets: Optional flat ``ActivationTarget -> value`` mapping. When
                given, it is used instead of re-deriving targets from
                ``values``, so callers must keep the two consistent.

        Raises:
            TypeError: if ``values`` contains unsupported key types.
            ValueError: if any target still contains a slice segment.
        """
        values = {} if values is None else values
        extracted_targets = targets if targets is not None else self._extract_targets(values)
        self._targets = self._canonicalize_targets(extracted_targets)
        self.values = SavedActivationValues(values)
        # Exposes ``saved.slice[...]`` batch-row selection.
        self.slice = _SavedActivationBatchSlicer(self)

    def save(self) -> "SavedActivation":
        """Registers this SavedActivation object with NNSight inside an active trace.

        Outside a trace this is a no-op. Returns ``self`` so it can be chained.
        """
        self._register_object_save(self)
        return self

    @classmethod
    def from_pairs(cls, *pairs: tuple[str, Any]) -> "SavedActivation":
        """Builds a SavedActivation directly from explicit path-value pairs.

        Args:
            *pairs: ``(path, value)`` tuples. Each path must be concrete, e.g.
                ``"model.transformer.h[2].output[10]"``.

        Raises:
            ValueError: if a path uses slice syntax (``[:]``, ``[a:b]``), which
                cannot name a single stored tensor, or if a path repeats.
        """
        targets: dict[ActivationTarget, Any] = {}
        for path, value in pairs:
            target = cls.parse_path(path)
            # Reject slice paths here. Otherwise they fail later inside
            # _build_values_from_targets with an opaque AttributeError.
            if cls._contains_slice(target.activation_path):
                raise ValueError(f"SavedActivation.from_pairs requires concrete paths, got slice path: {path}")
            if target in targets:
                raise ValueError(f"SavedActivation.from_pairs got duplicate path: {path}")
            targets[target] = value
        return cls._from_targets(targets)

    def subset(self, activation_paths: list[str]) -> "SavedActivation":
        """Keeps only selected saved paths so patches can be focused or composed.

        Raises:
            KeyError: if any requested path is not stored in this object.
        """
        requested_targets = [self.parse_path(path) for path in activation_paths]
        subset_targets: dict[ActivationTarget, Any] = {}
        for target in requested_targets:
            try:
                subset_targets[target] = self._targets[target]
            except KeyError as exc:
                raise KeyError(f"Unknown saved activation path: {self._format_target(target)}") from exc
        return self._from_targets(subset_targets)

    def keys(self) -> list[str]:
        """Returns the saved activation paths in the same syntax used by save_activations.

        The list follows the canonical order of ``_targets``.
        """
        return [self._format_target(target) for target in self._targets]

    def get(self, activation_path: str) -> Any:
        """Returns one saved tensor by its user-facing activation path.

        Raises:
            KeyError: if the path is not stored in this object.
        """
        target = self.parse_path(activation_path)
        try:
            return self._targets[target]
        except KeyError as exc:
            raise KeyError(f"Unknown saved activation path: {activation_path}") from exc

    def mean(self) -> "SavedActivation":
        """Averages each saved tensor across the batch dimension while keeping one row.

        Uses ``value.mean(dim=0, keepdim=True)``, so each result keeps a leading
        dimension of size 1.
        """
        return self._from_targets(
            {
                target: self._saved_result(value.mean(dim=0, keepdim=True))
                for target, value in self._targets.items()
            }
        )

    def union(self, other: "SavedActivation") -> "SavedActivation":
        """Combines two disjoint saved activation sets into one patch collection.

        Raises:
            ValueError: if the two objects share any path. The offending paths
                are listed in canonical order.
        """
        overlapping_targets = set(self._targets) & set(other._targets)
        if overlapping_targets:
            # Sort so the message is stable. Raw set order depends on string
            # hashing, which varies between interpreter runs.
            overlapping_paths = ", ".join(
                self._format_target(target)
                for target in sorted(overlapping_targets, key=self._target_order_key)
            )
            raise ValueError(f"SavedActivation.union requires disjoint keys, got overlap: {overlapping_paths}")
        return self._from_targets({**self._targets, **other._targets})

    def apply(self, model: Any) -> None:
        """Writes the stored activations into the current traced run of the model.

        Each stored value is assigned to ``activation[:, position, :]`` of its
        path through the handles yielded by ``updates``. Keys are passed in this
        object's canonical order, which ``updates`` requires.
        """
        for key, _current_value, update in updates(model, self.keys()):
            update(self.get(key))

    def __add__(self, other: "SavedActivation") -> "SavedActivation":
        """Combines two saved activation sets elementwise when they cover the same paths."""
        # Defer to Python's binary-operator protocol for foreign operands
        # instead of failing with AttributeError on ``other._targets``.
        if not isinstance(other, SavedActivation):
            return NotImplemented
        return self._binary_op(other, lambda left, right: left + right)

    def __sub__(self, other: "SavedActivation") -> "SavedActivation":
        """Subtracts one saved activation set from another when they cover the same paths."""
        if not isinstance(other, SavedActivation):
            return NotImplemented
        return self._binary_op(other, lambda left, right: left - right)

    def __mul__(self, scalar: float) -> "SavedActivation":
        """Scales every saved tensor by the same numeric factor."""
        if not isinstance(scalar, int | float):
            return NotImplemented
        return self._from_targets(
            {
                target: self._saved_result(value * scalar)
                for target, value in self._targets.items()
            }
        )

    def __rmul__(self, scalar: float) -> "SavedActivation":
        """Allows scalar multiplication with the scalar on the left-hand side."""
        return self * scalar

    def _binary_op(
        self,
        other: "SavedActivation",
        op: Callable[[Any, Any], Any],
    ) -> "SavedActivation":
        """Applies an elementwise operation across matching saved activation sets.

        Raises:
            ValueError: if the two objects do not hold exactly the same paths.
        """
        self._validate_matching_keys(other)
        return self._from_targets(
            {
                target: self._saved_result(op(self._targets[target], other._targets[target]))
                for target in self._targets
            }
        )

    @classmethod
    def _from_targets(cls, targets: dict[ActivationTarget, Any]) -> "SavedActivation":
        """Builds the public tree view from the flat target-to-tensor mapping.

        When called inside a trace, the new object is also registered with
        NNSight so it survives trace exit.
        """
        result = cls(cls._build_values_from_targets(targets), dict(targets))
        cls._register_object_save(result)
        return result

    @classmethod
    def _build_values_from_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[Any, Any]:
        """Builds the nested values tree from a flat target-to-value mapping.

        Attribute steps become ``str`` keys, index steps become ``int`` keys,
        and the token position is the innermost key holding the value.
        Targets must be concrete (no SliceSegment).
        """
        values: dict[Any, Any] = {}
        for target, value in targets.items():
            cursor = values
            for segment in target.activation_path:
                key = segment.name if isinstance(segment, AttrSegment) else segment.value
                cursor = cursor.setdefault(key, {})
            cursor[target.position] = value
        return values

    @staticmethod
    def parse_path(path: str) -> ActivationTarget:
        """Parses user-facing model paths into the internal target representation.

        Malformed paths raise the path parser's (Lark's) exception.
        """
        return _PATH_PARSER.parse(path)

    @staticmethod
    def _resolve_activation(model: Any, activation_path: tuple[PathSegment, ...]) -> Any:
        """Finds the live NNSight object identified by a parsed activation path.

        Raises:
            ValueError: if the path still contains a slice segment.
        """
        current: Any = model
        for segment in activation_path:
            current = SavedActivation._apply_concrete_segment(current, segment)
        return current

    @staticmethod
    def _validate_activation_shape(activation: Any) -> None:
        """Enforces the batch-sequence-hidden convention expected by this abstraction.

        Raises:
            ValueError: if ``activation.shape`` does not have exactly 3 dimensions.
        """
        if len(activation.shape) != 3:
            raise ValueError(
                "Saved activation paths must resolve to a tensor with shape "
                "(batch, sequence, hidden) before the final position index."
            )

    @classmethod
    def _extract_targets(
        cls,
        values: dict[Any, Any],
        prefix: tuple[PathSegment, ...] = (),
    ) -> dict[ActivationTarget, Any]:
        """Reconstructs the flat target mapping from the nested public tree form.

        Dict-valued entries are recursed into (``str`` keys become attribute
        steps, ``int`` keys become index steps). Non-dict values are leaves
        keyed by their token position.

        Raises:
            TypeError: for an unsupported branch key type, or a non-``int`` leaf key.
        """
        targets: dict[ActivationTarget, Any] = {}
        for key, value in values.items():
            if isinstance(value, dict):
                next_segment: PathSegment
                if isinstance(key, str):
                    next_segment = AttrSegment(key)
                elif isinstance(key, int):
                    next_segment = IndexSegment(key)
                else:
                    raise TypeError(f"Unsupported key type in saved activation tree: {type(key)!r}")
                targets.update(cls._extract_targets(value, prefix + (next_segment,)))
                continue
            if not isinstance(key, int):
                raise TypeError("Saved activation leaves must be keyed by token position integers.")
            targets[ActivationTarget(prefix, key)] = value
        return targets

    def _validate_matching_keys(self, other: "SavedActivation") -> None:
        """Ensures elementwise operations only happen across identical saved paths.

        Both ``_targets`` dicts are canonically ordered, so comparing their key
        sequences checks that the key sets are equal.
        """
        if tuple(self._targets) != tuple(other._targets):
            raise ValueError("SavedActivation objects must have the same keys for arithmetic.")

    @staticmethod
    def _saved_result(value: Any) -> Any:
        """Preserves derived activation values after the trace by saving proxies when needed.

        Inside a trace (``Globals.stack > 0``), values with a callable ``save``
        are replaced by ``value.save()``. Otherwise ``value`` is returned unchanged.
        """
        save = getattr(value, "save", None)
        if Globals.stack > 0 and callable(save):
            return save()
        return value

    @staticmethod
    def _register_object_save(value: "SavedActivation") -> None:
        """Registers a newly created SavedActivation with NNSight when inside a trace."""
        if Globals.stack > 0:
            nnsight_save(value)

    @classmethod
    def _canonicalize_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[ActivationTarget, Any]:
        """Normalizes target order so every SavedActivation iterates in graph-safe model order.

        Raises:
            ValueError: if any target still contains a slice segment.
        """
        return dict(sorted(targets.items(), key=lambda item: cls._target_order_key(item[0])))

    @staticmethod
    def _slice_value(value: Any, ranges: list[range]) -> Any:
        """Slices saved batch rows, concatenating multiple ranges when requested.

        Raises:
            ValueError: if ``ranges`` is empty (e.g. ``saved.slice[()]``).
        """
        if not ranges:
            # torch.cat([]) would otherwise fail with an unrelated RuntimeError.
            raise ValueError("slice requires at least one batch selector.")
        pieces = [value[batch_range] for batch_range in ranges]
        if len(pieces) == 1:
            return SavedActivation._saved_result(pieces[0])
        return SavedActivation._saved_result(torch.cat(pieces, dim=0))

    @staticmethod
    def _normalize_batch_slice(batch_slice: int | range | slice) -> range:
        """Converts supported batch selectors into a common range representation.

        An int ``i`` becomes ``range(i, i + 1)``. A slice needs explicit
        start/stop and a step of 1 (or None).

        Raises:
            ValueError: for a slice with a step other than 1 or a missing bound.
            TypeError: for any other selector type.
        """
        if isinstance(batch_slice, int):
            return range(batch_slice, batch_slice + 1)
        if isinstance(batch_slice, range):
            return batch_slice
        if isinstance(batch_slice, slice):
            if batch_slice.step not in (None, 1):
                raise ValueError("slice does not support slice steps other than 1.")
            if batch_slice.start is None or batch_slice.stop is None:
                raise ValueError("slice requires explicit start and stop values.")
            return range(batch_slice.start, batch_slice.stop)
        raise TypeError("slice accepts only ints, ranges, or slice objects.")

    @staticmethod
    def _format_target(target: ActivationTarget) -> str:
        """Turns an internal target back into the user-facing path syntax."""
        parts = ["model"]
        for segment in target.activation_path:
            if isinstance(segment, AttrSegment):
                parts.append(f".{segment.name}")
            elif isinstance(segment, IndexSegment):
                parts.append(f"[{segment.value}]")
            else:
                start = "" if segment.start is None else str(segment.start)
                stop = "" if segment.stop is None else str(segment.stop)
                parts.append(f"[{start}:{stop}]")
        parts.append(f"[{target.position}]")
        return "".join(parts)

    @classmethod
    def _target_order_key(cls, target: ActivationTarget) -> tuple[Any, ...]:
        """Orders concrete targets by model traversal path, then by token position.

        Each step is tagged (0 = attribute, 1 = index, 2 = position). The tags
        ensure two keys never compare a ``str`` against an ``int`` at the same
        tuple slot.

        NOTE(review): attribute steps sort alphabetically and index steps
        numerically, which only approximates execution order. For example,
        ``h[-1]`` sorts before ``h[0]``, and sibling modules are ordered by name
        rather than by when they run.

        Raises:
            ValueError: if the target still contains a slice segment.
        """
        path_key: list[tuple[int, Any]] = []
        for segment in target.activation_path:
            if isinstance(segment, AttrSegment):
                path_key.append((0, segment.name))
            elif isinstance(segment, IndexSegment):
                path_key.append((1, segment.value))
            else:
                raise ValueError("Targets must be concrete before access ordering is computed.")
        return (*path_key, (2, target.position))

    @classmethod
    def _expand_target(
        cls,
        model: Any,
        target: ActivationTarget,
    ) -> list[ActivationTarget]:
        """Expands slice segments into concrete index paths before saving activations."""
        return cls._expand_segments(
            cls._static_structure(model),
            tuple(),
            target.activation_path,
            target.position,
        )

    @classmethod
    def _expand_segments(
        cls,
        current: Any,
        prefix: tuple[PathSegment, ...],
        remaining: tuple[PathSegment, ...],
        position: int,
    ) -> list[ActivationTarget]:
        """Recursively expands one parsed activation path into concrete saved targets."""
        if not remaining:
            return [ActivationTarget(prefix, position)]
        # Once no slice is left, the rest of the path is appended verbatim
        # without walking the module tree any further.
        if not cls._contains_slice(remaining):
            return [ActivationTarget(prefix + remaining, position)]
        segment, *rest = remaining
        concrete_targets: list[ActivationTarget] = []
        for next_current, concrete_segment in cls._expand_segment(current, segment):
            concrete_targets.extend(
                cls._expand_segments(next_current, prefix + (concrete_segment,), tuple(rest), position)
            )
        return concrete_targets

    @staticmethod
    def _contains_slice(segments: tuple[PathSegment, ...]) -> bool:
        """Checks whether a remaining path still needs ModuleList slice expansion."""
        return any(isinstance(segment, SliceSegment) for segment in segments)

    @staticmethod
    def _normalize_path_slice(segment: SliceSegment, length: int) -> range:
        """Converts a parsed path slice into concrete indices for a module container.

        Uses Python slice semantics via ``slice.indices``, so negative and
        omitted bounds resolve against ``length``.
        """
        return range(*slice(segment.start, segment.stop).indices(length))

    @staticmethod
    def _static_structure(value: Any) -> Any:
        """Returns the underlying module tree used only for slice expansion.

        Presumably ``_module`` is the wrapped torch module on an NNSight
        envoy/model; objects without it are returned unchanged.
        """
        return getattr(value, "_module", value)

    @staticmethod
    def _apply_concrete_segment(current: Any, segment: PathSegment) -> Any:
        """Applies one non-slice path segment to either a live envoy or static module.

        Raises:
            ValueError: if ``segment`` is a SliceSegment.
        """
        if isinstance(segment, AttrSegment):
            return getattr(current, segment.name)
        if isinstance(segment, IndexSegment):
            return current[segment.value]
        raise ValueError("Slice segments must be expanded before resolving a concrete activation.")

    @classmethod
    def _expand_segment(
        cls,
        current: Any,
        segment: PathSegment,
    ) -> list[tuple[Any, AttrSegment | IndexSegment]]:
        """Expands one path segment into concrete child accesses.

        A SliceSegment yields one ``(child, IndexSegment(i))`` pair per index
        in the container. Other segments yield a single pair.

        Raises:
            TypeError: if a slice is applied to an object without
                ``__len__``/``__getitem__``.
        """
        if isinstance(segment, SliceSegment):
            static_current = cls._static_structure(current)
            if not hasattr(static_current, "__len__") or not hasattr(static_current, "__getitem__"):
                raise TypeError("Slice syntax is only supported on indexable module containers.")
            return [
                (static_current[index], IndexSegment(index))
                for index in cls._normalize_path_slice(segment, len(static_current))
            ]
        next_current = cls._apply_concrete_segment(current, segment)
        return [(cls._static_structure(next_current), segment)]
def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation:
    """Captures requested activations inside an already-active NNSight trace/invoke context.

    Args:
        model: The traced model whose live activations are read. Slice
            segments in the paths are expanded against its underlying module
            structure.
        activation_paths: Paths in ``model.<segments>[position]`` syntax.
            Slice segments such as ``h[:]`` expand to one concrete path per
            index. Overlapping requests are de-duplicated.

    Returns:
        A SavedActivation mapping each concrete path to
        ``activation[:, position, :].save()``.

    Raises:
        ValueError: if a resolved activation is not 3-dimensional.
    """
    parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
    # A set drops duplicates produced by overlapping requests (e.g. "h[:]" and
    # "h[2]" both expand to h[2]). Without it, the same activation would be
    # resolved and saved twice, with the second save overwriting the first.
    concrete_targets = {
        concrete_target
        for parsed_target in parsed_targets
        for concrete_target in SavedActivation._expand_target(model, parsed_target)
    }
    saved_values: dict[ActivationTarget, Any] = {}
    # Read activations in the same canonical order SavedActivation uses for
    # its targets. Distinct targets have distinct keys, so this is deterministic.
    for target in sorted(concrete_targets, key=SavedActivation._target_order_key):
        activation = SavedActivation._resolve_activation(model, target.activation_path)
        SavedActivation._validate_activation_shape(activation)
        saved_values[target] = activation[:, target.position, :].save()
    return SavedActivation._from_targets(saved_values)
def updates(
    model: Any,
    activation_paths: list[str],
) -> Iterator[tuple[str, Any, Callable[[Any], None]]]:
    """Yields ``(path, current_value, update)`` handles for live activation patching.

    This is a generator, so nothing is parsed or validated until iteration
    starts. The paths must be concrete (no slice segments), unique, and
    already in canonical order, as produced by ``SavedActivation.keys()``.
    Activations are then resolved lazily, one per yielded handle.

    ``current_value`` is ``activation[:, position, :]``. Calling
    ``update(new_value)`` assigns ``new_value`` to that same slice.

    Raises:
        ValueError: for slice paths, duplicate paths, out-of-order paths, or
            activations that are not 3-dimensional.
    """
    targets = list(map(SavedActivation.parse_path, activation_paths))

    for path, parsed in zip(activation_paths, targets, strict=True):
        if SavedActivation._contains_slice(parsed.activation_path):
            raise ValueError(f"updates() requires concrete keys, got slice path: {path}")

    if len(targets) != len(set(targets)):
        raise ValueError("updates() requires unique keys.")

    if targets != sorted(targets, key=SavedActivation._target_order_key):
        raise ValueError("updates() requires keys in canonical layer-major order.")

    for key, target in zip(activation_paths, targets, strict=True):
        activation = SavedActivation._resolve_activation(model, target.activation_path)
        SavedActivation._validate_activation_shape(activation)
        current_value = activation[:, target.position, :]

        # Default arguments freeze this iteration's activation and position;
        # a plain closure would see whatever the loop variables hold later.
        def update(new_value: Any, activation: Any = activation, position: int = target.position) -> None:
            activation[:, position, :] = new_value

        yield key, current_value, update