Repositories / more_nnsight.git

more_nnsight.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch

Add SavedActivation.from_pairs constructor

Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-03-30 20:00:51 -0400
Commit
bc7e18e9104436a4b6c9cf574c6b9179734ea90b
SKILL.md
index c3b4720..e958469 100644
--- a/SKILL.md
+++ b/SKILL.md
@@ -140,6 +140,22 @@ tensor = saved.get("model.transformer.h[2].output[10]")
 `saved.keys()` returns the saved path strings, and `saved.get(path)` returns one
 saved tensor. Missing paths raise immediately.
 
+## Building From Explicit Pairs
+
+If you already have tensors and want to build a `SavedActivation` directly, use
+`SavedActivation.from_pairs(...)`:
+
+```python
+patch = SavedActivation.from_pairs(
+    ("model.model.layers[34].output[10]", tensor_a),
+    ("model.model.layers[35].output[9]", tensor_b),
+)
+```
+
+Each entry is `(path, value)`. Paths are parsed exactly as in
+`save_activations(...)`, entries are stored in canonical path order, and a
+duplicate path raises a `ValueError`.
+
 ## Subsetting by Path
 
 ```python
@@ -150,6 +166,23 @@ This keeps only the listed saved activations and drops the rest. In direct
 NNSight, you would usually do this by manually building a smaller Python
 structure.
 
+## Union
+
+Use `union` when you want to combine different saved paths into one
+`SavedActivation`:
+
+```python
+combined = layer_2_patch.union(layer_3_patch)
+```
+
+This is different from `+`:
+
+- `a + b` means "add values on the same paths"
+- `a.union(b)` means "combine different paths into one object"
+
+`union` requires the two objects to have disjoint keys. If the same saved path
+appears in both, it raises an error.
+
 ## Batch Slicing
 
 Use bracket syntax on `.slice`:
src/more_nnsight/saved_activation.py
index 5577a2e..e83f096 100644
--- a/src/more_nnsight/saved_activation.py
+++ b/src/more_nnsight/saved_activation.py
@@ -171,6 +171,17 @@ class SavedActivation:
         self._register_object_save(self)
         return self
 
+    @classmethod
+    def from_pairs(cls, *pairs: tuple[str, Any]) -> "SavedActivation":
+        """Builds a SavedActivation directly from explicit path-value pairs."""
+        targets: dict[ActivationTarget, Any] = {}
+        for path, value in pairs:
+            target = cls.parse_path(path)
+            if target in targets:
+                raise ValueError(f"SavedActivation.from_pairs got duplicate path: {path}")
+            targets[target] = value
+        return cls._from_targets(targets)
+
     def subset(self, activation_paths: list[str]) -> "SavedActivation":
         """Keeps only selected saved paths so patches can be focused or composed."""
         requested_targets = [self.parse_path(path) for path in activation_paths]
tests/test_saved_activation.py
index 524b2d4..a9313a9 100644
--- a/tests/test_saved_activation.py
+++ b/tests/test_saved_activation.py
@@ -87,6 +87,30 @@ def test_keys_and_get_round_trip_saved_paths() -> None:
     assert saved.get("model.transformer.h[2].output[10]") is tensor
 
 
+# Explicit path-value construction should parse paths, canonicalize order, and reject duplicates.
+def test_from_pairs_builds_canonical_saved_activation_and_rejects_duplicates() -> None:
+    tensor_a = torch.randn(2, 3)
+    tensor_b = torch.randn(2, 3)
+
+    saved = SavedActivation.from_pairs(
+        ("model.transformer.h[3].output[9]", tensor_b),
+        ("model.transformer.h[2].output[10]", tensor_a),
+    )
+
+    assert saved.keys() == [
+        "model.transformer.h[2].output[10]",
+        "model.transformer.h[3].output[9]",
+    ]
+    assert saved.get("model.transformer.h[2].output[10]") is tensor_a
+    assert saved.get("model.transformer.h[3].output[9]") is tensor_b
+
+    with pytest.raises(ValueError, match="duplicate path"):
+        SavedActivation.from_pairs(
+            ("model.transformer.h[2].output[10]", tensor_a),
+            ("model.transformer.h[2].output[10]", tensor_b),
+        )
+
+
 # Batch slicing should preserve order across mixed slice and single-index selectors.
 def test_slice_keeps_requested_batch_rows() -> None:
     tensor = torch.arange(20, dtype=torch.float32).reshape(5, 4)