Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Add steer_and_eval for single-configuration evaluation

steer_and_eval runs one steering trial at fixed paths and alpha,
accepts either (positive_prompts + negative_prompts) or pre-computed
steering_vectors, and returns a SteeringSearchOutput with one trial.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-04-03 20:15:22 -0400
Commit
4888a9d74b380057ee16930f82875e9bcce82cb7
src/more_nnsight/__init__.py
index edd5a8a..c629e33 100644
--- a/src/more_nnsight/__init__.py
+++ b/src/more_nnsight/__init__.py
@@ -5,6 +5,7 @@ from .steering_search import (
     SteeringTrialResult,
     best_trial,
     layer_sweep,
+    steer_and_eval,
     steering_search,
 )
 
@@ -17,6 +18,7 @@ __all__ = [
     "SteeringTrialResult",
     "best_trial",
     "layer_sweep",
+    "steer_and_eval",
     "steering_search",
 ]
 
src/more_nnsight/steering_search.py
index 3a4b8e6..11bedd8 100644
--- a/src/more_nnsight/steering_search.py
+++ b/src/more_nnsight/steering_search.py
@@ -365,6 +365,104 @@ def steering_search(
     return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)
 
 
@torch.no_grad()
def steer_and_eval(
    paths: list[str],
    alpha: float,
    model: Any,
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    *,
    positive_prompts: list[str] | None = None,
    negative_prompts: list[str] | None = None,
    steering_vectors: dict[str, torch.Tensor] | None = None,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Steer on the given paths at a fixed alpha and evaluate eval_rows.

    Provide either ``positive_prompts`` and ``negative_prompts`` together to
    compute steering vectors, or pass pre-computed ``steering_vectors`` directly.
    Exactly one of the two options must be used.

    Args:
        paths: Activation paths to steer on.
        alpha: Steering coefficient applied at every path.
        model: An nnsight LanguageModel (already loaded).
        eval_rows: Rows to evaluate. Each row must contain a ``"prompt"`` key;
            other keys are opaque metadata passed through to ``score``.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float (lower is better).
        max_new_tokens: Maximum tokens to generate per prompt.
        positive_prompts: Prompts for the "positive" class used to compute
            steering vectors. Must be paired with ``negative_prompts``.
            Mutually exclusive with ``steering_vectors``.
        negative_prompts: Prompts for the "negative" class used to compute
            steering vectors. Must be paired with ``positive_prompts``.
            Mutually exclusive with ``steering_vectors``.
        steering_vectors: Pre-computed steering vectors keyed by path. Must
            contain an entry for every path in ``paths``. Mutually exclusive
            with ``positive_prompts`` / ``negative_prompts``.
        output_path: If provided, write the single trial result as one JSONL line.
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and a single trial.

    Raises:
        ValueError: If neither or both vector sources are given, if the
            prompt lists are not paired, if ``paths`` or ``eval_rows`` is
            empty, or if ``steering_vectors`` lacks an entry for some path.
    """
    prompts_provided = positive_prompts is not None or negative_prompts is not None
    vectors_provided = steering_vectors is not None
    if prompts_provided and vectors_provided:
        raise ValueError(
            "Provide either (positive_prompts + negative_prompts) or steering_vectors, not both."
        )
    if not prompts_provided and not vectors_provided:
        raise ValueError(
            "Provide either (positive_prompts + negative_prompts) or steering_vectors."
        )
    if prompts_provided and (positive_prompts is None or negative_prompts is None):
        raise ValueError(
            "positive_prompts and negative_prompts must both be provided together."
        )

    if not paths:
        raise ValueError("paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")

    if vectors_provided:
        # Fail fast with a clear message instead of a KeyError deep inside
        # steered generation when a path has no corresponding vector.
        missing = [p for p in paths if p not in steering_vectors]
        if missing:
            raise ValueError(f"steering_vectors is missing entries for paths: {missing}")
        vectors = steering_vectors
    else:
        vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)

    # Baseline (unsteered) generations, batched to bound peak memory.
    eval_prompts = [row["prompt"] for row in eval_rows]
    baseline_responses: list[str] = []
    for i in range(0, len(eval_prompts), batch_size):
        batch = eval_prompts[i : i + batch_size]
        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))

    responses, trial_score = _evaluate_steered(
        paths, alpha, vectors, eval_rows, baseline_responses, model, score, max_new_tokens, batch_size,
    )

    result = SteeringTrialResult(
        trial_index=0,
        selected_paths=paths,
        alpha=alpha,
        score=trial_score,
        raw_point=[alpha],
        responses=responses,
    )

    if output_path is not None:
        # Truncate only after the trial has succeeded, so a failed run does
        # not clobber an existing results file; then append the one record.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text("")
        _append_jsonl(output_path, result)

    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=[result])
+
+
 def _group_paths_by_layer(paths: list[str]) -> list[list[str]]:
     """Group paths that share the same layer by stripping the trailing token index.
 
tests/test_steering_search.py
index 495d4f9..91771e3 100644
--- a/tests/test_steering_search.py
+++ b/tests/test_steering_search.py
@@ -18,6 +18,7 @@ from more_nnsight.steering_search import (
     _group_paths_by_layer,
     best_trial,
     layer_sweep,
+    steer_and_eval,
     steering_search,
 )
 
@@ -275,6 +276,198 @@ def test_score_receives_row_metadata(monkeypatch):
 
 
 # ---------------------------------------------------------------------------
+# steer_and_eval
+# ---------------------------------------------------------------------------
+
+
def test_steer_and_eval_rejects_both_prompts_and_vectors(monkeypatch):
    """Passing a prompt pair and pre-computed vectors together must fail."""
    vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, vectors)
    kwargs = dict(
        paths=["p0"],
        alpha=1.0,
        model=None,
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
        steering_vectors=vectors,
    )
    with pytest.raises(ValueError, match="not both"):
        steer_and_eval(**kwargs)
