# Repository: more_nnsight.git — src/more_nnsight/steering_search.py
# Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
from __future__ import annotations
import json
import re
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Callable
import torch
from skopt import gp_minimize
from skopt.space import Integer, Real
from .saved_activation import SavedActivation, save_activations, updates
PENALTY = 1e6
@dataclass(frozen=True, slots=True)
class SteeringSearchConfig:
    """Configuration for a Bayesian-optimized steering vector search."""

    # Activation paths eligible for steering (the discrete axis of the search space).
    candidate_paths: list[str]
    # Number of integer path slots in each search point, i.e. the maximum
    # number of paths steered simultaneously in a single trial.
    max_simultaneous_paths: int
    # (low, high) bounds for the continuous steering coefficient alpha.
    alpha_range: tuple[float, float]
    # Total objective-evaluation budget, including the optional grid sweep.
    n_calls: int
    # Random exploration points evaluated before the GP model takes over.
    n_initial_points: int = 10
    # Random state passed to gp_minimize for reproducibility.
    seed: int = 0
    # If True, first sweep each single candidate path at alpha=1.0, then
    # hand the results to the GP as its starting observations.
    grid_first: bool = False
@dataclass(frozen=True, slots=True)
class SteeringTrialResult:
    """One trial from the search."""

    # Zero-based position of this trial in the run's execution order.
    trial_index: int
    # Candidate paths actually steered in this trial (may be empty, in which
    # case the trial was scored with the PENALTY constant).
    selected_paths: list[str]
    # Steering coefficient applied at every selected path.
    alpha: float
    # Mean score over the eval rows; lower is better.
    score: float
    # The raw, undecoded optimizer point: path-slot indices followed by alpha.
    raw_point: list[int | float]
    # Steered responses, one per eval row (empty when nothing was steered).
    responses: list[str]
@dataclass(frozen=True, slots=True)
class SteeringSearchOutput:
    """Full output from a steering search run."""

    # Unsteered responses, one per eval row; shared baseline for every trial.
    baseline_responses: list[str]
    # All trials in execution order.
    trials: list[SteeringTrialResult]
def best_trial(results: list[SteeringTrialResult]) -> SteeringTrialResult:
    """Return the trial with the lowest score (lower is better).

    Ties go to the earliest trial. Raises ValueError on an empty list.
    """
    winner: SteeringTrialResult | None = None
    for trial in results:
        if winner is None or trial.score < winner.score:
            winner = trial
    if winner is None:
        raise ValueError("best_trial() arg is an empty sequence")
    return winner
def _decode_point(
point: list[Any],
n_candidates: int,
candidate_paths: list[str],
) -> tuple[list[str], float]:
"""Decode a raw skopt point into (selected_paths, alpha).
Path slot indices equal to n_candidates are sentinels meaning "unused".
Indices are sorted and deduplicated.
"""
*path_indices, alpha = point
unique_indices = sorted(set(idx for idx in path_indices if idx < n_candidates))
selected = [candidate_paths[i] for i in unique_indices]
return selected, float(alpha)
def _append_jsonl(path: Path, result: SteeringTrialResult) -> None:
with open(path, "a") as f:
f.write(json.dumps(asdict(result)) + "\n")
@torch.no_grad()
def _compute_steering_vectors(
    model: Any,
    candidate_paths: list[str],
    positive_prompts: list[str],
    negative_prompts: list[str],
) -> dict[str, torch.Tensor]:
    """Compute mean-difference steering vectors for each candidate path.

    For each path: vector = mean(positive_activations) - mean(negative_activations).
    Processes each prompt individually.

    Args:
        model: An nnsight LanguageModel (already loaded).
        candidate_paths: Activation paths to capture on every forward pass.
        positive_prompts: Prompts for the "positive" class.
        negative_prompts: Prompts for the "negative" class.

    Returns:
        Mapping from path to a (1, flattened_dim) float32 CPU tensor.

    Raises:
        KeyError: If either prompt list is empty (the corresponding "pos" /
            "neg" running sum never gets created).
    """
    # Per-path running sums/counts, keyed "pos"/"neg"/"pos_n"/"neg_n".
    sums: dict[str, dict[str, Any]] = {p: {} for p in candidate_paths}
    for prompts, label in [(positive_prompts, "pos"), (negative_prompts, "neg")]:
        for prompt in prompts:
            with model.trace() as tracer:
                with tracer.invoke(prompt):
                    # Schedule capture of every candidate path for this prompt.
                    saved = save_activations(model, candidate_paths)
            # NOTE(review): runs after the trace context closes, so the saved
            # values should be concrete tensors here — confirm against
            # save_activations / nnsight semantics.
            for path in candidate_paths:
                tensor = (
                    torch.as_tensor(saved.get(path))
                    .detach()
                    .float()
                    .cpu()
                    .reshape(1, -1)  # flatten to a single row vector
                )
                bucket = sums[path]
                if label not in bucket:
                    # First observation for this (path, label): start the sums.
                    bucket[label] = tensor.clone()
                    bucket[f"{label}_n"] = 1
                else:
                    bucket[label] = bucket[label] + tensor
                    bucket[f"{label}_n"] += 1
    # Finalize: mean(positive) - mean(negative) per path.
    vectors: dict[str, torch.Tensor] = {}
    for path, bucket in sums.items():
        pos_mean = bucket["pos"] / bucket["pos_n"]
        neg_mean = bucket["neg"] / bucket["neg_n"]
        vectors[path] = pos_mean - neg_mean
    return vectors
