
tests/test_saved_activation_gpt2.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

from __future__ import annotations

import os
from pathlib import Path

import pytest
import torch
from nnsight import LanguageModel
from nnsight.modeling.base import NNsight
from transformers import AutoConfig, AutoModelForCausalLM

from more_nnsight import SavedActivation, save_activations, updates


@pytest.fixture(scope="module")
def gpt2_model() -> LanguageModel:
    os.environ.setdefault("HF_HOME", str(Path.home() / "models"))
    torch.cuda.memory.set_per_process_memory_fraction(0.8)
    return LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)


# Saving real GPT-2 activations should produce batch-by-hidden tensors at the requested paths.
def test_save_with_gpt2_returns_saved_tensors_at_requested_paths(gpt2_model: LanguageModel) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(
                gpt2_model,
                ["model.transformer.h[2].output[10]", "model.transformer.h[3].output[-1]"],
            )

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct_layer_2 = gpt2_model.transformer.h[2].output[:, 10, :].save()
            direct_layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()

    assert torch.allclose(saved.get("model.transformer.h[2].output[10]"), direct_layer_2)
    assert torch.allclose(saved.get("model.transformer.h[3].output[-1]"), direct_layer_3)


# Slice syntax over GPT-2's transformer block list should match a direct per-layer NNSight loop.
def test_save_with_gpt2_layer_slice_matches_direct_nnsight_loop(gpt2_model: LanguageModel) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, ["model.transformer.h[:].output[10]"])

    direct: list[torch.Tensor] = []
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct.extend(
                gpt2_model.transformer.h[layer].output[:, 10, :].save()
                for layer in range(len(gpt2_model.transformer.h))
            )

    assert len(saved.keys()) == len(direct)
    for layer, direct_tensor in enumerate(direct):
        assert torch.allclose(saved.get(f"model.transformer.h[{layer}].output[10]"), direct_tensor)


# Saving several token positions across all layers must still traverse layers in graph order.
def test_save_with_gpt2_layer_slice_and_multiple_positions_matches_direct_loop(
    gpt2_model: LanguageModel,
) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]
    positions = [-4, -3, -2, -1]
    paths = [f"model.transformer.h[:].output[{position}]" for position in positions]

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, paths)

    direct: dict[str, torch.Tensor] = {}
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            for layer in range(len(gpt2_model.transformer.h)):
                for position in positions:
                    direct[f"model.transformer.h[{layer}].output[{position}]"] = (
                        gpt2_model.transformer.h[layer].output[:, position, :].save()
                    )

    assert saved.keys() == list(direct)
    for path, direct_tensor in direct.items():
        assert torch.allclose(saved.get(path), direct_tensor)


# Unsaved GPT-2 paths should still error even when nearby paths were captured.
def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, ["model.transformer.h[2].output[10]"])

    with pytest.raises(KeyError, match="saved_activation.transformer.h\\[2\\].output"):
        _ = saved.values.transformer.h[2].output[9]
    with pytest.raises(KeyError, match="saved_activation.transformer.h"):
        _ = saved.values.transformer.h[3]


# Applying a saved activation should overwrite the live activation slice during a later run.
def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None:
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])

    clean_tensor = patch.values.transformer.h[2].output[9]

    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            before_patch = gpt2_model.transformer.h[2].output[:, 9, :].clone().save()
            patch.apply(gpt2_model)
            after_patch = gpt2_model.transformer.h[2].output[:, 9, :].save()

    assert before_patch.shape == (1, 768)
    assert after_patch.shape == (1, 768)
    assert torch.equal(after_patch, clean_tensor)
    assert not torch.equal(before_patch, after_patch)


# Applying multiple patches must replay them in model order even if the SavedActivation keys are reordered.
def test_apply_with_gpt2_reordered_patch_set_preserves_nnsight_access_order(
    gpt2_model: LanguageModel,
) -> None:
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    patch_paths = [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            patch = save_activations(gpt2_model, patch_paths)

    reordered_patch = patch.subset(
        [
            "model.transformer.h[3].output[8]",
            "model.transformer.h[2].output[9]",
        ]
    )
    assert reordered_patch.keys() == [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            reordered_patch.apply(gpt2_model)
            patched_logits = gpt2_model.lm_head.output.save()

    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            gpt2_model.transformer.h[2].output[:, 9, :] = patch.get("model.transformer.h[2].output[9]")
            gpt2_model.transformer.h[3].output[:, 8, :] = patch.get("model.transformer.h[3].output[8]")
            direct_logits = gpt2_model.lm_head.output.save()

    assert torch.allclose(patched_logits, direct_logits)


# Union should combine disjoint GPT-2 patches so they replay like one direct multi-assignment patch.
def test_union_with_gpt2_combines_disjoint_patches(gpt2_model: LanguageModel) -> None:
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            layer_2_patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])
            layer_3_patch = save_activations(gpt2_model, ["model.transformer.h[3].output[8]"])

    union_patch = layer_2_patch.union(layer_3_patch)
    assert union_patch.keys() == [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            union_patch.apply(gpt2_model)
            union_logits = gpt2_model.lm_head.output.save()

    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            gpt2_model.transformer.h[2].output[:, 9, :] = layer_2_patch.get("model.transformer.h[2].output[9]")
            gpt2_model.transformer.h[3].output[:, 8, :] = layer_3_patch.get("model.transformer.h[3].output[8]")
            direct_logits = gpt2_model.lm_head.output.save()

    assert torch.allclose(union_logits, direct_logits)


# Saving, reading, and updating a multi-layer last-token patch should work within one invoke.
def test_updates_with_gpt2_supports_current_plus_scaled_saved_activation_in_single_invoke(
    gpt2_model: LanguageModel,
) -> None:
    prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    patch_paths = [
        "model.transformer.h[2].output[-1]",
        "model.transformer.h[3].output[-1]",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            saved = save_activations(gpt2_model, patch_paths)

    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
            layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
            gpt2_model.transformer.h[3].output[:, -1, :] = layer_3 + 2.0 * saved.get(patch_paths[1])
            direct_logits = gpt2_model.lm_head.output.save()

    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            for key, value, update in updates(gpt2_model, saved.keys()):
                update(value + 2.0 * saved.get(key))
            patched_logits = gpt2_model.lm_head.output.save()

    assert torch.allclose(patched_logits, direct_logits)


def test_updates_with_gpt2_supports_key_specific_update_logic(
    gpt2_model: LanguageModel,
) -> None:
    prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    patch_paths = [
        "model.transformer.h[2].output[-1]",
        "model.transformer.h[3].output[-1]",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            saved = save_activations(gpt2_model, patch_paths)

    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
            layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
            gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])
            direct_logits = gpt2_model.lm_head.output.save()

    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            for key, value, update in updates(gpt2_model, saved.keys()):
                if key == "model.transformer.h[2].output[-1]":
                    update(value + 2.0 * saved.get(key))
                else:
                    update(saved.get(key))
            patched_logits = gpt2_model.lm_head.output.save()

    assert torch.allclose(patched_logits, direct_logits)


def test_updates_with_gpt2_supports_key_specific_batch_reductions(
    gpt2_model: LanguageModel,
) -> None:
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
        "The trip was terrible and I felt sad.",
    ]
    patch_paths = [
        "model.transformer.h[2].output[-1]",
        "model.transformer.h[3].output[-1]",
    ]

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, patch_paths)

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
            gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])[[0, 2]].mean(
                dim=0, keepdim=True
            )
            direct_logits = gpt2_model.lm_head.output.save()

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            for key, value, update in updates(gpt2_model, saved.keys()):
                if key == "model.transformer.h[2].output[-1]":
                    update(value + 2.0 * saved.get(key))
                else:
                    update(saved.get(key)[[0, 2]].mean(dim=0, keepdim=True))
            patched_logits = gpt2_model.lm_head.output.save()

    assert torch.allclose(patched_logits, direct_logits)


# `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
    ]
    path = "model.transformer.h[5].output[-1]"

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, [path])

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct_mean = gpt2_model.transformer.h[5].output[:, -1, :].mean(dim=0, keepdim=True).save()

    assert saved.keys() == [path]
    assert saved.get(path).shape == (2, 768)

    mean_saved = saved.mean()
    assert mean_saved.get(path).shape == (1, 768)
    assert torch.allclose(mean_saved.get(path), direct_mean)


# Slicing should let one GPT-2 save feed multiple downstream batch views.
def test_slice_with_gpt2_selects_requested_batch_rows(gpt2_model: LanguageModel) -> None:
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
        "The trip was terrible and I felt sad.",
    ]
    path = "model.transformer.h[5].output[-1]"

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, [path])

    sliced = saved.slice[0, 2:3]

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct_slice = gpt2_model.transformer.h[5].output[[0, 2], -1, :].save()

    assert saved.get(path).shape == (3, 768)
    assert sliced.get(path).shape == (2, 768)
    assert torch.allclose(sliced.get(path), direct_slice)


def test_mean_inside_trace_saves_reduced_result_on_device(gpt2_model: LanguageModel) -> None:
    """This guards the intended workflow where reduction happens before trace exit,
    avoiding saving the full batch when only the mean is needed."""
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
    ]
    path = "model.transformer.h[5].output[-1]"

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, [path])
            mean_saved = saved.mean()

    assert saved.get(path).shape == (2, 768)
    assert mean_saved.get(path).shape == (1, 768)
    assert torch.allclose(mean_saved.get(path), saved.get(path).mean(dim=0, keepdim=True))


# A randomly initialized non-GPT2 architecture should work as long as layers live in a ModuleList.
def test_save_with_random_init_qwen_layer_slice_matches_direct_loop() -> None:
    cfg = AutoConfig.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    cfg.hidden_size = 64
    cfg.intermediate_size = 128
    cfg.num_hidden_layers = 3
    cfg.num_attention_heads = 4
    cfg.num_key_value_heads = 4
    cfg.vocab_size = 256
    model = NNsight(AutoModelForCausalLM.from_config(cfg))
    input_ids = torch.randint(0, cfg.vocab_size, (2, 6))
    attention_mask = torch.ones_like(input_ids)

    with model.trace(input_ids=input_ids, attention_mask=attention_mask):
        saved = save_activations(model, ["model.model.layers[:].output[2]"])

    direct: list[torch.Tensor] = []
    with model.trace(input_ids=input_ids, attention_mask=attention_mask):
        direct.extend(
            model.model.layers[layer].output[:, 2, :].save()
            for layer in range(cfg.num_hidden_layers)
        )

    assert len(saved.keys()) == cfg.num_hidden_layers
    for layer, direct_tensor in enumerate(direct):
        assert torch.allclose(saved.get(f"model.model.layers[{layer}].output[2]"), direct_tensor)
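
For readers skimming this file, here is a minimal sketch of the save-and-apply workflow the tests above exercise. The prompts, layer, and position are illustrative, and only calls already demonstrated in the tests (`save_activations`, `patch.apply`, `.save()`) are used, so treat this as a sketch rather than canonical more_nnsight documentation.

# Minimal sketch of activation patching with more_nnsight, assuming the same
# API the tests above exercise. Prompts, layer, and position are illustrative.
from nnsight import LanguageModel

from more_nnsight import save_activations

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
path = "model.transformer.h[2].output[-1]"

# Capture the activation at one layer/position during a clean run.
with model.trace() as tracer:
    with tracer.invoke(["After John and Mary went to the store, Mary gave a bottle of milk to"]):
        patch = save_activations(model, [path])

# Replay the saved slice into a later (corrupted) run and read out the patched logits.
with model.trace() as tracer:
    with tracer.invoke(["After John and Mary went to the store, John gave a bottle of milk to"]):
        patch.apply(model)
        logits = model.lm_head.output.save()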