Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -161,7 +161,8 @@ class SavedActivation:
     ) -> None:
         """Creates a saved-activation tree from nested values or explicit targets."""
         values = {} if values is None else values
-        self._targets = targets if targets is not None else self._extract_targets(values)
+        extracted_targets = targets if targets is not None else self._extract_targets(values)
+        self._targets = self._canonicalize_targets(extracted_targets)
         self.values = SavedActivationValues(values)
         self.slice = _SavedActivationBatchSlicer(self)

@@ -334,6 +335,11 @@ class SavedActivation:
         if Globals.stack > 0:
             nnsight_save(value)

+    @classmethod
+    def _canonicalize_targets(cls, targets: dict[ActivationTarget, Any]) -> dict[ActivationTarget, Any]:
+        """Normalizes target order so every SavedActivation iterates in graph-safe model order."""
+        return dict(sorted(targets.items(), key=lambda item: cls._target_order_key(item[0])))
+
     @staticmethod
     def _slice_value(value: Any, ranges: list[range]) -> Any:
         """Slices saved batch rows, concatenating multiple ranges when requested."""
@@ -131,6 +131,47 @@ def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None:
     assert not torch.equal(before_patch, after_patch)


+# Applying multiple patches must replay them in model order even if the SavedActivation keys are reordered.
+def test_apply_with_gpt2_reordered_patch_set_preserves_nnsight_access_order(
+    gpt2_model: LanguageModel,
+) -> None:
+    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
+    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
+    patch_paths = [
+        "model.transformer.h[2].output[9]",
+        "model.transformer.h[3].output[8]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([clean_prompt]):
+            patch = save_activations(gpt2_model, patch_paths)
+
+    reordered_patch = patch.subset(
+        [
+            "model.transformer.h[3].output[8]",
+            "model.transformer.h[2].output[9]",
+        ]
+    )
+
+    assert reordered_patch.keys() == [
+        "model.transformer.h[2].output[9]",
+        "model.transformer.h[3].output[8]",
+    ]
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([corrupted_prompt]):
+            reordered_patch.apply(gpt2_model)
+            patched_logits = gpt2_model.lm_head.output.save()
+
+    with gpt2_model.trace() as tracer:
+        with tracer.invoke([corrupted_prompt]):
+            gpt2_model.transformer.h[2].output[:, 9, :] = patch.get("model.transformer.h[2].output[9]")
+            gpt2_model.transformer.h[3].output[:, 8, :] = patch.get("model.transformer.h[3].output[8]")
+            direct_logits = gpt2_model.lm_head.output.save()
+
+    assert torch.allclose(patched_logits, direct_logits)
+
+
 # `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
 def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
     prompts = [