@torch.no_grad()
def _generate_batch(
    model: Any,
    prompts: list[str],
    max_new_tokens: int,
) -> list[str]:
    """Generate responses for a batch of prompts without steering.

    Args:
        model: An nnsight LanguageModel (already loaded).
        prompts: Prompt strings to generate from in one batch.
        max_new_tokens: Maximum tokens to generate per prompt.

    Returns:
        Decoded responses (prompt prefix stripped), one per prompt.
    """
    tokenizer = model.tokenizer
    # Remember tokenizer state so it can be restored in the finally block.
    orig_padding_side = getattr(tokenizer, "padding_side", "right")
    orig_pad_token = tokenizer.pad_token
    # Left padding right-aligns the prompts so generated tokens for every
    # row start at the same offset (total_len).
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        # Tokenize only to measure the padded prompt length.
        # NOTE(review): assumes model.generate pads the batch the same way so
        # the slice offset below lines up — confirm against nnsight batching.
        encoded = tokenizer(
            prompts, return_tensors="pt", padding=True, add_special_tokens=False
        )
        total_len = int(encoded["input_ids"].shape[1])
        with model.generate(max_new_tokens=max_new_tokens, do_sample=False) as tracer:
            with tracer.invoke(prompts):
                output_ids = model.generator.output.save()
        responses: list[str] = []
        for i in range(len(prompts)):
            # Drop the padded prompt prefix; decode only newly generated tokens.
            response = tokenizer.decode(
                output_ids[i][total_len:], skip_special_tokens=True
            )
            responses.append(response)
        return responses
    finally:
        # Always restore tokenizer state, even if generation fails.
        tokenizer.padding_side = orig_padding_side
        tokenizer.pad_token = orig_pad_token
@torch.no_grad()
def _generate_steered_batch(
    model: Any,
    prompts: list[str],
    direction: SavedActivation,
    alpha: float,
    max_new_tokens: int,
) -> list[str]:
    """Generate responses for a batch of prompts with steering applied.

    Args:
        model: An nnsight LanguageModel (already loaded).
        prompts: Prompt strings to generate from in one batch.
        direction: Steering vectors keyed by activation path; each vector is
            added (scaled by ``alpha``) to the live activation at that path.
        alpha: Steering coefficient.
        max_new_tokens: Maximum tokens to generate per prompt.

    Returns:
        Decoded responses (prompt prefix stripped), one per prompt.
    """
    tokenizer = model.tokenizer
    # Remember tokenizer state so it can be restored in the finally block.
    orig_padding_side = getattr(tokenizer, "padding_side", "right")
    orig_pad_token = tokenizer.pad_token
    # Left padding right-aligns the prompts so generated tokens for every
    # row start at the same offset (total_len).
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        # Tokenize only to measure the padded prompt length.
        # NOTE(review): assumes model.generate pads the batch the same way so
        # the slice offset below lines up — confirm against nnsight batching.
        encoded = tokenizer(
            prompts, return_tensors="pt", padding=True, add_special_tokens=False
        )
        total_len = int(encoded["input_ids"].shape[1])
        with model.generate(max_new_tokens=max_new_tokens, do_sample=False) as tracer:
            with tracer.invoke(prompts):
                # Intervene at every steered path: activation += alpha * vector.
                for key, current_value, update in updates(model, direction.keys()):
                    # Match the live activation's device/dtype before adding.
                    steer = direction.get(key).to(
                        device=current_value.device,
                        dtype=current_value.dtype,
                    )
                    update(current_value + alpha * steer)
                output_ids = model.generator.output.save()
        responses: list[str] = []
        for i in range(len(prompts)):
            # Drop the padded prompt prefix; decode only newly generated tokens.
            response = tokenizer.decode(
                output_ids[i][total_len:], skip_special_tokens=True
            )
            responses.append(response)
        return responses
    finally:
        # Always restore tokenizer state, even if generation fails.
        tokenizer.padding_side = orig_padding_side
        tokenizer.pad_token = orig_pad_token
