Repositories / more_nnsight.git

src/more_nnsight/saved_activation.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch
22975 bytes · 89f0be3a0af6
"""Sparse capture and replay of per-token activations from a traced model.

Activation paths are written as strings such as
``model.transformer.h[2].output[10]``: the word ``model``, one or more
attribute / index / slice segments, and a final bracketed integer that
selects the token position.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Iterator

import torch
from lark import Lark, Token, Transformer
from nnsight import LanguageModel
from nnsight.intervention.tracing.globals import Globals, save as nnsight_save

from ._saved_activation_values import SavedActivationValues

# Grammar for activation paths. The trailing ``"[" SIGNED_INT "]"`` of ``start``
# is the token position; everything between ``model`` and it is the path.
GRAMMAR = r"""
start: "model" segment+ "[" SIGNED_INT "]"

segment: "." CNAME -> attr
       | "[" SIGNED_INT "]" -> index
       | "[" slice_expr "]" -> slice

slice_expr: SIGNED_INT ":" SIGNED_INT -> bounded_slice
          | SIGNED_INT ":" -> start_slice
          | ":" SIGNED_INT -> stop_slice
          | ":" -> full_slice

%import common.CNAME
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""


@dataclass(frozen=True, slots=True)
class AttrSegment:
    """Represents an attribute access (``.name``) in a model activation path."""

    name: str


@dataclass(frozen=True, slots=True)
class IndexSegment:
    """Represents an index access (``[value]``) in a model activation path."""

    value: int


@dataclass(frozen=True, slots=True)
class SliceSegment:
    """Represents a slice access (``[start:stop]``) in a model activation path.

    Either bound may be ``None`` when it was omitted in the path string.
    """

    start: int | None
    stop: int | None


PathSegment = AttrSegment | IndexSegment | SliceSegment


@dataclass(frozen=True, slots=True)
class ActivationTarget:
    """Identifies one token-position slice of a traced activation tensor.

    ``activation_path`` is the parsed segment sequence after ``model``;
    ``position`` is the integer from the final bracket of the path string.
    """

    activation_path: tuple[PathSegment, ...]
    position: int


class _ActivationTargetTransformer(Transformer[Token, ActivationTarget | PathSegment | str | int]):
    """Converts parsed path syntax into structured activation targets.

    Method names match the rule aliases and terminal names in ``GRAMMAR``.
    """

    def start(self, items: list[PathSegment | int]) -> ActivationTarget:
        """Builds the final target object used by save, subset, and apply.

        Raises:
            ValueError: if no path segment precedes the position.
            TypeError: if the final item is not an integer.
        """
        # The last child is the token position; all earlier children are segments.
        *activation_path, position = items
        if not activation_path:
            raise ValueError("Activation path must include at least one segment before the position.")
        if not isinstance(position, int):
            raise TypeError(f"Expected final position to be an integer, got {type(position)!r}")
        return ActivationTarget(tuple(activation_path), position)

    def attr(self, items: list[str]) -> AttrSegment:
        """Turns a parsed attribute name into a path segment."""
        return AttrSegment(items[0])

    def index(self, items: list[int]) -> IndexSegment:
        """Turns a parsed bracketed integer into a path segment."""
        return IndexSegment(items[0])

    def slice(self, items: list[tuple[int | None, int | None]]) -> SliceSegment:
        """Turns a parsed bracketed slice into a path segment."""
        start, stop = items[0]
        return SliceSegment(start, stop)

    def bounded_slice(self, items: list[int]) -> tuple[int | None, int | None]:
        """Parses a slice with both explicit bounds."""
        return (items[0], items[1])

    def start_slice(self, items: list[int]) -> tuple[int | None, int | None]:
        """Parses a slice with only an explicit start bound."""
        return (items[0], None)

    def stop_slice(self, items: list[int]) -> tuple[int | None, int | None]:
        """Parses a slice with only an explicit stop bound."""
        return (None, items[0])

    def full_slice(self, _: list[int]) -> tuple[int | None, int | None]:
        """Parses a slice that spans the entire container."""
        return (None, None)

    def CNAME(self, token: Token) -> str:
        """Preserves parsed attribute names as plain strings."""
        return str(token)

    def SIGNED_INT(self, token: Token) -> int:
        """Parses integer literals so positions and indices stay numeric."""
        return int(token)


