Repositories / more_nnsight.git
more_nnsight.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
@@ -25,6 +25,7 @@ class SteeringSearchConfig: n_calls: int n_initial_points: int = 10 seed: int = 0 + grid_first: bool = False @dataclass(frozen=True, slots=True) @@ -324,10 +325,24 @@ def steering_search( _append_jsonl(output_path, result) return trial_score + x0, y0, n_calls = None, None, config.n_calls + if config.grid_first: + # Sweep every single path at alpha=1.0 before handing off to the GP. + x0 = [[i] + [n_candidates] * (m - 1) + [1.0] for i in range(n_candidates)] + y0 = [objective(pt) for pt in x0] + n_calls = config.n_calls - n_candidates + if n_calls < 1: + raise ValueError( + f"n_calls ({config.n_calls}) must exceed the number of candidate_paths " + f"({n_candidates}) when grid_first=True to leave budget for the GP." + ) + gp_minimize( func=objective, dimensions=dimensions, - n_calls=config.n_calls, + x0=x0, + y0=y0, + n_calls=n_calls, n_initial_points=config.n_initial_points, random_state=config.seed, )