Add learned activations L0 norm metric #91

Merged
9 commits merged on Nov 27, 2023
7 changes: 6 additions & 1 deletion sparse_autoencoder/metrics/metrics_container.py
@@ -7,6 +7,7 @@
 from sparse_autoencoder.metrics.train.abstract_train_metric import AbstractTrainMetric
 from sparse_autoencoder.metrics.train.capacity import CapacityMetric
 from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
+from sparse_autoencoder.metrics.train.l0_norm_metric import TrainBatchLearnedActivationsL0
 from sparse_autoencoder.metrics.validate.abstract_validate_metric import AbstractValidationMetric
 from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore

@@ -32,7 +33,11 @@ class MetricsContainer:


 default_metrics = MetricsContainer(
-    train_metrics=[TrainBatchFeatureDensityMetric(), CapacityMetric()],
+    train_metrics=[
+        TrainBatchFeatureDensityMetric(),
+        CapacityMetric(),
+        TrainBatchLearnedActivationsL0(),
+    ],
     resample_metrics=[NeuronActivityMetric()],
     validation_metrics=[ModelReconstructionScore()],
 )
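Below is a minimal, illustrative sketch (not part of this PR) of how the new metric surfaces through the default container above. The default_metrics.train_metrics field and the calculate(TrainMetricData) interface are taken from this diff; the dummy tensor shapes and the printed value are arbitrary assumptions.

import torch

from sparse_autoencoder.metrics.metrics_container import default_metrics
from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData
from sparse_autoencoder.metrics.train.l0_norm_metric import TrainBatchLearnedActivationsL0

# Dummy batch of 16 examples; only the learned activations matter for this metric.
data = TrainMetricData(
    input_activations=torch.randn(16, 32),
    learned_activations=torch.relu(torch.randn(16, 64)),
    decoded_activations=torch.randn(16, 32),
)

for metric in default_metrics.train_metrics:
    if isinstance(metric, TrainBatchLearnedActivationsL0):
        # Roughly half of the 64 ReLU outputs are non-zero, so the value is around 32.
        print(metric.calculate(data))  # e.g. {'learned_activations_l0_norm': 31.6875}

Log items like this one are what get forwarded to Weights and Biases, per the calculate docstring in the new metric file below.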
25 changes: 25 additions & 0 deletions sparse_autoencoder/metrics/train/l0_norm_metric.py
@@ -0,0 +1,25 @@
"""L0 norm sparsity metric."""
from typing import final

import torch

from sparse_autoencoder.metrics.train.abstract_train_metric import (
AbstractTrainMetric,
TrainMetricData,
)


@final
class TrainBatchLearnedActivationsL0(AbstractTrainMetric):
"""Learned activations L0 norm sparsity metric.

The L0 norm is the number of non-zero elements in a learned activation vector. We then average
this over the batch.
"""

def calculate(self, data: TrainMetricData) -> dict[str, float]:
"""Create a log item for Weights and Biases."""
batch_size = data.learned_activations.size(0)
n_non_zero_activations = torch.count_nonzero(data.learned_activations)
batch_average = n_non_zero_activations / batch_size
return {"learned_activations_l0_norm": batch_average.item()}
19 changes: 19 additions & 0 deletions sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py
@@ -0,0 +1,19 @@
"""Tests for the L0NormMetric class."""
import torch

from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData
from sparse_autoencoder.metrics.train.l0_norm_metric import TrainBatchLearnedActivationsL0


def test_l0_norm_metric() -> None:
    """Test the L0NormMetric."""
    learned_activations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.01, 2.0]])
    l0_norm_metric = TrainBatchLearnedActivationsL0()
    data = TrainMetricData(
        input_activations=torch.zeros_like(learned_activations),
        learned_activations=learned_activations,
        decoded_activations=torch.zeros_like(learned_activations),
    )
    log = l0_norm_metric.calculate(data)
    # One non-zero activation in the first example and two in the second: (1 + 2) / 2 = 1.5.
    expected = 3 / 2
    assert log["learned_activations_l0_norm"] == expected