Repositories / more_nnsight.git

tests/test_saved_activation.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_nnsight.git

Branch
7339 bytes · af9ce73eaa63
from __future__ import annotations

import pytest
import torch

from more_nnsight import SavedActivation, updates
from more_nnsight.saved_activation import ActivationTarget, AttrSegment, IndexSegment


def test_parse_path_uses_real_model_names() -> None:
    """Check that a parsed path equals the expected attr/index segments plus position."""
    parsed = SavedActivation.parse_path("model.transformer.h[2].output[10]")
    expected_segments = (
        AttrSegment("transformer"),
        AttrSegment("h"),
        IndexSegment(2),
        AttrSegment("output"),
    )
    assert parsed == ActivationTarget(activation_path=expected_segments, position=10)


def test_saved_activation_mirrors_model_path() -> None:
    """Check that a stored tensor is reachable under ``.values`` via the same nested path."""
    activation = torch.randn(2, 768)
    nested = {"transformer": {"h": {2: {"output": {10: activation}}}}}
    stored = SavedActivation(nested)
    assert stored.values.transformer.h[2].output[10] is activation


def test_missing_attribute_errors() -> None:
    """Check that reading an absent attribute hop raises ``AttributeError``."""
    stored = SavedActivation({"transformer": {}})
    # The lookup stays inside the context manager so the raise is captured there.
    with pytest.raises(AttributeError, match=r"saved_activation.transformer.h"):
        _ = stored.values.transformer.h


def test_missing_index_errors() -> None:
    """Check that reading an absent index hop raises ``KeyError``."""
    stored = SavedActivation({"transformer": {"h": {2: {}}}})
    # Only index 2 was saved, so index 3 must not resolve.
    with pytest.raises(KeyError, match="saved_activation.transformer.h"):
        _ = stored.values.transformer.h[3]


# Subsetting should keep only the requested saved path and discard the rest.
def test_subset_returns_new_saved_activation() -> None:
    """Check that ``subset`` keeps the requested path and drops the other saved path."""
    dropped = torch.randn(2, 768)
    kept = torch.randn(2, 768)
    layers = {
        2: {"output": {10: dropped}},
        3: {"output": {-1: kept}},
    }
    full = SavedActivation({"transformer": {"h": layers}})

    narrowed = full.subset(["model.transformer.h[3].output[-1]"])

    assert narrowed.values.transformer.h[3].output[-1] is kept
    # Layer 2 was not requested, so it must be gone from the subset.
    with pytest.raises(KeyError, match="saved_activation.transformer.h"):
        _ = narrowed.values.transformer.h[2]


def test_keys_and_get_round_trip_saved_paths() -> None:
    """Check that ``keys()`` lists the string path and ``get(...)`` returns its tensor."""
    path = "model.transformer.h[2].output[10]"
    activation = torch.randn(2, 768)
    stored = SavedActivation({"transformer": {"h": {2: {"output": {10: activation}}}}})

    assert stored.keys() == [path]
    assert stored.get(path) is activation


def test_from_pairs_builds_canonical_saved_activation_and_rejects_duplicates() -> None:
    """Check that ``from_pairs`` reorders keys and raises on a repeated path."""
    layer2_path = "model.transformer.h[2].output[10]"
    layer3_path = "model.transformer.h[3].output[9]"
    layer2_value = torch.randn(2, 3)
    layer3_value = torch.randn(2, 3)

    # Pairs are passed layer 3 first; keys() is expected to come back layer 2 first.
    built = SavedActivation.from_pairs(
        (layer3_path, layer3_value),
        (layer2_path, layer2_value),
    )

    assert built.keys() == [layer2_path, layer3_path]
    assert built.get(layer2_path) is layer2_value
    assert built.get(layer3_path) is layer3_value

    with pytest.raises(ValueError, match="duplicate path"):
        SavedActivation.from_pairs(
            (layer2_path, layer2_value),
            (layer2_path, layer3_value),
        )


# Batch slicing should preserve order across mixed slice and single-index selectors.
def test_slice_keeps_requested_batch_rows() -> None:
    """Check that ``slice[1:3, 4]`` selects batch rows 1, 2 and 4 in that order."""
    batch = torch.arange(20, dtype=torch.float32).reshape(5, 4)
    stored = SavedActivation({"transformer": {"h": {2: {"output": {10: batch}}}}})

    picked = stored.slice[1:3, 4]

    assert picked.values.transformer.h[2].output[10].shape == (3, 4)
    assert torch.equal(picked.values.transformer.h[2].output[10], batch[[1, 2, 4]])


