tests/test_saved_activation_gpt2.py
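"""Integration tests for more_nnsight's SavedActivation workflow on a real GPT-2.

The core pattern under test, as a sketch:

    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, ["model.transformer.h[2].output[10]"])
    tensor = saved.get("model.transformer.h[2].output[10]")  # shape (batch, hidden)

The tests below also exercise patch application (`apply`), key-set algebra
(`subset`, `union`), batch reductions (`mean`, `slice`), and in-graph rewrites
(`updates`), each checked against a hand-written nnsight equivalent.
"""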
from __future__ import annotations
import os
from pathlib import Path
import pytest
import torch
from nnsight import LanguageModel
from nnsight.modeling.base import NNsight
from transformers import AutoConfig, AutoModelForCausalLM
from more_nnsight import SavedActivation, save_activations, updates
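# Module-scoped fixture: default HF_HOME to ~/models and cap the CUDA per-process
# memory fraction at 80% so all tests in this file share one dispatched GPT-2.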
@pytest.fixture(scope="module")
def gpt2_model() -> LanguageModel:
    os.environ.setdefault("HF_HOME", str(Path.home() / "models"))
    torch.cuda.memory.set_per_process_memory_fraction(0.8)
    return LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
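# Path grammar used throughout: "model.transformer.h[2].output[10]" names the
# output of transformer block 2 at token position 10 for every batch row
# (cf. the direct h[2].output[:, 10, :] access below), and "[:]" over the block
# list fans out to every layer.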
# Saving real GPT-2 activations should produce batch-by-hidden tensors at the requested paths.
def test_save_with_gpt2_returns_saved_tensors_at_requested_paths(gpt2_model: LanguageModel) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(
                gpt2_model,
                ["model.transformer.h[2].output[10]", "model.transformer.h[3].output[-1]"],
            )
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct_layer_2 = gpt2_model.transformer.h[2].output[:, 10, :].save()
            direct_layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
    assert torch.allclose(saved.get("model.transformer.h[2].output[10]"), direct_layer_2)
    assert torch.allclose(saved.get("model.transformer.h[3].output[-1]"), direct_layer_3)
# Slice syntax over GPT-2's transformer block list should match a direct per-layer nnsight loop.
def test_save_with_gpt2_layer_slice_matches_direct_nnsight_loop(gpt2_model: LanguageModel) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, ["model.transformer.h[:].output[10]"])
    direct: list[torch.Tensor] = []
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct.extend(
                gpt2_model.transformer.h[layer].output[:, 10, :].save()
                for layer in range(len(gpt2_model.transformer.h))
            )
    assert len(saved.keys()) == len(direct)
    for layer, direct_tensor in enumerate(direct):
        assert torch.allclose(saved.get(f"model.transformer.h[{layer}].output[10]"), direct_tensor)
# Saving several token positions across all layers must still traverse layers in graph order.
def test_save_with_gpt2_layer_slice_and_multiple_positions_matches_direct_loop(
    gpt2_model: LanguageModel,
) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]
    positions = [-4, -3, -2, -1]
    paths = [f"model.transformer.h[:].output[{position}]" for position in positions]
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, paths)
    direct: dict[str, torch.Tensor] = {}
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            for layer in range(len(gpt2_model.transformer.h)):
                for position in positions:
                    direct[f"model.transformer.h[{layer}].output[{position}]"] = (
                        gpt2_model.transformer.h[layer].output[:, position, :].save()
                    )
    assert saved.keys() == list(direct)
    for path, direct_tensor in direct.items():
        assert torch.allclose(saved.get(path), direct_tensor)
# Unsaved GPT-2 paths should still error even when nearby paths were captured.
def test_unsaved_paths_error_with_gpt2_saved_activation(gpt2_model: LanguageModel) -> None:
    prompts = [
        "One two three four five six seven eight nine ten eleven twelve.",
        "Alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu.",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, ["model.transformer.h[2].output[10]"])
    with pytest.raises(KeyError, match="saved_activation.transformer.h\\[2\\].output"):
        _ = saved.values.transformer.h[2].output[9]
    with pytest.raises(KeyError, match="saved_activation.transformer.h"):
        _ = saved.values.transformer.h[3]
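# The clean/corrupted prompt pair below differ only in which name is repeated
# (the classic IOI setup), so patching clean-run activations into the corrupted
# run should visibly change the corrupted activations.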
# Applying a saved activation should overwrite the live activation slice during a later run.
def test_apply_with_gpt2_patches_current_run(gpt2_model: LanguageModel) -> None:
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])
    clean_tensor = patch.values.transformer.h[2].output[9]
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            before_patch = gpt2_model.transformer.h[2].output[:, 9, :].clone().save()
            patch.apply(gpt2_model)
            after_patch = gpt2_model.transformer.h[2].output[:, 9, :].save()
    assert before_patch.shape == (1, 768)
    assert after_patch.shape == (1, 768)
    assert torch.equal(after_patch, clean_tensor)
    assert not torch.equal(before_patch, after_patch)
# Applying multiple patches must replay them in model order even if the SavedActivation keys are reordered.
def test_apply_with_gpt2_reordered_patch_set_preserves_nnsight_access_order(
    gpt2_model: LanguageModel,
) -> None:
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    patch_paths = [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            patch = save_activations(gpt2_model, patch_paths)
    reordered_patch = patch.subset(
        [
            "model.transformer.h[3].output[8]",
            "model.transformer.h[2].output[9]",
        ]
    )
    assert reordered_patch.keys() == [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            reordered_patch.apply(gpt2_model)
            patched_logits = gpt2_model.lm_head.output.save()
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            gpt2_model.transformer.h[2].output[:, 9, :] = patch.get("model.transformer.h[2].output[9]")
            gpt2_model.transformer.h[3].output[:, 8, :] = patch.get("model.transformer.h[3].output[8]")
            direct_logits = gpt2_model.lm_head.output.save()
    assert torch.allclose(patched_logits, direct_logits)
# Union should combine disjoint GPT-2 patches so they replay like one direct multi-assignment patch.
def test_union_with_gpt2_combines_disjoint_patches(gpt2_model: LanguageModel) -> None:
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    with gpt2_model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            layer_2_patch = save_activations(gpt2_model, ["model.transformer.h[2].output[9]"])
            layer_3_patch = save_activations(gpt2_model, ["model.transformer.h[3].output[8]"])
    union_patch = layer_2_patch.union(layer_3_patch)
    assert union_patch.keys() == [
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[8]",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            union_patch.apply(gpt2_model)
            union_logits = gpt2_model.lm_head.output.save()
    with gpt2_model.trace() as tracer:
        with tracer.invoke([corrupted_prompt]):
            gpt2_model.transformer.h[2].output[:, 9, :] = layer_2_patch.get("model.transformer.h[2].output[9]")
            gpt2_model.transformer.h[3].output[:, 8, :] = layer_3_patch.get("model.transformer.h[3].output[8]")
            direct_logits = gpt2_model.lm_head.output.save()
    assert torch.allclose(union_logits, direct_logits)
# Saving, reading, and updating a multi-layer last-token patch should work within one invoke.
def test_updates_with_gpt2_supports_current_plus_scaled_saved_activation_in_single_invoke(
    gpt2_model: LanguageModel,
) -> None:
    prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    patch_paths = [
        "model.transformer.h[2].output[-1]",
        "model.transformer.h[3].output[-1]",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            saved = save_activations(gpt2_model, patch_paths)
    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
            layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
            gpt2_model.transformer.h[3].output[:, -1, :] = layer_3 + 2.0 * saved.get(patch_paths[1])
            direct_logits = gpt2_model.lm_head.output.save()
    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            for key, value, update in updates(gpt2_model, saved.keys()):
                update(value + 2.0 * saved.get(key))
            patched_logits = gpt2_model.lm_head.output.save()
    assert torch.allclose(patched_logits, direct_logits)
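# `updates` should expose each saved key so callers can branch on it: here one
# layer gets a scaled additive update while the other is replaced outright.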
def test_updates_with_gpt2_supports_key_specific_update_logic(
    gpt2_model: LanguageModel,
) -> None:
    prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    patch_paths = [
        "model.transformer.h[2].output[-1]",
        "model.transformer.h[3].output[-1]",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            saved = save_activations(gpt2_model, patch_paths)
    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
            layer_3 = gpt2_model.transformer.h[3].output[:, -1, :].save()
            gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])
            direct_logits = gpt2_model.lm_head.output.save()
    with gpt2_model.trace() as tracer:
        with tracer.invoke([prompt]):
            for key, value, update in updates(gpt2_model, saved.keys()):
                if key == "model.transformer.h[2].output[-1]":
                    update(value + 2.0 * saved.get(key))
                else:
                    update(saved.get(key))
            patched_logits = gpt2_model.lm_head.output.save()
    assert torch.allclose(patched_logits, direct_logits)
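# Per-key updates should also admit batch reductions: layer 3 is overwritten
# with the mean of saved batch rows 0 and 2.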
def test_updates_with_gpt2_supports_key_specific_batch_reductions(
    gpt2_model: LanguageModel,
) -> None:
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
        "The trip was terrible and I felt sad.",
    ]
    patch_paths = [
        "model.transformer.h[2].output[-1]",
        "model.transformer.h[3].output[-1]",
    ]
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, patch_paths)
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            layer_2 = gpt2_model.transformer.h[2].output[:, -1, :].save()
            gpt2_model.transformer.h[2].output[:, -1, :] = layer_2 + 2.0 * saved.get(patch_paths[0])
            gpt2_model.transformer.h[3].output[:, -1, :] = saved.get(patch_paths[1])[[0, 2]].mean(
                dim=0, keepdim=True
            )
            direct_logits = gpt2_model.lm_head.output.save()
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            for key, value, update in updates(gpt2_model, saved.keys()):
                if key == "model.transformer.h[2].output[-1]":
                    update(value + 2.0 * saved.get(key))
                else:
                    update(saved.get(key)[[0, 2]].mean(dim=0, keepdim=True))
            patched_logits = gpt2_model.lm_head.output.save()
    assert torch.allclose(patched_logits, direct_logits)
# `mean()` should reduce a real saved GPT-2 batch to a single-row steering vector.
def test_mean_and_get_with_gpt2(gpt2_model: LanguageModel) -> None:
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
    ]
    path = "model.transformer.h[5].output[-1]"
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, [path])
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct_mean = gpt2_model.transformer.h[5].output[:, -1, :].mean(dim=0, keepdim=True).save()
    assert saved.keys() == [path]
    assert saved.get(path).shape == (2, 768)
    mean_saved = saved.mean()
    assert mean_saved.get(path).shape == (1, 768)
    assert torch.allclose(mean_saved.get(path), direct_mean)
# Slicing should let one GPT-2 save feed multiple downstream batch views.
def test_slice_with_gpt2_selects_requested_batch_rows(gpt2_model: LanguageModel) -> None:
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
        "The trip was terrible and I felt sad.",
    ]
    path = "model.transformer.h[5].output[-1]"
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, [path])
    sliced = saved.slice[0, 2:3]
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            direct_slice = gpt2_model.transformer.h[5].output[[0, 2], -1, :].save()
    assert saved.get(path).shape == (3, 768)
    assert sliced.get(path).shape == (2, 768)
    assert torch.allclose(sliced.get(path), direct_slice)
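# `mean()` called before the trace exits should yield the reduced (1, 768) row.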
def test_mean_inside_trace_saves_reduced_result_on_device(gpt2_model: LanguageModel) -> None:
    """This guards the intended workflow where reduction happens before trace
    exit, avoiding saving the full batch when only the mean is needed."""
    prompts = [
        "The movie was great and I felt happy.",
        "The dinner was excellent and I felt joyful.",
    ]
    path = "model.transformer.h[5].output[-1]"
    with gpt2_model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(gpt2_model, [path])
            mean_saved = saved.mean()
    assert saved.get(path).shape == (2, 768)
    assert mean_saved.get(path).shape == (1, 768)
    assert torch.allclose(mean_saved.get(path), saved.get(path).mean(dim=0, keepdim=True))
# A randomly initialized non-GPT-2 architecture should work as long as its layers live in a ModuleList.
def test_save_with_random_init_qwen_layer_slice_matches_direct_loop() -> None:
    cfg = AutoConfig.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    cfg.hidden_size = 64
    cfg.intermediate_size = 128
    cfg.num_hidden_layers = 3
    cfg.num_attention_heads = 4
    cfg.num_key_value_heads = 4
    cfg.vocab_size = 256
    model = NNsight(AutoModelForCausalLM.from_config(cfg))
    input_ids = torch.randint(0, cfg.vocab_size, (2, 6))
    attention_mask = torch.ones_like(input_ids)
    with model.trace(input_ids=input_ids, attention_mask=attention_mask):
        saved = save_activations(model, ["model.model.layers[:].output[2]"])
    direct: list[torch.Tensor] = []
    with model.trace(input_ids=input_ids, attention_mask=attention_mask):
        direct.extend(
            model.model.layers[layer].output[:, 2, :].save()
            for layer in range(cfg.num_hidden_layers)
        )
    assert len(saved.keys()) == cfg.num_hidden_layers
    for layer, direct_tensor in enumerate(direct):
        assert torch.allclose(saved.get(f"model.model.layers[{layer}].output[2]"), direct_tensor)