more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -5,6 +5,7 @@ from .steering_search import (
     SteeringTrialResult,
     best_trial,
     layer_sweep,
+    steer_and_eval,
     steering_search,
 )
 
@@ -17,6 +18,7 @@ __all__ = [
     "SteeringTrialResult",
     "best_trial",
     "layer_sweep",
+    "steer_and_eval",
     "steering_search",
 ]
 
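After this hunk, steer_and_eval is re-exported at the package root, so both imports below should resolve; nothing beyond the names added in the diff is assumed:

from more_nnsight import steer_and_eval
from more_nnsight.steering_search import steer_and_eval  # direct module import, same function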
@@ -365,6 +365,104 @@ def steering_search(
     return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)
 
 
+@torch.no_grad()
+def steer_and_eval(
+    paths: list[str],
+    alpha: float,
+    model: Any,
+    eval_rows: list[dict],
+    score: Callable[[dict, str, str], float],
+    max_new_tokens: int,
+    *,
+    positive_prompts: list[str] | None = None,
+    negative_prompts: list[str] | None = None,
+    steering_vectors: dict[str, torch.Tensor] | None = None,
+    output_path: Path | None = None,
+    batch_size: int = 1,
+) -> SteeringSearchOutput:
+    """Steer on the given paths at a fixed alpha and evaluate eval_rows.
+
+    Provide either ``positive_prompts`` and ``negative_prompts`` together to
+    compute steering vectors, or pass pre-computed ``steering_vectors`` directly.
+    Exactly one of the two options must be used.
+
+    Args:
+        paths: Activation paths to steer on.
+        alpha: Steering coefficient applied at every path.
+        model: An nnsight LanguageModel (already loaded).
+        eval_rows: Rows to evaluate. Each row must contain a ``"prompt"`` key;
+            other keys are opaque metadata passed through to ``score``.
+        score: Scoring function. Takes ``(row, baseline_response,
+            steered_response)`` and returns a float (lower is better).
+        max_new_tokens: Maximum tokens to generate per prompt.
+        positive_prompts: Prompts for the "positive" class used to compute
+            steering vectors. Must be paired with ``negative_prompts``.
+            Mutually exclusive with ``steering_vectors``.
+        negative_prompts: Prompts for the "negative" class used to compute
+            steering vectors. Must be paired with ``positive_prompts``.
+            Mutually exclusive with ``steering_vectors``.
+        steering_vectors: Pre-computed steering vectors keyed by path.
+            Mutually exclusive with ``positive_prompts`` / ``negative_prompts``.
+        output_path: If provided, write the single trial result as one JSONL line.
+        batch_size: Number of prompts to generate in parallel per batch.
+
+    Returns:
+        SteeringSearchOutput with baseline_responses and a single trial.
+    """
+    prompts_provided = positive_prompts is not None or negative_prompts is not None
+    vectors_provided = steering_vectors is not None
+    if prompts_provided and vectors_provided:
+        raise ValueError(
+            "Provide either (positive_prompts + negative_prompts) or steering_vectors, not both."
+        )
+    if not prompts_provided and not vectors_provided:
+        raise ValueError(
+            "Provide either (positive_prompts + negative_prompts) or steering_vectors."
+        )
+    if prompts_provided and (positive_prompts is None or negative_prompts is None):
+        raise ValueError(
+            "positive_prompts and negative_prompts must both be provided together."
+        )
+
+    if not paths:
+        raise ValueError("paths must not be empty")
+    if not eval_rows:
+        raise ValueError("eval_rows must not be empty")
+
+    if vectors_provided:
+        vectors = steering_vectors
+    else:
+        vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)
+
+    eval_prompts = [row["prompt"] for row in eval_rows]
+    baseline_responses: list[str] = []
+    for i in range(0, len(eval_prompts), batch_size):
+        batch = eval_prompts[i : i + batch_size]
+        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
+
+    if output_path is not None:
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        output_path.write_text("")
+
+    responses, trial_score = _evaluate_steered(
+        paths, alpha, vectors, eval_rows, baseline_responses, model, score, max_new_tokens, batch_size,
+    )
+
+    result = SteeringTrialResult(
+        trial_index=0,
+        selected_paths=paths,
+        alpha=alpha,
+        score=trial_score,
+        raw_point=[alpha],
+        responses=responses,
+    )
+
+    if output_path is not None:
+        _append_jsonl(output_path, result)
+
+    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=[result])
+
+
 def _group_paths_by_layer(paths: list[str]) -> list[list[str]]:
     """Group paths that share the same layer by stripping the trailing token index.
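For orientation, a sketch of driving the new entry point end to end. The model id, activation path strings, prompts, and scoring heuristic below are illustrative assumptions for this example, not part of the commit; only the signature and the lower-is-better score convention come from the diff above:

from pathlib import Path

from nnsight import LanguageModel

from more_nnsight import steer_and_eval

# Assumption: any already-loaded nnsight LanguageModel, per the docstring.
model = LanguageModel("openai-community/gpt2", device_map="auto")

# Assumption: path strings in whatever format the library's steering hooks expect.
paths = ["model.transformer.h.5.output", "model.transformer.h.6.output"]

eval_rows = [
    {"prompt": "Write a short apology email.", "task_id": "T1"},
    {"prompt": "Summarize the meeting notes.", "task_id": "T2"},
]

def score(row, baseline_response, steered_response):
    # Lower is better, per the docstring. This length-drift penalty is a toy heuristic.
    return abs(len(steered_response) - len(baseline_response)) / 100.0

out = steer_and_eval(
    paths=paths,
    alpha=4.0,
    model=model,
    eval_rows=eval_rows,
    score=score,
    max_new_tokens=64,
    positive_prompts=["I would be delighted to help."],
    negative_prompts=["I refuse to help."],
    output_path=Path("results/steer.jsonl"),  # one JSONL line for the single trial
    batch_size=2,
)

print(out.trials[0].score, out.baseline_responses[0])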
@@ -18,6 +18,7 @@ from more_nnsight.steering_search import (
     _group_paths_by_layer,
     best_trial,
     layer_sweep,
+    steer_and_eval,
     steering_search,
 )
 
@@ -275,6 +276,198 @@ def test_score_receives_row_metadata(monkeypatch):
 
 
 # ---------------------------------------------------------------------------
+# steer_and_eval
+# ---------------------------------------------------------------------------
+
+
+def test_steer_and_eval_rejects_both_prompts_and_vectors(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+    with pytest.raises(ValueError, match="not both"):
+        steer_and_eval(
+            paths=["p0"],
+            alpha=1.0,
+            model=None,
+            eval_rows=[{"prompt": "eval"}],
+            score=lambda row, b, s: 0.0,
+            max_new_tokens=10,
+            positive_prompts=["pos"],
+            negative_prompts=["neg"],
+            steering_vectors=fake_vectors,
+        )
+
+
+def test_steer_and_eval_rejects_neither_prompts_nor_vectors(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+    with pytest.raises(ValueError, match="Provide either"):
+        steer_and_eval(
+            paths=["p0"],
+            alpha=1.0,
+            model=None,
+            eval_rows=[{"prompt": "eval"}],
+            score=lambda row, b, s: 0.0,
+            max_new_tokens=10,
+        )
+
+
+def test_steer_and_eval_rejects_only_positive_prompts(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+    with pytest.raises(ValueError, match="both be provided together"):
+        steer_and_eval(
+            paths=["p0"],
+            alpha=1.0,
+            model=None,
+            eval_rows=[{"prompt": "eval"}],
+            score=lambda row, b, s: 0.0,
+            max_new_tokens=10,
+            positive_prompts=["pos"],
+        )
+
+
+def test_steer_and_eval_rejects_empty_paths(monkeypatch):
+    fake_vectors = {}
+    _patch_internals(monkeypatch, fake_vectors)
+    with pytest.raises(ValueError, match="paths must not be empty"):
+        steer_and_eval(
+            paths=[],
+            alpha=1.0,
+            model=None,
+            eval_rows=[{"prompt": "eval"}],
+            score=lambda row, b, s: 0.0,
+            max_new_tokens=10,
+            steering_vectors=fake_vectors,
+        )
+
+
+def test_steer_and_eval_rejects_empty_eval_rows(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+    with pytest.raises(ValueError, match="eval_rows must not be empty"):
+        steer_and_eval(
+            paths=["p0"],
+            alpha=1.0,
+            model=None,
+            eval_rows=[],
+            score=lambda row, b, s: 0.0,
+            max_new_tokens=10,
+            steering_vectors=fake_vectors,
+        )
+
+
+def test_steer_and_eval_with_prompts(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+
+    out = steer_and_eval(
+        paths=["p0", "p1"],
+        alpha=2.0,
+        model=None,
+        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
+        score=lambda row, b, s: 0.5,
+        max_new_tokens=10,
+        positive_prompts=["pos"],
+        negative_prompts=["neg"],
+    )
+
+    assert isinstance(out, SteeringSearchOutput)
+    assert len(out.trials) == 1
+    assert len(out.baseline_responses) == 2
+    trial = out.trials[0]
+    assert trial.trial_index == 0
+    assert trial.selected_paths == ["p0", "p1"]
+    assert trial.alpha == 2.0
+    assert trial.score == 0.5
+    assert len(trial.responses) == 2
+
+
+def test_steer_and_eval_with_steering_vectors(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+
+    # Ensure _compute_steering_vectors is NOT called when steering_vectors is passed.
+    compute_called = []
+
+    def fake_compute(*a, **kw):
+        compute_called.append(True)
+        return fake_vectors
+
+    monkeypatch.setattr(_ss, "_compute_steering_vectors", fake_compute)
+
+    def fake_from_pairs(*pairs):
+        return _FakeDirection([p for p, _ in pairs])
+
+    monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs))
+    monkeypatch.setattr(_ss, "_generate_batch", lambda model, prompts, max_new_tokens: ["baseline"] * len(prompts))
+    monkeypatch.setattr(_ss, "_generate_steered_batch", lambda model, prompts, direction, alpha, max_new_tokens: ["steered"] * len(prompts))
+
+    out = steer_and_eval(
+        paths=["p0"],
+        alpha=1.5,
+        model=None,
+        eval_rows=[{"prompt": "eval"}],
+        score=lambda row, b, s: 0.0,
+        max_new_tokens=10,
+        steering_vectors=fake_vectors,
+    )
+
+    assert not compute_called, "_compute_steering_vectors must not be called when steering_vectors is provided"
+    assert len(out.trials) == 1
+    assert out.trials[0].alpha == 1.5
+    assert out.trials[0].score == 0.0
+
+
+def test_steer_and_eval_jsonl_output(tmp_path, monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+
+    output = tmp_path / "steer.jsonl"
+    steer_and_eval(
+        paths=["p0"],
+        alpha=3.0,
+        model=None,
+        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
+        score=lambda row, b, s: 1.0,
+        max_new_tokens=10,
+        steering_vectors=fake_vectors,
+        output_path=output,
+    )
+
+    lines = output.read_text().strip().split("\n")
+    assert len(lines) == 1
+    record = json.loads(lines[0])
+    assert record["trial_index"] == 0
+    assert record["selected_paths"] == ["p0"]
+    assert record["alpha"] == 3.0
+    assert record["score"] == 1.0
+
+
+def test_steer_and_eval_score_receives_row_metadata(monkeypatch):
+    fake_vectors = {"p0": torch.randn(1, 8)}
+    _patch_internals(monkeypatch, fake_vectors)
+
+    received = []
+
+    def score_fn(row, baseline, steered):
+        received.append(row)
+        return 0.0
+
+    steer_and_eval(
+        paths=["p0"],
+        alpha=1.0,
+        model=None,
+        eval_rows=[{"prompt": "eval", "task_id": "T42", "extra": 99}],
+        score=score_fn,
+        max_new_tokens=10,
+        steering_vectors=fake_vectors,
+    )
+
+    assert len(received) == 1
+    assert received[0]["task_id"] == "T42"
+    assert received[0]["extra"] == 99
+
+
+# ---------------------------------------------------------------------------
 # _group_paths_by_layer
 # ---------------------------------------------------------------------------
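The JSONL test above fixes the on-disk record shape (trial_index, selected_paths, alpha, score). A small consumer sketch; the file path is illustrative:

import json
from pathlib import Path

# Each line is one trial record; steer_and_eval writes exactly one.
for line in Path("results/steer.jsonl").read_text().splitlines():
    record = json.loads(line)
    print(record["trial_index"], record["selected_paths"], record["alpha"], record["score"])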