more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Change steering_search score API to three-argument (row, baseline, steered)

The score callback now receives the full eval row dict (with opaque metadata)
plus both baseline and steered responses. eval_prompts replaced by eval_rows
(list[dict] with "prompt" key). Removes ineligible_score and pre-filtering
from more_nnsight — callers handle eligibility themselves.
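
For illustration, a caller under the new API looks roughly like this (a
minimal sketch; the "expected" metadata key and the score logic are made
up for illustration):

    def score(row: dict, baseline: str, steered: str) -> float:
        # row carries opaque caller metadata; lower scores are better
        return 0.0 if row["expected"] in steered else 1.0

    out = steering_search(
        config=config,  # a SteeringSearchConfig, unchanged by this commit
        model=model,
        positive_prompts=positive_prompts,
        negative_prompts=negative_prompts,
        eval_rows=[{"prompt": "The weather today is", "expected": "sunny"}],
        score=score,
        max_new_tokens=8,
    )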

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Author: Arjun Guha <a.guha@northeastern.edu>
Date:   2026-04-03 12:06:55 -0400
Commit: 9634ecc45094c3ed1f6f3069471f98e63ffdc117
scripts/equivalence.py
new file mode 100644
index 0000000..98873fe
--- /dev/null
+++ b/scripts/equivalence.py
+@@ -0,0 +1,272 @@
+"""
+Establishing the equivalence between activation steering approaches in nnsight.
+
+This script investigates how nnsight's .generate() applies interventions and confirms
+that using position -1 in a plain invoke context is equivalent to using the concrete
+last-prompt-token index with tracer.iter[0].
+
+Key findings established by these experiments:
+
+1. INTERVENTION FIRES ONCE (on prompt pass only)
+   nnsight's plain `with tracer.invoke(prompt):` fires the intervention graph exactly
+   once, during the first forward pass over the full prompt. Subsequent generation
+   steps use the KV cache and process only 1 new token per step — the intervention
+   does NOT fire again.
+
+2. POSITION -1 CORRECTLY TARGETS THE LAST PROMPT TOKEN
+   Since the intervention fires once when the full prompt (seq_len=4) is being
+   processed, `activation[-1]` resolves to position 3 (the last prompt token).
+   This is exactly what we want for steering.
+
+3. THE EFFECT PERSISTS VIA THE KV CACHE
+   Corrupting a hidden state at position 3 during the prompt pass alters the
+   key/value pairs stored in the KV cache for that position. All subsequent
+   generation steps attend to these corrupted values, so the effect propagates
+   through the entire generation without re-applying the intervention.
+
+4. iter[0] WITH CONCRETE POSITION GIVES IDENTICAL RESULTS
+   Applying the intervention at step 0 only (via tracer.iter[0]) with position n-1
+   is bit-for-bit identical to using -1 in the plain invoke context.
+
+5. iter[:] FAILS FOR POSITION > 0 AT STEPS 1+
+   At generation steps 1 and beyond (KV cache steps), the activation tensor has
+   shape [1, hidden_size] — only the new token. Position 3 is out of bounds.
+   This is proof that nnsight uses the KV cache during generation.
+
+Usage:
+    uv run --no-sync python scripts/equivalence.py
+
+Expected output (GPT-2, prompt "The weather today is"):
+
+    === Experiment 1: Which prompt position matters? (layer 11) ===
+    Prompt: 4 tokens ['The', 'Ġweather', 'Ġtoday', 'Ġis']
+    Normal (no zero)      :  very good, and we
+    Zero [0]              :  very good, and we
+    Zero [1]              :  very good, and we
+    Zero [2]              :  very good, and we
+    Zero [3] = [-1]       : , of course, very
+
+    Only zeroing the last prompt token (position 3 = -1) changes the output
+    when using a late layer (e.g. layer 11). Positions 0-2 have no effect
+    because the last prompt token is what directly determines the next-token
+    logits in an autoregressive model. (At early layers like layer 0, all
+    positions matter because the corruption propagates forward through all
+    subsequent layers.)
+
+    === Experiment 2: KV cache persistence ===
+    Steered 6 tokens: ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ', 'Ċ']
+    Context for trace: ['The', 'Ġweather', 'Ġtoday', 'Ġis', 'Ċ', 'Ċ', 'Ċ', '"']
+    Actual 5th token from generate:   '\n' (id=198)
+    Predicted by full trace at pos 3: '\n' (id=198)
+    KV cache persistent == full recompute: True
+
+    === Experiment 3: iter[0] equivalence ===
+      iter[0] step shape: torch.Size([4, 768])
+    Approach 1 (-1, plain invoke): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
+    Approach 2 (iter[0], pos 3):   ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
+    Identical: True
+
+    === Experiment 4: iter[:] shape at each step ===
+    Shapes seen before error: [(0, (4, 768)), (1, (1, 768))]
+    IndexError at step 1+: index 3 is out of bounds for dimension 0 with size 1
+      -> Step 0 has seq_len=4, position 3 valid.
+      -> Steps 1+ have seq_len=1 (KV cache), position 3 out of bounds.
+
+    === Experiment 5: Unsteered differs from steered ===
+    Unsteered: ['Ġvery', 'Ġgood', ',', 'Ġand', 'Ġwe']
+    Steered:   ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
+    Different: True
+"""
+
+import torch
+# Cap GPU memory when CUDA is available; the call raises on CPU-only machines.
+if torch.cuda.is_available():
+    torch.cuda.memory.set_per_process_memory_fraction(0.8)
+
+from nnsight import LanguageModel
+
+model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
+tokenizer = model.tokenizer
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+prompt = "The weather today is"
+prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
+n = len(prompt_ids)
+
+
+# ---------------------------------------------------------------------------
+# Experiment 1: Which token position matters?
+#
+# We zero layer 11's output at each individual token position to determine
+# which positions influence the generated output. At a late layer like 11,
+# only the last prompt token (position n-1 = -1) changes the output; earlier
+# positions do not. This is because causal attention means the last position
+# is the only one whose representation directly feeds into the next-token
+# logits. (At early layers like layer 0, corrupting any position propagates
+# through all subsequent layers and changes the output.)
+# ---------------------------------------------------------------------------
+
+print("=" * 60)
+print("Experiment 1: Which prompt position matters? (layer 11)")
+print("=" * 60)
+print(f"Prompt: {n} tokens {tokenizer.convert_ids_to_tokens(prompt_ids.tolist())}")
+
+for label, idx in [
+    ("Normal (no zero)", None),
+    ("Zero [0]", 0),
+    ("Zero [1]", 1),
+    ("Zero [2]", 2),
+    ("Zero [3] = [-1]", 3),
+]:
+    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
+        with tracer.invoke(prompt):
+            if idx is not None:
+                model.transformer.h[11].output[0][idx] = 0.0
+            out_ids = model.generator.output.save()
+    response = tokenizer.decode(out_ids[0][n:], skip_special_tokens=True)
+    print(f"{label:22s}: {repr(response)}")
+
+
+# ---------------------------------------------------------------------------
+# Experiment 2: KV cache persistence.
+#
+# We confirm that the intervention on the prompt pass persists via the KV
+# cache through all subsequent generation steps. Specifically:
+#   - Run generate from P with zero at position n-1, collect 6 steered tokens.
+#   - Do a fresh full forward pass over P + first 4 steered tokens, with the
+#     same zero at position n-1.
+#   - The full-pass logits at the last position should predict the same 5th
+#     token that generate produced.
+#
+# This works because transformers with causal attention are equivalent whether
+# computed incrementally (KV cache) or all at once, as long as the intervention
+# is applied at the same absolute position.
+# ---------------------------------------------------------------------------
+
+print()
+print("=" * 60)
+print("Experiment 2: KV cache persistence")
+print("=" * 60)
+
+with model.generate(max_new_tokens=6, do_sample=False) as tracer:
+    with tracer.invoke(prompt):
+        h = model.transformer.h[0].output[0]
+        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
+        out_ids = model.generator.output.save()
+
+steered = out_ids[0][n:].tolist()
+t5_actual = steered[4]
+print(f"Steered 6 tokens: {tokenizer.convert_ids_to_tokens(steered)}")
+
+# Full trace over P + first 4 steered tokens, zero at the same absolute position.
+ctx = torch.cat([prompt_ids, torch.tensor(steered[:4])]).unsqueeze(0)
+print(f"Context for trace: {tokenizer.convert_ids_to_tokens(ctx[0].tolist())}")
+
+with torch.no_grad():
+    with model.trace() as tracer:
+        with tracer.invoke(ctx):
+            h = model.transformer.h[0].output[0]
+            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
+            logits_proxy = model.lm_head.output.save()
+
+logits = torch.as_tensor(logits_proxy).float().cpu()   # shape [1, seq, vocab]
+last_logits = logits[0, -1]                             # logits predicting token after ctx
+t5_predicted = torch.argmax(last_logits).item()
+
+print(f"Actual 5th token from generate:   {repr(tokenizer.decode([t5_actual]))} (id={t5_actual})")
+print(f"Predicted by full trace at pos {n-1}: {repr(tokenizer.decode([t5_predicted]))} (id={t5_predicted})")
+print(f"KV cache persistent == full recompute: {t5_actual == t5_predicted}")
+
+
+# ---------------------------------------------------------------------------
+# Experiment 3: iter[0] with concrete position == -1 in plain invoke.
+#
+# tracer.iter[0] applies the intervention graph only at generation step 0,
+# which is the prompt pass. At this step the activation has shape
+# [seq_len, hidden_size] so position n-1 is valid. This is bit-for-bit
+# identical to using -1 in the plain invoke context.
+# ---------------------------------------------------------------------------
+
+print()
+print("=" * 60)
+print("Experiment 3: iter[0] equivalence")
+print("=" * 60)
+
+# Approach 1: -1 in plain invoke (fires once on prompt pass)
+with model.generate(max_new_tokens=5, do_sample=False) as tracer:
+    with tracer.invoke(prompt):
+        h = model.transformer.h[0].output[0]
+        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
+        out1 = model.generator.output.save()
+tokens1 = out1[0][n:].tolist()
+
+# Approach 2: iter[0] with concrete position n-1 (fires once at step 0)
+with model.generate(max_new_tokens=5, do_sample=False) as tracer:
+    with tracer.invoke(prompt):
+        for step in tracer.iter[0]:
+            h = model.transformer.h[0].output[0]
+            print(f"  iter[0] step shape: {h.shape}")
+            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
+        out2 = model.generator.output.save()
+tokens2 = out2[0][n:].tolist()
+
+print(f"Approach 1 (-1, plain invoke):  {tokenizer.convert_ids_to_tokens(tokens1)}")
+print(f"Approach 2 (iter[0], pos {n-1}):   {tokenizer.convert_ids_to_tokens(tokens2)}")
+print(f"Identical: {tokens1 == tokens2}")
+
+
+# ---------------------------------------------------------------------------
+# Experiment 4: iter[:] fails at steps 1+ for position > 0.
+#
+# tracer.iter[:] (= tracer.all()) fires the intervention at every generation
+# step. At step 0 the full prompt is processed (shape [seq_len, hidden]).
+# At steps 1+ nnsight uses the KV cache, processing only the new token
+# (shape [1, hidden]). Trying to access position n-1 at those steps raises
+# an IndexError, proving the KV cache is active.
+# ---------------------------------------------------------------------------
+
+print()
+print("=" * 60)
+print("Experiment 4: iter[:] shape at each step")
+print("=" * 60)
+
+shapes = []
+try:
+    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
+        with tracer.invoke(prompt):
+            for step in tracer.iter[:]:
+                h = model.transformer.h[0].output[0]
+                shapes.append((step, tuple(h.shape)))
+                # This will raise IndexError at step 1+ because size is 1
+                model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
+            model.generator.output.save()
+except IndexError as e:
+    print(f"Shapes seen before error: {shapes}")
+    print(f"IndexError at step 1+: {e}")
+    print(f"  -> Step 0 has seq_len={n}, position {n-1} valid.")
+    print(f"  -> Steps 1+ have seq_len=1 (KV cache), position {n-1} out of bounds.")
+
+# ---------------------------------------------------------------------------
+# Experiment 5: Sanity check — unsteered generation differs from steered.
+#
+# Confirms the intervention actually has an effect: generation without any
+# intervention produces different tokens than generation with position -1
+# zeroed. Without this check, the equivalences above could be vacuous.
+# ---------------------------------------------------------------------------
+
+print()
+print("=" * 60)
+print("Experiment 5: Unsteered differs from steered")
+print("=" * 60)
+
+with model.generate(max_new_tokens=5, do_sample=False) as tracer:
+    with tracer.invoke(prompt):
+        out_unsteered = model.generator.output.save()
+
+tokens_unsteered = out_unsteered[0][n:].tolist()
+tokens_steered = out1[0][n:].tolist()  # from experiment 3
+
+print(f"Unsteered: {tokenizer.convert_ids_to_tokens(tokens_unsteered)}")
+print(f"Steered:   {tokenizer.convert_ids_to_tokens(tokens_steered)}")
+print(f"Different: {tokens_unsteered != tokens_steered}")
src/more_nnsight/__init__.py
index 7b476de..edd5a8a 100644
--- a/src/more_nnsight/__init__.py
+++ b/src/more_nnsight/__init__.py
@@ -4,6 +4,7 @@ from .steering_search import (
     SteeringSearchOutput,
     SteeringTrialResult,
     best_trial,
+    layer_sweep,
     steering_search,
 )
 
@@ -15,6 +16,7 @@ __all__ = [
     "SteeringSearchOutput",
     "SteeringTrialResult",
     "best_trial",
+    "layer_sweep",
     "steering_search",
 ]
 
src/more_nnsight/steering_search.py
index 34779b1..3a4b8e6 100644
--- a/src/more_nnsight/steering_search.py
+++ b/src/more_nnsight/steering_search.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import re
 from dataclasses import dataclass, asdict
 from pathlib import Path
 from typing import Any, Callable
@@ -211,16 +212,46 @@ def _generate_steered(
     return _generate_steered_batch(model, [prompt], direction, alpha, max_new_tokens)[0]
 
 
+def _evaluate_steered(
+    selected_paths: list[str],
+    alpha: float,
+    vectors: dict[str, torch.Tensor],
+    eval_rows: list[dict],
+    baseline_responses: list[str],
+    model: Any,
+    score: Callable[[dict, str, str], float],
+    max_new_tokens: int,
+    batch_size: int,
+) -> tuple[list[str], float]:
+    """Steer with the given paths+alpha and score the results.
+
+    Returns ``(responses, mean_score)`` over ``eval_rows``.
+    """
+    direction = SavedActivation.from_pairs(*[(p, vectors[p]) for p in selected_paths])
+    trial_scores: list[float] = []
+    trial_responses: list[str] = []
+    for i in range(0, len(eval_rows), batch_size):
+        batch_rows = eval_rows[i : i + batch_size]
+        batch_baselines = baseline_responses[i : i + batch_size]
+        batch_prompts = [row["prompt"] for row in batch_rows]
+        batch_responses = _generate_steered_batch(
+            model, batch_prompts, direction, alpha, max_new_tokens
+        )
+        trial_responses.extend(batch_responses)
+        for row, baseline, steered in zip(batch_rows, batch_baselines, batch_responses):
+            trial_scores.append(score(row, baseline, steered))
+    return trial_responses, sum(trial_scores) / len(trial_scores)
+
+
 @torch.no_grad()
 def steering_search(
     config: SteeringSearchConfig,
     model: Any,
     positive_prompts: list[str],
     negative_prompts: list[str],
-    eval_prompts: list[str],
-    score: Callable[[str, str], float | None],
+    eval_rows: list[dict],
+    score: Callable[[dict, str, str], float],
     max_new_tokens: int,
-    ineligible_score: float,
     output_path: Path | None = None,
     batch_size: int = 1,
 ) -> SteeringSearchOutput:
@@ -233,12 +264,13 @@ def steering_search(
             steering vectors).
         negative_prompts: Prompts for the "negative" class (used to compute
             steering vectors).
-        eval_prompts: Prompts to generate steered responses on each trial.
-        score: Scoring function. Takes (prompt, response) and returns a float
-            where **lower is better**, or ``None`` if the result is ineligible.
+        eval_rows: Rows to evaluate on each trial. Each row is a dict that
+            **must** contain a ``"prompt"`` key (the formatted prompt string).
+            All other keys are opaque metadata passed through to ``score``.
+            The caller should pre-filter to rows whose baselines are scorable.
+        score: Scoring function. Takes ``(row, baseline_response,
+            steered_response)`` and returns a float where **lower is better**.
         max_new_tokens: Maximum tokens to generate per prompt.
-        ineligible_score: Score to assign when steering makes a previously
-            eligible prompt ineligible (e.g. program stops passing tests).
         output_path: If provided, write trial results as JSONL (appended
             incrementally so partial results survive crashes).
         batch_size: Number of prompts to generate in parallel per batch.
@@ -257,23 +289,17 @@ def steering_search(
         model, config.candidate_paths, positive_prompts, negative_prompts
     )
 
-    # Pre-filter: generate baselines (no steering) and score.
-    # Prompts where the baseline score is None are permanently excluded.
-    eligible_prompts: list[str] = []
+    # Generate baseline responses for all eval rows.
+    eval_prompts = [row["prompt"] for row in eval_rows]
     baseline_responses: list[str] = []
-    print(f"Pre-filtering {len(eval_prompts)} eval prompts...")
+    print(f"Generating baselines for {len(eval_prompts)} eval prompts...")
     for i in range(0, len(eval_prompts), batch_size):
         batch = eval_prompts[i : i + batch_size]
-        responses = _generate_batch(model, batch, max_new_tokens)
-        for prompt, response in zip(batch, responses):
-            baseline_score = score(prompt, response)
-            if baseline_score is not None:
-                eligible_prompts.append(prompt)
-                baseline_responses.append(response)
-    print(f"Eligible prompts: {len(eligible_prompts)} / {len(eval_prompts)}")
+        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
+    print(f"Baselines generated: {len(baseline_responses)}")
 
-    if not eligible_prompts:
-        raise ValueError("No eligible eval prompts after baseline filtering")
+    if not eval_rows:
+        raise ValueError("eval_rows must not be empty")
 
     # Clear output file if it exists.
     if output_path is not None:
@@ -296,21 +322,10 @@ def steering_search(
             trial_score = PENALTY
             trial_responses: list[str] = []
         else:
-            direction = SavedActivation.from_pairs(
-                *[(p, vectors[p]) for p in selected_paths]
+            trial_responses, trial_score = _evaluate_steered(
+                selected_paths, alpha, vectors, eval_rows,
+                baseline_responses, model, score, max_new_tokens, batch_size,
             )
-            trial_scores: list[float] = []
-            trial_responses = []
-            for i in range(0, len(eligible_prompts), batch_size):
-                batch_prompts = eligible_prompts[i : i + batch_size]
-                batch_responses = _generate_steered_batch(
-                    model, batch_prompts, direction, alpha, max_new_tokens
-                )
-                trial_responses.extend(batch_responses)
-                for prompt, response in zip(batch_prompts, batch_responses):
-                    s = score(prompt, response)
-                    trial_scores.append(s if s is not None else ineligible_score)
-            trial_score = sum(trial_scores) / len(trial_scores)
 
         result = SteeringTrialResult(
             trial_index=len(trials),
@@ -348,3 +363,122 @@
     )
 
     return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)
+
+
+def _group_paths_by_layer(paths: list[str]) -> list[list[str]]:
+    """Group paths that share the same layer by stripping the trailing token index.
+
+    ``"model.transformer.h[10].output[-1]"`` and
+    ``"model.transformer.h[10].output[-2]"`` both strip to
+    ``"model.transformer.h[10].output"`` and are placed in the same group,
+    producing a single trial that intervenes on both positions simultaneously.
+    Insertion order of groups is preserved.
+    """
+    groups: dict[str, list[str]] = {}
+    for path in paths:
+        key = re.sub(r"\[-?\d+\]$", "", path)
+        if key not in groups:
+            groups[key] = []
+        groups[key].append(path)
+    return list(groups.values())
+
+
+@torch.no_grad()
+def layer_sweep(
+    paths: list[str],
+    alpha: float,
+    model: Any,
+    positive_prompts: list[str],
+    negative_prompts: list[str],
+    eval_rows: list[dict],
+    score: Callable[[dict, str, str], float],
+    max_new_tokens: int,
+    output_path: Path | None = None,
+    batch_size: int = 1,
+) -> SteeringSearchOutput:
+    """Run one steering trial per layer at a fixed alpha value.
+
+    Paths that share the same layer (differing only in their trailing token
+    position index) are grouped automatically and intervened on simultaneously
+    within a single trial.  For example, passing
+    ``"model.transformer.h[10].output[-1]"`` and
+    ``"model.transformer.h[10].output[-2]"`` produces one trial that steers
+    both positions at once.
+
+    Args:
+        paths: Activation paths to sweep. Grouped by layer internally.
+        alpha: Fixed steering coefficient applied in every trial.
+        model: An nnsight LanguageModel (already loaded).
+        positive_prompts: Prompts for the "positive" class.
+        negative_prompts: Prompts for the "negative" class.
+        eval_rows: Rows to evaluate on each trial. Each row is a dict that
+            **must** contain a ``"prompt"`` key. Other keys are opaque metadata
+            passed through to ``score``.
+        score: Scoring function. Takes ``(row, baseline_response,
+            steered_response)`` and returns a float (lower is better).
+        max_new_tokens: Maximum tokens to generate per prompt.
+        output_path: If provided, append trial results as JSONL incrementally.
+        batch_size: Number of prompts to generate in parallel per batch.
+
+    Returns:
+        SteeringSearchOutput with baseline_responses and one trial per layer.
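+
+    Example (an illustrative sketch; assumes a loaded nnsight model and
+    caller-prepared prompt lists ``pos`` and ``neg``)::
+
+        out = layer_sweep(
+            paths=[f"model.transformer.h[{i}].output[-1]" for i in range(12)],
+            alpha=2.0,
+            model=model,
+            positive_prompts=pos,
+            negative_prompts=neg,
+            eval_rows=[{"prompt": "The weather today is"}],
+            score=lambda row, baseline, steered: 0.0,  # toy score; lower is better
+            max_new_tokens=8,
+        )
+        best = best_trial(out.trials)  # trial with the lowest mean score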
+    """
+    if not paths:
+        raise ValueError("paths must not be empty")
+
+    groups = _group_paths_by_layer(paths)
+    vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)
+
+    # Generate baseline responses for all eval rows.
+    eval_prompts = [row["prompt"] for row in eval_rows]
+    baseline_responses: list[str] = []
+    print(f"Generating baselines for {len(eval_prompts)} eval prompts...")
+    for i in range(0, len(eval_prompts), batch_size):
+        batch = eval_prompts[i : i + batch_size]
+        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
+    print(f"Baselines generated: {len(baseline_responses)}")
+
+    if not eval_rows:
+        raise ValueError("eval_rows must not be empty")
+
+    if output_path is not None:
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        output_path.write_text("")
+
+    trials: list[SteeringTrialResult] = []
+
+    for trial_index, group in enumerate(groups):
+        trial_responses, trial_score = _evaluate_steered(
+            group, alpha, vectors, eval_rows,
+            baseline_responses, model, score, max_new_tokens, batch_size,
+        )
+        result = SteeringTrialResult(
+            trial_index=trial_index,
+            selected_paths=group,
+            alpha=alpha,
+            score=trial_score,
+            raw_point=[trial_index, alpha],
+            responses=trial_responses,
+        )
+        trials.append(result)
+        if output_path is not None:
+            _append_jsonl(output_path, result)
+        label = group[0] if len(group) == 1 else f"{group[0]} (+{len(group) - 1} more)"
+        print(f"  [{trial_index + 1}/{len(groups)}] {label}: score={trial_score:.4f}")
+
+    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)
tests/test_steering_search.py
index f26bf58..495d4f9 100644
--- a/tests/test_steering_search.py
+++ b/tests/test_steering_search.py
@@ -15,7 +15,9 @@ from more_nnsight.steering_search import (
     SteeringSearchOutput,
     SteeringTrialResult,
     _decode_point,
+    _group_paths_by_layer,
     best_trial,
+    layer_sweep,
     steering_search,
 )
 
@@ -83,7 +85,7 @@ def _patch_internals(monkeypatch, fake_vectors, generate_fn=None):
 
     monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs))
 
-    # Mock baseline generation (for pre-filter).
+    # Mock baseline generation.
     def baseline_batch_fn(model, prompts, max_new_tokens):
         return ["baseline response"] * len(prompts)
 
@@ -113,10 +115,9 @@ def test_empty_candidates_raises():
             model=None,
             positive_prompts=[],
             negative_prompts=[],
-            eval_prompts=[],
-            score=lambda p, r: 0.0,
+            eval_rows=[],
+            score=lambda row, b, s: 0.0,
             max_new_tokens=10,
-            ineligible_score=1.0,
         )
 
 
@@ -137,10 +138,9 @@ def test_search_returns_correct_number_of_results(tmp_path, monkeypatch):
         model=None,
         positive_prompts=["pos1"],
         negative_prompts=["neg1"],
-        eval_prompts=["eval1", "eval2"],
-        score=lambda p, r: 1.0,
+        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
+        score=lambda row, b, s: 1.0,
         max_new_tokens=10,
-        ineligible_score=1.0,
         output_path=output,
     )
     assert isinstance(out, SteeringSearchOutput)
@@ -157,8 +157,8 @@ def test_search_finds_known_optimum(monkeypatch):
 
     _patch_internals(monkeypatch, fake_vectors, generate_fn=fake_generate)
 
-    def score_fn(prompt, response):
-        paths = response.split(",") if response else []
+    def score_fn(row, baseline, steered):
+        paths = steered.split(",") if steered else []
         return 0.0 if "path_2" in paths else 1.0
 
     config = SteeringSearchConfig(
@@ -173,10 +173,9 @@ def test_search_finds_known_optimum(monkeypatch):
         model=None,
         positive_prompts=["pos"],
         negative_prompts=["neg"],
-        eval_prompts=["eval"],
+        eval_rows=[{"prompt": "eval"}],
         score=score_fn,
         max_new_tokens=10,
-        ineligible_score=1.0,
     )
     best = best_trial(out.trials)
     assert best.score == 0.0
@@ -200,10 +199,9 @@ def test_jsonl_output_written_incrementally(tmp_path, monkeypatch):
         model=None,
         positive_prompts=["pos"],
         negative_prompts=["neg"],
-        eval_prompts=["eval"],
-        score=lambda p, r: 1.0,
+        eval_rows=[{"prompt": "eval"}],
+        score=lambda row, b, s: 1.0,
         max_new_tokens=10,
-        ineligible_score=1.0,
         output_path=output,
     )
     lines = output.read_text().strip().split("\n")
@@ -234,82 +232,167 @@ def test_empty_selection_gets_penalty(monkeypatch):
         model=None,
         positive_prompts=["pos"],
         negative_prompts=["neg"],
-        eval_prompts=["eval"],
-        score=lambda p, r: 0.0,
+        eval_rows=[{"prompt": "eval"}],
+        score=lambda row, b, s: 0.0,
         max_new_tokens=10,
-        ineligible_score=1.0,
     )
     penalty_trials = [r for r in out.trials if not r.selected_paths]
     for t in penalty_trials:
         assert t.score == PENALTY
 
 
-def test_ineligible_baseline_prompts_excluded(monkeypatch):
-    """Prompts that score None on baseline are excluded from evaluation."""
-    fake_vectors = {"p0": torch.randn(1, 8)}
+def test_score_receives_row_metadata(monkeypatch):
+    """Score function receives the full row dict, not just the prompt."""
+    fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)}
     _patch_internals(monkeypatch, fake_vectors)
 
-    # 3 eval prompts; "bad" one returns None on baseline, others return 0.5
-    def score_fn(prompt, response):
-        if prompt == "bad_prompt":
-            return None
+    received_rows = []
+
+    def score_fn(row, baseline, steered):
+        received_rows.append(row)
         return 0.5
 
     config = SteeringSearchConfig(
-        candidate_paths=["p0"],
-        max_simultaneous_paths=1,
+        candidate_paths=["p0", "p1"],
+        max_simultaneous_paths=2,
         alpha_range=(0.1, 3.0),
         n_calls=5,
-        n_initial_points=3,
+        n_initial_points=5,
     )
-    out = steering_search(
+    steering_search(
         config=config,
         model=None,
         positive_prompts=["pos"],
         negative_prompts=["neg"],
-        eval_prompts=["good1", "bad_prompt", "good2"],
+        eval_rows=[{"prompt": "eval", "task_id": "T1", "extra": 42}],
         score=score_fn,
         max_new_tokens=10,
-        ineligible_score=1.0,
     )
-    # All non-penalty trials should have score 0.5 (only good prompts evaluated)
-    for t in out.trials:
-        if t.selected_paths:
-            assert t.score == 0.5
+    # score is only called with real rows, so metadata must round-trip intact.
+    assert len(received_rows) > 0
+    assert received_rows[0]["task_id"] == "T1"
+    assert received_rows[0]["extra"] == 42
 
 
-def test_steered_none_uses_ineligible_score(monkeypatch):
-    """When steering makes an eligible prompt return None, ineligible_score is used."""
-    fake_vectors = {"p0": torch.randn(1, 8)}
-    _patch_internals(monkeypatch, fake_vectors)
+# ---------------------------------------------------------------------------
+# _group_paths_by_layer
+# ---------------------------------------------------------------------------
+
 
-    call_count = [0]
+def test_group_paths_same_layer_grouped_together():
+    paths = [
+        "model.transformer.h[10].output[-1]",
+        "model.transformer.h[10].output[-2]",
+        "model.transformer.h[11].output[-1]",
+    ]
+    groups = _group_paths_by_layer(paths)
+    assert len(groups) == 2
+    assert groups[0] == [
+        "model.transformer.h[10].output[-1]",
+        "model.transformer.h[10].output[-2]",
+    ]
+    assert groups[1] == ["model.transformer.h[11].output[-1]"]
 
-    def score_fn(prompt, response):
-        call_count[0] += 1
-        # Baseline call returns 0.5 (eligible), steered calls return None
-        if response == "baseline response":
-            return 0.5
-        return None
 
-    config = SteeringSearchConfig(
-        candidate_paths=["p0"],
-        max_simultaneous_paths=1,
-        alpha_range=(0.1, 3.0),
-        n_calls=3,
-        n_initial_points=3,
+def test_group_paths_preserves_insertion_order():
+    paths = [
+        "model.transformer.h[5].output[-1]",
+        "model.transformer.h[3].output[-1]",
+        "model.transformer.h[5].output[-2]",
+    ]
+    groups = _group_paths_by_layer(paths)
+    assert len(groups) == 2
+    assert groups[0][0].startswith("model.transformer.h[5]")
+    assert groups[1][0].startswith("model.transformer.h[3]")
+
+
+def test_group_paths_no_token_index():
+    # Paths without a trailing [N] each form their own group.
+    paths = ["model.transformer.h[10].output", "model.transformer.h[11].output"]
+    groups = _group_paths_by_layer(paths)
+    assert len(groups) == 2
+
+
+# ---------------------------------------------------------------------------
+# layer_sweep
+# ---------------------------------------------------------------------------
+
+
+def test_layer_sweep_groups_by_layer(monkeypatch):
+    # Two paths in the same layer → one trial; one path in a different layer → second trial.
+    paths = [
+        "model.transformer.h[10].output[-1]",
+        "model.transformer.h[10].output[-2]",
+        "model.transformer.h[11].output[-1]",
+    ]
+    fake_vectors = {p: torch.randn(1, 8) for p in paths}
+    _patch_internals(monkeypatch, fake_vectors)
+
+    out = layer_sweep(
+        paths=paths,
+        alpha=2.0,
+        model=None,
+        positive_prompts=["pos"],
+        negative_prompts=["neg"],
+        eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}],
+        score=lambda row, b, s: 0.5,
+        max_new_tokens=10,
     )
-    out = steering_search(
-        config=config,
+
+    assert len(out.trials) == 2
+    assert out.trials[0].selected_paths == [
+        "model.transformer.h[10].output[-1]",
+        "model.transformer.h[10].output[-2]",
+    ]
+    assert out.trials[1].selected_paths == ["model.transformer.h[11].output[-1]"]
+    for i, trial in enumerate(out.trials):
+        assert trial.trial_index == i
+        assert trial.alpha == 2.0
+        assert trial.raw_point == [i, 2.0]
+
+
+def test_layer_sweep_empty_paths_raises():
+    with pytest.raises(ValueError, match="paths must not be empty"):
+        layer_sweep(
+            paths=[],
+            alpha=1.0,
+            model=None,
+            positive_prompts=[],
+            negative_prompts=[],
+            eval_rows=[],
+            score=lambda row, b, s: 0.0,
+            max_new_tokens=10,
+        )
+
+
+def test_layer_sweep_jsonl_output(tmp_path, monkeypatch):
+    import json
+
+    paths = [
+        "model.transformer.h[10].output[-1]",
+        "model.transformer.h[10].output[-2]",
+        "model.transformer.h[11].output[-1]",
+    ]
+    fake_vectors = {p: torch.randn(1, 8) for p in paths}
+    _patch_internals(monkeypatch, fake_vectors)
+
+    output = tmp_path / "sweep.jsonl"
+    layer_sweep(
+        paths=paths,
+        alpha=3.0,
         model=None,
         positive_prompts=["pos"],
         negative_prompts=["neg"],
-        eval_prompts=["eval"],
-        score=score_fn,
+        eval_rows=[{"prompt": "eval"}],
+        score=lambda row, b, s: 0.0,
         max_new_tokens=10,
-        ineligible_score=1.0,
+        output_path=output,
     )
-    # All non-penalty trials should use ineligible_score since steered returns None
-    for t in out.trials:
-        if t.selected_paths:
-            assert t.score == 1.0
+
+    lines = output.read_text().strip().split("\n")
+    assert len(lines) == 2  # two groups → two trials
+    for line in lines:
+        record = json.loads(line)
+        assert "trial_index" in record
+        assert "selected_paths" in record
+        assert record["alpha"] == 3.0
tests/test_steering_search_gpt2.py
index ecc3d34..d4c6ee4 100644
--- a/tests/test_steering_search_gpt2.py
+++ b/tests/test_steering_search_gpt2.py
@@ -58,8 +58,8 @@ def test_steering_changes_sentiment(gpt2, tmp_path):
     ]
     eval_prompt = "The weather today is"
 
-    def score(prompt: str, response: str) -> float:
-        return _sentiment_score(response)
+    def score(row: dict, baseline: str, steered: str) -> float:
+        return _sentiment_score(steered)
 
     config = SteeringSearchConfig(
         candidate_paths=candidate_paths,
@@ -75,10 +75,9 @@ def test_steering_changes_sentiment(gpt2, tmp_path):
         model=gpt2,
         positive_prompts=positive_prompts,
         negative_prompts=negative_prompts,
-        eval_prompts=[eval_prompt],
+        eval_rows=[{"prompt": eval_prompt}],
         score=score,
         max_new_tokens=4,
-        ineligible_score=1.0,
         output_path=output,
     )