# Repository: more_nnsight.git -- scripts/demo_activation_steering.py
# Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
from __future__ import annotations
import os
from pathlib import Path
import torch
from nnsight import LanguageModel
from more_nnsight import SavedActivation, save_activations
def _token_id(model: LanguageModel, token: str) -> int:
    """Looks up the first tokenizer id for *token* so specific logits can be compared."""
    encoded = model.tokenizer.encode(token)
    return encoded[0]
def candidate_logits(
    model: LanguageModel,
    prompt: str,
    correct_token: str,
    incorrect_token: str,
    patch: SavedActivation | None = None,
) -> tuple[float, float]:
    """Measures two candidate next-token logits, optionally after steering the run.

    Runs *prompt* through the model once; if *patch* is given it is applied inside
    the invocation before the logits are captured, so the saved activations steer
    this forward pass.
    """
    target_id = _token_id(model, correct_token)
    distractor_id = _token_id(model, incorrect_token)
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            # The patch must be applied before the output is captured.
            if patch is not None:
                patch.apply(model)
            logits = model.lm_head.output.save()
    # Logits for the final position of the single (batch index 0) prompt.
    final_position = logits[0, -1]
    target_logit = float(final_position[target_id].detach())
    distractor_logit = float(final_position[distractor_id].detach())
    return (target_logit, distractor_logit)
def main() -> None:
    """Demonstrates simple activation steering built from positive and negative prompt sets.

    Builds a sentiment steering vector as (mean positive - mean negative) activation
    at a mid-layer hook, adds a scaled copy of it to a neutral prompt's activation,
    and compares the " happy" vs " sad" next-token logits before and after steering.
    """
    # Respect an existing HF_HOME; otherwise keep model downloads under ~/models.
    os.environ.setdefault("HF_HOME", str(Path.home() / "models"))
    # Cap this process's share of GPU memory so the demo coexists with other jobs.
    # Guarded: set_per_process_memory_fraction raises on CPU-only machines, which
    # would previously crash the demo before the model was even loaded.
    if torch.cuda.is_available():
        torch.cuda.memory.set_per_process_memory_fraction(0.8)
    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
    # NOTE(review): output[-1] presumably selects the hidden-state tensor from the
    # transformer block's output tuple -- confirm against more_nnsight's path syntax.
    steering_path = "model.transformer.h[5].output[-1]"
    positive_prompts = [
        "The movie was absolutely wonderful and I felt",
        "The dinner was excellent and I left feeling",
        "The vacation was amazing and it made me feel",
    ]
    negative_prompts = [
        "The movie was absolutely terrible and I felt",
        "The dinner was awful and I left feeling",
        "The vacation was horrible and it made me feel",
    ]
    neutral_prompt = "The day was long and by the end I felt"
    prompts = positive_prompts + negative_prompts + [neutral_prompt]
    # One batched trace captures activations for every prompt at once.
    with model.trace() as tracer:
        with tracer.invoke(prompts):
            saved = save_activations(model, [steering_path])
    # Slice the batch back apart: positives first, then negatives, neutral last.
    positive = saved.slice[0 : len(positive_prompts)].mean()
    negative = saved.slice[
        len(positive_prompts) : len(positive_prompts) + len(negative_prompts)
    ].mean()
    neutral = saved.slice[len(prompts) - 1]
    steering_vector = positive - negative
    # Push the neutral activation 1.5 steering vectors toward "positive".
    steered = neutral + 1.5 * steering_vector
    baseline_happy, baseline_sad = candidate_logits(model, neutral_prompt, " happy", " sad")
    steered_happy, steered_sad = candidate_logits(
        model, neutral_prompt, " happy", " sad", patch=steered
    )
    print(f"steering_keys={steering_vector.keys()}")
    print(f"steering_norm={steering_vector.get(steering_path).norm().item():.4f}")
    print(f"baseline_logit_happy={baseline_happy:.4f}")
    print(f"baseline_logit_sad={baseline_sad:.4f}")
    print(f"steered_logit_happy={steered_happy:.4f}")
    print(f"steered_logit_sad={steered_sad:.4f}")
# Allow the demo to be executed directly as a script.
if __name__ == "__main__":
    main()