# Built once at import time; the transformer is attached to the parser so
# ``parse`` returns an ``ActivationTarget`` directly.
_PATH_PARSER = Lark(GRAMMAR, parser="lalr", transformer=_ActivationTargetTransformer())


class _SavedActivationBatchSlicer:
    """Provides bracket-based batch slicing for saved activation batches."""

    def __init__(self, saved_activation: "SavedActivation") -> None:
        """Binds batch-slice syntax to one SavedActivation instance."""
        self._saved_activation = saved_activation

    def __getitem__(self, batch_slices: int | range | slice | tuple[int | range | slice, ...]) -> "SavedActivation":
        """Builds a sliced SavedActivation from one or more batch selectors.

        A tuple of selectors is applied selector by selector and the resulting
        pieces are concatenated along dimension 0 (see ``_slice_value``).
        """
        selectors = batch_slices if isinstance(batch_slices, tuple) else (batch_slices,)
        ranges = [self._saved_activation._normalize_batch_slice(selector) for selector in selectors]
        sliced_targets = {
            target: self._saved_activation._slice_value(value, ranges)
            for target, value in self._saved_activation._targets.items()
        }
        return self._saved_activation._from_targets(sliced_targets)


class SavedActivation:
    """Stores sparse activation slices and can replay them into later traces.

    Paths use the model's real attribute/index structure, with the final
    bracket selecting the token position to save or patch. For example:

    - ``save_activations(model, ["model.transformer.h[2].output[10]"])``
      captures layer 2 output at token 10
    - ``save_activations(model, ["model.transformer.h[3].output[-1]"])``
      captures the last token from layer 3
    - ``save_activations(model, ["model.transformer.h[:].output[10]"])``
      captures token 10 at every layer in a ModuleList-like container

    ``save_activations`` captures activation values, while ``save`` registers
    a ``SavedActivation`` object itself with NNSight so it survives trace exit.

    After capture, the same structure is exposed under ``values``, so
    ``saved.values.transformer.h[2].output[10]`` returns the stored tensor.
    """

    def __init__(
        self,
        values: dict[Any, Any] | None = None,
        targets: dict[ActivationTarget, Any] | None = None,
    ) -> None:
        """Creates a saved-activation tree from nested values or explicit targets.

        Args:
            values: nested dict mirroring the path structure; leaves are keyed
                by integer token position. Defaults to an empty tree.
            targets: flat target-to-value mapping. When omitted, it is derived
                from ``values``.
        """
        values = {} if values is None else values
        extracted_targets = targets if targets is not None else self._extract_targets(values)
        self._targets = self._canonicalize_targets(extracted_targets)
        self.values = SavedActivationValues(values)
        self.slice = _SavedActivationBatchSlicer(self)

    def save(self) -> "SavedActivation":
        """Registers this SavedActivation object with NNSight inside an active trace."""
        self._register_object_save(self)
        return self

    @classmethod
    def from_pairs(cls, *pairs: tuple[str, Any]) -> "SavedActivation":
        """Builds a SavedActivation directly from explicit path-value pairs.

        Raises:
            ValueError: if two pairs parse to the same target.
        """
        targets: dict[ActivationTarget, Any] = {}
        for path, value in pairs:
            target = cls.parse_path(path)
            if target in targets:
                raise ValueError(f"SavedActivation.from_pairs got duplicate path: {path}")
            targets[target] = value
        return cls._from_targets(targets)

    def subset(self, activation_paths: list[str]) -> "SavedActivation":
        """Keeps only selected saved paths so patches can be focused or composed.

        Raises:
            KeyError: if a requested path is not stored in this object.
        """
        requested_targets = [self.parse_path(path) for path in activation_paths]
        subset_targets: dict[ActivationTarget, Any] = {}
        for target in requested_targets:
            try:
                subset_targets[target] = self._targets[target]
            except KeyError as exc:
                raise KeyError(f"Unknown saved activation path: {self._format_target(target)}") from exc
        return self._from_targets(subset_targets)

    def keys(self) -> list[str]:
        """Returns the saved activation paths in the same syntax used by save_activations."""
        return [self._format_target(target) for target in self._targets]

    def get(self, activation_path: str) -> Any:
        """Returns one saved tensor by its user-facing activation path.

        Raises:
            KeyError: if the path is not stored in this object.
        """
        target = self.parse_path(activation_path)
        try:
            return self._targets[target]
        except KeyError as exc:
            raise KeyError(f"Unknown saved activation path: {activation_path}") from exc

    def mean(self) -> "SavedActivation":
        """Averages each saved tensor across the batch dimension while keeping one row."""
        return self._from_targets(
            {
                target: self._saved_result(value.mean(dim=0, keepdim=True))
                for target, value in self._targets.items()
            }
        )

    def union(self, other: "SavedActivation") -> "SavedActivation":
        """Combines two disjoint saved activation sets into one patch collection.

        Raises:
            ValueError: if the two objects share any target.
        """
        overlapping_targets = set(self._targets) & set(other._targets)
        if overlapping_targets:
            # Sort so the error message is deterministic; plain set iteration
            # order is not stable across runs.
            overlapping_paths = ", ".join(
                self._format_target(target)
                for target in sorted(overlapping_targets, key=self._target_order_key)
            )
            raise ValueError(f"SavedActivation.union requires disjoint keys, got overlap: {overlapping_paths}")
        return self._from_targets({**self._targets, **other._targets})

    def apply(self, model: Any) -> None:
        """Writes the stored activations into the current traced run of the model."""
        for key, _current_value, update in updates(model, self.keys()):
            update(self.get(key))

    def __add__(self, other: "SavedActivation") -> "SavedActivation":
        """Combines two saved activation sets elementwise when they cover the same paths."""
        if not isinstance(other, SavedActivation):
            # Let Python try the reflected operation / raise TypeError,
            # mirroring how __mul__ treats unsupported operands.
            return NotImplemented
        return self._binary_op(other, lambda left, right: left + right)

    def __sub__(self, other: "SavedActivation") -> "SavedActivation":
        """Subtracts one saved activation set from another when they cover the same paths."""
        if not isinstance(other, SavedActivation):
            return NotImplemented
        return self._binary_op(other, lambda left, right: left - right)

    def __mul__(self, scalar: float) -> "SavedActivation":
        """Scales every saved tensor by the same numeric factor."""
        if not isinstance(scalar, int | float):
            return NotImplemented
        return self._from_targets(
            {
                target: self._saved_result(value * scalar)
                for target, value in self._targets.items()
            }
        )

    def __rmul__(self, scalar: float) -> "SavedActivation":
        """Allows scalar multiplication with the scalar on the left-hand side."""
        return self * scalar

    def _binary_op(
        self,
        other: "SavedActivation",
        op: Callable[[Any, Any], Any],
    ) -> "SavedActivation":
        """Applies an elementwise operation across matching saved activation sets.

        Raises:
            ValueError: if the two objects do not hold the same keys.
        """
        self._validate_matching_keys(other)
        return self._from_targets(
            {
                target: self._saved_result(op(self._targets[target], other._targets[target]))
                for target in self._targets
            }
        )

    @classmethod
    def _from_targets(cls, targets: dict[ActivationTarget, Any]) -> "SavedActivation":
        """Builds the public tree view from the flat target-to-tensor mapping."""
        result = cls(cls._build_values_from_targets(targets), dict(targets))
        cls._register_object_save(result)
        return result

    @classmethod
    def _build_values_from_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[Any, Any]:
        """Builds the nested values tree from a flat target-to-value mapping.

        Attribute segments become string keys, index segments become integer
        keys, and the token position is the key of the leaf.

        Raises:
            ValueError: if a target still contains a slice segment.
        """
        values: dict[Any, Any] = {}
        for target, value in targets.items():
            cursor = values
            for segment in target.activation_path:
                if isinstance(segment, AttrSegment):
                    key: Any = segment.name
                elif isinstance(segment, IndexSegment):
                    key = segment.value
                else:
                    # A SliceSegment has no single key; fail clearly rather
                    # than with an AttributeError on a missing ``value`` slot.
                    raise ValueError("Targets must be concrete before the values tree is built.")
                cursor = cursor.setdefault(key, {})
            cursor[target.position] = value
        return values

    @staticmethod
    def parse_path(path: str) -> ActivationTarget:
        """Parses user-facing model paths into the internal target representation."""
        return _PATH_PARSER.parse(path)

    @staticmethod
    def _resolve_activation(model: Any, activation_path: tuple[PathSegment, ...]) -> Any:
        """Finds the live NNSight object identified by a parsed activation path."""
        current: Any = model
        for segment in activation_path:
            current = SavedActivation._apply_concrete_segment(current, segment)
        return current

    @staticmethod
    def _validate_activation_shape(activation: Any) -> None:
        """Enforces the batch-sequence-hidden convention expected by this abstraction.

        Raises:
            ValueError: if ``activation.shape`` does not have exactly three entries.
        """
        if len(activation.shape) != 3:
            raise ValueError(
                "Saved activation paths must resolve to a tensor with shape "
                "(batch, sequence, hidden) before the final position index."
            )

    @classmethod
    def _extract_targets(
        cls,
        values: dict[Any, Any],
        prefix: tuple[PathSegment, ...] = (),
    ) -> dict[ActivationTarget, Any]:
        """Reconstructs the flat target mapping from the nested public tree form.

        Nested dicts are treated as path segments (string key -> attribute,
        integer key -> index); any non-dict value is a leaf keyed by position.

        Raises:
            TypeError: on a non-str/int key above a nested dict, or a
                non-integer leaf key.
        """
        targets: dict[ActivationTarget, Any] = {}
        for key, value in values.items():
            if isinstance(value, dict):
                next_segment: PathSegment
                if isinstance(key, str):
                    next_segment = AttrSegment(key)
                elif isinstance(key, int):
                    next_segment = IndexSegment(key)
                else:
                    raise TypeError(f"Unsupported key type in saved activation tree: {type(key)!r}")
                targets.update(cls._extract_targets(value, prefix + (next_segment,)))
                continue
            if not isinstance(key, int):
                raise TypeError("Saved activation leaves must be keyed by token position integers.")
            targets[ActivationTarget(prefix, key)] = value
        return targets

    def _validate_matching_keys(self, other: "SavedActivation") -> None:
        """Ensures elementwise operations only happen across identical saved paths."""
        # Both sides are stored in canonical order, so an ordered comparison suffices.
        if tuple(self._targets) != tuple(other._targets):
            raise ValueError("SavedActivation objects must have the same keys for arithmetic.")

    @staticmethod
    def _saved_result(value: Any) -> Any:
        """Preserves derived activation values after the trace by saving proxies when needed.

        Calls ``value.save()`` only when ``Globals.stack > 0`` and the value
        exposes a callable ``save``; otherwise returns the value unchanged.
        """
        save = getattr(value, "save", None)
        if Globals.stack > 0 and callable(save):
            return save()
        return value

    @staticmethod
    def _register_object_save(value: "SavedActivation") -> None:
        """Registers a newly created SavedActivation with NNSight when inside a trace."""
        if Globals.stack > 0:
            nnsight_save(value)

    @classmethod
    def _canonicalize_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[ActivationTarget, Any]:
        """Normalizes target order so every SavedActivation iterates in graph-safe model order."""
        return dict(sorted(targets.items(), key=lambda item: cls._target_order_key(item[0])))

    @staticmethod
    def _slice_value(value: Any, ranges: list[range]) -> Any:
        """Slices saved batch rows, concatenating multiple ranges when requested."""
        pieces = [value[batch_range] for batch_range in ranges]
        if len(pieces) == 1:
            return SavedActivation._saved_result(pieces[0])
        return SavedActivation._saved_result(torch.cat(pieces, dim=0))

    @staticmethod
    def _normalize_batch_slice(batch_slice: int | range | slice) -> range:
        """Converts supported batch selectors into a common range representation.

        Raises:
            ValueError: for a slice with a step other than 1, or with a
                missing start or stop.
            TypeError: for any other selector type.
        """
        if isinstance(batch_slice, int):
            return range(batch_slice, batch_slice + 1)
        if isinstance(batch_slice, range):
            return batch_slice
        if isinstance(batch_slice, slice):
            if batch_slice.step not in (None, 1):
                raise ValueError("slice does not support slice steps other than 1.")
            if batch_slice.start is None or batch_slice.stop is None:
                raise ValueError("slice requires explicit start and stop values.")
            return range(batch_slice.start, batch_slice.stop)
        raise TypeError("slice accepts only ints, ranges, or slice objects.")

    @staticmethod
    def _format_target(target: ActivationTarget) -> str:
        """Turns an internal target back into the user-facing path syntax."""
        parts = ["model"]
        for segment in target.activation_path:
            if isinstance(segment, AttrSegment):
                parts.append(f".{segment.name}")
            elif isinstance(segment, IndexSegment):
                parts.append(f"[{segment.value}]")
            else:
                start = "" if segment.start is None else str(segment.start)
                stop = "" if segment.stop is None else str(segment.stop)
                parts.append(f"[{start}:{stop}]")
        parts.append(f"[{target.position}]")
        return "".join(parts)

    @classmethod
    def _target_order_key(cls, target: ActivationTarget) -> tuple[Any, ...]:
        """Orders concrete targets by model traversal path, then by token position.

        Each key element is tagged (0 = attribute, 1 = index, 2 = position) so
        that names and integers are never compared with each other.

        Raises:
            ValueError: if the target still contains a slice segment.
        """
        path_key: list[tuple[int, Any]] = []
        for segment in target.activation_path:
            if isinstance(segment, AttrSegment):
                path_key.append((0, segment.name))
            elif isinstance(segment, IndexSegment):
                path_key.append((1, segment.value))
            else:
                raise ValueError("Targets must be concrete before access ordering is computed.")
        return (*path_key, (2, target.position))

    @classmethod
    def _expand_target(
        cls,
        model: Any,
        target: ActivationTarget,
    ) -> list[ActivationTarget]:
        """Expands slice segments into concrete index paths before saving activations."""
        return cls._expand_segments(
            cls._static_structure(model),
            tuple(),
            target.activation_path,
            target.position,
        )

    @classmethod
    def _expand_segments(
        cls,
        current: Any,
        prefix: tuple[PathSegment, ...],
        remaining: tuple[PathSegment, ...],
        position: int,
    ) -> list[ActivationTarget]:
        """Recursively expands one parsed activation path into concrete saved targets."""
        if not remaining:
            return [ActivationTarget(prefix, position)]
        # Once no slice is left, the tail is appended as-is without walking
        # the static structure any further.
        if not cls._contains_slice(remaining):
            return [ActivationTarget(prefix + remaining, position)]
        segment, *rest = remaining
        concrete_targets: list[ActivationTarget] = []
        for next_current, concrete_segment in cls._expand_segment(current, segment):
            concrete_targets.extend(
                cls._expand_segments(next_current, prefix + (concrete_segment,), tuple(rest), position)
            )
        return concrete_targets

    @staticmethod
    def _contains_slice(segments: tuple[PathSegment, ...]) -> bool:
        """Checks whether a remaining path still needs ModuleList slice expansion."""
        return any(isinstance(segment, SliceSegment) for segment in segments)

    @staticmethod
    def _normalize_path_slice(segment: SliceSegment, length: int) -> range:
        """Converts a parsed path slice into concrete indices for a module container."""
        return range(*slice(segment.start, segment.stop).indices(length))

    @staticmethod
    def _static_structure(value: Any) -> Any:
        """Returns the underlying module tree used only for slice expansion.

        Uses ``value._module`` when that attribute exists, else ``value`` itself.
        """
        return getattr(value, "_module", value)

    @staticmethod
    def _apply_concrete_segment(current: Any, segment: PathSegment) -> Any:
        """Applies one non-slice path segment to either a live envoy or static module.

        Raises:
            ValueError: if ``segment`` is a slice segment.
        """
        if isinstance(segment, AttrSegment):
            return getattr(current, segment.name)
        if isinstance(segment, IndexSegment):
            return current[segment.value]
        raise ValueError("Slice segments must be expanded before resolving a concrete activation.")

    @classmethod
    def _expand_segment(
        cls,
        current: Any,
        segment: PathSegment,
    ) -> list[tuple[Any, AttrSegment | IndexSegment]]:
        """Expands one path segment into concrete child accesses.

        Raises:
            TypeError: if a slice is applied to an object lacking ``__len__``
                or ``__getitem__``.
        """
        if isinstance(segment, SliceSegment):
            static_current = cls._static_structure(current)
            if not hasattr(static_current, "__len__") or not hasattr(static_current, "__getitem__"):
                raise TypeError("Slice syntax is only supported on indexable module containers.")
            return [
                (static_current[index], IndexSegment(index))
                for index in cls._normalize_path_slice(segment, len(static_current))
            ]
        next_current = cls._apply_concrete_segment(current, segment)
        return [(cls._static_structure(next_current), segment)]


