Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
def union(self, other: "SavedActivation") -> "SavedActivation":
    """Merge the patches held by this collection and *other* into a new one.

    The two operands must cover non-overlapping target paths; a
    ``ValueError`` is raised when any path appears in both.
    """
    # Dict key views intersect like sets, so no temporary copies are needed.
    shared = self._targets.keys() & other._targets.keys()
    if shared:
        pretty = ", ".join(map(self._format_target, shared))
        raise ValueError(f"SavedActivation.union requires disjoint keys, got overlap: {pretty}")

    combined = dict(self._targets)
    combined.update(other._targets)
    return self._from_targets(combined)
def test_union_merges_disjoint_paths_and_rejects_overlap() -> None:
    """union() concatenates disjoint saved paths and refuses overlapping ones."""
    tensor_a = torch.ones(1, 2)
    tensor_b = torch.ones(1, 2) * 2
    tensor_c = torch.ones(1, 2) * 3

    # Helper keeps the nested target-dict shape in one place.
    def saved(layer: int, tensor: torch.Tensor) -> SavedActivation:
        return SavedActivation({"transformer": {"h": {layer: {"output": {10: tensor}}}}})

    first = saved(2, tensor_a)
    second = saved(3, tensor_b)
    clashing = saved(2, tensor_c)

    merged = first.union(second)

    expected_paths = [
        "model.transformer.h[2].output[10]",
        "model.transformer.h[3].output[10]",
    ]
    assert merged.keys() == expected_paths
    # Identity (not equality) checks: union must keep the exact stored tensors.
    assert merged.get(expected_paths[0]) is tensor_a
    assert merged.get(expected_paths[1]) is tensor_b

    with pytest.raises(ValueError, match="disjoint keys"):
        _ = first.union(clashing)
def test_union_with_gpt2_combines_disjoint_patches(gpt2_model: LanguageModel) -> None:
    """union() over two disjoint GPT-2 patch sets replays like one direct multi-assignment patch."""
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

    # Capture two single-position patches from the clean run.
    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            first_patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])
            second_patch = save_activations(gpt2_model, ["model.transformer.h[3].output[8]"])

    merged = first_patch.union(second_patch)

    assert merged.keys() == [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]

    # Replay the merged collection on the corrupted prompt via apply().
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            merged.apply(gpt2_model)
            patched_logits = gpt2_model.lm_head.output.save()

    # Reference run: write the same two tensors in by hand.
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            gpt2_model.transformer.h[2].output[:, 9, :] = first_patch.get("model.transformer.h[2].output[9]")
            gpt2_model.transformer.h[3].output[:, 8, :] = second_patch.get("model.transformer.h[3].output[8]")
            reference_logits = gpt2_model.lm_head.output.save()

    assert torch.allclose(patched_logits, reference_logits)