+
+
def test_steer_and_eval_rejects_neither_prompts_nor_vectors(monkeypatch):
    """Supplying neither prompts nor vectors must raise immediately."""
    _patch_internals(monkeypatch, {"p0": torch.randn(1, 8)})
    with pytest.raises(ValueError, match="Provide either"):
        steer_and_eval(
            model=None,
            paths=["p0"],
            alpha=1.0,
            max_new_tokens=10,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
        )
+
+
def test_steer_and_eval_rejects_only_positive_prompts(monkeypatch):
    """positive_prompts without negative_prompts is an unpaired usage error."""
    _patch_internals(monkeypatch, {"p0": torch.randn(1, 8)})
    with pytest.raises(ValueError, match="both be provided together"):
        steer_and_eval(
            model=None,
            paths=["p0"],
            alpha=1.0,
            max_new_tokens=10,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            positive_prompts=["pos"],
        )
+
+
def test_steer_and_eval_rejects_empty_paths(monkeypatch):
    """An empty paths list is rejected before any generation happens."""
    _patch_internals(monkeypatch, {})
    with pytest.raises(ValueError, match="paths must not be empty"):
        steer_and_eval(
            model=None,
            paths=[],
            alpha=1.0,
            max_new_tokens=10,
            eval_rows=[{"prompt": "eval"}],
            score=lambda row, b, s: 0.0,
            steering_vectors={},
        )
+
+
def test_steer_and_eval_rejects_empty_eval_rows(monkeypatch):
    """An empty evaluation set is rejected before any generation happens."""
    vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, vectors)
    with pytest.raises(ValueError, match="eval_rows must not be empty"):
        steer_and_eval(
            model=None,
            paths=["p0"],
            alpha=1.0,
            max_new_tokens=10,
            eval_rows=[],
            score=lambda row, b, s: 0.0,
            steering_vectors=vectors,
        )