def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation:
    """Captures requested activations inside an already-active NNSight trace/invoke context.

    Slice segments in the requested paths are expanded to concrete indices,
    and each resulting target is resolved and saved in canonical order.
    """
    parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
    # De-duplicate: overlapping requests (e.g. ``h[:]`` together with ``h[2]``)
    # expand to the same concrete target, which only needs to be saved once.
    concrete_targets = {
        concrete_target
        for parsed_target in parsed_targets
        for concrete_target in SavedActivation._expand_target(model, parsed_target)
    }
    saved_values: dict[ActivationTarget, Any] = {}
    for target in sorted(concrete_targets, key=SavedActivation._target_order_key):
        activation = SavedActivation._resolve_activation(model, target.activation_path)
        SavedActivation._validate_activation_shape(activation)
        saved_values[target] = activation[:, target.position, :].save()
    return SavedActivation._from_targets(saved_values)


def updates(
    model: Any,
    activation_paths: list[str],
) -> Iterator[tuple[str, Any, Callable[[Any], None]]]:
    """Returns layer-major live update handles for concrete saved-activation paths.

    Yields ``(key, current_value, update)`` triples, where ``update(new_value)``
    assigns into ``activation[:, position, :]`` for that key. This is a
    generator, so the validation below runs when iteration starts, not at
    call time.

    Raises:
        ValueError: if a path contains a slice, if keys are not unique, or if
            keys are not in canonical order.
    """
    parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
    for path, target in zip(activation_paths, parsed_targets, strict=True):
        if SavedActivation._contains_slice(target.activation_path):
            raise ValueError(f"updates() requires concrete keys, got slice path: {path}")
    if len(set(parsed_targets)) != len(parsed_targets):
        raise ValueError("updates() requires unique keys.")
    canonical_targets = sorted(parsed_targets, key=SavedActivation._target_order_key)
    if tuple(parsed_targets) != tuple(canonical_targets):
        raise ValueError("updates() requires keys in canonical layer-major order.")
    for key, target in zip(activation_paths, parsed_targets, strict=True):
        activation = SavedActivation._resolve_activation(model, target.activation_path)
        SavedActivation._validate_activation_shape(activation)
        current_value = activation[:, target.position, :]

        # Bind the loop variables as defaults so each handle keeps its own
        # activation and position after the loop advances.
        def update(new_value: Any, activation: Any = activation, position: int = target.position) -> None:
            activation[:, position, :] = new_value

        yield (key, current_value, update)