Commit 24bcc17: minor fix for typo and error msg

lkct committed Nov 17, 2023
1 parent 4ed08dd commit 24bcc17
Showing 8 changed files with 37 additions and 31 deletions.
4 changes: 2 additions & 2 deletions cirkit/layers/input/exp_family/exp_family.py
@@ -6,7 +6,7 @@
 
 from cirkit.layers.input import InputLayer
 from cirkit.reparams.leaf import ReparamIdentity
-from cirkit.reparams.reparam import Reparameterizaion
+from cirkit.reparams.reparam import Reparameterization
 from cirkit.utils.type_aliases import ReparamFactory
 
 
@@ -21,7 +21,7 @@ class ExpFamilyLayer(InputLayer):
     based on its implementation.
     """
 
-    params: Reparameterizaion
+    params: Reparameterization
     """The reparameterizaion that gives the natural parameters eta, shape (D, K, P, S)."""
 
     def __init__(  # type: ignore[misc]  # pylint: disable=too-many-arguments
4 changes: 2 additions & 2 deletions cirkit/layers/sum.py
@@ -5,7 +5,7 @@
 
 from cirkit.layers.layer import Layer
 from cirkit.reparams.leaf import ReparamIdentity
-from cirkit.reparams.reparam import Reparameterizaion
+from cirkit.reparams.reparam import Reparameterization
 from cirkit.utils.log_trick import log_func_exp
 from cirkit.utils.type_aliases import ReparamFactory
 
@@ -16,7 +16,7 @@ class SumLayer(Layer):
     TODO: currently this is only a sum for mixing, but not generic sum layer.
     """
 
-    params: Reparameterizaion
+    params: Reparameterization
     """The reparameterizaion that gives the parameters for sum units, shape (F, H, K)."""
 
     def __init__(  # pylint: disable=too-many-arguments
30 changes: 18 additions & 12 deletions cirkit/layers/sum_product/cp.py
@@ -5,19 +5,19 @@
 
 from cirkit.layers.sum_product.sum_product import SumProductLayer
 from cirkit.reparams.leaf import ReparamIdentity
-from cirkit.reparams.reparam import Reparameterizaion
+from cirkit.reparams.reparam import Reparameterization
 from cirkit.utils.log_trick import log_func_exp
 from cirkit.utils.type_aliases import ReparamFactory
 
 
 class BaseCPLayer(SumProductLayer):
     """Candecomp Parafac (decomposition) layer."""
 
-    params_in: Optional[Reparameterizaion]
+    params_in: Optional[Reparameterization]
     """The reparameterizaion that gives the parameters for sum units on input, shape as given by \
     dim names, e.g., (F, H, I, O). Can be None to disable this part of computation."""
 
-    params_out: Optional[Reparameterizaion]
+    params_out: Optional[Reparameterization]
     """The reparameterizaion that gives the parameters for sum units on output, shape as given by \
     dim names, e.g., (F, I, O). Can be None to disable this part of computation."""
 
@@ -74,10 +74,14 @@ def __init__(  # pylint: disable=too-many-arguments
         # TODO: convert to tuple currently required to unpack str, but will be changed in a
         # future version of mypy. see https://github.com/python/mypy/pull/15511
         i, o = tuple(params_in_dim_name[-2:])
-        assert i == "i" and o == ("r" if params_out_dim_name else "o")
+        assert i == "i" and o == (
+            "r" if params_out_dim_name else "o"
+        ), f"Unexpected {params_in_dim_name=} (with {params_out_dim_name=})."
         self._einsum_in = f"{params_in_dim_name},fh{i}...->fh{o}..."
         # TODO: currently we can only support this. any elegant impl?
-        assert params_in_dim_name[:2] == "fh" or fold_mask is None
+        assert (
+            params_in_dim_name[:2] == "fh" or fold_mask is None
+        ), f"Unexpected {params_in_dim_name=} with fold_mask."
         # Only params_in can see the folds and need mask.
         self.params_in = reparam(
             self._infer_shape(params_in_dim_name),
@@ -94,7 +98,9 @@
 
         if params_out_dim_name:
             i, o = tuple(params_out_dim_name[-2:])
-            assert i == ("r" if params_in_dim_name else "i") and o == "o"
+            assert (
+                i == ("r" if params_in_dim_name else "i") and o == "o"
+            ), f"Unexpected {params_out_dim_name=} (with {params_in_dim_name=})."
             self._einsum_out = f"{params_out_dim_name},f{i}...->f{o}..."
             self.params_out = reparam(self._infer_shape(params_out_dim_name), dim=-2)
         else:
@@ -123,7 +129,7 @@ def _infer_shape(self, dim_names: str) -> Tuple[int, ...]:
         return tuple(mapping[name] for name in dim_names)
 
     def _forward_in_linear(self, x: Tensor) -> Tensor:
-        assert self.params_in is not None and self._einsum_in
+        assert self.params_in is not None and self._einsum_in, "This should not happen."
         # shape (F, H, K, *B) -> (F, H, K, *B)
         return torch.einsum(self._einsum_in, self.params_in(), x)
 
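Note on the einsum above: the equation built from `params_in_dim_name` contracts the parameter tensor with the input over the `i` dim, while the `...` broadcasts over any batch dims. A minimal sketch of the `params_in_dim_name="fhio"` case, with shapes assumed here purely for illustration (F=2, H=3, I=4, O=5, batch=7):

```python
import torch

params = torch.randn(2, 3, 4, 5)  # (F, H, I, O) -- assumed shapes, for illustration
x = torch.randn(2, 3, 4, 7)       # (F, H, I, *B)

# Same equation BaseCPLayer builds for params_in_dim_name="fhio":
# contract over i, keep f/h, map i -> o, broadcast the batch dims.
out = torch.einsum("fhio,fhi...->fho...", params, x)
print(out.shape)  # torch.Size([2, 3, 5, 7]), i.e. (F, H, O, *B)
```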
@@ -140,7 +146,7 @@ def _forward_reduce_log(self, x: Tensor) -> Tensor:
         return x.sum(dim=1)  # shape (F, H, K, *B) -> (F, K, *B)
 
     def _forward_out_linear(self, x: Tensor) -> Tensor:
-        assert self.params_out is not None and self._einsum_out
+        assert self.params_out is not None and self._einsum_out, "This should not happen."
         # shape (F, K, *B) -> (F, K, *B)
         return torch.einsum(self._einsum_out, self.params_out(), x)
 
@@ -168,7 +174,7 @@ def forward(self, x: Tensor) -> Tensor:
 class CollapsedCPLayer(BaseCPLayer):
     """Candecomp Parafac (decomposition) layer, with matrix C collapsed."""
 
-    params_in: Reparameterizaion
+    params_in: Reparameterization
     """The reparameterizaion that gives the parameters for sum units on input, \
     shape (F, H, I, O)."""
 
@@ -210,11 +216,11 @@ def __init__(  # pylint: disable=too-many-arguments
 class UncollapsedCPLayer(BaseCPLayer):
     """Candecomp Parafac (decomposition) layer, without collapsing."""
 
-    params_in: Reparameterizaion
+    params_in: Reparameterization
     """The reparameterizaion that gives the parameters for sum units on input, \
     shape (F, H, I, R)."""
 
-    params_out: Reparameterizaion
+    params_out: Reparameterization
     """The reparameterizaion that gives the parameters for sum units on output, shape (F, R, O)."""
 
     def __init__(  # pylint: disable=too-many-arguments
@@ -257,7 +263,7 @@ def __init__(  # pylint: disable=too-many-arguments
 class SharedCPLayer(BaseCPLayer):
     """Candecomp Parafac (decomposition) layer, with parameter sharing and matrix C collapsed."""
 
-    params_in: Reparameterizaion
+    params_in: Reparameterization
     """The reparameterizaion that gives the parameters for sum units on input, shape (H, I, O)."""
 
     params_out: None
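The reworked assertions in this file rely on Python 3.8's self-documenting f-strings: `f"{name=}"` renders the expression text together with its repr, so a failed assert reports both the variable and its runtime value. A minimal sketch of how the new messages render (values assumed for illustration):

```python
params_in_dim_name = "fhio"   # assumed value, for illustration
params_out_dim_name = ""

# f"{expr=}" (Python 3.8+) expands to the expression source plus its repr.
print(f"Unexpected {params_in_dim_name=} (with {params_out_dim_name=}).")
# Unexpected params_in_dim_name='fhio' (with params_out_dim_name='').

# The same pattern used in the diff's asserts:
i, o = tuple(params_in_dim_name[-2:])
assert i == "i" and o == (
    "r" if params_out_dim_name else "o"
), f"Unexpected {params_in_dim_name=} (with {params_out_dim_name=})."
```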
6 changes: 3 additions & 3 deletions cirkit/layers/sum_product/tucker.py
@@ -5,7 +5,7 @@
 
 from cirkit.layers.sum_product.sum_product import SumProductLayer
 from cirkit.reparams.leaf import ReparamIdentity
-from cirkit.reparams.reparam import Reparameterizaion
+from cirkit.reparams.reparam import Reparameterization
 from cirkit.utils.log_trick import log_func_exp
 from cirkit.utils.type_aliases import ReparamFactory
 
@@ -15,7 +15,7 @@
 class TuckerLayer(SumProductLayer):
     """Tucker (2) layer."""
 
-    params: Reparameterizaion
+    params: Reparameterization
     """The reparameterizaion that gives the parameters for sum units, shape (F, I, J, O)."""
 
     def __init__(  # pylint: disable=too-many-arguments
@@ -42,7 +42,7 @@ def __init__(  # pylint: disable=too-many-arguments
             NotImplementedError: When arity is not 2.
         """
         if arity != 2:
-            raise NotImplementedError("Tucker layers only implements binary product units.")
+            raise NotImplementedError("Tucker layers only implement binary product units.")
         assert fold_mask is None, "Input for Tucker layer should not be masked."
         super().__init__(
             num_input_units=num_input_units,
6 changes: 3 additions & 3 deletions cirkit/reparams/exp_family.py
@@ -11,11 +11,11 @@
 # This is just Indentity, optionally we can add a scaling factor but currently not implemented.
 ##
 # class ReparamEFBinomial(ReparamLeaf):
-#     """Reparametrization for ExpFamily -- Binomial."""
+#     """Reparameterization for ExpFamily -- Binomial."""
 
 
 class ReparamEFCategorical(ReparamLeaf):
-    """Reparametrization for ExpFamily -- Categorical."""
+    """Reparameterization for ExpFamily -- Categorical."""
 
     def __init__(  # type: ignore[misc]
         self,
@@ -51,7 +51,7 @@ def forward(self) -> Tensor:
 
 
 class ReparamEFNormal(ReparamLeaf):
-    """Reparametrization for ExpFamily -- Normal."""
+    """Reparameterization for ExpFamily -- Normal."""
 
     def __init__(  # type: ignore[misc]
         self,
4 changes: 2 additions & 2 deletions cirkit/reparams/leaf.py
@@ -5,10 +5,10 @@
 
 from cirkit.utils.type_aliases import ClampBounds
 
-from .reparam import Reparameterizaion
+from .reparam import Reparameterization
 
 
-class ReparamLeaf(Reparameterizaion):
+class ReparamLeaf(Reparameterization):
     """A leaf in reparameterizaion that holds the parameter instance and does simple transforms.
 
     There's no param initialization here. That's the responsibility of Layers.
6 changes: 3 additions & 3 deletions cirkit/reparams/reparam.py
@@ -5,7 +5,7 @@
 from torch import Tensor, nn
 
 
-class Reparameterizaion(nn.Module, ABC):
+class Reparameterization(nn.Module, ABC):
     """The base class for all reparameterizaions."""
 
     log_mask: Optional[Tensor]  # to be registered as buffer
@@ -52,11 +52,11 @@ def __init__(  # type: ignore[misc]
         assert -len(size) <= dim < len(size), f"dim={dim} out of range for {len(size)}-d."
         self.dims = (dim if dim >= 0 else dim + len(size),)
 
-        assert mask is None or log_mask is None, "mask/log_mask may not be supplied together."
+        assert mask is None or log_mask is None, "mask and log_mask may not be supplied together."
 
         # Currently only saves log_mask. We can add mask if useful
         if mask is not None:
-            # broadcast_to raises RuntimeError is not broadcastable
+            # broadcast_to raises RuntimeError when not broadcastable
             mask.broadcast_to(size)
             self.register_buffer("log_mask", torch.log(mask))
         elif log_mask is not None:
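The corrected comment describes a validation trick: `Tensor.broadcast_to` raises `RuntimeError` when the mask cannot be broadcast to the target size, so calling it (and discarding the result) checks the shape before the buffer is registered. A minimal sketch with shapes assumed for illustration:

```python
import torch

size = (4, 3)
good_mask = torch.ones(1, 3)  # broadcastable to (4, 3): size-1 dims expand
bad_mask = torch.ones(2, 3)   # not broadcastable: 2 cannot expand to 4

good_mask.broadcast_to(size)  # passes silently; result is discarded

try:
    bad_mask.broadcast_to(size)
except RuntimeError as e:
    print(f"shape rejected: {e}")
```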
8 changes: 4 additions & 4 deletions cirkit/utils/type_aliases.py
@@ -2,7 +2,7 @@
 
 from torch import Tensor
 
-from cirkit.reparams.reparam import Reparameterizaion
+from cirkit.reparams.reparam import Reparameterization
 
 # Here're all the type defs and aliases shared across the lib.
 # For private types that is only used in one file, can be defined there.
@@ -21,7 +21,7 @@ class ClampBounds(TypedDict, total=False):
 
 
 class ReparamFactory(Protocol):  # pylint: disable=too-few-public-methods
-    """Protocol for Callable that mimics Reparameterizaion constructor."""
+    """Protocol for Callable that mimics Reparameterization constructor."""
 
     def __call__(
         self,
@@ -31,7 +31,7 @@ def __call__(
         dim: Union[int, Sequence[int]],
         mask: Optional[Tensor] = None,
         log_mask: Optional[Tensor] = None,
-    ) -> Reparameterizaion:
-        """Construct a Reparameterizaion object."""
+    ) -> Reparameterization:
+        """Construct a Reparameterization object."""
         # TODO: pylance issue, ellipsis is required here
         ...  # pylint:disable=unnecessary-ellipsis

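`ReparamFactory` is a `typing.Protocol`, so conformance is structural: any callable whose signature lines up with `__call__` type-checks as a factory, with no inheritance from the protocol. A self-contained sketch of the same pattern (the names below are illustrative, not from the library):

```python
from typing import Optional, Protocol, Sequence


class ArrayFactory(Protocol):  # matched by signature alone, never subclassed
    def __call__(self, size: Sequence[int], dim: int, scale: Optional[float] = None) -> list:
        ...


def zeros_factory(size: Sequence[int], dim: int, scale: Optional[float] = None) -> list:
    # Conforms to ArrayFactory purely by having a compatible signature.
    return [0.0] * len(size)


def build(factory: ArrayFactory) -> list:
    # Callers need *a* conforming callable, not one specific class.
    return factory((2, 3), dim=-1)


print(build(zeros_factory))  # [0.0, 0.0]
```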