Repositories / more_nnsight.git

scripts/demo_activation_patching.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch
2881 bytes · 7b8c0c87b113
# Demo: activation patching with more_nnsight on GPT-2.
#
# Saves activations at a few paths while running a "clean" IOI-style prompt,
# keeps a single path via ``subset``, replays it into a "corrupted" prompt with
# ``apply``, and prints how the model's preference between two names changes.
from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Needed for annotations only. The runtime imports live inside ``main`` so
    # that HF_HOME can be configured before transformers/huggingface_hub load.
    from nnsight import LanguageModel

    from more_nnsight import SavedActivation


def _token_id(model: LanguageModel, token: str) -> int:
    """Resolves a single-token string to its vocabulary id so specific logits can be compared.

    Special tokens are suppressed so that tokenizers which prepend e.g. a BOS
    token do not make the first returned id the BOS id instead of ``token``'s.

    Raises:
        ValueError: if ``token`` does not encode to exactly one id; a multi-token
            string has no single next-token logit to compare.
    """
    ids = model.tokenizer.encode(token, add_special_tokens=False)
    if len(ids) != 1:
        raise ValueError(
            f"expected {token!r} to encode to exactly one token, got ids {list(ids)}"
        )
    return ids[0]


def _final_logit_diff(
    model: LanguageModel,
    prompt: str,
    correct_token: str,
    incorrect_token: str,
    patch: SavedActivation | None = None,
) -> float:
    """Runs ``prompt`` once (optionally replaying ``patch``) and returns the last-position logit gap.

    The result is ``logits[0, -1, correct] - logits[0, -1, incorrect]``: first
    batch entry, final sequence position. Positive means the model prefers
    ``correct_token`` as the next token.
    """
    correct_index = _token_id(model, correct_token)
    incorrect_index = _token_id(model, incorrect_token)
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            if patch is not None:
                patch.apply(model)
            logits = model.lm_head.output.save()
    return float((logits[0, -1, correct_index] - logits[0, -1, incorrect_index]).detach())


def logit_diff(model: LanguageModel, prompt: str, correct_token: str, incorrect_token: str) -> float:
    """Measures which of two candidate next tokens the model currently prefers.

    Returns the final-position logit of ``correct_token`` minus that of
    ``incorrect_token``. Raises ``ValueError`` if either is not a single token.
    """
    return _final_logit_diff(model, prompt, correct_token, incorrect_token)


def patched_logit_diff(
    model: LanguageModel,
    prompt: str,
    patch: SavedActivation,
    correct_token: str,
    incorrect_token: str,
) -> float:
    """Measures the same preference as ``logit_diff`` after replaying a saved activation patch.

    ``patch.apply(model)`` is called inside the trace before the logits are saved.
    Raises ``ValueError`` if either candidate is not a single token.
    """
    return _final_logit_diff(model, prompt, correct_token, incorrect_token, patch=patch)


def main() -> None:
    """Demonstrates save, subset, and apply on the GPT-2 IOI-style patching setup."""
    # huggingface_hub/transformers read HF_HOME into module-level constants when
    # they are first imported, so the default must be set before nnsight (which
    # imports transformers) is loaded, otherwise it has no effect.
    os.environ.setdefault("HF_HOME", str(Path.home() / "models"))
    from nnsight import LanguageModel

    from more_nnsight import save_activations

    # Guarded so the demo still runs on CPU-only machines, where this call raises;
    # applied to every visible GPU because device_map="auto" may use several.
    if torch.cuda.is_available():
        for device in range(torch.cuda.device_count()):
            torch.cuda.set_per_process_memory_fraction(0.8, device)

    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
    correct_name, incorrect_name = " John", " Mary"

    # NOTE(review): the trailing [9] appears to select token position 9, the
    # second name (" Mary" / " John") where the two prompts differ. Confirm
    # against more_nnsight's path syntax.
    patch_paths = [
        "model.transformer.h[0].output[9]",
        "model.transformer.h[1].output[9]",
        "model.transformer.h[2].output[9]",
        "model.transformer.h[3].output[9]",
    ]
    focused_path = "model.transformer.h[2].output[9]"

    with model.trace() as tracer:
        with tracer.invoke([clean_prompt]):
            clean_saved = save_activations(model, patch_paths)
    focused_patch = clean_saved.subset([focused_path])

    clean_diff = logit_diff(model, clean_prompt, correct_name, incorrect_name)
    corrupted_diff = logit_diff(model, corrupted_prompt, correct_name, incorrect_name)
    patched_diff = patched_logit_diff(
        model, corrupted_prompt, focused_patch, correct_name, incorrect_name
    )

    print(f"clean_logit_diff={clean_diff:.4f}")
    print(f"corrupted_logit_diff={corrupted_diff:.4f}")
    print(f"patched_logit_diff={patched_diff:.4f}")
    print(f"patched_path={focused_path}")


if __name__ == "__main__":
    main()