def test_mean_reduces_batch_dimension() -> None:
    """Check that ``mean()`` yields a batch-size-1 result and leaves the source unchanged."""
    batch = torch.randn(3, 768)
    stored = SavedActivation({"transformer": {"h": {2: {"output": {10: batch}}}}})

    averaged = stored.mean()

    # The source object keeps its original batch dimension.
    assert stored.values.transformer.h[2].output[10].shape == (3, 768)
    assert averaged.values.transformer.h[2].output[10].shape == (1, 768)
    expected = batch.mean(dim=0, keepdim=True)
    assert torch.allclose(averaged.values.transformer.h[2].output[10], expected)


def test_add_and_subtract_require_matching_keys() -> None:
    """Check ``+``, ``-`` and scalar ``*`` on matching keys, and the error on mismatch."""

    def saved_at(layer: int, value: torch.Tensor) -> SavedActivation:
        # Wrap ``value`` at transformer.h[layer].output[10].
        return SavedActivation({"transformer": {"h": {layer: {"output": {10: value}}}}})

    left = saved_at(2, torch.ones(1, 2))
    right = saved_at(2, 2 * torch.ones(1, 2))
    other = saved_at(3, torch.ones(1, 2))

    total = left + right
    assert torch.equal(total.values.transformer.h[2].output[10], 3 * torch.ones(1, 2))

    difference = right - left
    assert torch.equal(difference.values.transformer.h[2].output[10], torch.ones(1, 2))

    scaled = 2.0 * left
    assert torch.equal(scaled.values.transformer.h[2].output[10], 2 * torch.ones(1, 2))

    # ``other`` lives at layer 3, so its keys do not match ``left``.
    with pytest.raises(ValueError, match="same keys"):
        _ = left + other


# Union should merge disjoint saved paths and reject overlapping ones.
def test_union_merges_disjoint_paths_and_rejects_overlap() -> None:
    """Check that ``union`` combines disjoint paths and raises when a path is shared."""
    layer2_path = "model.transformer.h[2].output[10]"
    layer3_path = "model.transformer.h[3].output[10]"
    layer2_value = torch.ones(1, 2)
    layer3_value = 2 * torch.ones(1, 2)
    clashing_value = 3 * torch.ones(1, 2)

    left = SavedActivation({"transformer": {"h": {2: {"output": {10: layer2_value}}}}})
    right = SavedActivation({"transformer": {"h": {3: {"output": {10: layer3_value}}}}})
    # Same path as ``left`` but a different tensor.
    overlap = SavedActivation({"transformer": {"h": {2: {"output": {10: clashing_value}}}}})

    merged = left.union(right)

    assert merged.keys() == [layer2_path, layer3_path]
    assert merged.get(layer2_path) is layer2_value
    assert merged.get(layer3_path) is layer3_value

    with pytest.raises(ValueError, match="disjoint keys"):
        _ = left.union(overlap)


def test_updates_reject_duplicate_keys() -> None:
    """Check that consuming ``updates`` with a repeated key raises ``ValueError``."""
    repeated_keys = [
        "model.transformer.h[2].output[10]",
        "model.transformer.h[2].output[10]",
    ]
    # list(...) consumes the result; presumably validation may be lazy -- TODO confirm.
    with pytest.raises(ValueError, match="unique keys"):
        list(updates(object(), repeated_keys))


def test_updates_reject_noncanonical_key_order() -> None:
    """Check that ``updates`` raises when layer 3 is listed before layer 2."""
    descending_keys = [
        "model.transformer.h[3].output[10]",
        "model.transformer.h[2].output[10]",
    ]
    with pytest.raises(ValueError, match="canonical layer-major order"):
        list(updates(object(), descending_keys))


def test_updates_reject_slice_paths() -> None:
    """Check that ``updates`` raises for a key containing a ``[:]`` slice."""
    slice_keys = ["model.transformer.h[:].output[10]"]
    with pytest.raises(ValueError, match="concrete keys"):
        list(updates(object(), slice_keys))