@torch.no_grad()
def _generate_steered(
    model: Any,
    prompt: str,
    direction: SavedActivation,
    alpha: float,
    max_new_tokens: int,
) -> str:
    """Steer a single prompt and return its response text.

    Convenience wrapper that delegates to the batched implementation with a
    one-element batch.
    """
    batch = _generate_steered_batch(model, [prompt], direction, alpha, max_new_tokens)
    return batch[0]
def _evaluate_steered(
    selected_paths: list[str],
    alpha: float,
    vectors: dict[str, torch.Tensor],
    eval_rows: list[dict],
    baseline_responses: list[str],
    model: Any,
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    batch_size: int,
) -> tuple[list[str], float]:
    """Steer with the given paths+alpha and score the results.

    Returns ``(responses, mean_score)`` over ``eval_rows``.
    """
    # Bundle the vectors for the selected paths into one intervention spec.
    pairs = [(path, vectors[path]) for path in selected_paths]
    direction = SavedActivation.from_pairs(*pairs)
    all_responses: list[str] = []
    all_scores: list[float] = []
    for start in range(0, len(eval_rows), batch_size):
        stop = start + batch_size
        rows = eval_rows[start:stop]
        baselines = baseline_responses[start:stop]
        prompts = [row["prompt"] for row in rows]
        steered = _generate_steered_batch(
            model, prompts, direction, alpha, max_new_tokens
        )
        all_responses.extend(steered)
        all_scores.extend(
            score(row, base, out)
            for row, base, out in zip(rows, baselines, steered)
        )
    mean_score = sum(all_scores) / len(all_scores)
    return all_responses, mean_score
