Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
"""
Establishing the equivalence between activation steering approaches in nnsight.

This script investigates how nnsight's .generate() applies interventions and confirms
that using position -1 in a plain invoke context is equivalent to using the concrete
last-prompt-token index with tracer.iter[0].

Key findings established by these experiments:

1. INTERVENTION FIRES ONCE (on prompt pass only)
   nnsight's plain `with tracer.invoke(prompt):` fires the intervention graph exactly
   once, during the first forward pass over the full prompt. Subsequent generation
   steps use the KV cache and process only 1 new token per step — the intervention
   does NOT fire again.

2. POSITION -1 CORRECTLY TARGETS THE LAST PROMPT TOKEN
   Since the intervention fires once when the full prompt (seq_len=4) is being
   processed, `activation[-1]` resolves to position 3 (the last prompt token).
   This is exactly what we want for steering.

3. THE EFFECT PERSISTS VIA THE KV CACHE
   Corrupting a hidden state at position 3 during the prompt pass alters the
   key/value pairs stored in the KV cache for that position. All subsequent
   generation steps attend to these corrupted values, so the effect propagates
   through the entire generation without re-applying the intervention.

4. iter[0] WITH CONCRETE POSITION GIVES IDENTICAL RESULTS
   Applying the intervention at step 0 only (via tracer.iter[0]) with position n-1
   is bit-for-bit identical to using -1 in the plain invoke context.

5. iter[:] FAILS FOR POSITION > 0 AT STEPS 1+
   At generation steps 1 and beyond (KV cache steps), the activation tensor has
   shape [1, hidden_size] — only the new token. Position 3 is out of bounds.
   This is proof that nnsight uses the KV cache during generation.

Usage:
    uv run --no-sync python scripts/equivalence.py

Expected output (GPT-2, prompt "The weather today is"):

    === Experiment 1: Which prompt position matters? ===
    Prompt: 4 tokens ['The', 'Ġweather', 'Ġtoday', 'Ġis']
    Normal (no zero)      : very good, and we
    Zero [0]              : very good, and we
    Zero [1]              : very good, and we
    Zero [2]              : very good, and we
    Zero [3] = [-1]       : , of course, very

    Only zeroing the last prompt token (position 3 = -1) changes the output
    when using a late layer (e.g. layer 11). Positions 0-2 have no effect
    because the last prompt token is what directly determines the next-token
    logits in an autoregressive model. (At early layers like layer 0, all
    positions matter because the corruption propagates forward through all
    subsequent layers.)

    === Experiment 2: KV cache persistence ===
    Steered 6 tokens: ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ', 'Ċ']
    Context for trace: ['The', 'Ġweather', 'Ġtoday', 'Ġis', 'Ċ', 'Ċ', 'Ċ', '"']
    Actual 5th token from generate: '\n' (id=198)
    Predicted by full trace at pos 3: '\n' (id=198)
    KV cache persistent == full recompute: True

    === Experiment 3: iter[0] equivalence ===
    iter[0] step shape: torch.Size([4, 768])
    Approach 1 (-1, plain invoke): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Approach 2 (iter[0], pos 3): ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Identical: True

    === Experiment 4: iter[:] out-of-bounds at steps 1+ ===
    Shapes seen before error: [(0, (4, 768)), (1, (1, 768))]
    IndexError at step 1+: index 3 is out of bounds for dimension 0 with size 1
     -> Step 0 has seq_len=4, position 3 valid.
     -> Steps 1+ have seq_len=1 (KV cache), position 3 out of bounds.

    === Experiment 5: Unsteered differs from steered ===
    Unsteered: ['Ġvery', 'Ġgood', ',', 'Ġand', 'Ġwe']
    Steered:   ['Ċ', 'Ċ', 'Ċ', '"', 'Ċ']
    Different: True
"""

import torch

