Repositories / more_nnsight.git

scripts/demo_activation_steering.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch
3062 bytes · 842e39c64fdc
from __future__ import annotations

import os
from pathlib import Path

import torch
from nnsight import LanguageModel

from more_nnsight import SavedActivation, save_activations


def _token_id(model: LanguageModel, token: str) -> int:
    """Resolve a single-token string to its vocabulary id.

    The demo compares individual next-token logits, so ``token`` must map to
    exactly one id. Special tokens are not added, so a tokenizer that would
    otherwise prepend one cannot be mistaken for the requested token.

    Args:
        model: Language model whose ``tokenizer`` performs the encoding.
        token: Text expected to encode to exactly one token (leading spaces
            are significant, e.g. ``" happy"``).

    Returns:
        The vocabulary id of ``token``.

    Raises:
        ValueError: If ``token`` encodes to zero tokens or to more than one.
    """
    token_ids = model.tokenizer.encode(token, add_special_tokens=False)
    if len(token_ids) != 1:
        raise ValueError(
            f"expected {token!r} to encode to exactly one token, "
            f"got {len(token_ids)}: {list(token_ids)}"
        )
    return token_ids[0]


def candidate_logits(
    model: LanguageModel,
    prompt: str,
    correct_token: str,
    incorrect_token: str,
    patch: SavedActivation | None = None,
) -> tuple[float, float]:
    """Measure two candidate next-token logits, optionally after steering the run.

    Args:
        model: Language model to trace.
        prompt: Prompt whose final-position logits are inspected.
        correct_token: Single-token string for the first returned logit.
        incorrect_token: Single-token string for the second returned logit.
        patch: Optional saved activation; when given, ``patch.apply(model)``
            is called inside the traced invocation before logits are read.

    Returns:
        ``(correct_logit, incorrect_logit)`` read from ``logits[0, -1, :]``,
        i.e. the first batch row at the last sequence position.

    Raises:
        ValueError: If either candidate does not encode to exactly one token.
    """
    # Resolve ids before tracing so a bad candidate fails fast, without a
    # forward pass.
    correct_index = _token_id(model, correct_token)
    incorrect_index = _token_id(model, incorrect_token)

    with model.trace() as tracer:
        with tracer.invoke(prompt):
            if patch is not None:
                patch.apply(model)
            # .save() keeps the lm_head output available after the trace exits.
            logits = model.lm_head.output.save()

    return (
        float(logits[0, -1, correct_index].detach()),
        float(logits[0, -1, incorrect_index].detach()),
    )


def main() -> None:
    """Demonstrate simple activation steering built from positive and negative prompt sets.

    Loads GPT-2, records the activation at ``steering_path`` for a batch of
    positive, negative and neutral prompts, builds a steering vector from
    ``mean(positive) - mean(negative)``, and prints the " happy" / " sad"
    logits for the neutral prompt with and without the steered activation.
    """
    # Only supplies a default; an HF_HOME already set by the user is respected.
    os.environ.setdefault("HF_HOME", str(Path.home() / "models"))

    # The memory-fraction call initialises CUDA and raises on hosts without a
    # GPU, so guard it; device_map="auto" can still place the model on CPU.
    if torch.cuda.is_available():
        torch.cuda.memory.set_per_process_memory_fraction(0.8)

    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
    steering_path = "model.transformer.h[5].output[-1]"

    positive_prompts = [
        "The movie was absolutely wonderful and I felt",
        "The dinner was excellent and I left feeling",
        "The vacation was amazing and it made me feel",
    ]
    negative_prompts = [
        "The movie was absolutely terrible and I felt",
        "The dinner was awful and I left feeling",
        "The vacation was horrible and it made me feel",
    ]
    neutral_prompt = "The day was long and by the end I felt"

    # One batched pass; the slices below rely on this concatenation order.
    prompts = positive_prompts + negative_prompts + [neutral_prompt]
    with model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(model, [steering_path])

    # NOTE(review): assumes SavedActivation.slice indexes the prompt (batch)
    # axis and .mean() averages over it - confirm against more_nnsight.
    n_positive = len(positive_prompts)
    n_negative = len(negative_prompts)
    positive = saved.slice[0:n_positive].mean()
    negative = saved.slice[n_positive : n_positive + n_negative].mean()
    neutral = saved.slice[len(prompts) - 1]

    steering_vector = positive - negative
    steered = neutral + 1.5 * steering_vector

    baseline_happy, baseline_sad = candidate_logits(
        model, neutral_prompt, " happy", " sad"
    )
    steered_happy, steered_sad = candidate_logits(
        model, neutral_prompt, " happy", " sad", patch=steered
    )

    print(f"steering_keys={steering_vector.keys()}")
    print(f"steering_norm={steering_vector.get(steering_path).norm().item():.4f}")
    print(f"baseline_logit_happy={baseline_happy:.4f}")
    print(f"baseline_logit_sad={baseline_sad:.4f}")
    print(f"steered_logit_happy={steered_happy:.4f}")
    print(f"steered_logit_sad={steered_sad:.4f}")


if __name__ == "__main__":
    main()