@torch.no_grad()
def steering_search(
    config: SteeringSearchConfig,
    model: Any,
    positive_prompts: list[str],
    negative_prompts: list[str],
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Run Bayesian-optimized search over steering path subsets and alpha values.

    Args:
        config: Search configuration (candidate paths, budget, etc.).
        model: An nnsight LanguageModel (already loaded).
        positive_prompts: Prompts for the "positive" class (used to compute
            steering vectors).
        negative_prompts: Prompts for the "negative" class (used to compute
            steering vectors).
        eval_rows: Rows to evaluate on each trial. Each row is a dict that
            **must** contain a ``"prompt"`` key (the formatted prompt string).
            All other keys are opaque metadata passed through to ``score``.
            The caller should pre-filter to rows whose baselines are scorable.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float where **lower is better**.
        max_new_tokens: Maximum tokens to generate per prompt.
        output_path: If provided, write trial results as JSONL (appended
            incrementally so partial results survive crashes).
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and all trial results.

    Raises:
        ValueError: If candidate_paths or eval_rows is empty, if
            max_simultaneous_paths < 1, or if grid_first=True leaves no
            objective budget for the GP.
    """
    # Validate everything up front, before any expensive model work.
    if not config.candidate_paths:
        raise ValueError("candidate_paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")
    n_candidates = len(config.candidate_paths)
    m = config.max_simultaneous_paths
    if m < 1:
        raise ValueError("max_simultaneous_paths must be at least 1")
    if config.grid_first and config.n_calls - n_candidates < 1:
        # Check the budget BEFORE spending n_candidates objective calls on
        # the grid sweep below.
        raise ValueError(
            f"n_calls ({config.n_calls}) must exceed the number of candidate_paths "
            f"({n_candidates}) when grid_first=True to leave budget for the GP."
        )
    # Compute steering vectors once up front.
    vectors = _compute_steering_vectors(
        model, config.candidate_paths, positive_prompts, negative_prompts
    )
    # Generate baseline responses for all eval rows.
    eval_prompts = [row["prompt"] for row in eval_rows]
    baseline_responses: list[str] = []
    print(f"Generating baselines for {len(eval_prompts)} eval prompts...")
    for i in range(0, len(eval_prompts), batch_size):
        batch = eval_prompts[i : i + batch_size]
        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
    print(f"Baselines generated: {len(baseline_responses)}")
    # Clear output file if it exists.
    if output_path is not None:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text("")
    # Search space: m integer path slots (value n_candidates == "unused"
    # sentinel, see _decode_point) plus the continuous alpha dimension.
    dimensions = [
        Integer(0, n_candidates, name=f"path_{i}") for i in range(m)
    ] + [
        Real(config.alpha_range[0], config.alpha_range[1], name="alpha"),
    ]
    trials: list[SteeringTrialResult] = []

    def objective(point: list[Any]) -> float:
        # Evaluate one raw optimizer point; records the trial as a side effect.
        selected_paths, alpha = _decode_point(point, n_candidates, config.candidate_paths)
        if not selected_paths:
            # Every slot decoded to the "unused" sentinel: nothing to steer.
            trial_score = PENALTY
            trial_responses: list[str] = []
        else:
            trial_responses, trial_score = _evaluate_steered(
                selected_paths, alpha, vectors, eval_rows,
                baseline_responses, model, score, max_new_tokens, batch_size,
            )
        result = SteeringTrialResult(
            trial_index=len(trials),
            selected_paths=selected_paths,
            alpha=alpha,
            score=trial_score,
            # skopt may hand back numpy scalars. np.float64 IS a float
            # subclass but np.int64 is NOT an int subclass, so classify by
            # float-ness; the old `isinstance(x, (int,))` test silently
            # turned integer slots into floats.
            raw_point=[float(x) if isinstance(x, float) else int(x) for x in point],
            responses=trial_responses,
        )
        trials.append(result)
        if output_path is not None:
            _append_jsonl(output_path, result)
        return trial_score

    x0, y0, n_calls = None, None, config.n_calls
    if config.grid_first:
        # Sweep every single path at alpha=1.0 before handing off to the GP;
        # the sweep's points/scores seed gp_minimize via x0/y0.
        x0 = [[i] + [n_candidates] * (m - 1) + [1.0] for i in range(n_candidates)]
        y0 = [objective(pt) for pt in x0]
        n_calls = config.n_calls - n_candidates
    gp_minimize(
        func=objective,
        dimensions=dimensions,
        x0=x0,
        y0=y0,
        n_calls=n_calls,
        n_initial_points=config.n_initial_points,
        random_state=config.seed,
    )
    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)
@torch.no_grad()
def steer_and_eval(
    paths: list[str],
    alpha: float,
    model: Any,
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    *,
    positive_prompts: list[str] | None = None,
    negative_prompts: list[str] | None = None,
    steering_vectors: dict[str, torch.Tensor] | None = None,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Steer on the given paths at a fixed alpha and evaluate eval_rows.

    Provide either ``positive_prompts`` and ``negative_prompts`` together to
    compute steering vectors, or pass pre-computed ``steering_vectors`` directly.
    Exactly one of the two options must be used.

    Args:
        paths: Activation paths to steer on.
        alpha: Steering coefficient applied at every path.
        model: An nnsight LanguageModel (already loaded).
        eval_rows: Rows to evaluate. Each row must contain a ``"prompt"`` key;
            other keys are opaque metadata passed through to ``score``.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float (lower is better).
        max_new_tokens: Maximum tokens to generate per prompt.
        positive_prompts: Prompts for the "positive" class used to compute
            steering vectors. Must be paired with ``negative_prompts``.
            Mutually exclusive with ``steering_vectors``.
        negative_prompts: Prompts for the "negative" class used to compute
            steering vectors. Must be paired with ``positive_prompts``.
            Mutually exclusive with ``steering_vectors``.
        steering_vectors: Pre-computed steering vectors keyed by path. Must
            contain an entry for every element of ``paths``. Mutually
            exclusive with ``positive_prompts`` / ``negative_prompts``.
        output_path: If provided, write the single trial result as one JSONL line.
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and a single trial.

    Raises:
        ValueError: If the prompts/vectors options are misused, if paths or
            eval_rows is empty, or if steering_vectors lacks a requested path.
    """
    prompts_provided = positive_prompts is not None or negative_prompts is not None
    vectors_provided = steering_vectors is not None
    if prompts_provided and vectors_provided:
        raise ValueError(
            "Provide either (positive_prompts + negative_prompts) or steering_vectors, not both."
        )
    if not prompts_provided and not vectors_provided:
        raise ValueError(
            "Provide either (positive_prompts + negative_prompts) or steering_vectors."
        )
    if prompts_provided and (positive_prompts is None or negative_prompts is None):
        raise ValueError(
            "positive_prompts and negative_prompts must both be provided together."
        )
    if not paths:
        raise ValueError("paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")
    if vectors_provided:
        # Fail fast with a clear message instead of a KeyError deep inside
        # _evaluate_steered after baselines have already been generated.
        missing = [p for p in paths if p not in steering_vectors]
        if missing:
            raise ValueError(f"steering_vectors is missing vectors for paths: {missing}")
        vectors = steering_vectors
    else:
        vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)
    # Generate unsteered baselines, one per eval row, in batches.
    eval_prompts = [row["prompt"] for row in eval_rows]
    baseline_responses: list[str] = []
    for i in range(0, len(eval_prompts), batch_size):
        batch = eval_prompts[i : i + batch_size]
        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
    # Clear the output file so the trial line is the only content.
    if output_path is not None:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text("")
    responses, trial_score = _evaluate_steered(
        paths, alpha, vectors, eval_rows, baseline_responses, model, score, max_new_tokens, batch_size,
    )
    result = SteeringTrialResult(
        trial_index=0,
        selected_paths=paths,
        alpha=alpha,
        score=trial_score,
        raw_point=[alpha],
        responses=responses,
    )
    if output_path is not None:
        _append_jsonl(output_path, result)
    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=[result])