# Cap this process's GPU memory so the script can share a device with other
# jobs. Guarded so the script still runs on CPU-only machines: the original
# unconditional call raises at import time when CUDA is unavailable, even
# though device_map="auto" below would happily fall back to CPU.
if torch.cuda.is_available():
    torch.cuda.memory.set_per_process_memory_fraction(0.8)

from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
tokenizer = model.tokenizer
if tokenizer.pad_token is None:
    # GPT-2 has no pad token; reuse EOS so batching/decoding don't complain.
    tokenizer.pad_token = tokenizer.eos_token

prompt = "The weather today is"
prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
n = len(prompt_ids)  # number of prompt tokens (4 for this prompt)


# ---------------------------------------------------------------------------
# Experiment 1: Which token position matters?
#
# We zero layer 11's output at each individual token position to determine
# which positions influence the generated output. At a late layer like 11,
# only the last prompt token (position n-1 = -1) changes the output; earlier
# positions do not. This is because causal attention means the last position
# is the only one whose representation directly feeds into the next-token
# logits. (At early layers like layer 0, corrupting any position propagates
# through all subsequent layers and changes the output.)
# ---------------------------------------------------------------------------

print("=" * 60)
print("Experiment 1: Which prompt position matters? (layer 11)")
print("=" * 60)
print(f"Prompt: {n} tokens {tokenizer.convert_ids_to_tokens(prompt_ids.tolist())}")

for label, idx in [
    ("Normal (no zero)", None),
    ("Zero [0]", 0),
    ("Zero [1]", 1),
    ("Zero [2]", 2),
    ("Zero [3] = [-1]", 3),
]:
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            if idx is not None:
                model.transformer.h[11].output[0][idx] = 0.0
            out_ids = model.generator.output.save()
    response = tokenizer.decode(out_ids[0][n:], skip_special_tokens=True)
    print(f"{label:22s}: {repr(response)}")


# ---------------------------------------------------------------------------
# Experiment 2: KV cache persistence.
#
# We confirm that the intervention on the prompt pass persists via the KV
# cache through all subsequent generation steps. Specifically:
#   - Run generate from P with zero at position n-1, collect 6 steered tokens.
#   - Do a fresh full forward pass over P + first 4 steered tokens, with the
#     same zero at position n-1.
#   - The full-pass logits at the last position should predict the same 5th
#     token that generate produced.
#
# This works because transformers with causal attention are equivalent whether
# computed incrementally (KV cache) or all at once, as long as the intervention
# is applied at the same absolute position.
# ---------------------------------------------------------------------------

print()
print("=" * 60)
print("Experiment 2: KV cache persistence")
print("=" * 60)

