diff --git a/cirkit/layers/input/exp_family/exp_family.py b/cirkit/layers/input/exp_family/exp_family.py index 28d9d74a..2c477117 100644 --- a/cirkit/layers/input/exp_family/exp_family.py +++ b/cirkit/layers/input/exp_family/exp_family.py @@ -6,7 +6,7 @@ from cirkit.layers.input import InputLayer from cirkit.reparams.leaf import ReparamIdentity -from cirkit.reparams.reparam import Reparameterizaion +from cirkit.reparams.reparam import Reparameterization from cirkit.utils.type_aliases import ReparamFactory @@ -21,7 +21,7 @@ class ExpFamilyLayer(InputLayer): based on its implementation. """ - params: Reparameterizaion + params: Reparameterization """The reparameterizaion that gives the natural parameters eta, shape (D, K, P, S).""" def __init__( # type: ignore[misc] # pylint: disable=too-many-arguments diff --git a/cirkit/layers/sum.py b/cirkit/layers/sum.py index 33322b37..8bfbd8a8 100644 --- a/cirkit/layers/sum.py +++ b/cirkit/layers/sum.py @@ -5,7 +5,7 @@ from cirkit.layers.layer import Layer from cirkit.reparams.leaf import ReparamIdentity -from cirkit.reparams.reparam import Reparameterizaion +from cirkit.reparams.reparam import Reparameterization from cirkit.utils.log_trick import log_func_exp from cirkit.utils.type_aliases import ReparamFactory @@ -16,7 +16,7 @@ class SumLayer(Layer): TODO: currently this is only a sum for mixing, but not generic sum layer. """ - params: Reparameterizaion + params: Reparameterization """The reparameterizaion that gives the parameters for sum units, shape (F, H, K).""" def __init__( # pylint: disable=too-many-arguments diff --git a/cirkit/layers/sum_product/cp.py b/cirkit/layers/sum_product/cp.py index 389c311b..5591ba4a 100644 --- a/cirkit/layers/sum_product/cp.py +++ b/cirkit/layers/sum_product/cp.py @@ -5,7 +5,7 @@ from cirkit.layers.sum_product.sum_product import SumProductLayer from cirkit.reparams.leaf import ReparamIdentity -from cirkit.reparams.reparam import Reparameterizaion +from cirkit.reparams.reparam import Reparameterization from cirkit.utils.log_trick import log_func_exp from cirkit.utils.type_aliases import ReparamFactory @@ -13,11 +13,11 @@ class BaseCPLayer(SumProductLayer): """Candecomp Parafac (decomposition) layer.""" - params_in: Optional[Reparameterizaion] + params_in: Optional[Reparameterization] """The reparameterizaion that gives the parameters for sum units on input, shape as given by \ dim names, e.g., (F, H, I, O). Can be None to disable this part of computation.""" - params_out: Optional[Reparameterizaion] + params_out: Optional[Reparameterization] """The reparameterizaion that gives the parameters for sum units on output, shape as given by \ dim names, e.g., (F, I, O). Can be None to disable this part of computation.""" @@ -74,10 +74,14 @@ def __init__( # pylint: disable=too-many-arguments # TODO: convert to tuple currently required to unpack str, but will be changed in a # future version of mypy. see https://github.com/python/mypy/pull/15511 i, o = tuple(params_in_dim_name[-2:]) - assert i == "i" and o == ("r" if params_out_dim_name else "o") + assert i == "i" and o == ( + "r" if params_out_dim_name else "o" + ), f"Unexpected {params_in_dim_name=} (with {params_out_dim_name=})." self._einsum_in = f"{params_in_dim_name},fh{i}...->fh{o}..." # TODO: currently we can only support this. any elegant impl? - assert params_in_dim_name[:2] == "fh" or fold_mask is None + assert ( + params_in_dim_name[:2] == "fh" or fold_mask is None + ), f"Unexpected {params_in_dim_name=} with fold_mask." # Only params_in can see the folds and need mask. self.params_in = reparam( self._infer_shape(params_in_dim_name), @@ -94,7 +98,9 @@ def __init__( # pylint: disable=too-many-arguments if params_out_dim_name: i, o = tuple(params_out_dim_name[-2:]) - assert i == ("r" if params_in_dim_name else "i") and o == "o" + assert ( + i == ("r" if params_in_dim_name else "i") and o == "o" + ), f"Unexpected {params_out_dim_name=} (with {params_in_dim_name=})." self._einsum_out = f"{params_out_dim_name},f{i}...->f{o}..." self.params_out = reparam(self._infer_shape(params_out_dim_name), dim=-2) else: @@ -123,7 +129,7 @@ def _infer_shape(self, dim_names: str) -> Tuple[int, ...]: return tuple(mapping[name] for name in dim_names) def _forward_in_linear(self, x: Tensor) -> Tensor: - assert self.params_in is not None and self._einsum_in + assert self.params_in is not None and self._einsum_in, "This should not happen." # shape (F, H, K, *B) -> (F, H, K, *B) return torch.einsum(self._einsum_in, self.params_in(), x) @@ -140,7 +146,7 @@ def _forward_reduce_log(self, x: Tensor) -> Tensor: return x.sum(dim=1) # shape (F, H, K, *B) -> (F, K, *B) def _forward_out_linear(self, x: Tensor) -> Tensor: - assert self.params_out is not None and self._einsum_out + assert self.params_out is not None and self._einsum_out, "This should not happen." # shape (F, K, *B) -> (F, K, *B) return torch.einsum(self._einsum_out, self.params_out(), x) @@ -168,7 +174,7 @@ def forward(self, x: Tensor) -> Tensor: class CollapsedCPLayer(BaseCPLayer): """Candecomp Parafac (decomposition) layer, with matrix C collapsed.""" - params_in: Reparameterizaion + params_in: Reparameterization """The reparameterizaion that gives the parameters for sum units on input, \ shape (F, H, I, O).""" @@ -210,11 +216,11 @@ def __init__( # pylint: disable=too-many-arguments class UncollapsedCPLayer(BaseCPLayer): """Candecomp Parafac (decomposition) layer, without collapsing.""" - params_in: Reparameterizaion + params_in: Reparameterization """The reparameterizaion that gives the parameters for sum units on input, \ shape (F, H, I, R).""" - params_out: Reparameterizaion + params_out: Reparameterization """The reparameterizaion that gives the parameters for sum units on output, shape (F, R, O).""" def __init__( # pylint: disable=too-many-arguments @@ -257,7 +263,7 @@ def __init__( # pylint: disable=too-many-arguments class SharedCPLayer(BaseCPLayer): """Candecomp Parafac (decomposition) layer, with parameter sharing and matrix C collapsed.""" - params_in: Reparameterizaion + params_in: Reparameterization """The reparameterizaion that gives the parameters for sum units on input, shape (H, I, O).""" params_out: None diff --git a/cirkit/layers/sum_product/tucker.py b/cirkit/layers/sum_product/tucker.py index 80128b5a..c54e8eba 100644 --- a/cirkit/layers/sum_product/tucker.py +++ b/cirkit/layers/sum_product/tucker.py @@ -5,7 +5,7 @@ from cirkit.layers.sum_product.sum_product import SumProductLayer from cirkit.reparams.leaf import ReparamIdentity -from cirkit.reparams.reparam import Reparameterizaion +from cirkit.reparams.reparam import Reparameterization from cirkit.utils.log_trick import log_func_exp from cirkit.utils.type_aliases import ReparamFactory @@ -15,7 +15,7 @@ class TuckerLayer(SumProductLayer): """Tucker (2) layer.""" - params: Reparameterizaion + params: Reparameterization """The reparameterizaion that gives the parameters for sum units, shape (F, I, J, O).""" def __init__( # pylint: disable=too-many-arguments @@ -42,7 +42,7 @@ def __init__( # pylint: disable=too-many-arguments NotImplementedError: When arity is not 2. """ if arity != 2: - raise NotImplementedError("Tucker layers only implements binary product units.") + raise NotImplementedError("Tucker layers only implement binary product units.") assert fold_mask is None, "Input for Tucker layer should not be masked." super().__init__( num_input_units=num_input_units, diff --git a/cirkit/reparams/exp_family.py b/cirkit/reparams/exp_family.py index d5b19b80..9eadfb64 100644 --- a/cirkit/reparams/exp_family.py +++ b/cirkit/reparams/exp_family.py @@ -11,11 +11,11 @@ # This is just Indentity, optionally we can add a scaling factor but currently not implemented. ## # class ReparamEFBinomial(ReparamLeaf): -# """Reparametrization for ExpFamily -- Binomial.""" +# """Reparameterization for ExpFamily -- Binomial.""" class ReparamEFCategorical(ReparamLeaf): - """Reparametrization for ExpFamily -- Categorical.""" + """Reparameterization for ExpFamily -- Categorical.""" def __init__( # type: ignore[misc] self, @@ -51,7 +51,7 @@ def forward(self) -> Tensor: class ReparamEFNormal(ReparamLeaf): - """Reparametrization for ExpFamily -- Normal.""" + """Reparameterization for ExpFamily -- Normal.""" def __init__( # type: ignore[misc] self, diff --git a/cirkit/reparams/leaf.py b/cirkit/reparams/leaf.py index f0944bef..e037eb52 100644 --- a/cirkit/reparams/leaf.py +++ b/cirkit/reparams/leaf.py @@ -5,10 +5,10 @@ from cirkit.utils.type_aliases import ClampBounds -from .reparam import Reparameterizaion +from .reparam import Reparameterization -class ReparamLeaf(Reparameterizaion): +class ReparamLeaf(Reparameterization): """A leaf in reparameterizaion that holds the parameter instance and does simple transforms. There's no param initialization here. That's the responsibility of Layers. diff --git a/cirkit/reparams/reparam.py b/cirkit/reparams/reparam.py index 84be2a4b..407c8c2e 100644 --- a/cirkit/reparams/reparam.py +++ b/cirkit/reparams/reparam.py @@ -5,7 +5,7 @@ from torch import Tensor, nn -class Reparameterizaion(nn.Module, ABC): +class Reparameterization(nn.Module, ABC): """The base class for all reparameterizaions.""" log_mask: Optional[Tensor] # to be registered as buffer @@ -52,11 +52,11 @@ def __init__( # type: ignore[misc] assert -len(size) <= dim < len(size), f"dim={dim} out of range for {len(size)}-d." self.dims = (dim if dim >= 0 else dim + len(size),) - assert mask is None or log_mask is None, "mask/log_mask may not be supplied together." + assert mask is None or log_mask is None, "mask and log_mask may not be supplied together." # Currently only saves log_mask. We can add mask if useful if mask is not None: - # broadcast_to raises RuntimeError is not broadcastable + # broadcast_to raises RuntimeError when not broadcastable mask.broadcast_to(size) self.register_buffer("log_mask", torch.log(mask)) elif log_mask is not None: diff --git a/cirkit/utils/type_aliases.py b/cirkit/utils/type_aliases.py index 3b31eb67..e665e8b0 100644 --- a/cirkit/utils/type_aliases.py +++ b/cirkit/utils/type_aliases.py @@ -2,7 +2,7 @@ from torch import Tensor -from cirkit.reparams.reparam import Reparameterizaion +from cirkit.reparams.reparam import Reparameterization # Here're all the type defs and aliases shared across the lib. # For private types that is only used in one file, can be defined there. @@ -21,7 +21,7 @@ class ClampBounds(TypedDict, total=False): class ReparamFactory(Protocol): # pylint: disable=too-few-public-methods - """Protocol for Callable that mimics Reparameterizaion constructor.""" + """Protocol for Callable that mimics Reparameterization constructor.""" def __call__( self, @@ -31,7 +31,7 @@ def __call__( dim: Union[int, Sequence[int]], mask: Optional[Tensor] = None, log_mask: Optional[Tensor] = None, - ) -> Reparameterizaion: - """Construct a Reparameterizaion object.""" + ) -> Reparameterization: + """Construct a Reparameterization object.""" # TODO: pylance issue, ellipsis is required here ... # pylint:disable=unnecessary-ellipsis