forked from ai-safety-foundation/sparse_autoencoder
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add learned activations L0 norm metric (ai-safety-foundation#91)
--------- Co-authored-by: Alan Cooney <[email protected]>
- Loading branch information
1 parent
5782bc8
commit c38860d
Showing 3 changed files with 50 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
"""L0 norm sparsity metric.""" | ||
from typing import final | ||
|
||
import torch | ||
|
||
from sparse_autoencoder.metrics.train.abstract_train_metric import ( | ||
AbstractTrainMetric, | ||
TrainMetricData, | ||
) | ||
|
||
|
||
@final
class TrainBatchLearnedActivationsL0(AbstractTrainMetric):
    """Learned activations L0 norm sparsity metric.

    The L0 norm of a learned activation vector is the number of its elements that are non-zero.
    The value reported by this metric is that count averaged over the batch.
    """

    def calculate(self, data: TrainMetricData) -> dict[str, float]:
        """Create a log item for Weights and Biases.

        Args:
            data: Train metric data. Only ``data.learned_activations`` is read.

        Returns:
            A single-entry dictionary mapping ``"learned_activations_l0_norm"`` to the total
            number of non-zero learned activations divided by the size of the first dimension
            (treated as the batch dimension).
        """
        activations = data.learned_activations
        # Non-zero entries are counted over the whole tensor, then normalised by the
        # length of dimension 0 to obtain a per-item average.
        total_non_zero = torch.count_nonzero(activations)
        mean_l0 = total_non_zero / activations.shape[0]
        return {"learned_activations_l0_norm": mean_l0.item()}
19 changes: 19 additions & 0 deletions
19
sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""Tests for the L0NormMetric class.""" | ||
import torch | ||
|
||
from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData | ||
from sparse_autoencoder.metrics.train.l0_norm_metric import TrainBatchLearnedActivationsL0 | ||
|
||
|
||
def test_l0_norm_metric() -> None:
    """Check that TrainBatchLearnedActivationsL0 reports the batch-averaged non-zero count."""
    activations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.01, 2.0]])
    metric = TrainBatchLearnedActivationsL0()

    # Input and decoded activations are zero-filled placeholders of matching shape.
    metric_data = TrainMetricData(
        input_activations=torch.zeros_like(activations),
        learned_activations=activations,
        decoded_activations=torch.zeros_like(activations),
    )
    result = metric.calculate(metric_data)

    # Three non-zero entries (1.0, 0.01 and 2.0) spread across a batch of two vectors.
    expected_l0 = 3 / 2
    assert result["learned_activations_l0_norm"] == expected_l0