# Repositories / more_nnsight.git
# tests/test_saved_activation.py
# Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git
from __future__ import annotations
import pytest
import torch
from more_nnsight import SavedActivation, updates
from more_nnsight.saved_activation import ActivationTarget, AttrSegment, IndexSegment
# The parser should preserve the model's real attribute/index path structure.
def test_parse_path_uses_real_model_names() -> None:
    """Check that ``parse_path`` turns a path string into typed segments.

    The expected target carries no segment for the leading ``model``
    component, and the trailing ``[10]`` is compared against ``position``
    rather than appearing as an ``IndexSegment``.
    """
    parsed = SavedActivation.parse_path("model.transformer.h[2].output[10]")

    expected_segments = (
        AttrSegment("transformer"),
        AttrSegment("h"),
        IndexSegment(2),
        AttrSegment("output"),
    )
    expected = ActivationTarget(activation_path=expected_segments, position=10)

    assert parsed == expected
# Values should be reachable through the same nested path under `.values`.
def test_saved_activation_mirrors_model_path() -> None:
    """Check that a stored tensor is reachable via attribute/index access.

    The nested dict handed to ``SavedActivation`` is walked through
    ``.values`` using the same attribute and index hops, and the object
    found at the leaf must be the very tensor that was stored.
    """
    stored = torch.randn(2, 768)
    nested = {"transformer": {"h": {2: {"output": {10: stored}}}}}

    saved = SavedActivation(nested)
    leaf = saved.values.transformer.h[2].output[10]

    assert leaf is stored
# Missing attribute hops should fail immediately instead of producing empty nodes.
def test_missing_attribute_errors() -> None:
    """Check that reading an absent attribute hop raises ``AttributeError``.

    ``transformer`` exists but is empty, so asking for ``h`` beneath it must
    fail with a message matching the pattern given to ``pytest.raises``.
    """
    empty_transformer = {"transformer": {}}
    saved = SavedActivation(empty_transformer)

    # The whole access chain stays inside the context manager, as in the
    # original, so any hop along it may be the one that raises.
    with pytest.raises(AttributeError, match=r"saved_activation.transformer.h"):
        _ = saved.values.transformer.h
# Missing index hops should also fail immediately.
def test_missing_index_errors() -> None:
    """Check that indexing an absent entry raises ``KeyError``.

    Only index ``2`` is stored under ``transformer.h``; requesting index
    ``3`` must fail with a message matching the expected pattern.
    """
    only_layer_two = {"transformer": {"h": {2: {}}}}
    saved = SavedActivation(only_layer_two)

    # The whole access chain stays inside the context manager, as in the
    # original, so any hop along it may be the one that raises.
    with pytest.raises(KeyError, match="saved_activation.transformer.h"):
        _ = saved.values.transformer.h[3]
# Subsetting should keep only the requested saved path and discard the rest.
def test_subset_returns_new_saved_activation() -> None:
    """Check that ``subset`` keeps the requested path and drops the others.

    Two entries are stored under ``transformer.h``. After subsetting to the
    index-3 path, that tensor is returned by identity and indexing the
    index-2 entry raises ``KeyError``.
    """
    layer2_tensor = torch.randn(2, 768)
    layer3_tensor = torch.randn(2, 768)
    layers = {
        2: {"output": {10: layer2_tensor}},
        3: {"output": {-1: layer3_tensor}},
    }
    saved = SavedActivation({"transformer": {"h": layers}})

    subset = saved.subset(["model.transformer.h[3].output[-1]"])

    assert subset.values.transformer.h[3].output[-1] is layer3_tensor
    with pytest.raises(KeyError, match="saved_activation.transformer.h"):
        _ = subset.values.transformer.h[2]
