Repositories / more_nnsight.git

src/more_nnsight/steering_search.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch
20786 bytes · 11bedd825604
from __future__ import annotations

import json
import numbers
import re
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Iterator

import torch
from skopt import gp_minimize
from skopt.space import Integer, Real

from .saved_activation import SavedActivation, save_activations, updates

# Score recorded for a trial whose decoded point selects no paths at all.
PENALTY = 1e6

# Matches a trailing ``[<int>]`` index at the end of an activation path.
_TRAILING_INDEX_RE = re.compile(r"\[-?\d+\]$")


@dataclass(frozen=True, slots=True)
class SteeringSearchConfig:
    """Configuration for a Bayesian-optimized steering vector search.

    Attributes:
        candidate_paths: Activation paths the search may steer on.
        max_simultaneous_paths: Number of path slots in each search point,
            i.e. the maximum number of paths steered in one trial.
        alpha_range: ``(low, high)`` bounds of the steering coefficient.
        n_calls: Evaluation budget. With ``grid_first`` the initial sweep
            consumes ``len(candidate_paths)`` of it and the remainder is
            passed to ``gp_minimize``.
        n_initial_points: Forwarded to ``gp_minimize``.
        seed: Forwarded to ``gp_minimize`` as ``random_state``.
        grid_first: If true, evaluate every candidate path on its own before
            handing off to the optimizer.
    """

    candidate_paths: list[str]
    max_simultaneous_paths: int
    alpha_range: tuple[float, float]
    n_calls: int
    n_initial_points: int = 10
    seed: int = 0
    grid_first: bool = False


@dataclass(frozen=True, slots=True)
class SteeringTrialResult:
    """One trial from the search.

    Attributes:
        trial_index: Zero-based position of the trial within its run.
        selected_paths: Paths that were steered in this trial.
        alpha: Steering coefficient used.
        score: Mean score over the eval rows (lower is better).
        raw_point: The undecoded point that produced this trial.
        responses: Steered responses, one per eval row.
    """

    trial_index: int
    selected_paths: list[str]
    alpha: float
    score: float
    raw_point: list[int | float]
    responses: list[str]


@dataclass(frozen=True, slots=True)
class SteeringSearchOutput:
    """Full output from a steering search run.

    Attributes:
        baseline_responses: Unsteered responses, one per eval row.
        trials: Every trial that was evaluated, in evaluation order.
    """

    baseline_responses: list[str]
    trials: list[SteeringTrialResult]


def best_trial(results: list[SteeringTrialResult]) -> SteeringTrialResult:
    """Return the trial with the lowest score.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("results must not be empty")
    return min(results, key=lambda r: r.score)


def _decode_point(
    point: list[Any],
    n_candidates: int,
    candidate_paths: list[str],
) -> tuple[list[str], float]:
    """Decode a raw skopt point into (selected_paths, alpha).

    The last element of ``point`` is alpha; everything before it is a path
    slot. Path slot indices equal to n_candidates are sentinels meaning
    "unused". Indices are sorted and deduplicated.
    """
    *path_indices, alpha = point
    unique_indices = sorted({idx for idx in path_indices if idx < n_candidates})
    selected = [candidate_paths[i] for i in unique_indices]
    return selected, float(alpha)


def _append_jsonl(path: Path, result: SteeringTrialResult) -> None:
    """Append ``result`` to ``path`` as a single JSON line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(asdict(result)) + "\n")


def _reset_output_file(output_path: Path | None) -> None:
    """Create the parent directory of ``output_path`` and truncate the file.

    Does nothing when ``output_path`` is ``None``.
    """
    if output_path is None:
        return
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("")


