"""Unit tests for steering_search — CPU-only, no model required."""
from __future__ import annotations

import json
import sys

import pytest
import torch
from more_nnsight.steering_search import (
    PENALTY,
    SavedActivation,
    SteeringSearchConfig,
    SteeringSearchOutput,
    SteeringTrialResult,
    _decode_point,
    _group_paths_by_layer,
    best_trial,
    layer_sweep,
    steer_and_eval,
    steering_search,
)
_ss = sys.modules["more_nnsight.steering_search"]
# ---------------------------------------------------------------------------
# _decode_point
# ---------------------------------------------------------------------------
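# Raw-point layout, as inferred from the tests below: up to
# max_simultaneous_paths integer slots index into candidate_paths, with the
# value n_candidates serving as a "no path" sentinel, followed by a float
# alpha. Duplicate indices collapse, and the selected paths come back in
# candidate order.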
def test_decode_deduplicates_and_removes_sentinel():
    candidates = ["path_a", "path_b", "path_c"]
    selected, alpha = _decode_point([2, 0, 2, 1.5], n_candidates=3, candidate_paths=candidates)
    assert selected == ["path_a", "path_c"]
    assert alpha == 1.5


def test_decode_all_sentinel_returns_empty():
    candidates = ["path_a", "path_b"]
    selected, alpha = _decode_point([2, 2, 0.5], n_candidates=2, candidate_paths=candidates)
    assert selected == []
    assert alpha == 0.5
# ---------------------------------------------------------------------------
# best_trial
# ---------------------------------------------------------------------------
def test_best_trial_returns_minimum_score():
    trials = [
        SteeringTrialResult(0, ["a"], 1.0, 5.0, [0, 1.0], ["r0"]),
        SteeringTrialResult(1, ["b"], 2.0, 1.0, [1, 2.0], ["r1"]),
        SteeringTrialResult(2, ["a", "b"], 1.5, 3.0, [0, 1, 1.5], ["r2"]),
    ]
    assert best_trial(trials).trial_index == 1
# ---------------------------------------------------------------------------
# steering_search (with mocked internals)
# ---------------------------------------------------------------------------
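# These tests swap the model-dependent internals (steering-vector computation
# and batch generation) for CPU-only fakes via _patch_internals below, so the
# optimizer loop can be exercised end to end without loading a model.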
class _FakeDirection:
    """Mimics SavedActivation enough for the objective function."""

    def __init__(self, paths):
        self._paths = list(paths)

    def keys(self):
        return list(self._paths)

    def get(self, key):
        return torch.zeros(1, 8)


def _patch_internals(monkeypatch, fake_vectors, generate_fn=None):
    """Patch _compute_steering_vectors, SavedActivation.from_pairs,
    _generate_batch (baseline), and _generate_steered_batch."""
    monkeypatch.setattr(_ss, "_compute_steering_vectors", lambda *a, **kw: fake_vectors)

    def fake_from_pairs(*pairs):
        return _FakeDirection([p for p, _ in pairs])

    monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs))

    # Mock baseline generation.
    def baseline_batch_fn(model, prompts, max_new_tokens):
        return ["baseline response"] * len(prompts)

    monkeypatch.setattr(_ss, "_generate_batch", baseline_batch_fn)

    # Mock steered generation.
    if generate_fn is not None:
        def batch_fn(model, prompts, direction, alpha, max_new_tokens):
            return [generate_fn(model, p, direction, alpha, max_new_tokens) for p in prompts]
    else:
        def batch_fn(model, prompts, direction, alpha, max_new_tokens):
            return ["fake response"] * len(prompts)

    monkeypatch.setattr(_ss, "_generate_steered_batch", batch_fn)
def test_empty_candidates_raises():
    config = SteeringSearchConfig(
        candidate_paths=[],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 5.0),
        n_calls=5,
    )
    with pytest.raises(ValueError, match="candidate_paths must not be empty"):
        steering_search(
            config=config,
            model=None,
            positive_prompts=[],
            negative_prompts=[],
            eval_rows=[],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
        )
def test_search_returns_correct_number_of_results(tmp_path, monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    config = SteeringSearchConfig(
        candidate_paths=["p0", "p1"],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 3.0),
        n_calls=15,
        n_initial_points=5,
    )
    output = tmp_path / "trials.jsonl"
    out = steering_search(
        config=config,
        model=None,
        positive_prompts=["pos1"],
        negative_prompts=["neg1"],
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        output_path=output,
    )
    assert isinstance(out, SteeringSearchOutput)
    assert len(out.trials) == 15
    assert len(out.baseline_responses) == 2
def test_search_finds_known_optimum(monkeypatch):
    candidates = [f"path_{i}" for i in range(4)]
    fake_vectors = {p: torch.randn(1, 8) for p in candidates}

    def fake_generate(model, prompt, direction, alpha, max_new_tokens):
        return ",".join(direction.keys())

    _patch_internals(monkeypatch, fake_vectors, generate_fn=fake_generate)

    def score_fn(row, baseline, steered):
        paths = steered.split(",") if steered else []
        return 0.0 if "path_2" in paths else 1.0

    config = SteeringSearchConfig(
        candidate_paths=candidates,
        max_simultaneous_paths=2,
        alpha_range=(0.5, 5.0),
        n_calls=30,
        n_initial_points=10,
    )
    out = steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=score_fn,
        max_new_tokens=10,
    )
    best = best_trial(out.trials)
    assert best.score == 0.0
    assert "path_2" in best.selected_paths
def test_jsonl_output_written_incrementally(tmp_path, monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    config = SteeringSearchConfig(
        candidate_paths=["p0"],
        max_simultaneous_paths=1,
        alpha_range=(0.1, 3.0),
        n_calls=8,
        n_initial_points=4,
    )
    output = tmp_path / "out.jsonl"
    steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        output_path=output,
    )
    lines = output.read_text().strip().split("\n")
    assert len(lines) == 8
    for line in lines:
        record = json.loads(line)
        assert "trial_index" in record
        assert "selected_paths" in record
        assert "alpha" in record
        assert "score" in record
        assert "raw_point" in record
        assert "responses" in record
def test_empty_selection_gets_penalty(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    config = SteeringSearchConfig(
        candidate_paths=["p0"],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 3.0),
        n_calls=10,
        n_initial_points=5,
    )
    out = steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
    )
    penalty_trials = [r for r in out.trials if not r.selected_paths]
    for t in penalty_trials:
        assert t.score == PENALTY
def test_score_receives_row_metadata(monkeypatch):
    """Score function receives the full row dict, not just the prompt."""
    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    received_rows = []

    def score_fn(row, baseline, steered):
        received_rows.append(row)
        return 0.5

    config = SteeringSearchConfig(
        candidate_paths=["p0", "p1"],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 3.0),
        n_calls=5,
        n_initial_points=5,
    )
    steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval", "task_id": "T1", "extra": 42}],
        score=score_fn,
        max_new_tokens=10,
    )
    non_penalty = [r for r in received_rows if r is not None]
    assert len(non_penalty) > 0
    assert non_penalty[0]["task_id"] == "T1"
    assert non_penalty[0]["extra"] == 42
# ---------------------------------------------------------------------------
# steer_and_eval
# ---------------------------------------------------------------------------
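# Contract pinned down by the rejection tests below: steer_and_eval takes
# exactly one source of steering vectors, either a positive/negative prompt
# pair (computed on the fly) or a precomputed steering_vectors mapping.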
def test_steer_and_eval_rejects_both_prompts_and_vectors(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="not both"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            positive_prompts=["pos"],
            negative_prompts=["neg"],
            steering_vectors=fake_vectors,
        )


def test_steer_and_eval_rejects_neither_prompts_nor_vectors(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="Provide either"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
        )


def test_steer_and_eval_rejects_only_positive_prompts(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="both be provided together"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            positive_prompts=["pos"],
        )


def test_steer_and_eval_rejects_empty_paths(monkeypatch):
    fake_vectors = {}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="paths must not be empty"):
        steer_and_eval(
            paths=[],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            steering_vectors=fake_vectors,
        )


def test_steer_and_eval_rejects_empty_eval_rows(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="eval_rows must not be empty"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            steering_vectors=fake_vectors,
        )
def test_steer_and_eval_with_prompts(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    out = steer_and_eval(
        paths=["p0", "p1"],
        alpha=2.0,
        model=None,
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 0.5,
        max_new_tokens=10,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
    )
    assert isinstance(out, SteeringSearchOutput)
    assert len(out.trials) == 1
    assert len(out.baseline_responses) == 2
    trial = out.trials[0]
    assert trial.trial_index == 0
    assert trial.selected_paths == ["p0", "p1"]
    assert trial.alpha == 2.0
    assert trial.score == 0.5
    assert len(trial.responses) == 2
def test_steer_and_eval_with_steering_vectors(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    # Ensure _compute_steering_vectors is NOT called when steering_vectors is passed.
    compute_called = []

    def fake_compute(*a, **kw):
        compute_called.append(True)
        return fake_vectors

    monkeypatch.setattr(_ss, "_compute_steering_vectors", fake_compute)

    def fake_from_pairs(*pairs):
        return _FakeDirection([p for p, _ in pairs])

    monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs))
    monkeypatch.setattr(_ss, "_generate_batch", lambda model, prompts, max_new_tokens: ["baseline"] * len(prompts))
    monkeypatch.setattr(_ss, "_generate_steered_batch", lambda model, prompts, direction, alpha, max_new_tokens: ["steered"] * len(prompts))
    out = steer_and_eval(
        paths=["p0"],
        alpha=1.5,
        model=None,
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
        steering_vectors=fake_vectors,
    )
    assert not compute_called, "_compute_steering_vectors must not be called when steering_vectors is provided"
    assert len(out.trials) == 1
    assert out.trials[0].alpha == 1.5
    assert out.trials[0].score == 0.0
def test_steer_and_eval_jsonl_output(tmp_path, monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    output = tmp_path / "steer.jsonl"
    steer_and_eval(
        paths=["p0"],
        alpha=3.0,
        model=None,
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        steering_vectors=fake_vectors,
        output_path=output,
    )
    lines = output.read_text().strip().split("\n")
    assert len(lines) == 1
    record = json.loads(lines[0])
    assert record["trial_index"] == 0
    assert record["selected_paths"] == ["p0"]
    assert record["alpha"] == 3.0
    assert record["score"] == 1.0
def test_steer_and_eval_score_receives_row_metadata(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    received = []

    def score_fn(row, baseline, steered):
        received.append(row)
        return 0.0

    steer_and_eval(
        paths=["p0"],
        alpha=1.0,
        model=None,
        eval_rows=[{"prompt": "eval", "task_id": "T42", "extra": 99}],
        score=score_fn,
        max_new_tokens=10,
        steering_vectors=fake_vectors,
    )
    assert len(received) == 1
    assert received[0]["task_id"] == "T42"
    assert received[0]["extra"] == 99
# ---------------------------------------------------------------------------
# _group_paths_by_layer
# ---------------------------------------------------------------------------
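# Behavior pinned down below: paths sharing a layer prefix such as
# "model.transformer.h[10]" but differing in the trailing token index are
# grouped together, and groups keep the order of first appearance.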
def test_group_paths_same_layer_grouped_together():
    paths = [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
        "model.transformer.h[11].output[-1]",
    ]
    groups = _group_paths_by_layer(paths)
    assert len(groups) == 2
    assert groups[0] == [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
    ]
    assert groups[1] == ["model.transformer.h[11].output[-1]"]


def test_group_paths_preserves_insertion_order():
    paths = [
        "model.transformer.h[5].output[-1]",
        "model.transformer.h[3].output[-1]",
        "model.transformer.h[5].output[-2]",
    ]
    groups = _group_paths_by_layer(paths)
    assert len(groups) == 2
    assert groups[0][0].startswith("model.transformer.h[5]")
    assert groups[1][0].startswith("model.transformer.h[3]")
def test_group_paths_no_token_index():
    # Paths without a trailing token index [N] each form their own group.
    paths = ["model.transformer.h[10].output", "model.transformer.h[11].output"]
    groups = _group_paths_by_layer(paths)
    assert len(groups) == 2
# ---------------------------------------------------------------------------
# layer_sweep
# ---------------------------------------------------------------------------
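# As the assertions below suggest, layer_sweep runs one trial per layer group
# at a fixed alpha, and records raw_point as [group_index, alpha].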
def test_layer_sweep_groups_by_layer(monkeypatch):
    # Two paths in the same layer → one trial; one path in a different layer → second trial.
    paths = [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
        "model.transformer.h[11].output[-1]",
    ]
    fake_vectors = {p: torch.randn(1, 8) for p in paths}
    _patch_internals(monkeypatch, fake_vectors)
    out = layer_sweep(
        paths=paths,
        alpha=2.0,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 0.5,
        max_new_tokens=10,
    )
    assert len(out.trials) == 2
    assert out.trials[0].selected_paths == [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
    ]
    assert out.trials[1].selected_paths == ["model.transformer.h[11].output[-1]"]
    for i, trial in enumerate(out.trials):
        assert trial.trial_index == i
        assert trial.alpha == 2.0
        assert trial.raw_point == [i, 2.0]
def test_layer_sweep_empty_paths_raises():
    with pytest.raises(ValueError, match="paths must not be empty"):
        layer_sweep(
            paths=[],
            alpha=1.0,
            model=None,
            positive_prompts=[],
            negative_prompts=[],
            eval_rows=[],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
        )
def test_layer_sweep_jsonl_output(tmp_path, monkeypatch):
    paths = [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
        "model.transformer.h[11].output[-1]",
    ]
    fake_vectors = {p: torch.randn(1, 8) for p in paths}
    _patch_internals(monkeypatch, fake_vectors)
    output = tmp_path / "sweep.jsonl"
    layer_sweep(
        paths=paths,
        alpha=3.0,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
        output_path=output,
    )
    lines = output.read_text().strip().split("\n")
    assert len(lines) == 2  # two groups → two trials
    for line in lines:
        record = json.loads(line)
        assert "trial_index" in record
        assert "selected_paths" in record
        assert record["alpha"] == 3.0