+
+
def test_steer_and_eval_with_prompts(monkeypatch):
    """Happy path: vectors are computed from a positive/negative prompt pair."""
    vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
    _patch_internals(monkeypatch, vectors)

    rows = [{"prompt": "eval1"}, {"prompt": "eval2"}]
    out = steer_and_eval(
        paths=["p0", "p1"],
        alpha=2.0,
        model=None,
        eval_rows=rows,
        score=lambda row, b, s: 0.5,
        max_new_tokens=10,
        positive_prompts=["pos"],
        negative_prompts=["neg"],
    )

    # One trial, one baseline response per eval row.
    assert isinstance(out, SteeringSearchOutput)
    assert len(out.trials) == 1
    assert len(out.baseline_responses) == len(rows)
    trial = out.trials[0]
    assert trial.trial_index == 0
    assert trial.selected_paths == ["p0", "p1"]
    assert trial.alpha == 2.0
    assert trial.score == 0.5
    assert len(trial.responses) == len(rows)
+
+
def test_steer_and_eval_with_steering_vectors(monkeypatch):
    """Pre-computed vectors are used verbatim; no vector computation happens."""
    vectors = {"p0": torch.randn(1, 8)}

    # Record any call to _compute_steering_vectors so we can assert it is skipped.
    calls = []

    def tracked_compute(*args, **kwargs):
        calls.append(True)
        return vectors

    def tracked_from_pairs(*pairs):
        return _FakeDirection([first for first, _ in pairs])

    monkeypatch.setattr(_ss, "_compute_steering_vectors", tracked_compute)
    monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(tracked_from_pairs))
    monkeypatch.setattr(
        _ss,
        "_generate_batch",
        lambda model, prompts, max_new_tokens: ["baseline"] * len(prompts),
    )
    monkeypatch.setattr(
        _ss,
        "_generate_steered_batch",
        lambda model, prompts, direction, alpha, max_new_tokens: ["steered"] * len(prompts),
    )

    out = steer_and_eval(
        paths=["p0"],
        alpha=1.5,
        model=None,
        eval_rows=[{"prompt": "eval"}],
        score=lambda row, b, s: 0.0,
        max_new_tokens=10,
        steering_vectors=vectors,
    )

    assert not calls, "_compute_steering_vectors must not be called when steering_vectors is provided"
    assert len(out.trials) == 1
    assert out.trials[0].alpha == 1.5
    assert out.trials[0].score == 0.0
+
+
def test_steer_and_eval_jsonl_output(tmp_path, monkeypatch):
    """Exactly one JSONL record describing the trial is written to output_path."""
    vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, vectors)

    out_file = tmp_path / "steer.jsonl"
    steer_and_eval(
        paths=["p0"],
        alpha=3.0,
        model=None,
        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
        score=lambda row, b, s: 1.0,
        max_new_tokens=10,
        steering_vectors=vectors,
        output_path=out_file,
    )

    raw_lines = out_file.read_text().strip().split("\n")
    assert len(raw_lines) == 1
    record = json.loads(raw_lines[0])
    assert record["trial_index"] == 0
    assert record["selected_paths"] == ["p0"]
    assert record["alpha"] == 3.0
    assert record["score"] == 1.0
+
+
def test_steer_and_eval_score_receives_row_metadata(monkeypatch):
    """The score callback sees the full eval row, including opaque metadata."""
    vectors = {"p0": torch.randn(1, 8)}
    _patch_internals(monkeypatch, vectors)

    seen_rows = []

    def recording_score(row, baseline, steered):
        seen_rows.append(row)
        return 0.0

    steer_and_eval(
        paths=["p0"],
        alpha=1.0,
        model=None,
        eval_rows=[{"prompt": "eval", "task_id": "T42", "extra": 99}],
        score=recording_score,
        max_new_tokens=10,
        steering_vectors=vectors,
    )

    assert len(seen_rows) == 1
    assert seen_rows[0]["task_id"] == "T42"
    assert seen_rows[0]["extra"] == 99
+
+
+# ---------------------------------------------------------------------------
 # _group_paths_by_layer
 # ---------------------------------------------------------------------------