with model.generate(max_new_tokens=6, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        h = model.transformer.h[0].output[0]
        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
        out_ids = model.generator.output.save()

steered = out_ids[0][n:].tolist()
t5_actual = steered[4]  # the 5th generated token (index 4)
print(f"Steered 6 tokens: {tokenizer.convert_ids_to_tokens(steered)}")

# Full trace over P + first 4 steered tokens, zero at the same absolute position.
ctx = torch.cat([prompt_ids, torch.tensor(steered[:4])]).unsqueeze(0)
print(f"Context for trace: {tokenizer.convert_ids_to_tokens(ctx[0].tolist())}")

with torch.no_grad():
    with model.trace() as tracer:
        with tracer.invoke(ctx):
            h = model.transformer.h[0].output[0]
            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
            logits_proxy = model.lm_head.output.save()

logits = torch.as_tensor(logits_proxy).float().cpu()  # shape [1, seq, vocab]
last_logits = logits[0, -1]  # logits predicting token after ctx
t5_predicted = torch.argmax(last_logits).item()

print(f"Actual 5th token from generate: {repr(tokenizer.decode([t5_actual]))} (id={t5_actual})")
print(f"Predicted by full trace at pos {n-1}: {repr(tokenizer.decode([t5_predicted]))} (id={t5_predicted})")
print(f"KV cache persistent == full recompute: {t5_actual == t5_predicted}")


# ---------------------------------------------------------------------------
# Experiment 3: iter[0] with concrete position == -1 in plain invoke.
#
# tracer.iter[0] applies the intervention graph only at generation step 0,
# which is the prompt pass. At this step the activation has shape
# [seq_len, hidden_size] so position n-1 is valid. This is bit-for-bit
# identical to using -1 in the plain invoke context.
# ---------------------------------------------------------------------------

print()
print("=" * 60)
print("Experiment 3: iter[0] equivalence")
print("=" * 60)

# Approach 1: -1 in plain invoke (fires once on prompt pass)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        h = model.transformer.h[0].output[0]
        model.transformer.h[0].output[0][-1] = torch.zeros_like(h[-1])
        out1 = model.generator.output.save()
tokens1 = out1[0][n:].tolist()

# Approach 2: iter[0] with concrete position n-1 (fires once at step 0)
with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        for step in tracer.iter[0]:
            h = model.transformer.h[0].output[0]
            print(f" iter[0] step shape: {h.shape}")
            model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
        out2 = model.generator.output.save()
tokens2 = out2[0][n:].tolist()

print(f"Approach 1 (-1, plain invoke): {tokenizer.convert_ids_to_tokens(tokens1)}")
print(f"Approach 2 (iter[0], pos {n-1}): {tokenizer.convert_ids_to_tokens(tokens2)}")
print(f"Identical: {tokens1 == tokens2}")


# ---------------------------------------------------------------------------
# Experiment 4: iter[:] fails at steps 1+ for position > 0.
#
# tracer.iter[:] (= tracer.all()) fires the intervention at every generation
# step. At step 0 the full prompt is processed (shape [seq_len, hidden]).
# At steps 1+ nnsight uses the KV cache, processing only the new token
# (shape [1, hidden]). Trying to access position n-1 at those steps raises
# an IndexError, proving the KV cache is active.
# ---------------------------------------------------------------------------

print()
print("=" * 60)
print("Experiment 4: iter[:] shape at each step")
print("=" * 60)

shapes = []
try:
    with model.generate(max_new_tokens=5, do_sample=False) as tracer:
        with tracer.invoke(prompt):
            for step in tracer.iter[:]:
                h = model.transformer.h[0].output[0]
                shapes.append((step, tuple(h.shape)))
                # This will raise IndexError at step 1+ because size is 1
                model.transformer.h[0].output[0][n - 1] = torch.zeros_like(h[n - 1])
            model.generator.output.save()
except IndexError as e:
    print(f"Shapes seen before error: {shapes}")
    print(f"IndexError at step 1+: {e}")
    print(f" -> Step 0 has seq_len={n}, position {n-1} valid.")
    print(f" -> Steps 1+ have seq_len=1 (KV cache), position {n-1} out of bounds.")

# ---------------------------------------------------------------------------
# Experiment 5: Sanity check — unsteered generation differs from steered.
#
# Confirms the intervention actually has an effect: generation without any
# intervention produces different tokens than generation with position -1
# zeroed. Without this check, the equivalences above could be vacuous.
# ---------------------------------------------------------------------------

print()
print("=" * 60)
print("Experiment 5: Unsteered differs from steered")
print("=" * 60)

with model.generate(max_new_tokens=5, do_sample=False) as tracer:
    with tracer.invoke(prompt):
        out_unsteered = model.generator.output.save()

tokens_unsteered = out_unsteered[0][n:].tolist()
tokens_steered = out1[0][n:].tolist()  # from experiment 3

print(f"Unsteered: {tokenizer.convert_ids_to_tokens(tokens_unsteered)}")
print(f"Steered:   {tokenizer.convert_ids_to_tokens(tokens_steered)}")
print(f"Different: {tokens_unsteered != tokens_steered}")
@@ -4,6 +4,7 @@ from .steering_search import ( SteeringSearchOutput, SteeringTrialResult, best_trial, + layer_sweep, steering_search, ) @@ -15,6 +16,7 @@ __all__ = [ "SteeringSearchOutput", "SteeringTrialResult", "best_trial", + "layer_sweep", "steering_search", ]
@@ -1,6 +1,7 @@ from __future__ import annotations import json +import re from dataclasses import dataclass, asdict from pathlib import Path from typing import Any, Callable @@ -211,16 +212,46 @@ def _generate_steered( return _generate_steered_batch(model, [prompt], direction, alpha, max_new_tokens)[0] +def _evaluate_steered( + selected_paths: list[str], + alpha: float, + vectors: dict[str, torch.Tensor], + eval_rows: list[dict], + baseline_responses: list[str], + model: Any, + score: Callable[[dict, str, str], float], + max_new_tokens: int, + batch_size: int, +) -> tuple[list[str], float]: + """Steer with the given paths+alpha and score the results. + + Returns ``(responses, mean_score)`` over ``eval_rows``. + """ + direction = SavedActivation.from_pairs(*[(p, vectors[p]) for p in selected_paths]) + trial_scores: list[float] = [] + trial_responses: list[str] = [] + for i in range(0, len(eval_rows), batch_size): + batch_rows = eval_rows[i : i + batch_size] + batch_baselines = baseline_responses[i : i + batch_size] + batch_prompts = [row["prompt"] for row in batch_rows] + batch_responses = _generate_steered_batch( + model, batch_prompts, direction, alpha, max_new_tokens + ) + trial_responses.extend(batch_responses) + for row, baseline, steered in zip(batch_rows, batch_baselines, batch_responses): + trial_scores.append(score(row, baseline, steered)) + return trial_responses, sum(trial_scores) / len(trial_scores) + + @torch.no_grad() def steering_search( config: SteeringSearchConfig, model: Any, positive_prompts: list[str], negative_prompts: list[str], - eval_prompts: list[str], - score: Callable[[str, str], float | None], + eval_rows: list[dict], + score: Callable[[dict, str, str], float], max_new_tokens: int, - ineligible_score: float, output_path: Path | None = None, batch_size: int = 1, ) -> SteeringSearchOutput: @@ -233,12 +264,13 @@ def steering_search( steering vectors). negative_prompts: Prompts for the "negative" class (used to compute steering vectors). 
- eval_prompts: Prompts to generate steered responses on each trial. - score: Scoring function. Takes (prompt, response) and returns a float - where **lower is better**, or ``None`` if the result is ineligible. + eval_rows: Rows to evaluate on each trial. Each row is a dict that + **must** contain a ``"prompt"`` key (the formatted prompt string). + All other keys are opaque metadata passed through to ``score``. + The caller should pre-filter to rows whose baselines are scorable. + score: Scoring function. Takes ``(row, baseline_response, + steered_response)`` and returns a float where **lower is better**. max_new_tokens: Maximum tokens to generate per prompt. - ineligible_score: Score to assign when steering makes a previously - eligible prompt ineligible (e.g. program stops passing tests). output_path: If provided, write trial results as JSONL (appended incrementally so partial results survive crashes). batch_size: Number of prompts to generate in parallel per batch. @@ -257,23 +289,17 @@ def steering_search( model, config.candidate_paths, positive_prompts, negative_prompts ) - # Pre-filter: generate baselines (no steering) and score. - # Prompts where the baseline score is None are permanently excluded. - eligible_prompts: list[str] = [] + # Generate baseline responses for all eval rows. 
+ eval_prompts = [row["prompt"] for row in eval_rows] baseline_responses: list[str] = [] - print(f"Pre-filtering {len(eval_prompts)} eval prompts...") + print(f"Generating baselines for {len(eval_prompts)} eval prompts...") for i in range(0, len(eval_prompts), batch_size): batch = eval_prompts[i : i + batch_size] - responses = _generate_batch(model, batch, max_new_tokens) - for prompt, response in zip(batch, responses): - baseline_score = score(prompt, response) - if baseline_score is not None: - eligible_prompts.append(prompt) - baseline_responses.append(response) - print(f"Eligible prompts: {len(eligible_prompts)} / {len(eval_prompts)}") + baseline_responses.extend(_generate_batch(model, batch, max_new_tokens)) + print(f"Baselines generated: {len(baseline_responses)}") - if not eligible_prompts: - raise ValueError("No eligible eval prompts after baseline filtering") + if not eval_rows: + raise ValueError("eval_rows must not be empty") # Clear output file if it exists. if output_path is not None: @@ -296,21 +322,10 @@ def steering_search( trial_score = PENALTY trial_responses: list[str] = [] else: - direction = SavedActivation.from_pairs( - *[(p, vectors[p]) for p in selected_paths] + trial_responses, trial_score = _evaluate_steered( + selected_paths, alpha, vectors, eval_rows, + baseline_responses, model, score, max_new_tokens, batch_size, ) - trial_scores: list[float] = [] - trial_responses = [] - for i in range(0, len(eligible_prompts), batch_size): - batch_prompts = eligible_prompts[i : i + batch_size] - batch_responses = _generate_steered_batch( - model, batch_prompts, direction, alpha, max_new_tokens - ) - trial_responses.extend(batch_responses) - for prompt, response in zip(batch_prompts, batch_responses): - s = score(prompt, response) - trial_scores.append(s if s is not None else ineligible_score) - trial_score = sum(trial_scores) / len(trial_scores) result = SteeringTrialResult( trial_index=len(trials), @@ -348,3 +363,107 @@ def steering_search( ) return 
SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials) + + +def _group_paths_by_layer(paths: list[str]) -> list[list[str]]: + """Group paths that share the same layer by stripping the trailing token index. + + ``"model.transformer.h[10].output[-1]"`` and + ``"model.transformer.h[10].output[-2]"`` both strip to + ``"model.transformer.h[10].output"`` and are placed in the same group, + producing a single trial that intervenes on both positions simultaneously. + Insertion order of groups is preserved. + """ + groups: dict[str, list[str]] = {} + for path in paths: + key = re.sub(r"\[-?\d+\]$", "", path) + if key not in groups: + groups[key] = [] + groups[key].append(path) + return list(groups.values()) + + +@torch.no_grad() +def layer_sweep( + paths: list[str], + alpha: float, + model: Any, + positive_prompts: list[str], + negative_prompts: list[str], + eval_rows: list[dict], + score: Callable[[dict, str, str], float], + max_new_tokens: int, + output_path: Path | None = None, + batch_size: int = 1, +) -> SteeringSearchOutput: + """Run one steering trial per layer at a fixed alpha value. + + Paths that share the same layer (differing only in their trailing token + position index) are grouped automatically and intervened on simultaneously + within a single trial. For example, passing + ``"model.transformer.h[10].output[-1]"`` and + ``"model.transformer.h[10].output[-2]"`` produces one trial that steers + both positions at once. + + Args: + paths: Activation paths to sweep. Grouped by layer internally. + alpha: Fixed steering coefficient applied in every trial. + model: An nnsight LanguageModel (already loaded). + positive_prompts: Prompts for the "positive" class. + negative_prompts: Prompts for the "negative" class. + eval_rows: Rows to evaluate on each trial. Each row is a dict that + **must** contain a ``"prompt"`` key. Other keys are opaque metadata + passed through to ``score``. + score: Scoring function. 
Takes ``(row, baseline_response, + steered_response)`` and returns a float (lower is better). + max_new_tokens: Maximum tokens to generate per prompt. + output_path: If provided, append trial results as JSONL incrementally. + batch_size: Number of prompts to generate in parallel per batch. + + Returns: + SteeringSearchOutput with baseline_responses and one trial per layer. + """ + if not paths: + raise ValueError("paths must not be empty") + + groups = _group_paths_by_layer(paths) + vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts) + + # Generate baseline responses for all eval rows. + eval_prompts = [row["prompt"] for row in eval_rows] + baseline_responses: list[str] = [] + print(f"Generating baselines for {len(eval_prompts)} eval prompts...") + for i in range(0, len(eval_prompts), batch_size): + batch = eval_prompts[i : i + batch_size] + baseline_responses.extend(_generate_batch(model, batch, max_new_tokens)) + print(f"Baselines generated: {len(baseline_responses)}") + + if not eval_rows: + raise ValueError("eval_rows must not be empty") + + if output_path is not None: + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("") + + trials: list[SteeringTrialResult] = [] + + for trial_index, group in enumerate(groups): + trial_responses, trial_score = _evaluate_steered( + group, alpha, vectors, eval_rows, + baseline_responses, model, score, max_new_tokens, batch_size, + ) + result = SteeringTrialResult( + trial_index=trial_index, + selected_paths=group, + alpha=alpha, + score=trial_score, + raw_point=[trial_index, alpha], + responses=trial_responses, + ) + trials.append(result) + if output_path is not None: + _append_jsonl(output_path, result) + label = group[0] if len(group) == 1 else f"{group[0]} (+{len(group) - 1} more)" + print(f" [{trial_index + 1}/{len(groups)}] {label}: score={trial_score:.4f}") + + return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)
@@ -15,7 +15,9 @@ from more_nnsight.steering_search import ( SteeringSearchOutput, SteeringTrialResult, _decode_point, + _group_paths_by_layer, best_trial, + layer_sweep, steering_search, ) @@ -83,7 +85,7 @@ def _patch_internals(monkeypatch, fake_vectors, generate_fn=None): monkeypatch.setattr(SavedActivation, "from_pairs", staticmethod(fake_from_pairs)) - # Mock baseline generation (for pre-filter). + # Mock baseline generation. def baseline_batch_fn(model, prompts, max_new_tokens): return ["baseline response"] * len(prompts) @@ -113,10 +115,9 @@ def test_empty_candidates_raises(): model=None, positive_prompts=[], negative_prompts=[], - eval_prompts=[], - score=lambda p, r: 0.0, + eval_rows=[], + score=lambda row, b, s: 0.0, max_new_tokens=10, - ineligible_score=1.0, ) @@ -137,10 +138,9 @@ def test_search_returns_correct_number_of_results(tmp_path, monkeypatch): model=None, positive_prompts=["pos1"], negative_prompts=["neg1"], - eval_prompts=["eval1", "eval2"], - score=lambda p, r: 1.0, + eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}], + score=lambda row, b, s: 1.0, max_new_tokens=10, - ineligible_score=1.0, output_path=output, ) assert isinstance(out, SteeringSearchOutput) @@ -157,8 +157,8 @@ def test_search_finds_known_optimum(monkeypatch): _patch_internals(monkeypatch, fake_vectors, generate_fn=fake_generate) - def score_fn(prompt, response): - paths = response.split(",") if response else [] + def score_fn(row, baseline, steered): + paths = steered.split(",") if steered else [] return 0.0 if "path_2" in paths else 1.0 config = SteeringSearchConfig( @@ -173,10 +173,9 @@ def test_search_finds_known_optimum(monkeypatch): model=None, positive_prompts=["pos"], negative_prompts=["neg"], - eval_prompts=["eval"], + eval_rows=[{"prompt": "eval"}], score=score_fn, max_new_tokens=10, - ineligible_score=1.0, ) best = best_trial(out.trials) assert best.score == 0.0 @@ -200,10 +199,9 @@ def test_jsonl_output_written_incrementally(tmp_path, monkeypatch): model=None, 
positive_prompts=["pos"], negative_prompts=["neg"], - eval_prompts=["eval"], - score=lambda p, r: 1.0, + eval_rows=[{"prompt": "eval"}], + score=lambda row, b, s: 1.0, max_new_tokens=10, - ineligible_score=1.0, output_path=output, ) lines = output.read_text().strip().split("\n") @@ -234,82 +232,167 @@ def test_empty_selection_gets_penalty(monkeypatch): model=None, positive_prompts=["pos"], negative_prompts=["neg"], - eval_prompts=["eval"], - score=lambda p, r: 0.0, + eval_rows=[{"prompt": "eval"}], + score=lambda row, b, s: 0.0, max_new_tokens=10, - ineligible_score=1.0, ) penalty_trials = [r for r in out.trials if not r.selected_paths] for t in penalty_trials: assert t.score == PENALTY -def test_ineligible_baseline_prompts_excluded(monkeypatch): - """Prompts that score None on baseline are excluded from evaluation.""" - fake_vectors = {"p0": torch.randn(1, 8)} +def test_score_receives_row_metadata(monkeypatch): + """Score function receives the full row dict, not just the prompt.""" + fake_vectors = {"p0": torch.randn(1, 8), "p1": torch.randn(1, 8)} _patch_internals(monkeypatch, fake_vectors) - # 3 eval prompts; "bad" one returns None on baseline, others return 0.5 - def score_fn(prompt, response): - if prompt == "bad_prompt": - return None + received_rows = [] + + def score_fn(row, baseline, steered): + received_rows.append(row) return 0.5 config = SteeringSearchConfig( - candidate_paths=["p0"], - max_simultaneous_paths=1, + candidate_paths=["p0", "p1"], + max_simultaneous_paths=2, alpha_range=(0.1, 3.0), n_calls=5, - n_initial_points=3, + n_initial_points=5, ) - out = steering_search( + steering_search( config=config, model=None, positive_prompts=["pos"], negative_prompts=["neg"], - eval_prompts=["good1", "bad_prompt", "good2"], + eval_rows=[{"prompt": "eval", "task_id": "T1", "extra": 42}], score=score_fn, max_new_tokens=10, - ineligible_score=1.0, ) - # All non-penalty trials should have score 0.5 (only good prompts evaluated) - for t in out.trials: - if 
t.selected_paths: - assert t.score == 0.5 + non_penalty = [r for r in received_rows if r is not None] + assert len(non_penalty) > 0 + assert non_penalty[0]["task_id"] == "T1" + assert non_penalty[0]["extra"] == 42 -def test_steered_none_uses_ineligible_score(monkeypatch): - """When steering makes an eligible prompt return None, ineligible_score is used.""" - fake_vectors = {"p0": torch.randn(1, 8)} - _patch_internals(monkeypatch, fake_vectors) +# --------------------------------------------------------------------------- +# _group_paths_by_layer +# --------------------------------------------------------------------------- + - call_count = [0] +def test_group_paths_same_layer_grouped_together(): + paths = [ + "model.transformer.h[10].output[-1]", + "model.transformer.h[10].output[-2]", + "model.transformer.h[11].output[-1]", + ] + groups = _group_paths_by_layer(paths) + assert len(groups) == 2 + assert groups[0] == [ + "model.transformer.h[10].output[-1]", + "model.transformer.h[10].output[-2]", + ] + assert groups[1] == ["model.transformer.h[11].output[-1]"] - def score_fn(prompt, response): - call_count[0] += 1 - # Baseline call returns 0.5 (eligible), steered calls return None - if response == "baseline response": - return 0.5 - return None - config = SteeringSearchConfig( - candidate_paths=["p0"], - max_simultaneous_paths=1, - alpha_range=(0.1, 3.0), - n_calls=3, - n_initial_points=3, +def test_group_paths_preserves_insertion_order(): + paths = [ + "model.transformer.h[5].output[-1]", + "model.transformer.h[3].output[-1]", + "model.transformer.h[5].output[-2]", + ] + groups = _group_paths_by_layer(paths) + assert len(groups) == 2 + assert groups[0][0].startswith("model.transformer.h[5]") + assert groups[1][0].startswith("model.transformer.h[3]") + + +def test_group_paths_no_token_index(): + # Paths without a trailing [N] each form their own group. 
+ paths = ["model.transformer.h[10].output", "model.transformer.h[11].output"] + groups = _group_paths_by_layer(paths) + assert len(groups) == 2 + + +# --------------------------------------------------------------------------- +# layer_sweep +# --------------------------------------------------------------------------- + + +def test_layer_sweep_groups_by_layer(monkeypatch): + # Two paths in the same layer → one trial; one path in a different layer → second trial. + paths = [ + "model.transformer.h[10].output[-1]", + "model.transformer.h[10].output[-2]", + "model.transformer.h[11].output[-1]", + ] + fake_vectors = {p: torch.randn(1, 8) for p in paths} + _patch_internals(monkeypatch, fake_vectors) + + out = layer_sweep( + paths=paths, + alpha=2.0, + model=None, + positive_prompts=["pos"], + negative_prompts=["neg"], + eval_rows=[{"prompt": "eval1"}, {"prompt": "eval2"}], + score=lambda row, b, s: 0.5, + max_new_tokens=10, ) - out = steering_search( - config=config, + + assert len(out.trials) == 2 + assert out.trials[0].selected_paths == [ + "model.transformer.h[10].output[-1]", + "model.transformer.h[10].output[-2]", + ] + assert out.trials[1].selected_paths == ["model.transformer.h[11].output[-1]"] + for i, trial in enumerate(out.trials): + assert trial.trial_index == i + assert trial.alpha == 2.0 + assert trial.raw_point == [i, 2.0] + + +def test_layer_sweep_empty_paths_raises(): + with pytest.raises(ValueError, match="paths must not be empty"): + layer_sweep( + paths=[], + alpha=1.0, + model=None, + positive_prompts=[], + negative_prompts=[], + eval_rows=[], + score=lambda row, b, s: 0.0, + max_new_tokens=10, + ) + + +def test_layer_sweep_jsonl_output(tmp_path, monkeypatch): + import json + + paths = [ + "model.transformer.h[10].output[-1]", + "model.transformer.h[10].output[-2]", + "model.transformer.h[11].output[-1]", + ] + fake_vectors = {p: torch.randn(1, 8) for p in paths} + _patch_internals(monkeypatch, fake_vectors) + + output = tmp_path / "sweep.jsonl" + 
layer_sweep( + paths=paths, + alpha=3.0, model=None, positive_prompts=["pos"], negative_prompts=["neg"], - eval_prompts=["eval"], - score=score_fn, + eval_rows=[{"prompt": "eval"}], + score=lambda row, b, s: 0.0, max_new_tokens=10, - ineligible_score=1.0, + output_path=output, ) - # All non-penalty trials should use ineligible_score since steered returns None - for t in out.trials: - if t.selected_paths: - assert t.score == 1.0 + + lines = output.read_text().strip().split("\n") + assert len(lines) == 2 # two groups → two trials + for line in lines: + record = json.loads(line) + assert "trial_index" in record + assert "selected_paths" in record + assert record["alpha"] == 3.0
@@ -58,8 +58,8 @@ def test_steering_changes_sentiment(gpt2, tmp_path): ] eval_prompt = "The weather today is" - def score(prompt: str, response: str) -> float: - return _sentiment_score(response) + def score(row: dict, baseline: str, steered: str) -> float: + return _sentiment_score(steered) config = SteeringSearchConfig( candidate_paths=candidate_paths, @@ -75,10 +75,9 @@ def test_steering_changes_sentiment(gpt2, tmp_path): model=gpt2, positive_prompts=positive_prompts, negative_prompts=negative_prompts, - eval_prompts=[eval_prompt], + eval_rows=[{"prompt": eval_prompt}], score=score, max_new_tokens=4, - ineligible_score=1.0, output_path=output, )