def _group_paths_by_layer(paths: list[str]) -> list[list[str]]:
"""Group paths that share the same layer by stripping the trailing token index.
``"model.transformer.h[10].output[-1]"`` and
``"model.transformer.h[10].output[-2]"`` both strip to
``"model.transformer.h[10].output"`` and are placed in the same group,
producing a single trial that intervenes on both positions simultaneously.
Insertion order of groups is preserved.
"""
groups: dict[str, list[str]] = {}
for path in paths:
key = re.sub(r"\[-?\d+\]$", "", path)
if key not in groups:
groups[key] = []
groups[key].append(path)
return list(groups.values())
@torch.no_grad()
def layer_sweep(
    paths: list[str],
    alpha: float,
    model: Any,
    positive_prompts: list[str],
    negative_prompts: list[str],
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Run one steering trial per layer at a fixed alpha value.

    Paths that share the same layer (differing only in their trailing token
    position index) are grouped automatically and intervened on simultaneously
    within a single trial. For example, passing
    ``"model.transformer.h[10].output[-1]"`` and
    ``"model.transformer.h[10].output[-2]"`` produces one trial that steers
    both positions at once.

    Args:
        paths: Activation paths to sweep. Grouped by layer internally.
        alpha: Fixed steering coefficient applied in every trial.
        model: An nnsight LanguageModel (already loaded).
        positive_prompts: Prompts for the "positive" class.
        negative_prompts: Prompts for the "negative" class.
        eval_rows: Rows to evaluate on each trial. Each row is a dict that
            **must** contain a ``"prompt"`` key. Other keys are opaque metadata
            passed through to ``score``.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float (lower is better).
        max_new_tokens: Maximum tokens to generate per prompt.
        output_path: If provided, append trial results as JSONL incrementally.
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and one trial per layer.

    Raises:
        ValueError: If paths or eval_rows is empty.
    """
    # Validate inputs up front, before any expensive model work (previously
    # eval_rows was only checked after every baseline had been generated).
    if not paths:
        raise ValueError("paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")
    groups = _group_paths_by_layer(paths)
    vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)
    # Generate baseline responses for all eval rows.
    eval_prompts = [row["prompt"] for row in eval_rows]
    baseline_responses: list[str] = []
    print(f"Generating baselines for {len(eval_prompts)} eval prompts...")
    for i in range(0, len(eval_prompts), batch_size):
        batch = eval_prompts[i : i + batch_size]
        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
    print(f"Baselines generated: {len(baseline_responses)}")
    # Clear output file if it exists.
    if output_path is not None:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text("")
    trials: list[SteeringTrialResult] = []
    for trial_index, group in enumerate(groups):
        trial_responses, trial_score = _evaluate_steered(
            group, alpha, vectors, eval_rows,
            baseline_responses, model, score, max_new_tokens, batch_size,
        )
        result = SteeringTrialResult(
            trial_index=trial_index,
            selected_paths=group,
            alpha=alpha,
            score=trial_score,
            raw_point=[trial_index, alpha],
            responses=trial_responses,
        )
        trials.append(result)
        if output_path is not None:
            # Append incrementally so partial results survive crashes.
            _append_jsonl(output_path, result)
        label = group[0] if len(group) == 1 else f"{group[0]} (+{len(group) - 1} more)"
        print(f"  [{trial_index + 1}/{len(groups)}] {label}: score={trial_score:.4f}")
    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)