more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -374,6 +374,19 @@ class SavedActivation:
         return "".join(parts)
 
     @classmethod
+    def _target_order_key(cls, target: ActivationTarget) -> tuple[Any, ...]:
+        """Orders concrete targets by model traversal path, then by token position."""
+        path_key: list[tuple[int, Any]] = []
+        for segment in target.activation_path:
+            if isinstance(segment, AttrSegment):
+                path_key.append((0, segment.name))
+            elif isinstance(segment, IndexSegment):
+                path_key.append((1, segment.value))
+            else:
+                raise ValueError("Targets must be concrete before access ordering is computed.")
+        return (*path_key, (2, target.position))
+
+    @classmethod
     def _expand_target(
         cls,
         model: Any,
@@ -457,10 +470,15 @@ class SavedActivation:
 def save_activations(model: Any, activation_paths: list[str]) -> SavedActivation:
     """Captures requested activations inside an already-active NNSight trace/invoke context."""
     parsed_targets = [SavedActivation.parse_path(path) for path in activation_paths]
+    concrete_targets = [
+        concrete_target
+        for parsed_target in parsed_targets
+        for concrete_target in SavedActivation._expand_target(model, parsed_target)
+    ]
+
     saved_values: dict[ActivationTarget, Any] = {}
-    for parsed_target in parsed_targets:
-        for target in SavedActivation._expand_target(model, parsed_target):
-            activation = SavedActivation._resolve_activation(model, target.activation_path)
-            SavedActivation._validate_activation_shape(activation)
-            saved_values[target] = activation[:, target.position, :].save()
+    for target in sorted(concrete_targets, key=SavedActivation._target_order_key):
+        activation = SavedActivation._resolve_activation(model, target.activation_path)
+        SavedActivation._validate_activation_shape(activation)
+        saved_values[target] = activation[:, target.position, :].save()
     return SavedActivation._from_targets(saved_values)
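For intuition, here is a minimal sketch of what the new ordering key buys. The AttrSegment, IndexSegment, and ActivationTarget definitions below are illustrative stand-ins, since the repository's own classes are not shown in this diff; the key function mirrors _target_order_key above.

# Illustrative stand-ins for the path-segment and target types; the real
# definitions live elsewhere in the repository and may differ.
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class AttrSegment:
    name: str

@dataclass(frozen=True)
class IndexSegment:
    value: int

@dataclass(frozen=True)
class ActivationTarget:
    activation_path: tuple[Any, ...]
    position: int

def target_order_key(target: ActivationTarget) -> tuple[Any, ...]:
    # Same shape as _target_order_key: tag attribute segments with 0 and
    # index segments with 1 so comparisons never mix strings with ints,
    # then append the token position as the final tiebreaker.
    path_key: list[tuple[int, Any]] = []
    for segment in target.activation_path:
        if isinstance(segment, AttrSegment):
            path_key.append((0, segment.name))
        elif isinstance(segment, IndexSegment):
            path_key.append((1, segment.value))
        else:
            raise ValueError("Targets must be concrete before ordering.")
    return (*path_key, (2, target.position))

h = (AttrSegment("transformer"), AttrSegment("h"))
targets = [
    ActivationTarget((*h, IndexSegment(1), AttrSegment("output")), -1),
    ActivationTarget((*h, IndexSegment(0), AttrSegment("output")), -1),
    ActivationTarget((*h, IndexSegment(0), AttrSegment("output")), -2),
]
for target in sorted(targets, key=target_order_key):
    print(target.activation_path[-2], target.position)
# Prints layer 0 at positions -2 then -1, then layer 1: module order
# first, token position second.

Tagging each tuple entry keeps the comparison total: attribute names only ever compare against attribute names, indices against indices, and positions against positions.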
@@ -62,6 +62,35 @@ def test_save_with_gpt2_layer_slice_matches_direct_nnsight_loop(gpt2_model: Lang
     assert torch.allclose(saved.get(f"model.transformer.h[{layer}].output[10]"), direct_tensor)
 
 
+# Saving several token positions across all layers must still traverse layers in graph order.
+def test_save_with_gpt2_layer_slice_and_multiple_positions_matches_direct_loop(
+    gpt2_model: LanguageModel,
+) -> None:
+    prompts = [
+        "One two three four five six seven eight nine ten eleven twelve.",
+        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
+    ]
+    positions = [-4, -3, -2, -1]
+    paths = [f"model.transformer.h[:].output[{position}]" for position in positions]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            saved = save_activations(gpt2_model, paths)
+
+    direct: dict[str, torch.Tensor] = {}
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke(prompts):
+            for layer in range(len(gpt2_model.transformer.h)):
+                for position in positions:
+                    direct[f"model.transformer.h[{layer}].output[{position}]"] = (
+                        gpt2_model.transformer.h[layer].output[:, position, :].save()
+                    )
+
+    assert saved.keys() == list(direct)
+    for path, direct_tensor in direct.items():
+        assert torch.allclose(saved.get(path), direct_tensor)
+
+
 # Unsaved GPT-2 paths should still error even when nearby paths were captured.
 def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None:
     prompts = [
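As a usage sketch, following the trace/invoke call pattern in the test above. The model id and the import path for save_activations are guesses, not confirmed by this diff.

# Hypothetical usage, assuming an nnsight LanguageModel wrapping GPT-2.
# The path syntax matches the tests: h[:] expands across all layers and
# [-1] selects the last token position.
from nnsight import LanguageModel

from more_nnsight import save_activations  # import path is a guess

gpt2 = LanguageModel("gpt2", device_map="cpu")
with gpt2.trace() as tracer:
    with tracer.invoke("The quick brown fox jumps over the lazy dog"):
        saved = save_activations(gpt2, ["model.transformer.h[:].output[-1]"])

# Each expanded concrete path is retrievable once the trace exits.
print(saved.get("model.transformer.h[0].output[-1]").shape)  # (batch, hidden)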