# String keys should round-trip through `keys()` and `get(...)`.
def test_keys_and_get_round_trip_saved_paths() -> None:
    """Check that ``keys()`` lists the stored path and ``get`` resolves it.

    A single tensor is stored; ``keys()`` must equal a one-element list
    holding its ``model.``-prefixed path string, and ``get`` with that same
    string must hand back the identical tensor object.
    """
    stored = torch.randn(2, 768)
    nested = {"transformer": {"h": {2: {"output": {10: stored}}}}}
    saved = SavedActivation(nested)

    listed = saved.keys()
    assert listed == ["model.transformer.h[2].output[10]"]

    fetched = saved.get("model.transformer.h[2].output[10]")
    assert fetched is stored
# Explicit path-value construction should parse paths, canonicalize order, and reject duplicates.
def test_from_pairs_builds_canonical_saved_activation_and_rejects_duplicates() -> None:
    """Check ``from_pairs`` ordering, lookup, and duplicate detection.

    The pairs are supplied with the ``h[3]`` path first, yet ``keys()`` is
    expected to list ``h[2]`` before ``h[3]``. Each path must resolve to the
    tensor it was paired with, and passing the same path twice must raise
    ``ValueError``.
    """
    tensor_for_h2 = torch.randn(2, 3)
    tensor_for_h3 = torch.randn(2, 3)
    path_h2 = "model.transformer.h[2].output[10]"
    path_h3 = "model.transformer.h[3].output[9]"

    # Deliberately supplied out of order.
    saved = SavedActivation.from_pairs(
        (path_h3, tensor_for_h3),
        (path_h2, tensor_for_h2),
    )

    assert saved.keys() == [path_h2, path_h3]
    assert saved.get(path_h2) is tensor_for_h2
    assert saved.get(path_h3) is tensor_for_h3

    with pytest.raises(ValueError, match="duplicate path"):
        SavedActivation.from_pairs(
            (path_h2, tensor_for_h2),
            (path_h2, tensor_for_h3),
        )
# Batch slicing should preserve order across mixed slice and single-index selectors.
def test_slice_keeps_requested_batch_rows() -> None:
    """Check that ``.slice`` selects rows for a mixed slice/index selector.

    From a 5x4 tensor, the selector ``1:3, 4`` must produce rows 1, 2 and 4
    in that order, giving a 3x4 result.
    """
    rows = torch.arange(20, dtype=torch.float32).reshape(5, 4)
    saved = SavedActivation({"transformer": {"h": {2: {"output": {10: rows}}}}})

    sliced = saved.slice[1:3, 4]
    picked = sliced.values.transformer.h[2].output[10]

    assert picked.shape == (3, 4)
    assert torch.equal(picked, rows[[1, 2, 4]])
# Mean should return a new SavedActivation reduced across the batch dimension.
def test_mean_reduces_batch_dimension() -> None:
    """Check that ``mean`` averages over dim 0 without touching the source.

    After calling ``mean`` the original entry must still be 3x768, while the
    returned object holds a 1x768 tensor close to
    ``tensor.mean(dim=0, keepdim=True)``.
    """
    batch = torch.randn(3, 768)
    saved = SavedActivation({"transformer": {"h": {2: {"output": {10: batch}}}}})

    averaged = saved.mean()

    # The source activation keeps its full batch.
    original_entry = saved.values.transformer.h[2].output[10]
    assert original_entry.shape == (3, 768)

    reduced_entry = averaged.values.transformer.h[2].output[10]
    assert reduced_entry.shape == (1, 768)
    assert torch.allclose(reduced_entry, batch.mean(dim=0, keepdim=True))
