Add pydantic validation (#167)
Check all function calls where it is performant to do so.
alan-cooney authored Jan 2, 2024
1 parent 5132eb6 commit a2de674
Showing 18 changed files with 100 additions and 65 deletions.
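
The pattern applied throughout is pydantic's `@validate_call` decorator, which validates arguments against their type annotations at call time and raises `ValidationError` on failure. A minimal sketch of the idea, assuming pydantic v2 (the function name and defaults mirror `fill_with_test_data` from the first diff below, but the body is illustrative only):

```python
from pydantic import PositiveInt, ValidationError, validate_call


@validate_call
def fill_with_test_data(n_batches: PositiveInt = 1, batch_size: PositiveInt = 16) -> None:
    """Arguments are validated against their annotations at call time."""


fill_with_test_data(n_batches=2)  # passes: both arguments are positive ints

try:
    fill_with_test_data(batch_size=0)  # fails: PositiveInt requires > 0
except ValidationError as error:
    print(error)
```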
10 changes: 6 additions & 4 deletions sparse_autoencoder/activation_store/base_store.py
@@ -4,6 +4,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -105,12 +106,13 @@ def shuffle(self) -> None:
         """Optional shuffle method."""
 
     @final
+    @validate_call
     def fill_with_test_data(
         self,
-        n_batches: int = 1,
-        batch_size: int = 16,
-        n_components: int = 1,
-        input_features: int = 256,
+        n_batches: PositiveInt = 1,
+        batch_size: PositiveInt = 16,
+        n_components: PositiveInt = 1,
+        input_features: PositiveInt = 256,
     ) -> None:
         """Fill the store with test data.
8 changes: 5 additions & 3 deletions sparse_autoencoder/activation_store/disk_store.py
@@ -4,6 +4,7 @@
 import tempfile
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 
@@ -70,12 +71,13 @@ def current_activations_stored_per_component(self) -> list[int]:
         disk_items_stored = len(self)
         return [cache_items + disk_items_stored for cache_items in self._items_stored]
 
+    @validate_call
     def __init__(
         self,
-        n_neurons: int,
+        n_neurons: PositiveInt,
         storage_path: Path = DEFAULT_DISK_ACTIVATION_STORE_PATH,
-        max_cache_size: int = 10_000,
-        n_components: int = 1,
+        max_cache_size: PositiveInt = 10_000,
+        n_components: PositiveInt = 1,
         *,
         empty_dir: bool = False,
     ):
8 changes: 5 additions & 3 deletions sparse_autoencoder/activation_store/tensor_store.py
@@ -1,5 +1,6 @@
 """Tensor Activation Store."""
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 
@@ -75,11 +76,12 @@ def current_activations_stored_per_component(self) -> list[int]:
         """Number of activations stored per component."""
         return self._items_stored
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
-        max_items: int,
-        n_neurons: int,
-        n_components: int = 1,
+        max_items: PositiveInt,
+        n_neurons: PositiveInt,
+        n_components: PositiveInt = 1,
         device: torch.device | None = None,
     ) -> None:
         """Initialise the Tensor Activation Store.
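
This constructor takes a `torch.device`, which pydantic cannot generate a schema for on its own; that is why this call site uses the `config={"arbitrary_types_allowed": True}` variant, which falls back to an `isinstance` check for such types. A sketch of the distinction (`make_store` is a hypothetical stand-in, not a function from the repo):

```python
import torch
from pydantic import PositiveInt, validate_call


# Without arbitrary_types_allowed, pydantic raises a schema-generation error
# for the torch.device annotation when the decorator is applied.
@validate_call(config={"arbitrary_types_allowed": True})
def make_store(max_items: PositiveInt, device: torch.device | None = None) -> None:
    """Arbitrary types are isinstance-checked rather than schema-validated."""


make_store(max_items=1_000, device=torch.device("cpu"))  # passes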
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/components/abstract_decoder.py
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float, Int64
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Module, Parameter
@@ -24,11 +25,12 @@ class AbstractDecoder(Module, ABC):
 
     _n_components: int | None
 
+    @validate_call
     def __init__(
         self,
-        learnt_features: int,
-        decoded_features: int,
-        n_components: int | None,
+        learnt_features: PositiveInt,
+        decoded_features: PositiveInt,
+        n_components: PositiveInt | None,
     ) -> None:
         """Initialise the decoder.
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/components/abstract_encoder.py
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float, Int64
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Module, Parameter
@@ -25,11 +26,12 @@ class AbstractEncoder(Module, ABC):
 
     _n_components: int | None
 
+    @validate_call
     def __init__(
         self,
-        input_features: int,
-        learnt_features: int,
-        n_components: int | None,
+        input_features: PositiveInt,
+        learnt_features: PositiveInt,
+        n_components: PositiveInt | None,
     ) -> None:
         """Initialise the encoder.
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/components/linear_encoder.py
@@ -4,6 +4,7 @@
 
 import einops
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Parameter, ReLU, init
@@ -77,11 +78,12 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
     activation_function: ReLU
     """Activation function."""
 
+    @validate_call
     def __init__(
         self,
-        input_features: int,
-        learnt_features: int,
-        n_components: int | None,
+        input_features: PositiveInt,
+        learnt_features: PositiveInt,
+        n_components: PositiveInt | None,
     ):
         """Initialize the linear encoder layer.
@@ -3,6 +3,7 @@
 
 import einops
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Parameter, init
@@ -76,11 +77,12 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
         """
         return [(self.weight, -1)]
 
+    @validate_call
     def __init__(
         self,
-        learnt_features: int,
-        decoded_features: int,
-        n_components: int | None,
+        learnt_features: PositiveInt,
+        decoded_features: PositiveInt,
+        n_components: PositiveInt | None,
         *,
         enable_gradient_hook: bool = True,
     ) -> None:
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/model.py
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn.parameter import Parameter
@@ -72,15 +73,16 @@ def post_decoder_bias(self) -> TiedBias:
         """Post-decoder bias."""
         return self._post_decoder_bias
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
-        n_input_features: int,
-        n_learned_features: int,
+        n_input_features: PositiveInt,
+        n_learned_features: PositiveInt,
         geometric_median_dataset: Float[
             Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
         ]
         | None = None,
-        n_components: int | None = None,
+        n_components: PositiveInt | None = None,
     ) -> None:
         """Initialize the Sparse Autoencoder Model.
4 changes: 3 additions & 1 deletion sparse_autoencoder/loss/learned_activations_l1.py
@@ -2,6 +2,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveFloat, validate_call
 import torch
 from torch import Tensor
 
@@ -37,8 +38,9 @@ def log_name(self) -> str:
         """
         return "learned_activations_l1_loss_penalty"
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
-        self, l1_coefficient: float | Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]
+        self, l1_coefficient: PositiveFloat | Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]
     ) -> None:
         """Initialize the absolute error loss.
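
Here the coefficient may be either a scalar or a per-component tensor, so only the scalar arm of the union gets a positivity check; the tensor arm is admitted via `arbitrary_types_allowed`. A simplified sketch of that behaviour, using a plain `torch.Tensor` in place of the jaxtyping annotation (`set_l1_coefficient` is illustrative, not from the repo):

```python
import torch
from pydantic import PositiveFloat, ValidationError, validate_call


@validate_call(config={"arbitrary_types_allowed": True})
def set_l1_coefficient(l1_coefficient: PositiveFloat | torch.Tensor) -> None:
    """Scalars must be positive; tensors are only isinstance-checked."""


set_l1_coefficient(0.001)  # passes the PositiveFloat arm
set_l1_coefficient(torch.tensor([0.001, 0.002]))  # passes the Tensor arm

try:
    set_l1_coefficient(-0.5)  # fails both arms of the union
except ValidationError as error:
    print(error)
```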
4 changes: 3 additions & 1 deletion sparse_autoencoder/metrics/train/feature_density.py
@@ -3,6 +3,7 @@
 from jaxtyping import Float
 import numpy as np
 from numpy import histogram
+from pydantic import NonNegativeFloat, validate_call
 import torch
 from torch import Tensor
 import wandb
@@ -32,9 +33,10 @@ class TrainBatchFeatureDensityMetric(AbstractTrainMetric):
 
     threshold: float
 
+    @validate_call
     def __init__(
         self,
-        threshold: float = 0.0,
+        threshold: NonNegativeFloat = 0.0,
     ) -> None:
         """Initialise the train batch feature density metric.
10 changes: 6 additions & 4 deletions sparse_autoencoder/source_data/abstract_dataset.py
@@ -5,6 +5,7 @@
 
 from datasets import Dataset, IterableDataset, load_dataset
 from jaxtyping import Int
+from pydantic import PositiveInt, validate_call
 from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset
@@ -106,17 +107,18 @@ def preprocess(
         """
 
     @abstractmethod
+    @validate_call
    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
-        context_size: int,
-        buffer_size: int = 1000,
+        context_size: PositiveInt,
+        buffer_size: PositiveInt = 1000,
        dataset_dir: str | None = None,
        dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
        dataset_column_name: str = "input_ids",
-        n_processes_preprocessing: int | None = None,
-        preprocess_batch_size: int = 1000,
+        n_processes_preprocessing: PositiveInt | None = None,
+        preprocess_batch_size: PositiveInt = 1000,
        *,
        pre_download: bool = False,
    ):
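
Note the decorator ordering in this hunk: `@abstractmethod` wraps the already-validated function, so validation still fires when a concrete subclass chains up through `super().__init__`. A sketch of that interaction (class names are illustrative, not from the repo):

```python
from abc import ABC, abstractmethod

from pydantic import PositiveInt, validate_call


class SourceDatasetSketch(ABC):
    """Abstract base whose __init__ validates its arguments."""

    @abstractmethod
    @validate_call
    def __init__(self, dataset_path: str, context_size: PositiveInt) -> None:
        self.dataset_path = dataset_path
        self.context_size = context_size


class ConcreteDataset(SourceDatasetSketch):
    def __init__(self) -> None:
        # Passing context_size=0 here would raise a pydantic ValidationError.
        super().__init__(dataset_path="dummy", context_size=512)


dataset = ConcreteDataset()  # validation runs inside super().__init__
```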
8 changes: 5 additions & 3 deletions sparse_autoencoder/source_data/mock_dataset.py
@@ -7,6 +7,7 @@
 
 from datasets import IterableDataset
 from jaxtyping import Int
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from transformers import PreTrainedTokenizerFast
@@ -139,11 +140,12 @@ def preprocess(
         # Nothing to do here
         return source_batch
 
+    @validate_call
     def __init__(
         self,
-        context_size: int = 250,
-        buffer_size: int = 1000,  # noqa: ARG002
-        preprocess_batch_size: int = 1000,  # noqa: ARG002
+        context_size: PositiveInt = 250,
+        buffer_size: PositiveInt = 1000,  # noqa: ARG002
+        preprocess_batch_size: PositiveInt = 1000,  # noqa: ARG002
         dataset_path: str = "dummy",  # noqa: ARG002
         dataset_split: str = "train",  # noqa: ARG002
     ):
9 changes: 6 additions & 3 deletions sparse_autoencoder/source_data/pretokenized_dataset.py
@@ -14,6 +14,8 @@
 from collections.abc import Mapping, Sequence
 from typing import final
 
+from pydantic import PositiveInt, validate_call
+
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
 
 
@@ -67,16 +69,17 @@ def preprocess(
 
         return {"input_ids": context_size_prompts}
 
+    @validate_call
     def __init__(
         self,
         dataset_path: str,
-        context_size: int = 256,
-        buffer_size: int = 1000,
+        context_size: PositiveInt = 256,
+        buffer_size: PositiveInt = 1000,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
         dataset_column_name: str = "input_ids",
-        preprocess_batch_size: int = 1000,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
13 changes: 8 additions & 5 deletions sparse_autoencoder/source_data/text_dataset.py
@@ -9,6 +9,7 @@
 from typing import TypedDict, final
 
 from datasets import IterableDataset
+from pydantic import PositiveInt, validate_call
 from transformers import PreTrainedTokenizerBase
 
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
@@ -63,18 +64,19 @@ def preprocess(
 
         return {"input_ids": context_size_prompts}
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
         dataset_path: str,
         tokenizer: PreTrainedTokenizerBase,
-        buffer_size: int = 1000,
-        context_size: int = 256,
+        buffer_size: PositiveInt = 1000,
+        context_size: PositiveInt = 256,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
         dataset_column_name: str = "input_ids",
-        n_processes_preprocessing: int | None = None,
-        preprocess_batch_size: int = 1000,
+        n_processes_preprocessing: PositiveInt | None = None,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
@@ -116,12 +118,13 @@ def __init__(
             preprocess_batch_size=preprocess_batch_size,
         )
 
+    @validate_call
     def push_to_hugging_face_hub(
         self,
         repo_id: str,
         commit_message: str = "Upload preprocessed dataset using sparse_autoencoder.",
         max_shard_size: str | None = None,
-        n_shards: int = 64,
+        n_shards: PositiveInt = 64,
         revision: str = "main",
         *,
         private: bool = False,