Showing 22 changed files with 966 additions and 0 deletions.
@@ -0,0 +1,15 @@
from .inner import CPLayer as CPLayer
from .inner import DenseLayer as DenseLayer
from .inner import HadamardLayer as HadamardLayer
from .inner import InnerLayer as InnerLayer
from .inner import KroneckerLayer as KroneckerLayer
from .inner import MixingLayer as MixingLayer
from .inner import ProductLayer as ProductLayer
from .inner import SumLayer as SumLayer
from .inner import SumProductLayer as SumProductLayer
from .inner import TuckerLayer as TuckerLayer
from .input import CategoricalLayer as CategoricalLayer
from .input import ExpFamilyLayer as ExpFamilyLayer
from .input import InputLayer as InputLayer
from .input import NormalLayer as NormalLayer
from .layer import Layer as Layer

@@ -0,0 +1,10 @@
from .inner import InnerLayer as InnerLayer
from .product import HadamardLayer as HadamardLayer
from .product import KroneckerLayer as KroneckerLayer
from .product import ProductLayer as ProductLayer
from .sum import DenseLayer as DenseLayer
from .sum import MixingLayer as MixingLayer
from .sum import SumLayer as SumLayer
from .sum_product import CPLayer as CPLayer
from .sum_product import SumProductLayer as SumProductLayer
from .sum_product import TuckerLayer as TuckerLayer

@@ -0,0 +1,34 @@
from typing import Optional

from cirkit.new.layers.layer import Layer
from cirkit.new.reparams import Reparameterization


class InnerLayer(Layer):
    """The abstract base class for inner layers."""

    # __init__ is overridden here to change the default value of arity, as arity=2 is the most
    # common case for all inner layers.
    def __init__(
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        reparam: Optional[Reparameterization] = None,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units.
            arity (int, optional): The arity of the layer. Defaults to 2.
            reparam (Optional[Reparameterization], optional): The reparameterization for layer \
                parameters, can be None if the layer has no params. Defaults to None.
        """
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=reparam,
        )

@@ -0,0 +1,3 @@
from .hadamard import HadamardLayer as HadamardLayer
from .kronecker import KroneckerLayer as KroneckerLayer
from .product import ProductLayer as ProductLayer

@@ -0,0 +1,48 @@
from typing import Optional

from torch import Tensor

from cirkit.new.layers.inner.product.product import ProductLayer
from cirkit.new.reparams import Reparameterization


class HadamardLayer(ProductLayer):
    """The Hadamard product layer."""

    def __init__(
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        reparam: Optional[Reparameterization] = None,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units, must be the same as input.
            arity (int, optional): The arity of the layer. Defaults to 2.
            reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
                Defaults to None.
        """
        assert (
            num_output_units == num_input_units
        ), "The number of input and output units must be the same for Hadamard product."
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=None,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run forward pass.

        Args:
            x (Tensor): The input to this layer, shape (H, *B, K).

        Returns:
            Tensor: The output of this layer, shape (*B, K).
        """
        return self.comp_space.prod(x, dim=0, keepdim=False)  # shape (H, *B, K) -> (*B, K).
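
As a shape sanity check: in the plain linear computational space, this forward reduces to an elementwise product over the arity dimension H. A minimal plain-torch sketch of that semantics (illustrative only, not part of the diff, bypassing cirkit's comp_space abstraction):

import torch

x = torch.rand(2, 3, 4)  # arity H=2, batch B=3, K=4 units: shape (H, *B, K).
out = x.prod(dim=0)      # elementwise product over H: shape (*B, K) = (3, 4).
assert torch.allclose(out, x[0] * x[1])  # same as multiplying the two inputs.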

@@ -0,0 +1,52 @@
from typing import Literal, Optional

from torch import Tensor

from cirkit.new.layers.inner.product.product import ProductLayer
from cirkit.new.reparams import Reparameterization


class KroneckerLayer(ProductLayer):
    """The Kronecker product layer."""

    def __init__(
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: Literal[2] = 2,
        reparam: Optional[Reparameterization] = None,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units, must be input**arity.
            arity (Literal[2], optional): The arity of the layer, must be 2. Defaults to 2.
            reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
                Defaults to None.
        """
        assert (
            num_output_units == num_input_units**arity
        ), "The number of output units must be num_input_units**arity for Kronecker product."
        if arity != 2:
            raise NotImplementedError("Kronecker only implemented for binary product units.")
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=None,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run forward pass.

        Args:
            x (Tensor): The input to this layer, shape (H, *B, K).

        Returns:
            Tensor: The output of this layer, shape (*B, K**2).
        """
        x0 = x[0].unsqueeze(dim=-1)  # shape (*B, K, 1).
        x1 = x[1].unsqueeze(dim=-2)  # shape (*B, 1, K).
        return self.comp_space.mul(x0, x1).flatten(start_dim=-2)  # shape (*B, K, K) -> (*B, K**2).
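
For intuition: in the linear computational space, this outer-product-and-flatten is exactly a Kronecker product of the two unit vectors, applied per batch element. A minimal plain-torch sketch (illustrative only, outside the diff):

import torch

x = torch.rand(2, 3, 4)  # arity H=2, batch B=3, K=4 units: shape (H, *B, K).
x0 = x[0].unsqueeze(dim=-1)            # shape (3, 4, 1).
x1 = x[1].unsqueeze(dim=-2)            # shape (3, 1, 4).
out = (x0 * x1).flatten(start_dim=-2)  # shape (3, 16) = (*B, K**2).
assert torch.allclose(out[0], torch.kron(x[0, 0], x[1, 0]))  # matches per-element torch.kron.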

@@ -0,0 +1,37 @@
from typing import Optional

from cirkit.new.layers.inner.inner import InnerLayer
from cirkit.new.reparams import Reparameterization


class ProductLayer(InnerLayer):
    """The abstract base class for product layers."""

    # We still accept any Reparameterization instance for reparam, but it will be ignored.
    # TODO: the need for this disable is likely a pylint bug.
    def __init__(  # pylint: disable=useless-parent-delegation
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        reparam: Optional[Reparameterization] = None,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units.
            arity (int, optional): The arity of the layer. Defaults to 2.
            reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
                Defaults to None.
        """
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=None,
        )

    def reset_parameters(self) -> None:
        """Do nothing as the product layers do not have parameters."""

@@ -0,0 +1,3 @@
from .dense import DenseLayer as DenseLayer
from .mixing import MixingLayer as MixingLayer
from .sum import SumLayer as SumLayer

@@ -0,0 +1,55 @@
from typing import Literal

import torch
from torch import Tensor

from cirkit.new.layers.inner.sum.sum import SumLayer
from cirkit.new.reparams import Reparameterization


class DenseLayer(SumLayer):
    """The sum layer for dense sum within a layer."""

    def __init__(
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: Literal[1] = 1,
        reparam: Reparameterization,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units.
            arity (Literal[1], optional): The arity of the layer, must be 1. Defaults to 1.
            reparam (Reparameterization): The reparameterization for layer parameters.
        """
        assert arity == 1, "DenseLayer must have arity=1. For arity>1, use MixingLayer."
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=reparam,
        )

        self.params = reparam
        self.params.materialize((num_output_units, num_input_units), dim=1)

        self.reset_parameters()

    def _forward_linear(self, x: Tensor) -> Tensor:
        return torch.einsum("oi,...i->...o", self.params(), x)  # shape (*B, Ki) -> (*B, Ko).

    def forward(self, x: Tensor) -> Tensor:
        """Run forward pass.

        Args:
            x (Tensor): The input to this layer, shape (H, *B, K).

        Returns:
            Tensor: The output of this layer, shape (*B, K).
        """
        x = x.squeeze(dim=0)  # shape (H=1, *B, K) -> (*B, K).
        return self.comp_space.sum(self._forward_linear, x, dim=-1, keepdim=True)  # shape (*B, K).
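
The einsum in _forward_linear is an ordinary linear map (Ki -> Ko) applied independently to each batch element. A minimal plain-torch sketch (illustrative only, outside the diff):

import torch

w = torch.rand(5, 4)     # params, shape (Ko=5, Ki=4).
x = torch.rand(1, 3, 4)  # input, shape (H=1, *B=3, Ki=4).
out = torch.einsum("oi,...i->...o", w, x.squeeze(dim=0))  # shape (3, 5).
assert torch.allclose(out, x.squeeze(dim=0) @ w.T)  # same as a plain matmul.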

@@ -0,0 +1,56 @@
import torch
from torch import Tensor

from cirkit.new.layers.inner.sum.sum import SumLayer
from cirkit.new.reparams import Reparameterization


class MixingLayer(SumLayer):
    """The sum layer for mixture among layers."""

    def __init__(
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        reparam: Reparameterization,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units, must be the same as input.
            arity (int, optional): The arity of the layer. Defaults to 2.
            reparam (Reparameterization): The reparameterization for layer parameters.
        """
        assert (
            num_output_units == num_input_units
        ), "The number of input and output units must be the same for MixingLayer."
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=reparam,
        )

        self.params = reparam
        self.params.materialize((num_output_units, arity), dim=1)

        self.reset_parameters()

    def _forward_linear(self, x: Tensor) -> Tensor:
        return torch.einsum("kh,h...k->...k", self.params(), x)  # shape (H, *B, K) -> (*B, K).

    def forward(self, x: Tensor) -> Tensor:
        """Run forward pass.

        Args:
            x (Tensor): The input to this layer, shape (H, *B, K).

        Returns:
            Tensor: The output of this layer, shape (*B, K).
        """
        return self.comp_space.sum(
            self._forward_linear, x, dim=0, keepdim=False
        )  # shape (H, *B, K) -> (*B, K).
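
Here each output unit k is a weighted sum of the same unit k across the H input layers, with its own weight vector of length H. A minimal plain-torch sketch of the einsum (illustrative only, outside the diff):

import torch

w = torch.rand(4, 2)     # params, shape (K=4, H=2).
x = torch.rand(2, 3, 4)  # input, shape (H=2, *B=3, K=4).
out = torch.einsum("kh,h...k->...k", w, x)  # shape (3, 4).
assert torch.allclose(out, (w.T.unsqueeze(1) * x).sum(dim=0))  # weighted sum over H.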

@@ -0,0 +1,41 @@
import functools

import torch
from torch import nn

from cirkit.new.layers.inner.inner import InnerLayer
from cirkit.new.reparams import Reparameterization


class SumLayer(InnerLayer):
    """The abstract base class for sum layers."""

    def __init__(
        self,
        *,
        num_input_units: int,
        num_output_units: int,
        arity: int = 2,
        reparam: Reparameterization,
    ) -> None:
        """Init class.

        Args:
            num_input_units (int): The number of input units.
            num_output_units (int): The number of output units.
            arity (int, optional): The arity of the layer. Defaults to 2.
            reparam (Reparameterization): The reparameterization for layer parameters.
        """
        super().__init__(
            num_input_units=num_input_units,
            num_output_units=num_output_units,
            arity=arity,
            reparam=reparam,
        )

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Reset parameters to default: U(0.01, 0.99)."""
        for child in self.children():
            if isinstance(child, Reparameterization):
                child.initialize(functools.partial(nn.init.uniform_, a=0.01, b=0.99))
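
The functools.partial here just pre-binds the bounds of an in-place torch initializer before handing it to the reparameterization. A minimal plain-torch sketch of the same pattern (illustrative only, outside the diff):

import functools
import torch
from torch import nn

init_ = functools.partial(nn.init.uniform_, a=0.01, b=0.99)
t = torch.empty(3, 4)
init_(t)  # fills t in-place with samples from U(0.01, 0.99).
assert (t >= 0.01).all() and (t <= 0.99).all()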

@@ -0,0 +1,3 @@
from .cp import CPLayer as CPLayer
from .sum_product import SumProductLayer as SumProductLayer
from .tucker import TuckerLayer as TuckerLayer