Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Add grid_first mode to steering_search

Sweeps all single-path candidates at alpha=1.0 as warm-start x0/y0
before handing off to the GP, consuming that many calls from the budget.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-04-02 18:05:35 -0400
Commit
4dfce2b9bde4b95a0249f9e401013aaafef14aac
src/more_nnsight/steering_search.py
index 4ed3ace..34779b1 100644
--- a/src/more_nnsight/steering_search.py
+++ b/src/more_nnsight/steering_search.py
@@ -25,6 +25,7 @@ class SteeringSearchConfig:
     n_calls: int
     n_initial_points: int = 10
     seed: int = 0
+    grid_first: bool = False
 
 
 @dataclass(frozen=True, slots=True)
@@ -324,10 +325,24 @@ def steering_search(
             _append_jsonl(output_path, result)
         return trial_score
 
+    x0, y0, n_calls = None, None, config.n_calls
+    if config.grid_first:
+        # Sweep every single path at alpha=1.0 before handing off to the GP.
+        x0 = [[i] + [n_candidates] * (m - 1) + [1.0] for i in range(n_candidates)]
+        y0 = [objective(pt) for pt in x0]
+        n_calls = config.n_calls - n_candidates
+        if n_calls < 1:
+            raise ValueError(
+                f"n_calls ({config.n_calls}) must exceed the number of candidate_paths "
+                f"({n_candidates}) when grid_first=True to leave budget for the GP."
+            )
+
     gp_minimize(
         func=objective,
         dimensions=dimensions,
-        n_calls=config.n_calls,
+        x0=x0,
+        y0=y0,
+        n_calls=n_calls,
         n_initial_points=config.n_initial_points,
         random_state=config.seed,
     )