def _require_positive_batch_size(batch_size: int) -> None:
    """Reject batch sizes that would make the batching loops misbehave."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")


@torch.no_grad()
def _accumulate_activations(
    model: Any,
    candidate_paths: list[str],
    prompts: list[str],
) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
    """Sum the flattened activations at each path over ``prompts``.

    Each prompt is traced on its own. Returns ``(totals, counts)`` keyed by
    path, where ``counts[path]`` is the number of tensors added into
    ``totals[path]``.
    """
    totals: dict[str, torch.Tensor] = {}
    counts: dict[str, int] = {}
    for prompt in prompts:
        with model.trace() as tracer:
            with tracer.invoke(prompt):
                saved = save_activations(model, candidate_paths)
        for path in candidate_paths:
            tensor = (
                torch.as_tensor(saved.get(path))
                .detach()
                .float()
                .cpu()
                .reshape(1, -1)
            )
            if path in totals:
                totals[path] = totals[path] + tensor
                counts[path] += 1
            else:
                totals[path] = tensor.clone()
                counts[path] = 1
    return totals, counts


@torch.no_grad()
def _compute_steering_vectors(
    model: Any,
    candidate_paths: list[str],
    positive_prompts: list[str],
    negative_prompts: list[str],
) -> dict[str, torch.Tensor]:
    """Compute mean-difference steering vectors for each candidate path.

    For each path: vector = mean(positive_activations) - mean(negative_activations).
    Processes each prompt individually. Activations are flattened to shape
    ``(1, -1)`` as float tensors on the CPU before averaging.

    Raises:
        ValueError: If either prompt list is empty, since the corresponding
            mean would be undefined.
    """
    if not positive_prompts:
        raise ValueError("positive_prompts must not be empty")
    if not negative_prompts:
        raise ValueError("negative_prompts must not be empty")

    pos_totals, pos_counts = _accumulate_activations(
        model, candidate_paths, positive_prompts
    )
    neg_totals, neg_counts = _accumulate_activations(
        model, candidate_paths, negative_prompts
    )

    vectors: dict[str, torch.Tensor] = {}
    for path, pos_total in pos_totals.items():
        pos_mean = pos_total / pos_counts[path]
        neg_mean = neg_totals[path] / neg_counts[path]
        vectors[path] = pos_mean - neg_mean
    return vectors


@contextmanager
def _left_padded(tokenizer: Any) -> Iterator[None]:
    """Temporarily switch ``tokenizer`` to left padding with a pad token set.

    If the tokenizer has no pad token, its EOS token is used. The original
    ``padding_side`` and ``pad_token`` are restored on exit, even on error.
    """
    orig_padding_side = getattr(tokenizer, "padding_side", "right")
    orig_pad_token = tokenizer.pad_token
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        yield
    finally:
        tokenizer.padding_side = orig_padding_side
        tokenizer.pad_token = orig_pad_token


def _padded_prompt_length(tokenizer: Any, prompts: list[str]) -> int:
    """Return the padded token length of ``prompts`` as one batch.

    NOTE(review): this tokenizes with ``add_special_tokens=False``; it is
    assumed that ``model.generate`` pads the batch to the same length --
    confirm against the nnsight tokenization settings in use.
    """
    encoded = tokenizer(
        prompts, return_tensors="pt", padding=True, add_special_tokens=False
    )
    return int(encoded["input_ids"].shape[1])


def _decode_completions(
    tokenizer: Any,
    output_ids: Any,
    prompt_len: int,
    n_prompts: int,
) -> list[str]:
    """Decode the tokens after ``prompt_len`` for the first ``n_prompts`` rows."""
    return [
        tokenizer.decode(output_ids[i][prompt_len:], skip_special_tokens=True)
        for i in range(n_prompts)
    ]


@torch.no_grad()
def _generate_batch(
    model: Any,
    prompts: list[str],
    max_new_tokens: int,
) -> list[str]:
    """Generate responses for a batch of prompts without steering.

    Generation is greedy (``do_sample=False``). Returns one decoded
    completion per prompt, with the prompt tokens stripped.
    """
    tokenizer = model.tokenizer
    with _left_padded(tokenizer):
        total_len = _padded_prompt_length(tokenizer, prompts)
        with model.generate(max_new_tokens=max_new_tokens, do_sample=False) as tracer:
            with tracer.invoke(prompts):
                output_ids = model.generator.output.save()
        return _decode_completions(tokenizer, output_ids, total_len, len(prompts))


@torch.no_grad()
def _generate_steered_batch(
    model: Any,
    prompts: list[str],
    direction: SavedActivation,
    alpha: float,
    max_new_tokens: int,
) -> list[str]:
    """Generate responses for a batch of prompts with steering applied.

    For every key in ``direction`` the current activation is replaced by
    ``current_value + alpha * steer``, with the steering tensor moved to the
    activation's device and dtype. Generation is greedy.
    """
    tokenizer = model.tokenizer
    with _left_padded(tokenizer):
        total_len = _padded_prompt_length(tokenizer, prompts)
        with model.generate(max_new_tokens=max_new_tokens, do_sample=False) as tracer:
            with tracer.invoke(prompts):
                for key, current_value, update in updates(model, direction.keys()):
                    steer = direction.get(key).to(
                        device=current_value.device,
                        dtype=current_value.dtype,
                    )
                    update(current_value + alpha * steer)
                output_ids = model.generator.output.save()
        return _decode_completions(tokenizer, output_ids, total_len, len(prompts))


@torch.no_grad()
def _generate_steered(
    model: Any,
    prompt: str,
    direction: SavedActivation,
    alpha: float,
    max_new_tokens: int,
) -> str:
    """Generate a response with steering applied, return the response text."""
    return _generate_steered_batch(model, [prompt], direction, alpha, max_new_tokens)[0]


def _generate_baselines(
    model: Any,
    eval_rows: list[dict],
    max_new_tokens: int,
    batch_size: int,
    *,
    verbose: bool,
) -> list[str]:
    """Generate unsteered responses for every eval row, in batches.

    When ``verbose`` is true, progress is printed before and after.
    """
    eval_prompts = [row["prompt"] for row in eval_rows]
    baseline_responses: list[str] = []
    if verbose:
        print(f"Generating baselines for {len(eval_prompts)} eval prompts...")
    for start in range(0, len(eval_prompts), batch_size):
        batch = eval_prompts[start : start + batch_size]
        baseline_responses.extend(_generate_batch(model, batch, max_new_tokens))
    if verbose:
        print(f"Baselines generated: {len(baseline_responses)}")
    return baseline_responses


def _evaluate_steered(
    selected_paths: list[str],
    alpha: float,
    vectors: dict[str, torch.Tensor],
    eval_rows: list[dict],
    baseline_responses: list[str],
    model: Any,
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    batch_size: int,
) -> tuple[list[str], float]:
    """Steer with the given paths+alpha and score the results.

    Returns ``(responses, mean_score)`` over ``eval_rows``.

    Raises:
        ValueError: If ``eval_rows`` is empty (the mean would be undefined).
    """
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")

    direction = SavedActivation.from_pairs(*[(p, vectors[p]) for p in selected_paths])
    trial_scores: list[float] = []
    trial_responses: list[str] = []
    for i in range(0, len(eval_rows), batch_size):
        batch_rows = eval_rows[i : i + batch_size]
        batch_baselines = baseline_responses[i : i + batch_size]
        batch_prompts = [row["prompt"] for row in batch_rows]
        batch_responses = _generate_steered_batch(
            model, batch_prompts, direction, alpha, max_new_tokens
        )
        trial_responses.extend(batch_responses)
        for row, baseline, steered in zip(batch_rows, batch_baselines, batch_responses):
            trial_scores.append(score(row, baseline, steered))
    return trial_responses, sum(trial_scores) / len(trial_scores)


@torch.no_grad()
def steering_search(
    config: SteeringSearchConfig,
    model: Any,
    positive_prompts: list[str],
    negative_prompts: list[str],
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Run Bayesian-optimized search over steering path subsets and alpha values.

    All argument validation happens before any model work, so a bad
    configuration fails without first paying for steering vectors, baselines
    or a grid sweep.

    Args:
        config: Search configuration (candidate paths, budget, etc.).
        model: An nnsight LanguageModel (already loaded).
        positive_prompts: Prompts for the "positive" class (used to compute
            steering vectors).
        negative_prompts: Prompts for the "negative" class (used to compute
            steering vectors).
        eval_rows: Rows to evaluate on each trial. Each row is a dict that
            **must** contain a ``"prompt"`` key (the formatted prompt string).
            All other keys are opaque metadata passed through to ``score``.
            The caller should pre-filter to rows whose baselines are scorable.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float where **lower is better**.
        max_new_tokens: Maximum tokens to generate per prompt.
        output_path: If provided, write trial results as JSONL (appended
            incrementally so partial results survive crashes). An existing
            file is truncated first.
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and all trial results.

    Raises:
        ValueError: If ``candidate_paths`` or ``eval_rows`` is empty, if
            ``max_simultaneous_paths`` or ``batch_size`` is below 1, or if
            ``grid_first`` leaves no budget for the optimizer.
    """
    if not config.candidate_paths:
        raise ValueError("candidate_paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")
    _require_positive_batch_size(batch_size)

    n_candidates = len(config.candidate_paths)
    m = config.max_simultaneous_paths
    if m < 1:
        raise ValueError("max_simultaneous_paths must be at least 1")

    n_calls = config.n_calls
    if config.grid_first:
        # The grid sweep spends one evaluation per candidate path.
        n_calls = config.n_calls - n_candidates
        if n_calls < 1:
            raise ValueError(
                f"n_calls ({config.n_calls}) must exceed the number of candidate_paths "
                f"({n_candidates}) when grid_first=True to leave budget for the GP."
            )

    # Compute steering vectors once up front.
    vectors = _compute_steering_vectors(
        model, config.candidate_paths, positive_prompts, negative_prompts
    )

    # Generate baseline responses for all eval rows.
    baseline_responses = _generate_baselines(
        model, eval_rows, max_new_tokens, batch_size, verbose=True
    )

    # Clear output file if it exists.
    _reset_output_file(output_path)

    # Build search space: m path slots (index n_candidates means "unused")
    # followed by alpha.
    alpha_low, alpha_high = config.alpha_range
    dimensions = [
        Integer(0, n_candidates, name=f"path_{i}") for i in range(m)
    ] + [
        Real(alpha_low, alpha_high, name="alpha"),
    ]

    trials: list[SteeringTrialResult] = []

    def objective(point: list[Any]) -> float:
        selected_paths, alpha = _decode_point(point, n_candidates, config.candidate_paths)
        if not selected_paths:
            trial_score = PENALTY
            trial_responses: list[str] = []
        else:
            trial_responses, trial_score = _evaluate_steered(
                selected_paths,
                alpha,
                vectors,
                eval_rows,
                baseline_responses,
                model,
                score,
                max_new_tokens,
                batch_size,
            )
        result = SteeringTrialResult(
            trial_index=len(trials),
            selected_paths=selected_paths,
            alpha=alpha,
            score=trial_score,
            # numbers.Integral also covers NumPy integer scalars, which are
            # not instances of the builtin ``int``.
            raw_point=[
                int(x) if isinstance(x, numbers.Integral) else float(x) for x in point
            ],
            responses=trial_responses,
        )
        trials.append(result)
        if output_path is not None:
            _append_jsonl(output_path, result)
        return trial_score

    x0, y0 = None, None
    if config.grid_first:
        # Sweep every single path at alpha=1.0 before handing off to the GP.
        # The value is clamped into alpha_range so that the seeded points
        # always lie inside the search space.
        grid_alpha = min(max(1.0, alpha_low), alpha_high)
        x0 = [[i] + [n_candidates] * (m - 1) + [grid_alpha] for i in range(n_candidates)]
        y0 = [objective(pt) for pt in x0]

    gp_minimize(
        func=objective,
        dimensions=dimensions,
        x0=x0,
        y0=y0,
        n_calls=n_calls,
        n_initial_points=config.n_initial_points,
        random_state=config.seed,
    )

    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)


