Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
# more-nnsight

This library adds one abstraction on top of NNSight: `SavedActivation`, a
container that holds activation slices keyed by their model path and token
position.

NNSight already lets you save and patch individual activations. The problem
starts when you need several at once — across layers, token positions, or
prompts. You end up juggling loose tensors and remembering which one came from
where. NNSight also requires that you access modules in layer-major order
within a trace: you cannot touch layer 2, then layer 5, then go back to layer
2. `SavedActivation` keeps activations together, enforces canonical ordering
internally, and gives you batch, arithmetic, and patching operations over the
whole set.

## API

```python
from more_nnsight import SavedActivation, save_activations, updates
```

Call `save_activations` inside an NNSight trace to capture activations. Paths
follow the model's attribute structure; the final bracket is the token position.

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, [
            "model.transformer.h[2].output[10]",
            "model.transformer.h[3].output[-1]",
        ])
```

Use `[:]` to expand over a `ModuleList` — `"model.transformer.h[:].output[10]"`
becomes one key per layer. Each saved tensor has shape `(batch, hidden)`.
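As a rough sketch of the wildcard semantics, `[:]` expansion can be pictured as substituting one concrete layer index per entry of the `ModuleList`. The helper below is purely illustrative and not part of `more_nnsight`'s actual API:

```python
# Hypothetical helper illustrating "[:]" expansion semantics.
# Not part of more_nnsight's real implementation.
def expand_path(path: str, num_layers: int) -> list[str]:
    # Substitute the first "[:]" wildcard with one concrete index per layer.
    if "[:]" not in path:
        return [path]
    return [path.replace("[:]", f"[{i}]", 1) for i in range(num_layers)]

print(expand_path("model.transformer.h[:].output[10]", 3))
# → ['model.transformer.h[0].output[10]',
#    'model.transformer.h[1].output[10]',
#    'model.transformer.h[2].output[10]']
```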
Values come back by string or by attribute traversal:

```python
saved.get("model.transformer.h[2].output[10]")
saved.values.transformer.h[2].output[10]
saved.keys()  # list of saved path strings
```

You can slice the batch dimension, take the mean, or do arithmetic across all
saved paths at once:

```python
saved.slice[0:3]  # first three batch rows
saved.mean()      # (1, hidden) per path
direction = positive.mean() - negative.mean()  # +, -, scalar * all work
```

`subset` narrows to specific paths; `union` merges two disjoint sets:

```python
focused = saved.subset(["model.transformer.h[2].output[10]"])
combined = patch_a.union(patch_b)
```

To write stored activations into a later forward pass, use `apply`:

```python
with model.trace() as tracer:
    with tracer.invoke(corrupted_prompts):
        saved.apply(model)
```

When the replacement depends on the live activation, use `updates` instead.
It is a generator that walks paths in layer-major order and yields three values
per path: `key` (the path string), `current` (the activation from the current
forward pass at that path and token position), and `update` (a callback that
writes a new tensor back to the same location). At each step, you call
`update(new_value)` with whatever you want to write.
This lets you express per-layer logic — for example, adding a scaled steering
vector at one layer while replacing outright at another:

```python
with model.trace() as tracer:
    with tracer.invoke(prompts):
        for key, current, update in updates(model, saved.keys()):
            if key == "model.transformer.h[2].output[-1]":
                update(current + 2.0 * saved.get(key))
            else:
                update(saved.get(key))
```

You can also build a `SavedActivation` directly from tensors you already have:

```python
patch = SavedActivation.from_pairs(
    ("model.transformer.h[2].output[10]", tensor_a),
    ("model.transformer.h[3].output[9]", tensor_b),
)
```

## Typical workflows

**Activation patching** — save from a clean run, replay into a corrupted run:

```python
with model.trace() as tracer:
    with tracer.invoke([clean_prompt]):
        patch = save_activations(model, ["model.transformer.h[2].output[9]"])

with model.trace() as tracer:
    with tracer.invoke([corrupted_prompt]):
        patch.apply(model)
        logits = model.lm_head.output.save()
```

**Steering** — compute a direction from contrastive prompts, add it to a
neutral run:

```python
prompts = positive_prompts + negative_prompts + [neutral_prompt]
path = "model.transformer.h[5].output[-1]"

with model.trace() as tracer:
    with tracer.invoke(prompts):
        saved = save_activations(model, [path])
        pos = saved.slice[0:3].mean()
        neg = saved.slice[3:6].mean()
        neutral = saved.slice[6]

steered = neutral + 1.5 * (pos - neg)
```
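The steering recipe above is ordinary tensor arithmetic once the activations are in hand. As a self-contained sketch, with plain NumPy arrays standing in for the saved activations (no model or NNSight required; the batch layout of 3 positive, 3 negative, 1 neutral prompt mirrors the example):

```python
import numpy as np

rng = np.random.default_rng(0)
acts = rng.normal(size=(7, 16))  # 7 prompts (3 pos, 3 neg, 1 neutral), hidden=16

pos = acts[0:3].mean(axis=0, keepdims=True)  # (1, 16) mean positive activation
neg = acts[3:6].mean(axis=0, keepdims=True)  # (1, 16) mean negative activation
neutral = acts[6:7]                          # (1, 16) neutral activation

# Contrastive direction, scaled and added to the neutral activation.
steered = neutral + 1.5 * (pos - neg)
assert steered.shape == (1, 16)
```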