Skip to content

Commit

Permalink
annotate shaped for params
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Nov 16, 2023
1 parent 407989e commit 4ed08dd
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 5 deletions.
4 changes: 4 additions & 0 deletions cirkit/layers/input/exp_family/exp_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from cirkit.layers.input import InputLayer
from cirkit.reparams.leaf import ReparamIdentity
from cirkit.reparams.reparam import Reparameterizaion

This comment has been minimized.

Copy link
@arranger1044

arranger1044 Nov 16, 2023

Member

typo: it should be Reparameterization

This comment has been minimized.

Copy link
@lkct

lkct Nov 17, 2023

Author Member

Thanks! Fixed

from cirkit.utils.type_aliases import ReparamFactory


Expand All @@ -20,6 +21,9 @@ class ExpFamilyLayer(InputLayer):
based on its implementation.
"""

params: Reparameterizaion
"""The reparameterizaion that gives the natural parameters eta, shape (D, K, P, S)."""

def __init__( # type: ignore[misc] # pylint: disable=too-many-arguments
self,
*,
Expand Down
5 changes: 4 additions & 1 deletion cirkit/layers/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from cirkit.layers.layer import Layer
from cirkit.reparams.leaf import ReparamIdentity
from cirkit.reparams.reparam import Reparameterizaion
from cirkit.utils.log_trick import log_func_exp
from cirkit.utils.type_aliases import ReparamFactory

Expand All @@ -15,6 +16,9 @@ class SumLayer(Layer):
TODO: currently this is only a sum for mixing, but not generic sum layer.
"""

params: Reparameterizaion
"""The reparameterizaion that gives the parameters for sum units, shape (F, H, K)."""

def __init__( # pylint: disable=too-many-arguments
self,
*,
Expand Down Expand Up @@ -49,7 +53,6 @@ def __init__( # pylint: disable=too-many-arguments
reparam=reparam,
)

# TODO: how to annotate shapes for reparams?
# TODO: better way to handle fold_mask shape? too many None checks
self.params = reparam(
(num_folds, arity, num_output_units),
Expand Down
34 changes: 30 additions & 4 deletions cirkit/layers/sum_product/cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
class BaseCPLayer(SumProductLayer):
"""Candecomp Parafac (decomposition) layer."""

params_in: Optional[Reparameterizaion]
"""The reparameterizaion that gives the parameters for sum units on input, shape as given by \
dim names, e.g., (F, H, I, O). Can be None to disable this part of computation."""

params_out: Optional[Reparameterizaion]
"""The reparameterizaion that gives the parameters for sum units on output, shape as given by \
dim names, e.g., (F, I, O). Can be None to disable this part of computation."""

def __init__( # pylint: disable=too-many-arguments
self,
*,
Expand Down Expand Up @@ -71,7 +79,7 @@ def __init__( # pylint: disable=too-many-arguments
# TODO: currently we can only support this. any elegant impl?
assert params_in_dim_name[:2] == "fh" or fold_mask is None
# Only params_in can see the folds and need mask.
self.params_in: Optional[Reparameterizaion] = reparam(
self.params_in = reparam(
self._infer_shape(params_in_dim_name),
dim=-2,
mask=fold_mask.view(
Expand All @@ -88,9 +96,7 @@ def __init__( # pylint: disable=too-many-arguments
i, o = tuple(params_out_dim_name[-2:])
assert i == ("r" if params_in_dim_name else "i") and o == "o"
self._einsum_out = f"{params_out_dim_name},f{i}...->f{o}..."
self.params_out: Optional[Reparameterizaion] = reparam(
self._infer_shape(params_out_dim_name), dim=-2
)
self.params_out = reparam(self._infer_shape(params_out_dim_name), dim=-2)
else:
self._einsum_out = ""
self.params_out = None
Expand Down Expand Up @@ -162,6 +168,13 @@ def forward(self, x: Tensor) -> Tensor:
class CollapsedCPLayer(BaseCPLayer):
"""Candecomp Parafac (decomposition) layer, with matrix C collapsed."""

params_in: Reparameterizaion
"""The reparameterizaion that gives the parameters for sum units on input, \
shape (F, H, I, O)."""

params_out: None
"""CollapsedCPLayer does not have sum units on output."""

def __init__( # pylint: disable=too-many-arguments
self,
*,
Expand Down Expand Up @@ -197,6 +210,13 @@ def __init__( # pylint: disable=too-many-arguments
class UncollapsedCPLayer(BaseCPLayer):
"""Candecomp Parafac (decomposition) layer, without collapsing."""

params_in: Reparameterizaion
"""The reparameterizaion that gives the parameters for sum units on input, \
shape (F, H, I, R)."""

params_out: Reparameterizaion
"""The reparameterizaion that gives the parameters for sum units on output, shape (F, R, O)."""

def __init__( # pylint: disable=too-many-arguments
self,
*,
Expand Down Expand Up @@ -237,6 +257,12 @@ def __init__( # pylint: disable=too-many-arguments
class SharedCPLayer(BaseCPLayer):
"""Candecomp Parafac (decomposition) layer, with parameter sharing and matrix C collapsed."""

params_in: Reparameterizaion
"""The reparameterizaion that gives the parameters for sum units on input, shape (H, I, O)."""

params_out: None
"""SharedCPLayer does not have sum units on output."""

def __init__( # pylint: disable=too-many-arguments
self,
*,
Expand Down
4 changes: 4 additions & 0 deletions cirkit/layers/sum_product/tucker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from cirkit.layers.sum_product.sum_product import SumProductLayer
from cirkit.reparams.leaf import ReparamIdentity
from cirkit.reparams.reparam import Reparameterizaion
from cirkit.utils.log_trick import log_func_exp
from cirkit.utils.type_aliases import ReparamFactory

Expand All @@ -14,6 +15,9 @@
class TuckerLayer(SumProductLayer):
"""Tucker (2) layer."""

params: Reparameterizaion
"""The reparameterizaion that gives the parameters for sum units, shape (F, I, J, O)."""

def __init__( # pylint: disable=too-many-arguments
self,
*,
Expand Down

0 comments on commit 4ed08dd

Please sign in to comment.