diff --git a/cirkit/layers/input/exp_family/exp_family.py b/cirkit/layers/input/exp_family/exp_family.py index 52d9ebc5..28d9d74a 100644 --- a/cirkit/layers/input/exp_family/exp_family.py +++ b/cirkit/layers/input/exp_family/exp_family.py @@ -6,6 +6,7 @@ from cirkit.layers.input import InputLayer from cirkit.reparams.leaf import ReparamIdentity +from cirkit.reparams.reparam import Reparameterizaion from cirkit.utils.type_aliases import ReparamFactory @@ -20,6 +21,9 @@ class ExpFamilyLayer(InputLayer): based on its implementation. """ + params: Reparameterizaion + """The reparameterization that gives the natural parameters eta, shape (D, K, P, S).""" + def __init__( # type: ignore[misc] # pylint: disable=too-many-arguments self, *, diff --git a/cirkit/layers/sum.py b/cirkit/layers/sum.py index 8f50214e..33322b37 100644 --- a/cirkit/layers/sum.py +++ b/cirkit/layers/sum.py @@ -5,6 +5,7 @@ from cirkit.layers.layer import Layer from cirkit.reparams.leaf import ReparamIdentity +from cirkit.reparams.reparam import Reparameterizaion from cirkit.utils.log_trick import log_func_exp from cirkit.utils.type_aliases import ReparamFactory @@ -15,6 +16,9 @@ class SumLayer(Layer): TODO: currently this is only a sum for mixing, but not generic sum layer. """ + params: Reparameterizaion + """The reparameterization that gives the parameters for sum units, shape (F, H, K).""" + def __init__( # pylint: disable=too-many-arguments self, *, @@ -49,7 +53,6 @@ def __init__( # pylint: disable=too-many-arguments reparam=reparam, ) - # TODO: how to annotate shapes for reparams? # TODO: better way to handle fold_mask shape? 
too many None checks self.params = reparam( (num_folds, arity, num_output_units), diff --git a/cirkit/layers/sum_product/cp.py b/cirkit/layers/sum_product/cp.py index c351f643..389c311b 100644 --- a/cirkit/layers/sum_product/cp.py +++ b/cirkit/layers/sum_product/cp.py @@ -13,6 +13,14 @@ class BaseCPLayer(SumProductLayer): """Candecomp Parafac (decomposition) layer.""" + params_in: Optional[Reparameterizaion] + """The reparameterization that gives the parameters for sum units on input, shape as given by \ + dim names, e.g., (F, H, I, O). Can be None to disable this part of computation.""" + + params_out: Optional[Reparameterizaion] + """The reparameterization that gives the parameters for sum units on output, shape as given by \ + dim names, e.g., (F, I, O). Can be None to disable this part of computation.""" + def __init__( # pylint: disable=too-many-arguments self, *, @@ -71,7 +79,7 @@ def __init__( # pylint: disable=too-many-arguments # TODO: currently we can only support this. any elegant impl? assert params_in_dim_name[:2] == "fh" or fold_mask is None # Only params_in can see the folds and need mask. - self.params_in: Optional[Reparameterizaion] = reparam( + self.params_in = reparam( self._infer_shape(params_in_dim_name), dim=-2, mask=fold_mask.view( @@ -88,9 +96,7 @@ def __init__( # pylint: disable=too-many-arguments i, o = tuple(params_out_dim_name[-2:]) assert i == ("r" if params_in_dim_name else "i") and o == "o" self._einsum_out = f"{params_out_dim_name},f{i}...->f{o}..." 
- self.params_out: Optional[Reparameterizaion] = reparam( - self._infer_shape(params_out_dim_name), dim=-2 - ) + self.params_out = reparam(self._infer_shape(params_out_dim_name), dim=-2) else: self._einsum_out = "" self.params_out = None @@ -162,6 +168,13 @@ def forward(self, x: Tensor) -> Tensor: class CollapsedCPLayer(BaseCPLayer): """Candecomp Parafac (decomposition) layer, with matrix C collapsed.""" + params_in: Reparameterizaion + """The reparameterization that gives the parameters for sum units on input, \ + shape (F, H, I, O).""" + + params_out: None + """CollapsedCPLayer does not have sum units on output.""" + def __init__( # pylint: disable=too-many-arguments self, *, @@ -197,6 +210,13 @@ def __init__( # pylint: disable=too-many-arguments class UncollapsedCPLayer(BaseCPLayer): """Candecomp Parafac (decomposition) layer, without collapsing.""" + params_in: Reparameterizaion + """The reparameterization that gives the parameters for sum units on input, \ + shape (F, H, I, R).""" + + params_out: Reparameterizaion + """The reparameterization that gives the parameters for sum units on output, shape (F, R, O).""" + def __init__( # pylint: disable=too-many-arguments self, *, @@ -237,6 +257,12 @@ def __init__( # pylint: disable=too-many-arguments class SharedCPLayer(BaseCPLayer): """Candecomp Parafac (decomposition) layer, with parameter sharing and matrix C collapsed.""" + params_in: Reparameterizaion + """The reparameterization that gives the parameters for sum units on input, shape (H, I, O).""" + + params_out: None + """SharedCPLayer does not have sum units on output.""" + def __init__( # pylint: disable=too-many-arguments self, *, diff --git a/cirkit/layers/sum_product/tucker.py b/cirkit/layers/sum_product/tucker.py index 815a5599..80128b5a 100644 --- a/cirkit/layers/sum_product/tucker.py +++ b/cirkit/layers/sum_product/tucker.py @@ -5,6 +5,7 @@ from cirkit.layers.sum_product.sum_product import SumProductLayer from cirkit.reparams.leaf import ReparamIdentity 
+from cirkit.reparams.reparam import Reparameterizaion from cirkit.utils.log_trick import log_func_exp from cirkit.utils.type_aliases import ReparamFactory @@ -14,6 +15,9 @@ class TuckerLayer(SumProductLayer): """Tucker (2) layer.""" + params: Reparameterizaion + """The reparameterization that gives the parameters for sum units, shape (F, I, J, O).""" + def __init__( # pylint: disable=too-many-arguments self, *,