Skip to content

Commit

Permalink
Add learned activations L0 norm metric (ai-safety-foundation#91)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Alan Cooney <[email protected]>
  • Loading branch information
lucyfarnik and alan-cooney authored Nov 27, 2023
1 parent 5782bc8 commit c38860d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
7 changes: 6 additions & 1 deletion sparse_autoencoder/metrics/metrics_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sparse_autoencoder.metrics.train.abstract_train_metric import AbstractTrainMetric
from sparse_autoencoder.metrics.train.capacity import CapacityMetric
from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
from sparse_autoencoder.metrics.train.l0_norm_metric import TrainBatchLearnedActivationsL0
from sparse_autoencoder.metrics.validate.abstract_validate_metric import AbstractValidationMetric
from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore

Expand All @@ -32,7 +33,11 @@ class MetricsContainer:


# Default metrics recorded during training, resampling, and validation.
# NOTE(review): this span was diff residue containing both the pre- and
# post-commit `train_metrics` lines; only the post-commit version is kept.
default_metrics = MetricsContainer(
    train_metrics=[
        TrainBatchFeatureDensityMetric(),
        CapacityMetric(),
        TrainBatchLearnedActivationsL0(),
    ],
    resample_metrics=[NeuronActivityMetric()],
    validation_metrics=[ModelReconstructionScore()],
)
Expand Down
25 changes: 25 additions & 0 deletions sparse_autoencoder/metrics/train/l0_norm_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""L0 norm sparsity metric."""
from typing import final

import torch

from sparse_autoencoder.metrics.train.abstract_train_metric import (
AbstractTrainMetric,
TrainMetricData,
)


@final
class TrainBatchLearnedActivationsL0(AbstractTrainMetric):
    """Learned activations L0 norm sparsity metric.

    The L0 "norm" of a learned activation vector is the number of its
    non-zero elements; this metric reports that count averaged over the
    batch, which is a standard measure of sparsity.
    """

    def calculate(self, data: TrainMetricData) -> dict[str, float]:
        """Create a log item for Weights and Biases.

        Args:
            data: Train metric data containing the batch of learned
                activations (first dimension is the batch).

        Returns:
            Dict mapping ``learned_activations_l0_norm`` to the
            batch-averaged count of non-zero learned activations.
        """
        activations = data.learned_activations
        # Count non-zero entries over the whole batch, then divide by the
        # number of examples to get the per-example average.
        total_non_zero = torch.count_nonzero(activations)
        mean_l0 = total_non_zero / activations.size(0)
        return {"learned_activations_l0_norm": mean_l0.item()}
19 changes: 19 additions & 0 deletions sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Tests for the L0NormMetric class."""
import torch

from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData
from sparse_autoencoder.metrics.train.l0_norm_metric import TrainBatchLearnedActivationsL0


def test_l0_norm_metric() -> None:
    """Test the L0NormMetric."""
    # Batch of 2 examples with 3 non-zero learned activations in total.
    activations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.01, 2.0]])
    metric = TrainBatchLearnedActivationsL0()
    data = TrainMetricData(
        input_activations=torch.zeros_like(activations),
        learned_activations=activations,
        decoded_activations=torch.zeros_like(activations),
    )
    result = metric.calculate(data)
    # 3 non-zero entries / batch size of 2.
    assert result["learned_activations_l0_norm"] == 1.5

0 comments on commit c38860d

Please sign in to comment.