
tests/test_steering_search.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

"""Unit tests for steering_search — CPU-only, no model required.""" from __future__ import annotations import json import sys import pytest import torch from more_nnsight.steering_search import ( PENALTY, SavedActivation, SteeringSearchConfig, SteeringSearchOutput, SteeringTrialResult, _decode_point, _group_paths_by_layer, best_trial, layer_sweep, steer_and_eval, steering_search, ) _ss = sys.modules["more_nnsight.steering_search"] # --------------------------------------------------------------------------- # _decode_point # --------------------------------------------------------------------------- def test_decode_deduplicates_and_removes_sentinel(): candidates = ["path_a", "path_b", "path_c"] selected, alpha = _decode_point([2, 0, 2, 1.5], n_candidates=3, candidate_paths=candidates) assert selected == ["path_a", "path_c"] assert alpha == 1.5 def test_decode_all_sentinel_returns_empty(): candidates = ["path_a", "path_b"] selected, alpha = _decode_point([2, 2, 0.5], n_candidates=2, candidate_paths=candidates) assert selected == [] assert alpha == 0.5 # --------------------------------------------------------------------------- # best_trial # --------------------------------------------------------------------------- def test_best_trial_returns_minimum_score(): trials = [ SteeringTrialResult(0, ["a"], 1.0, 5.0, [0, 1.0], ["r0"]), SteeringTrialResult(1, ["b"], 2.0, 1.0, [1, 2.0], ["r1"]), SteeringTrialResult(2, ["a", "b"], 1.5, 3.0, [0, 1, 1.5], ["r2"]), ] assert best_trial(trials).trial_index == 1 # --------------------------------------------------------------------------- # steering_search (with mocked internals) # --------------------------------------------------------------------------- class _FakeDirection: """Mimics SavedActivation enough for the objective function.""" def __init__(self, paths): self._paths = list(paths) def keys(self): return list(self._paths) def get(self, key): return torch.zeros(1, 8) def _patch_internals(monkeypatch, fake_vectors, generate_fn=None): """Patch _compute_steering_vectors, SavedActivation.from_pairs, _generate_batch (baseline), and _generate_steered_batch.""" monkeypatch.setattr(_ss, "_compute_steering_vectors", lambda *a, **kw: fake_vectors) def fake_from_pairs(*pairs): return _FakeDirection([p for p, _ in pairs]) monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs)) # Mock baseline generation. def baseline_batch_fn(model, prompts, max_new_tokens): return ["baseline response"] * len(prompts) monkeypatch.setattr(_ss, "_generate_batch", baseline_batch_fn) # Mock steered generation. 
def _patch_internals(monkeypatch, fake_vectors, generate_fn=None):
    """Patch _compute_steering_vectors, SavedActivation.from_pairs,
    _generate_batch (baseline), and _generate_steered_batch."""
    monkeypatch.setattr(_ss, "_compute_steering_vectors", lambda *a, **kw: fake_vectors)

    def fake_from_pairs(*pairs):
        return _FakeDirection([p for p, _ in pairs])

    monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs))

    # Mock baseline generation.
    def baseline_batch_fn(model, prompts, max_new_tokens):
        return ["baseline response"] * len(prompts)

    monkeypatch.setattr(_ss, "_generate_batch", baseline_batch_fn)

    # Mock steered generation.
    if generate_fn is not None:
        def batch_fn(model, prompts, direction, alpha, max_new_tokens):
            return [generate_fn(model, p, direction, alpha, max_new_tokens) for p in prompts]
    else:
        def batch_fn(model, prompts, direction, alpha, max_new_tokens):
            return ["fake response"] * len(prompts)

    monkeypatch.setattr(_ss, "_generate_steered_batch", batch_fn)


def test_empty_candidates_raises():
    config = SteeringSearchConfig(
        candidate_paths=[],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 5.0),
        n_calls=5,
    )
    with pytest.raises(ValueError, match="candidate_paths must not be empty"):
        steering_search(
            config=config,
            model=None,
            positive_prompts=[],
            negative_prompts=[],
            eval_rows=[],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
        )


def test_search_returns_correct_number_of_results(tmp_path, monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    config = SteeringSearchConfig(
        candidate_paths=["p0", "p1"],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 3.0),
        n_calls=15,
        n_initial_points=5,
    )
    output = tmp_path / "trials.jsonl"
    out = steering_search(
        config=config,
        model=None,
        positive_prompts=["pos1"],
        negative_prompts=["neg1"],
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        output_path=output,
    )
    assert isinstance(out, SteeringSearchOutput)
    assert len(out.trials) == 15
    assert len(out.baseline_responses) == 2


def test_search_finds_known_optimum(monkeypatch):
    candidates = [f"path_{i}" for i in range(4)]
    fake_vectors = {p: torch.randn(1, 8) for p in candidates}

    def fake_generate(model, prompt, direction, alpha, max_new_tokens):
        return ",".join(direction.keys())

    _patch_internals(monkeypatch, fake_vectors, generate_fn=fake_generate)

    def score_fn(row, baseline, steered):
        paths = steered.split(",") if steered else []
        return 0.0 if "path_2" in paths else 1.0

    config = SteeringSearchConfig(
        candidate_paths=candidates,
        max_simultaneous_paths=2,
        alpha_range=(0.5, 5.0),
        n_calls=30,
        n_initial_points=10,
    )
    out = steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=score_fn,
        max_new_tokens=10,
    )
    best = best_trial(out.trials)
    assert best.score == 0.0
    assert "path_2" in best.selected_paths


def test_jsonl_output_written_incrementally(tmp_path, monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    config = SteeringSearchConfig(
        candidate_paths=["p0"],
        max_simultaneous_paths=1,
        alpha_range=(0.1, 3.0),
        n_calls=8,
        n_initial_points=4,
    )
    output = tmp_path / "out.jsonl"
    steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        output_path=output,
    )
    lines = output.read_text().strip().split("\n")
    assert len(lines) == 8
    for line in lines:
        record = json.loads(line)
        assert "trial_index" in record
        assert "selected_paths" in record
        assert "alpha" in record
        assert "score" in record
        assert "raw_point" in record
        assert "responses" in record
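
# For reference, the output file holds one JSON object per trial with the keys
# asserted above. An illustrative record (made-up values, not from a real run)
# could look like:
#   {"trial_index": 0, "selected_paths": ["p0"], "alpha": 1.2, "score": 1.0,
#    "raw_point": [0, 1.2], "responses": ["fake response"]}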
def test_empty_selection_gets_penalty(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    config = SteeringSearchConfig(
        candidate_paths=["p0"],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 3.0),
        n_calls=10,
        n_initial_points=5,
    )
    out = steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
    )
    penalty_trials = [r for r in out.trials if not r.selected_paths]
    for t in penalty_trials:
        assert t.score == PENALTY


def test_score_receives_row_metadata(monkeypatch):
    """Score function receives the full row dict, not just the prompt."""
    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    received_rows = []

    def score_fn(row, baseline, steered):
        received_rows.append(row)
        return 0.5

    config = SteeringSearchConfig(
        candidate_paths=["p0", "p1"],
        max_simultaneous_paths=2,
        alpha_range=(0.1, 3.0),
        n_calls=5,
        n_initial_points=5,
    )
    steering_search(
        config=config,
        model=None,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        eval_rows=[{"prompt": "eval", "task_id": "T1", "extra": 42}],
        score=score_fn,
        max_new_tokens=10,
    )
    non_penalty = [r for r in received_rows if r is not None]
    assert len(non_penalty) > 0
    assert non_penalty[0]["task_id"] == "T1"
    assert non_penalty[0]["extra"] == 42


# ---------------------------------------------------------------------------
# steer_and_eval
# ---------------------------------------------------------------------------


def test_steer_and_eval_rejects_both_prompts_and_vectors(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="not both"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            positive_prompts=["pos"],
            negative_prompts=["neg"],
            steering_vectors=fake_vectors,
        )


def test_steer_and_eval_rejects_neither_prompts_nor_vectors(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="Provide either"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
        )


def test_steer_and_eval_rejects_only_positive_prompts(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="both be provided together"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            positive_prompts=["pos"],
        )


def test_steer_and_eval_rejects_empty_paths(monkeypatch):
    fake_vectors = {}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="paths must not be empty"):
        steer_and_eval(
            paths=[],
            alpha=1.0,
            model=None,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            steering_vectors=fake_vectors,
        )


def test_steer_and_eval_rejects_empty_eval_rows(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    with pytest.raises(ValueError, match="eval_rows must not be empty"):
        steer_and_eval(
            paths=["p0"],
            alpha=1.0,
            model=None,
            eval_rows=[],
            score=lambda row, b, s: 0.0,
            max_new_tokens=10,
            steering_vectors=fake_vectors,
        )
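
# The rejection tests above pin down steer_and_eval's argument contract, as
# far as these tests reveal it: callers must supply exactly one of
# (positive_prompts AND negative_prompts together) or precomputed
# steering_vectors, plus non-empty paths and eval_rows.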
def test_steer_and_eval_with_prompts(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    out = steer_and_eval(
        paths=["p0", "p1"],
        alpha=2.0,
        model=None,
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 0.5,
        max_new_tokens=10,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
    )
    assert isinstance(out, SteeringSearchOutput)
    assert len(out.trials) == 1
    assert len(out.baseline_responses) == 2
    trial = out.trials[0]
    assert trial.trial_index == 0
    assert trial.selected_paths == ["p0", "p1"]
    assert trial.alpha == 2.0
    assert trial.score == 0.5
    assert len(trial.responses) == 2


def test_steer_and_eval_with_steering_vectors(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    # Ensure _compute_steering_vectors is NOT called when steering_vectors is passed.
    compute_called = []

    def fake_compute(*a, **kw):
        compute_called.append(True)
        return fake_vectors

    monkeypatch.setattr(_ss, "_compute_steering_vectors", fake_compute)

    def fake_from_pairs(*pairs):
        return _FakeDirection([p for p, _ in pairs])

    monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs))
    monkeypatch.setattr(
        _ss,
        "_generate_batch",
        lambda model, prompts, max_new_tokens: ["baseline"] * len(prompts),
    )
    monkeypatch.setattr(
        _ss,
        "_generate_steered_batch",
        lambda model, prompts, direction, alpha, max_new_tokens: ["steered"] * len(prompts),
    )
    out = steer_and_eval(
        paths=["p0"],
        alpha=1.5,
        model=None,
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
        steering_vectors=fake_vectors,
    )
    assert not compute_called, "_compute_steering_vectors must not be called when steering_vectors is provided"
    assert len(out.trials) == 1
    assert out.trials[0].alpha == 1.5
    assert out.trials[0].score == 0.0


def test_steer_and_eval_jsonl_output(tmp_path, monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    output = tmp_path / "steer.jsonl"
    steer_and_eval(
        paths=["p0"],
        alpha=3.0,
        model=None,
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        steering_vectors=fake_vectors,
        output_path=output,
    )
    lines = output.read_text().strip().split("\n")
    assert len(lines) == 1
    record = json.loads(lines[0])
    assert record["trial_index"] == 0
    assert record["selected_paths"] == ["p0"]
    assert record["alpha"] == 3.0
    assert record["score"] == 1.0


def test_steer_and_eval_score_receives_row_metadata(monkeypatch):
    fake_vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, fake_vectors)
    received = []

    def score_fn(row, baseline, steered):
        received.append(row)
        return 0.0

    steer_and_eval(
        paths=["p0"],
        alpha=1.0,
        model=None,
        eval_rows=[{"prompt": "eval", "task_id": "T42", "extra": 99}],
        score=score_fn,
        max_new_tokens=10,
        steering_vectors=fake_vectors,
    )
    assert len(received) == 1
    assert received[0]["task_id"] == "T42"
    assert received[0]["extra"] == 99


# ---------------------------------------------------------------------------
# _group_paths_by_layer
# ---------------------------------------------------------------------------


def test_group_paths_same_layer_grouped_together():
    paths = [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
        "model.transformer.h[11].output[-1]",
    ]
    groups = _group_paths_by_layer(paths)
    assert len(groups) == 2
    assert groups[0] == [
        "model.transformer.h[10].output[-1]",
        "model.transformer.h[10].output[-2]",
    ]
    assert groups[1] == ["model.transformer.h[11].output[-1]"]


def test_group_paths_preserves_insertion_order():
    paths = [
        "model.transformer.h[5].output[-1]",
        "model.transformer.h[3].output[-1]",
        "model.transformer.h[5].output[-2]",
    ]
    groups = _group_paths_by_layer(paths)
    assert len(groups) == 2
    assert groups[0][0].startswith("model.transformer.h[5]")
    assert groups[1][0].startswith("model.transformer.h[3]")
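
# Taken together, the grouping tests imply _group_paths_by_layer keys each
# path on everything before its trailing [N] token index and yields groups in
# order of first appearance, e.g. the three paths above collapse to keys
# "model.transformer.h[5].output" and "model.transformer.h[3].output".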
paths = ["model.transformer.h[10].output", "model.transformer.h[11].output"] groups = _group_paths_by_layer(paths) assert len(groups) == 2 # --------------------------------------------------------------------------- # layer_sweep # --------------------------------------------------------------------------- def test_layer_sweep_groups_by_layer(monkeypatch): # Two paths in the same layer → one trial; one path in a different layer → second trial. paths = [ "model.transformer.h[10].output[-1]", "model.transformer.h[10].output[-2]", "model.transformer.h[11].output[-1]", ] fake_vectors = {p: torch.randn(1, 8) for p in paths} _patch_internals(monkeypatch, fake_vectors) out = layer_sweep( paths=paths, alpha=2.0, model=None, positive_prompts=["pos"], negative_prompts=["neg"], eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}], score=lambda row, b, s: 0.5, max_new_tokens=10, ) assert len(out.trials) == 2 assert out.trials[0].selected_paths == [ "model.transformer.h[10].output[-1]", "model.transformer.h[10].output[-2]", ] assert out.trials[1].selected_paths == ["model.transformer.h[11].output[-1]"] for i, trial in enumerate(out.trials): assert trial.trial_index == i assert trial.alpha == 2.0 assert trial.raw_point == [i, 2.0] def test_layer_sweep_empty_paths_raises(): with pytest.raises(ValueError, match="paths must not be empty"): layer_sweep( paths=[], alpha=1.0, model=None, positive_prompts=[], negative_prompts=[], eval_rows=[], score=lambda row, b, s: 0.0, max_new_tokens=10, ) def test_layer_sweep_jsonl_output(tmp_path, monkeypatch): import json paths = [ "model.transformer.h[10].output[-1]", "model.transformer.h[10].output[-2]", "model.transformer.h[11].output[-1]", ] fake_vectors = {p: torch.randn(1, 8) for p in paths} _patch_internals(monkeypatch, fake_vectors) output = tmp_path / "sweep.jsonl" layer_sweep( paths=paths, alpha=3.0, model=None, positive_prompts=["pos"], negative_prompts=["neg"], eval_rows=[{"prompt": "eval"}], score=lambda row, b, s: 0.0, max_new_tokens=10, output_path=output, ) lines = output.read_text().strip().split("\n") assert len(lines) == 2 # two groups → two trials for line in lines: record = json.loads(line) assert "trial_index" in record assert "selected_paths" in record assert record["alpha"] == 3.0