Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -14,8 +14,14 @@ from ._saved_activation_values import SavedActivationValues GRAMMAR = r""" start: "model" segment+ "[" SIGNED_INT "]" -segment: "." CNAME -> attr - | "[" SIGNED_INT "]" -> index +segment: "." CNAME -> attr + | "[" SIGNED_INT "]" -> index + | "[" slice_expr "]" -> slice + +slice_expr: SIGNED_INT ":" SIGNED_INT -> bounded_slice + | SIGNED_INT ":" -> start_slice + | ":" SIGNED_INT -> stop_slice + | ":" -> full_slice %import common.CNAME %import common.SIGNED_INT @@ -38,7 +44,15 @@ class IndexSegment: value: int -PathSegment = AttrSegment | IndexSegment +@dataclass(frozen=True, slots=True) +class SliceSegment: + """Represents a slice access in a model activation path.""" + + start: int | None + stop: int | None + + +PathSegment = AttrSegment | IndexSegment | SliceSegment @dataclass(frozen=True, slots=True) @@ -69,6 +83,27 @@ class _ActivationTargetTransformer(Transformer[Token, ActivationTarget | PathSeg """Turns a parsed bracketed integer into a path segment.""" return IndexSegment(items[0]) + def slice(self, items: list[tuple[int | None, int | None]]) -> SliceSegment: + """Turns a parsed bracketed slice into a path segment.""" + start, stop = items[0] + return SliceSegment(start, stop) + + def bounded_slice(self, items: list[int]) -> tuple[int | None, int | None]: + """Parses a slice with both explicit bounds.""" + return (items[0], items[1]) + + def start_slice(self, items: list[int]) -> tuple[int | None, int | None]: + """Parses a slice with only an explicit start bound.""" + return (items[0], None) + + def stop_slice(self, items: list[int]) -> tuple[int | None, int | None]: + """Parses a slice with only an explicit stop bound.""" + return (None, items[0]) + + def full_slice(self, _: list[int]) -> tuple[int | None, int | None]: + """Parses a slice that spans the entire container.""" + return (None, None) + def CNAME(self, token: Token) -> str: """Preserves parsed attribute names as plain strings.""" return str(token) @@ -105,10 +140,12 @@ class SavedActivation: Paths use the model's real attribute/index structure, with the final bracket selecting the token position to save or patch. For example: - - ``SavedActivation.save_activations(model, ["model.transformer.h[2].output[10]"])`` + - ``save_activations(model, ["model.transformer.h[2].output[10]"])`` captures layer 2 output at token 10 - - ``SavedActivation.save_activations(model, ["model.transformer.h[3].output[-1]"])`` + - ``save_activations(model, ["model.transformer.h[3].output[-1]"])`` captures the last token from layer 3 + - ``save_activations(model, ["model.transformer.h[:].output[10]"])`` + captures token 10 at every layer in a ModuleList-like container ``save_activations`` captures activation values, while ``save`` registers a ``SavedActivation`` object itself with NNSight so it survives trace exit. @@ -167,7 +204,7 @@ class SavedActivation: } ) - def apply(self, model: LanguageModel) -> None: + def apply(self, model: Any) -> None: """Writes the stored activations into the current traced run of the model.""" for target, saved_tensor in self._targets.items(): activation = self._resolve_activation(model, target.activation_path) @@ -236,14 +273,11 @@ class SavedActivation: return _PATH_PARSER.parse(path) @staticmethod - def _resolve_activation(model: LanguageModel, activation_path: tuple[PathSegment, ...]) -> Any: + def _resolve_activation(model: Any, activation_path: tuple[PathSegment, ...]) -> Any: """Finds the live NNSight object identified by a parsed activation path.""" current: Any = model for segment in activation_path: - if isinstance(segment, AttrSegment): - current = getattr(current, segment.name) - else: - current = current[segment.value] + current = SavedActivation._apply_concrete_segment(current, segment) return current @staticmethod @@ -330,18 +364,103 @@ class SavedActivation: for segment in target.activation_path: if isinstance(segment, AttrSegment): parts.append(f".{segment.name}") - else: + elif isinstance(segment, IndexSegment): parts.append(f"[{segment.value}]") + else: + start = "" if segment.start is None else str(segment.start) + stop = "" if segment.stop is None else str(segment.stop) + parts.append(f"[{start}:{stop}]") parts.append(f"[{target.position}]") return "".join(parts) + @classmethod + def _expand_target( + cls, + model: Any, + target: ActivationTarget, + ) -> list[ActivationTarget]: + """Expands slice segments into concrete index paths before saving activations.""" + return cls._expand_segments( + cls._static_structure(model), + tuple(), + target.activation_path, + target.position, + ) + + @classmethod + def _expand_segments( + cls, + current: Any, + prefix: tuple[PathSegment, ...], + remaining: tuple[PathSegment, ...], + position: int, + ) -> list[ActivationTarget]: + """Recursively expands one parsed activation path into concrete saved targets.""" + if not remaining: + return [ActivationTarget(prefix, position)] + + if not cls._contains_slice(remaining): + return [ActivationTarget(prefix + remaining, position)] + + segment, *rest = remaining + concrete_targets: list[ActivationTarget] = [] + for next_current, concrete_segment in cls._expand_segment(current, segment): + concrete_targets.extend( + cls._expand_segments(next_current, prefix + (concrete_segment,), tuple(rest), position) + ) + return concrete_targets + + @staticmethod + def _contains_slice(segments: tuple[PathSegment, ...]) -> bool: + """Checks whether a remaining path still needs ModuleList slice expansion.""" + return any(isinstance(segment, SliceSegment) for segment in segments) -def save_activations(model: LanguageModel, activation_paths: list[str]) -> SavedActivation: + @staticmethod + def _normalize_path_slice(segment: SliceSegment, length: int) -> range: + """Converts a parsed path slice into concrete indices for a module container.""" + return range(*slice(segment.start, segment.stop).indices(length)) + + @staticmethod + def _static_structure(value: Any) -> Any: + """Returns the underlying module tree used only for slice expansion.""" + return getattr(value, "_module", value) + + @staticmethod + def _apply_concrete_segment(current: Any, segment: PathSegment) -> Any: + """Applies one non-slice path segment to either a live envoy or static module.""" + if isinstance(segment, AttrSegment): + return getattr(current, segment.name) + if isinstance(segment, IndexSegment): + return current[segment.value] + raise ValueError("Slice segments must be expanded before resolving a concrete activation.") + + @classmethod + def _expand_segment( + cls, + current: Any, + segment: PathSegment, + ) -> list[tuple[Any, AttrSegment | IndexSegment]]: + """Expands one path segment into concrete child accesses.""" + if isinstance(segment, SliceSegment): + static_current = cls._static_structure(current) + if not hasattr(static_current, "__len__") or not hasattr(static_current, "__getitem__"): + raise TypeError("Slice syntax is only supported on indexable module containers.") + return [ + (static_current[index], IndexSegment(index)) + for index in cls._normalize_path_slice(segment, len(static_current)) + ] + + next_current = cls._apply_concrete_segment(current, segment) + return [(cls._static_structure(next_current), segment)] + + +def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation: """Captures requested activations inside an already-active NNSight trace/invoke context.""" - targets = [SavedActivation.parse_path(path) for path in activation_paths] + parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths] saved_values: dict[ActivationTarget, Any] = {} - for target in targets: - activation = SavedActivation._resolve_activation(model, target.activation_path) - SavedActivation._validate_activation_shape(activation) - saved_values[target] = activation[:, target.position, :].save() + for parsed_target in parsed_targets: + for target in SavedActivation._expand_target(model, parsed_target): + activation = SavedActivation._resolve_activation(model, target.activation_path) + SavedActivation._validate_activation_shape(activation) + saved_values[target] = activation[:, target.position, :].save() return SavedActivation._from_targets(saved_values)
@@ -6,6 +6,8 @@ from pathlib import Path import pytest import torch from nnsight import LanguageModel +from nnsight.modeling.base import NNsight +from transformers import AutoConfig, AutoModelForCausalLM from more_nnsight import SavedActivation, save_activations @@ -38,6 +40,28 @@ def test_save_with_gpt2_returns_saved_tensors_at_requested_paths(gpt2_model: Lan assert torch.allclose(saved.get("model.transformer.h[3].output[-1]"), direct_layer_3) +# Slice syntax over GPT-2's transformer block list should match a direct per-layer NNSight loop. +def test_save_with_gpt2_layer_slice_matches_direct_nnsight_loop(gpt2_model: LanguageModel) -> None: + prompts = [ + "One two three four five six seven eight nine ten eleven twelve.", + "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.", + ] + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + saved = save_activations(gpt2_model, ["model.transformer.h[:].output[10]"]) + direct: list[torch.Tensor] = [] + with gpt2_model.trace() as tracer: + with tracer.invoke(prompts): + direct.extend( + gpt2_model.transformer.h[layer].output[:, 10, :].save() + for layer in range(len(gpt2_model.transformer.h)) + ) + + assert len(saved.keys()) == len(direct) + for layer, direct_tensor in enumerate(direct): + assert torch.allclose(saved.get(f"model.transformer.h[{layer}].output[10]"), direct_tensor) + + # Unsaved GPT-2 paths should still error even when nearby paths were captured. def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None: prompts = [ @@ -138,3 +162,31 @@ def test_mean_inside_trace_saves_reduced_result_on_device(gpt2_model: LanguageMo assert saved.get(path).shape == (2, 768) assert mean_saved.get(path).shape == (1, 768) assert torch.allclose(mean_saved.get(path), saved.get(path).mean(dim=0, keepdim=True)) + + +# A randomly initialized non-GPT2 architecture should work as long as layers live in a ModuleList. +def test_save_with_random_init_qwen_layer_slice_matches_direct_loop() -> None: + cfg = AutoConfig.from_pretrained("Qwen/Qwen3-4B-Instruct-2507") + cfg.hidden_size = 64 + cfg.intermediate_size = 128 + cfg.num_hidden_layers = 3 + cfg.num_attention_heads = 4 + cfg.num_key_value_heads = 4 + cfg.vocab_size = 256 + + model = NNsight(AutoModelForCausalLM.from_config(cfg)) + input_ids = torch.randint(0, cfg.vocab_size, (2, 6)) + attention_mask = torch.ones_like(input_ids) + + with model.trace(input_ids=input_ids, attention_mask=attention_mask): + saved = save_activations(model, ["model.model.layers[:].output[2]"]) + direct: list[torch.Tensor] = [] + with model.trace(input_ids=input_ids, attention_mask=attention_mask): + direct.extend( + model.model.layers[layer].output[:, 2, :].save() + for layer in range(cfg.num_hidden_layers) + ) + + assert len(saved.keys()) == cfg.num_hidden_layers + for layer, direct_tensor in enumerate(direct): + assert torch.allclose(saved.get(f"model.model.layers[{layer}].output[2]"), direct_tensor)