# Arithmetic should work elementwise on matching keys and reject mismatched ones.
def test_add_and_subtract_require_matching_keys() -> None:
    """Check ``+``, ``-`` and scalar ``*`` on activations with equal keys.

    ``left`` holds ones and ``right`` holds twos under the same path, so the
    sum is threes, ``right - left`` is ones, and ``2.0 * left`` is twos.
    Adding ``other``, which is stored under a different index, must raise
    ``ValueError``.
    """
    left = SavedActivation(
        {"transformer": {"h": {2: {"output": {10: torch.ones(1, 2)}}}}}
    )
    right = SavedActivation(
        {"transformer": {"h": {2: {"output": {10: 2 * torch.ones(1, 2)}}}}}
    )
    other = SavedActivation(
        {"transformer": {"h": {3: {"output": {10: torch.ones(1, 2)}}}}}
    )

    total = (left + right).values.transformer.h[2].output[10]
    assert torch.equal(total, 3 * torch.ones(1, 2))

    difference = (right - left).values.transformer.h[2].output[10]
    assert torch.equal(difference, torch.ones(1, 2))

    scaled = (2.0 * left).values.transformer.h[2].output[10]
    assert torch.equal(scaled, 2 * torch.ones(1, 2))

    with pytest.raises(ValueError, match="same keys"):
        _ = left + other
# Union should merge disjoint saved paths and reject overlapping ones.
def test_union_merges_disjoint_paths_and_rejects_overlap() -> None:
    """Check that ``union`` combines disjoint keys and rejects shared ones.

    ``left`` (index 2) and ``right`` (index 3) share no path, so their union
    must list both keys and return each original tensor by identity.
    ``overlap`` reuses ``left``'s path, so uniting the two must raise
    ``ValueError``.
    """
    left_tensor = torch.ones(1, 2)
    right_tensor = 2 * torch.ones(1, 2)
    overlap_tensor = 3 * torch.ones(1, 2)
    path_h2 = "model.transformer.h[2].output[10]"
    path_h3 = "model.transformer.h[3].output[10]"

    left = SavedActivation(
        {"transformer": {"h": {2: {"output": {10: left_tensor}}}}}
    )
    right = SavedActivation(
        {"transformer": {"h": {3: {"output": {10: right_tensor}}}}}
    )
    overlap = SavedActivation(
        {"transformer": {"h": {2: {"output": {10: overlap_tensor}}}}}
    )

    merged = left.union(right)

    assert merged.keys() == [path_h2, path_h3]
    assert merged.get(path_h2) is left_tensor
    assert merged.get(path_h3) is right_tensor

    with pytest.raises(ValueError, match="disjoint keys"):
        _ = left.union(overlap)
def test_updates_reject_duplicate_keys() -> None:
    """Check that ``updates`` raises ``ValueError`` for a repeated key.

    A bare ``object()`` is passed as the first argument. The result is
    wrapped in ``list`` so the error still surfaces if ``updates`` only
    raises on iteration (its implementation is not visible from this file).
    """
    repeated_keys = [
        "model.transformer.h[2].output[10]",
        "model.transformer.h[2].output[10]",
    ]

    with pytest.raises(ValueError, match="unique keys"):
        list(updates(object(), repeated_keys))
def test_updates_reject_noncanonical_key_order() -> None:
    """Check that ``updates`` raises ``ValueError`` for out-of-order keys.

    The ``h[3]`` key is listed before the ``h[2]`` key. A bare ``object()``
    is passed as the first argument, and the result is wrapped in ``list``
    so the error still surfaces if ``updates`` only raises on iteration (its
    implementation is not visible from this file).
    """
    descending_keys = [
        "model.transformer.h[3].output[10]",
        "model.transformer.h[2].output[10]",
    ]

    with pytest.raises(ValueError, match="canonical layer-major order"):
        list(updates(object(), descending_keys))
def test_updates_reject_slice_paths() -> None:
    """Check that ``updates`` raises ``ValueError`` for a slice-style key.

    The only key uses ``h[:]`` instead of a concrete index. A bare
    ``object()`` is passed as the first argument, and the result is wrapped
    in ``list`` so the error still surfaces if ``updates`` only raises on
    iteration (its implementation is not visible from this file).
    """
    slice_keys = ["model.transformer.h[:].output[10]"]

    with pytest.raises(ValueError, match="concrete keys"):
        list(updates(object(), slice_keys))