# Repositories / more_nnsight.git
# scripts/demo_activation_patching.py
# Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
from __future__ import annotations
import os
from pathlib import Path
import torch
from nnsight import LanguageModel
from more_nnsight import SavedActivation, save_activations
def _token_id(model: LanguageModel, token: str) -> int:
"""Resolves a single-token string so the demo can compare specific logits."""
return model.tokenizer.encode(token)[0]
def logit_diff(model: LanguageModel, prompt: str, correct_token: str, incorrect_token: str) -> float:
    """Report which of two candidate next tokens the model currently prefers.

    The result is the final-position logit of *correct_token* minus that of
    *incorrect_token*; a positive value means the model favors the former.
    """
    good_id = _token_id(model, correct_token)
    bad_id = _token_id(model, incorrect_token)
    # Clean forward pass: no interventions, just capture the head output.
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            logits = model.lm_head.output.save()
    # Compare the two candidates at the last position of batch item 0.
    last_position = logits[0, -1]
    return float((last_position[good_id] - last_position[bad_id]).detach())
def patched_logit_diff(
    model: LanguageModel,
    prompt: str,
    patch: SavedActivation,
    correct_token: str,
    incorrect_token: str,
) -> float:
    """Re-measure the token preference with a saved activation spliced in.

    Identical to ``logit_diff`` except that *patch* is replayed inside the
    trace before the logits are read, so the returned difference reflects
    the patched forward pass rather than the clean one.
    """
    good_id = _token_id(model, correct_token)
    bad_id = _token_id(model, incorrect_token)
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            # Replay the stored activations before reading the head output
            # so the patch influences all downstream computation.
            patch.apply(model)
            logits = model.lm_head.output.save()
    preference = logits[0, -1, good_id] - logits[0, -1, bad_id]
    return float(preference.detach())
def main() -> None:
    """Demonstrate save, subset, and apply on the GPT-2 IOI-style patching setup."""
    # Keep Hugging Face downloads under the user's home unless already configured.
    os.environ.setdefault("HF_HOME", str(Path.home() / "models"))
    # Cap GPU memory usage — but only when CUDA exists; calling the CUDA
    # memory API on a CPU-only machine raises at runtime.
    if torch.cuda.is_available():
        torch.cuda.memory.set_per_process_memory_fraction(0.8)
    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

    # IOI-style prompt pair: only the repeated name differs, so any change
    # in the " John" vs " Mary" preference is attributable to that swap.
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

    # Candidate patch sites: outputs of the first four transformer blocks,
    # index 9 presumably selecting a token position — see more_nnsight's
    # path syntax to confirm.
    patch_paths = [f"model.transformer.h[{layer}].output[9]" for layer in range(4)]

    # Record the clean run's activations at every candidate site.
    with model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            clean_saved = save_activations(model, patch_paths)

    # Narrow the saved bundle to a single site before replaying it.
    focused_patch = clean_saved.subset(["model.transformer.h[2].output[9]"])

    clean_diff = logit_diff(model, clean_prompt, " John", " Mary")
    corrupted_diff = logit_diff(model, corrupted_prompt, " John", " Mary")
    patched_diff = patched_logit_diff(model, corrupted_prompt, focused_patch, " John", " Mary")

    print(f"clean_logit_diff={clean_diff:.4f}")
    print(f"corrupted_logit_diff={corrupted_diff:.4f}")
    print(f"patched_logit_diff={patched_diff:.4f}")
    print("patched_path=model.transformer.h[2].output[9]")
# Entry-point guard: run the demo only when executed directly as a script.
if __name__ == "__main__":
    main()