@torch.no_grad()
def steer_and_eval(
    paths: list[str],
    alpha: float,
    model: Any,
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    *,
    positive_prompts: list[str] | None = None,
    negative_prompts: list[str] | None = None,
    steering_vectors: dict[str, torch.Tensor] | None = None,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Steer on the given paths at a fixed alpha and evaluate eval_rows.

    Provide either ``positive_prompts`` and ``negative_prompts`` together to
    compute steering vectors, or pass pre-computed ``steering_vectors``
    directly. Exactly one of the two options must be used.

    Args:
        paths: Activation paths to steer on.
        alpha: Steering coefficient applied at every path.
        model: An nnsight LanguageModel (already loaded).
        eval_rows: Rows to evaluate. Each row must contain a ``"prompt"`` key;
            other keys are opaque metadata passed through to ``score``.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float (lower is better).
        max_new_tokens: Maximum tokens to generate per prompt.
        positive_prompts: Prompts for the "positive" class used to compute
            steering vectors. Must be paired with ``negative_prompts``.
            Mutually exclusive with ``steering_vectors``.
        negative_prompts: Prompts for the "negative" class used to compute
            steering vectors. Must be paired with ``positive_prompts``.
            Mutually exclusive with ``steering_vectors``.
        steering_vectors: Pre-computed steering vectors keyed by path. Must
            contain an entry for every element of ``paths``. Mutually
            exclusive with ``positive_prompts`` / ``negative_prompts``.
        output_path: If provided, write the single trial result as one JSONL
            line. An existing file is truncated first.
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and a single trial.

    Raises:
        ValueError: If the prompt/vector arguments are inconsistent, if
            ``paths`` or ``eval_rows`` is empty, if ``steering_vectors`` lacks
            a requested path, or if ``batch_size`` is below 1.
    """
    prompts_provided = positive_prompts is not None or negative_prompts is not None
    vectors_provided = steering_vectors is not None

    if prompts_provided and vectors_provided:
        raise ValueError(
            "Provide either (positive_prompts + negative_prompts) or steering_vectors, not both."
        )
    if not prompts_provided and not vectors_provided:
        raise ValueError(
            "Provide either (positive_prompts + negative_prompts) or steering_vectors."
        )
    if prompts_provided and (positive_prompts is None or negative_prompts is None):
        raise ValueError(
            "positive_prompts and negative_prompts must both be provided together."
        )
    if not paths:
        raise ValueError("paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")
    _require_positive_batch_size(batch_size)

    if vectors_provided:
        # Fail before generating baselines rather than with a KeyError after.
        missing = [p for p in paths if p not in steering_vectors]
        if missing:
            raise ValueError(f"steering_vectors is missing entries for paths: {missing}")
        vectors = steering_vectors
    else:
        vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)

    baseline_responses = _generate_baselines(
        model, eval_rows, max_new_tokens, batch_size, verbose=False
    )

    _reset_output_file(output_path)

    responses, trial_score = _evaluate_steered(
        paths,
        alpha,
        vectors,
        eval_rows,
        baseline_responses,
        model,
        score,
        max_new_tokens,
        batch_size,
    )
    result = SteeringTrialResult(
        trial_index=0,
        selected_paths=paths,
        alpha=alpha,
        score=trial_score,
        raw_point=[alpha],
        responses=responses,
    )
    if output_path is not None:
        _append_jsonl(output_path, result)

    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=[result])


def _group_paths_by_layer(paths: list[str]) -> list[list[str]]:
    """Group paths that share the same layer by stripping the trailing token index.

    ``"model.transformer.h[10].output[-1]"`` and
    ``"model.transformer.h[10].output[-2]"`` both strip to
    ``"model.transformer.h[10].output"`` and are placed in the same group,
    producing a single trial that intervenes on both positions simultaneously.
    Insertion order of groups is preserved.
    """
    groups: dict[str, list[str]] = {}
    for path in paths:
        key = _TRAILING_INDEX_RE.sub("", path)
        groups.setdefault(key, []).append(path)
    return list(groups.values())


@torch.no_grad()
def layer_sweep(
    paths: list[str],
    alpha: float,
    model: Any,
    positive_prompts: list[str],
    negative_prompts: list[str],
    eval_rows: list[dict],
    score: Callable[[dict, str, str], float],
    max_new_tokens: int,
    output_path: Path | None = None,
    batch_size: int = 1,
) -> SteeringSearchOutput:
    """Run one steering trial per layer at a fixed alpha value.

    Paths that share the same layer (differing only in their trailing token
    position index) are grouped automatically and intervened on simultaneously
    within a single trial. For example, passing
    ``"model.transformer.h[10].output[-1]"`` and
    ``"model.transformer.h[10].output[-2]"`` produces one trial that steers
    both positions at once.

    Args:
        paths: Activation paths to sweep. Grouped by layer internally.
        alpha: Fixed steering coefficient applied in every trial.
        model: An nnsight LanguageModel (already loaded).
        positive_prompts: Prompts for the "positive" class.
        negative_prompts: Prompts for the "negative" class.
        eval_rows: Rows to evaluate on each trial. Each row is a dict that
            **must** contain a ``"prompt"`` key. Other keys are opaque metadata
            passed through to ``score``.
        score: Scoring function. Takes ``(row, baseline_response,
            steered_response)`` and returns a float (lower is better).
        max_new_tokens: Maximum tokens to generate per prompt.
        output_path: If provided, append trial results as JSONL incrementally.
            An existing file is truncated first.
        batch_size: Number of prompts to generate in parallel per batch.

    Returns:
        SteeringSearchOutput with baseline_responses and one trial per layer.

    Raises:
        ValueError: If ``paths`` or ``eval_rows`` is empty, or if
            ``batch_size`` is below 1. Checked before any model work.
    """
    if not paths:
        raise ValueError("paths must not be empty")
    if not eval_rows:
        raise ValueError("eval_rows must not be empty")
    _require_positive_batch_size(batch_size)

    groups = _group_paths_by_layer(paths)

    vectors = _compute_steering_vectors(model, paths, positive_prompts, negative_prompts)

    # Generate baseline responses for all eval rows.
    baseline_responses = _generate_baselines(
        model, eval_rows, max_new_tokens, batch_size, verbose=True
    )

    _reset_output_file(output_path)

    trials: list[SteeringTrialResult] = []
    for trial_index, group in enumerate(groups):
        trial_responses, trial_score = _evaluate_steered(
            group,
            alpha,
            vectors,
            eval_rows,
            baseline_responses,
            model,
            score,
            max_new_tokens,
            batch_size,
        )
        result = SteeringTrialResult(
            trial_index=trial_index,
            selected_paths=group,
            alpha=alpha,
            score=trial_score,
            raw_point=[trial_index, alpha],
            responses=trial_responses,
        )
        trials.append(result)
        if output_path is not None:
            _append_jsonl(output_path, result)
        label = group[0] if len(group) == 1 else f"{group[0]} (+{len(group) - 1} more)"
        print(f" [{trial_index + 1}/{len(groups)}] {label}: score={trial_score:.4f}")

    return SteeringSearchOutput(baseline_responses=baseline_responses, trials=trials)