Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
# SavedActivation

`SavedActivation` is for the point where plain NNSight starts to get awkward:
you are no longer saving one activation, but several, and now you have to keep
track of which tensor came from which layer and position. It keeps those saves
together in one object, addressed by the same paths you used to create them, so
operations like subsetting, averaging, and patching stay attached to the
activations themselves instead of turning into bookkeeping. This file documents
the API and shows the equivalent direct NNSight code.

## Imports

```python
from more_nnsight import SavedActivation, save_activations
```

## Core Rule

`save_activations(...)` must be called inside an already-active NNSight
`trace`/`invoke` context.

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, ["model.transformer.h[2].output[10]"])
```

This matches ordinary NNSight usage: activation saves only work inside a trace.

## Path Syntax

Paths follow the model's real attribute/index structure.

Examples:

- `model.transformer.h[2].output[10]`
- `model.transformer.h[3].output[-1]`
- `model.transformer.h[:].output[10]`
- `model.model.layers[:].output[2]`

Everything before the final bracket names the activation tensor. The final
bracket gives the token position to save. Intermediate `[:]` syntax expands
over repeated blocks such as GPT-2 `transformer.h[:]` or Qwen `model.layers[:]`.

For example, `model.transformer.h[2].output[10]` means "take layer 2, take its
output, and save token position 10", producing a tensor of shape
`(batch_size, hidden_size)`.
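That split rule can be sketched as plain string handling. `parse_path` below is a hypothetical helper written for illustration, not part of the `more_nnsight` API:

```python
import re

def parse_path(path):
    """Split a path into the module path naming the activation tensor
    and the token position given by the final bracket."""
    match = re.match(r"^(.*)\[(-?\d+)\]$", path)
    if match is None:
        raise ValueError(f"path must end in a token-position bracket: {path}")
    module_path, token_pos = match.group(1), int(match.group(2))
    return module_path, token_pos

print(parse_path("model.transformer.h[2].output[10]"))
# → ('model.transformer.h[2].output', 10)
print(parse_path("model.transformer.h[3].output[-1]"))
# → ('model.transformer.h[3].output', -1)
```

Because the leading match is greedy, earlier brackets such as `h[2]` or an intermediate `[:]` stay inside the module path; only the final bracket is read as the token position.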
## Saving Activations

### Single path

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, ["model.transformer.h[2].output[10]"])
```

Direct NNSight equivalent:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        direct = model.transformer.h[2].output[:, 10, :].save()
```

Equivalent access:

```python
saved.get("model.transformer.h[2].output[10]") == direct
```

### Multiple paths

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(
            model,
            [
                "model.transformer.h[2].output[10]",
                "model.transformer.h[3].output[-1]",
            ],
        )
```

Direct NNSight equivalent:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        layer_2 = model.transformer.h[2].output[:, 10, :].save()
        layer_3 = model.transformer.h[3].output[:, -1, :].save()
```

### All layers with `[:]`

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, ["model.transformer.h[:].output[10]"])
```

Direct NNSight equivalent:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        direct = [
            model.transformer.h[layer].output[:, 10, :].save()
            for layer in range(len(model.transformer.h))
        ]
```

The saved keys become concrete:

```python
saved.keys()
# [
#     "model.transformer.h[0].output[10]",
#     "model.transformer.h[1].output[10]",
#     ...
# ]
```

## Accessing Saved Values

Saved values are exposed under `.values`.

```python
tensor = saved.values.transformer.h[2].output[10]
```

Equivalent string lookup:

```python
tensor = saved.get("model.transformer.h[2].output[10]")
```

`saved.keys()` returns the saved path strings, and `saved.get(path)` returns one
saved tensor. Missing paths raise immediately.
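The bookkeeping behind `keys()`/`get()` can be pictured as a path-keyed mapping. `SavedStore` below is a toy illustration of that idea, with plain lists standing in for tensors; it is not the library's implementation:

```python
class SavedStore:
    """Toy stand-in for the path-keyed storage behind keys()/get()."""

    def __init__(self, tensors_by_path):
        self._tensors = dict(tensors_by_path)

    def keys(self):
        # Saved path strings, in insertion order.
        return list(self._tensors)

    def get(self, path):
        # Missing paths raise immediately instead of returning None.
        if path not in self._tensors:
            raise KeyError(f"no saved activation for path: {path}")
        return self._tensors[path]

store = SavedStore({
    "model.transformer.h[2].output[10]": [0.1, 0.2],
    "model.transformer.h[3].output[-1]": [0.3, 0.4],
})
print(store.keys())
# → ['model.transformer.h[2].output[10]', 'model.transformer.h[3].output[-1]']
print(store.get("model.transformer.h[2].output[10]"))
# → [0.1, 0.2]
```

Failing loudly on a missing path keeps typos in long path strings from silently producing empty results.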
## Subsetting by Path

```python
focused = saved.subset(["model.transformer.h[2].output[10]"])
```

This keeps only the listed saved activations and drops the rest. In direct
NNSight, you would usually do this by manually building a smaller Python
structure.

## Batch Slicing

Use bracket syntax on `.slice`:

```python
first = saved.slice[0]
first_two = saved.slice[0:2]
mixed = saved.slice[0:2, 5, 8:10]
```

This slices the batch dimension of every saved tensor.

Direct NNSight equivalent for one path:

```python
saved.get("model.transformer.h[2].output[10]")[0:2]
```

The point is that the same batch selection is applied across every saved
activation at once.

## Mean Over Batch

```python
mean_saved = saved.mean()
```

This reduces each saved tensor from shape `(batch_size, hidden_size)` to
`(1, hidden_size)`.

Direct NNSight equivalent for one path:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        direct_mean = model.transformer.h[2].output[:, 10, :].mean(dim=0, keepdim=True).save()
```

If `saved.mean()` is called inside the active trace, the reduction happens
there before the reduced value is saved. That is more memory-efficient than
saving the full batch and averaging later.

## Arithmetic

If two `SavedActivation` objects have the same keys, you can combine them:

```python
direction = positive.mean() - negative.mean()
steered = neutral + 1.5 * direction
```

Supported operations are `a + b`, `a - b`, `scalar * a`, and `a * scalar`.
They are applied elementwise across matching saved tensors.
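The matching-keys rule can be pictured as key-wise dict arithmetic. A minimal sketch, with plain lists standing in for tensors and hypothetical helper names:

```python
def combine(a, b, op):
    """Apply op elementwise across two path-keyed dicts of vectors.
    Requires identical key sets, mirroring the matching-keys rule."""
    if a.keys() != b.keys():
        raise ValueError("arithmetic needs matching saved keys")
    return {path: [op(x, y) for x, y in zip(a[path], b[path])] for path in a}

def scale(a, scalar):
    """scalar * a, applied to every saved vector."""
    return {path: [scalar * x for x in vec] for path, vec in a.items()}

pos_mean = {"model.transformer.h[5].output[-1]": [1.0, 2.0]}
neg_mean = {"model.transformer.h[5].output[-1]": [0.5, 1.0]}

direction = combine(pos_mean, neg_mean, lambda x, y: x - y)
print(direction)  # → {'model.transformer.h[5].output[-1]': [0.5, 1.0]}

steered = combine(pos_mean, scale(direction, 1.5), lambda x, y: x + y)
print(steered)    # → {'model.transformer.h[5].output[-1]': [1.75, 3.5]}
```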
Direct NNSight equivalent for one path:

```python
direction = positive_tensor.mean(dim=0, keepdim=True) - negative_tensor.mean(dim=0, keepdim=True)
steered = neutral_tensor + 1.5 * direction
```

## Applying Saved Activations

You can patch a later run with:

```python
with model.trace() as tracer:
    with tracer.invoke(corrupted_prompts):
        saved.apply(model)
```

This writes each stored tensor back into the live traced activation at its
saved path and token position.

Direct NNSight equivalent for one path:

```python
with model.trace() as tracer:
    with tracer.invoke(corrupted_prompts):
        model.transformer.h[2].output[:, 10, :] = saved_tensor
```

`SavedActivation.apply(model)` performs that assignment for every saved key.

## `saved.save()`

`SavedActivation.save()` is different from `save_activations(...)`.

- `save_activations(...)` captures activation values
- `saved.save()` registers the `SavedActivation` object itself with NNSight

New `SavedActivation` objects created inside the trace, such as the result of
`save_activations(...)`, `saved.mean()`, `saved.subset(...)`, or arithmetic,
are registered automatically by the library, so the usual pattern works:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, ["model.transformer.h[2].output[10]"])
        mean_saved = saved.mean()
```

Both `saved` and `mean_saved` remain usable after trace exit. You only need
`.save()` if you want to register an existing `SavedActivation` object
yourself inside the trace.
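The auto-registration behavior can be pictured as a context that records every object created inside it. The `toy_trace`/`ToySaved` pair below is a hypothetical illustration of that idea, not NNSight's actual mechanism:

```python
import contextlib

_active_registry = None  # stands in for the active trace's registry

class ToySaved:
    """Toy stand-in: registers itself with the active trace on creation,
    the way new SavedActivation objects are registered automatically."""

    def __init__(self, value):
        self.value = value
        if _active_registry is not None:
            _active_registry.append(self)  # automatic registration

    def save(self):
        # Manual registration for an object created outside the trace.
        if _active_registry is not None and self not in _active_registry:
            _active_registry.append(self)
        return self

@contextlib.contextmanager
def toy_trace():
    global _active_registry
    _active_registry = []
    try:
        yield _active_registry
    finally:
        _active_registry = None

preexisting = ToySaved([0.0])  # created outside any trace: not registered

with toy_trace() as registry:
    saved = ToySaved([1.0, 2.0])  # registered automatically
    mean_saved = ToySaved([1.5])  # registered automatically
    preexisting.save()            # the one case needing an explicit .save()

print(len(registry))  # → 3
```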
## Activation Steering Example

Single forward pass:

```python
positive_prompts = [
    "The movie was absolutely wonderful and I felt",
    "The dinner was excellent and I left feeling",
    "The vacation was amazing and it made me feel",
]
negative_prompts = [
    "The movie was absolutely terrible and I felt",
    "The dinner was awful and I left feeling",
    "The vacation was horrible and it made me feel",
]
neutral_prompt = "The day was long and by the end I felt"
prompts = positive_prompts + negative_prompts + [neutral_prompt]

path = "model.transformer.h[5].output[-1]"

with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, [path])
        positive = saved.slice[0:3].mean()
        negative = saved.slice[3:6].mean()
        neutral = saved.slice[6]

direction = positive - negative
steered = neutral + 1.5 * direction
```

Direct NNSight equivalent for one path:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        full = model.transformer.h[5].output[:, -1, :].save()

positive = full[0:3].mean(dim=0, keepdim=True)
negative = full[3:6].mean(dim=0, keepdim=True)
neutral = full[6:7]
direction = positive - negative
steered = neutral + 1.5 * direction
```

The difference is that `SavedActivation` keeps the same pattern workable when
you are carrying several saved paths at once instead of a single tensor.
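The steering arithmetic itself is ordinary vector math. A self-contained sketch on toy two-dimensional hidden states (all numbers hypothetical):

```python
def batch_mean(vectors):
    """Mean over the batch dimension of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

# Toy last-token hidden states: 3 positive, 3 negative, 1 neutral prompt.
positive = [[1.0, 0.5], [0.5, 0.25], [1.5, 0.75]]
negative = [[-1.0, 0.5], [-0.5, 0.25], [-1.5, 0.75]]
neutral = [0.25, 0.5]

# direction = mean(positive) - mean(negative); steered = neutral + 1.5 * direction
direction = [p - n for p, n in zip(batch_mean(positive), batch_mean(negative))]
steered = [x + 1.5 * d for x, d in zip(neutral, direction)]
print(direction)  # → [2.0, 0.0]
print(steered)    # → [3.25, 0.5]
```

In the real example the same two lines of arithmetic run over `(1, hidden_size)` tensors instead of two-element lists.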
## Patching Example

```python
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
patch_paths = [
    "model.transformer.h[0].output[9]",
    "model.transformer.h[1].output[9]",
    "model.transformer.h[2].output[9]",
    "model.transformer.h[3].output[9]",
]

with model.trace() as tracer:
    with tracer.invoke([clean_prompt]):
        clean_saved = save_activations(model, patch_paths)

focused_patch = clean_saved.subset(["model.transformer.h[2].output[9]"])

with model.trace() as tracer:
    with tracer.invoke([corrupted_prompt]):
        focused_patch.apply(model)
        logits = model.lm_head.output.save()
```

Direct NNSight equivalent for one path:

```python
with model.trace() as tracer:
    with tracer.invoke([clean_prompt]):
        clean_layer_2 = model.transformer.h[2].output[:, 9, :].save()

with model.trace() as tracer:
    with tracer.invoke([corrupted_prompt]):
        model.transformer.h[2].output[:, 9, :] = clean_layer_2
        logits = model.lm_head.output.save()
```

## Other Models

Use the model's real path names. For example, a Qwen-style decoder stack can
use:

```python
save_activations(model, ["model.model.layers[:].output[2]"])
```
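How a `[:]` wildcard might expand into concrete per-layer keys can be sketched as string substitution. `expand_layers` is a hypothetical helper; in practice the layer count comes from the real model:

```python
def expand_layers(path, num_layers):
    """Expand a single '[:]' wildcard into concrete per-layer paths."""
    if "[:]" not in path:
        return [path]
    return [path.replace("[:]", f"[{i}]", 1) for i in range(num_layers)]

# Hypothetical 2-layer Qwen-style stack:
print(expand_layers("model.model.layers[:].output[2]", 2))
# → ['model.model.layers[0].output[2]', 'model.model.layers[1].output[2]']
```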