# Repository: more_nnsight.git — tests/test_steering_search_gpt2.py
# Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
"""Integration tests for steering_search with GPT-2."""
from __future__ import annotations
import pytest
import torch
from nnsight import LanguageModel
from more_nnsight import SteeringSearchConfig, SteeringSearchOutput, best_trial, steering_search
# Keyword lists for the crude sentiment heuristic in _sentiment_score.
# Derived by running a forward pass on GPT-2 with "The weather today is" and
# manually classifying the top-50 next tokens by sentiment.
# NOTE: matching is plain substring matching on lowercased text, so short
# words can fire inside longer ones (e.g. "fine" in "define") — acceptable
# for this test's purposes, but keep entries reasonably distinctive.
_POSITIVE_KEYWORDS = [
    "good", "perfect", "nice", "beautiful", "sunny", "great", "fine",
    "calm", "amazing", "excellent", "warm", "wonderful", "pleasant", "lovely",
]
_NEGATIVE_KEYWORDS = [
    "bad", "cold", "freezing", "cloudy", "terrible", "awful",
    "gloomy", "miserable", "horrible",
]
@pytest.fixture(scope="module")
def gpt2():
    """Module-scoped GPT-2 loaded through nnsight (dispatched immediately).

    Caps this process's CUDA allocation at 80% so the test can share a GPU
    with other jobs. The cap is applied only when CUDA is actually available;
    the original unconditional call raised on CPU-only hosts.
    """
    if torch.cuda.is_available():
        torch.cuda.memory.set_per_process_memory_fraction(0.8)
    return LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
def _sentiment_score(response: str) -> float:
    """Binary score: -1 if only negative keywords, +1 if only positive, 0 if neither/both."""
    lowered = response.lower()
    positive_hit = any(word in lowered for word in _POSITIVE_KEYWORDS)
    negative_hit = any(word in lowered for word in _NEGATIVE_KEYWORDS)
    # Neutral when the signal is absent or mixed (both polarities present).
    if positive_hit == negative_hit:
        return 0.0
    return 1.0 if positive_hit else -1.0
def test_steering_changes_sentiment(gpt2, tmp_path):
    """Steering negative - positive should push output toward negative sentiment.

    End-to-end run of steering_search against real GPT-2: the best trial's
    continuation of the eval prompt should flip from positive to negative
    sentiment relative to the unsteered baseline.
    """
    # Hook points in GPT-2's last two transformer blocks; the "[-1]" index
    # selects an element of each block's output (presumably the hidden
    # state — confirm against more_nnsight's path syntax).
    candidate_paths = [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[11].output[-1]",
    ]
    # Swapped: steering vector = mean(neg_activations) - mean(pos_activations)
    # i.e. the list BOUND to `positive_prompts` intentionally holds negative
    # text (and vice versa), so the derived direction pushes generations
    # toward NEGATIVE sentiment. Do not "fix" the apparent mismatch.
    positive_prompts = [
        "I hate this! It's terrible",
        "This is awful and horrible",
        "I'm so sad about this",
    ]
    negative_prompts = [
        "I love this! It's fantastic",
        "This is wonderful and amazing",
        "I'm so happy about this",
    ]
    eval_prompt = "The weather today is"
    # Objective for the search: score only the steered continuation; the
    # row and baseline are ignored. Lower is better here (we assert '<').
    def score(row: dict, baseline: str, steered: str) -> float:
        return _sentiment_score(steered)
    config = SteeringSearchConfig(
        candidate_paths=candidate_paths,
        max_simultaneous_paths=2,
        alpha_range=(0.5, 5.0),
        n_calls=8,            # total optimization trials (asserted below)
        n_initial_points=4,   # random trials before model-based search
        seed=42,              # fixed seed keeps the run deterministic
    )
    output = tmp_path / "trials.jsonl"
    out = steering_search(
        config=config,
        model=gpt2,
        positive_prompts=positive_prompts,
        negative_prompts=negative_prompts,
        eval_rows=[{"prompt": eval_prompt}],
        score=score,
        max_new_tokens=4,  # short generations keep the test fast
        output_path=output,
    )
    assert isinstance(out, SteeringSearchOutput)
    assert len(out.trials) == 8  # one trial per n_calls
    assert output.exists()       # trial log was written to disk
    # Baseline response should contain a positive keyword (GPT-2 default).
    assert len(out.baseline_responses) == 1
    baseline_response = out.baseline_responses[0]
    baseline_score = _sentiment_score(baseline_response)
    print(f"\nBaseline: {repr(baseline_response)} score={baseline_score}")
    assert any(kw in baseline_response.lower() for kw in _POSITIVE_KEYWORDS), (
        f"Expected positive keyword in baseline: {repr(baseline_response)}"
    )
    # Best steered trial should score better (more negative) than baseline.
    best = best_trial(out.trials)
    print(f"Best steered score={best.score} paths={best.selected_paths} alpha={best.alpha:.2f}")
    assert len(best.selected_paths) > 0
    assert best.score < baseline_score, (
        f"Steering did not make sentiment more negative: baseline={baseline_score}, best={best.score}"
    )
    # Best trial's responses should contain a negative keyword.
    assert len(best.responses) == 1
    best_response = best.responses[0]
    print(f"Best steered response: {repr(best_response)}")
    assert any(kw in best_response.lower() for kw in _NEGATIVE_KEYWORDS), (
        f"Expected negative keyword in steered response: {repr(best_response)}"
    )