Skip to content

Commit

Permalink
tensorized layers without folding
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Dec 6, 2023
1 parent f73f0d9 commit 794fce5
Show file tree
Hide file tree
Showing 22 changed files with 966 additions and 0 deletions.
15 changes: 15 additions & 0 deletions cirkit/new/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .inner import CPLayer as CPLayer
from .inner import DenseLayer as DenseLayer
from .inner import HadamardLayer as HadamardLayer
from .inner import InnerLayer as InnerLayer
from .inner import KroneckerLayer as KroneckerLayer
from .inner import MixingLayer as MixingLayer
from .inner import ProductLayer as ProductLayer
from .inner import SumLayer as SumLayer
from .inner import SumProductLayer as SumProductLayer
from .inner import TuckerLayer as TuckerLayer
from .input import CategoricalLayer as CategoricalLayer
from .input import ExpFamilyLayer as ExpFamilyLayer
from .input import InputLayer as InputLayer
from .input import NormalLayer as NormalLayer
from .layer import Layer as Layer
10 changes: 10 additions & 0 deletions cirkit/new/layers/inner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .inner import InnerLayer as InnerLayer
from .product import HadamardLayer as HadamardLayer
from .product import KroneckerLayer as KroneckerLayer
from .product import ProductLayer as ProductLayer
from .sum import DenseLayer as DenseLayer
from .sum import MixingLayer as MixingLayer
from .sum import SumLayer as SumLayer
from .sum_product import CPLayer as CPLayer
from .sum_product import SumProductLayer as SumProductLayer
from .sum_product import TuckerLayer as TuckerLayer
34 changes: 34 additions & 0 deletions cirkit/new/layers/inner/inner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Optional

from cirkit.new.layers.layer import Layer
from cirkit.new.reparams import Reparameterization


class InnerLayer(Layer):
"""The abstract base class for inner layers."""

# __init__ is overriden here to change the default value of arity, as arity=2 is the most common
# case for all inner layers.
def __init__(
self,
*,
num_input_units: int,
num_output_units: int,
arity: int = 2,
reparam: Optional[Reparameterization] = None,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units.
arity (int, optional): The arity of the layer. Defaults to 2.
reparam (Optional[Reparameterization], optional): The reparameterization for layer \
parameters, can be None if the layer has no params. Defaults to None.
"""
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=reparam,
)
3 changes: 3 additions & 0 deletions cirkit/new/layers/inner/product/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .hadamard import HadamardLayer as HadamardLayer
from .kronecker import KroneckerLayer as KroneckerLayer
from .product import ProductLayer as ProductLayer
48 changes: 48 additions & 0 deletions cirkit/new/layers/inner/product/hadamard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional

from torch import Tensor

from cirkit.new.layers.inner.product.product import ProductLayer
from cirkit.new.reparams import Reparameterization


class HadamardLayer(ProductLayer):
"""The Hadamard product layer."""

def __init__(
self,
*,
num_input_units: int,
num_output_units: int,
arity: int = 2,
reparam: Optional[Reparameterization] = None,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units, must be the same as input.
arity (int, optional): The arity of the layer. Defaults to 2.
reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
Defaults to None.
"""
assert (
num_output_units == num_input_units
), "The number of input and output units must be the same for Hadamard product."
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=None,
)

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
Args:
x (Tensor): The input to this layer, shape (H, *B, K).
Returns:
Tensor: The output of this layer, shape (*B, K).
"""
return self.comp_space.prod(x, dim=0, keepdim=False) # shape (H, *B, K) -> (*B, K).
52 changes: 52 additions & 0 deletions cirkit/new/layers/inner/product/kronecker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Literal, Optional

from torch import Tensor

from cirkit.new.layers.inner.product.product import ProductLayer
from cirkit.new.reparams import Reparameterization


class KroneckerLayer(ProductLayer):
"""The Kronecker product layer."""

def __init__(
self,
*,
num_input_units: int,
num_output_units: int,
arity: Literal[2] = 2,
reparam: Optional[Reparameterization] = None,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units, must be input**arity.
arity (Literal[2], optional): The arity of the layer, must be 2. Defaults to 2.
reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
Defaults to None.
"""
assert (
num_output_units == num_input_units**arity
), "The number of input and output units must be the same for Hadamard product."
if arity != 2:
raise NotImplementedError("Kronecker only implemented for binary product units.")
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=None,
)

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
Args:
x (Tensor): The input to this layer, shape (H, *B, K).
Returns:
Tensor: The output of this layer, shape (*B, K).
"""
x0 = x[0].unsqueeze(dim=-1) # shape (*B, K, 1).
x1 = x[1].unsqueeze(dim=-2) # shape (*B, 1, K).
return self.comp_space.mul(x0, x1).flatten(start_dim=-2) # shape (*B, K, K) -> (*B, K**2).
37 changes: 37 additions & 0 deletions cirkit/new/layers/inner/product/product.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional

from cirkit.new.layers.inner.inner import InnerLayer
from cirkit.new.reparams import Reparameterization


class ProductLayer(InnerLayer):
"""The abstract base class for product layers."""

# We still accept any Reparameterization instance for reparam, but it will be ignored.
# TODO: this disable should be a pylint bug
def __init__( # pylint: disable=useless-parent-delegation
self,
*,
num_input_units: int,
num_output_units: int,
arity: int = 2,
reparam: Optional[Reparameterization] = None,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units.
arity (int, optional): The arity of the layer. Defaults to 2.
reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
Defaults to None.
"""
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=None,
)

def reset_parameters(self) -> None:
"""Do nothing as the product layers do not have parameters."""
3 changes: 3 additions & 0 deletions cirkit/new/layers/inner/sum/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dense import DenseLayer as DenseLayer
from .mixing import MixingLayer as MixingLayer
from .sum import SumLayer as SumLayer
55 changes: 55 additions & 0 deletions cirkit/new/layers/inner/sum/dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Literal

import torch
from torch import Tensor

from cirkit.new.layers.inner.sum.sum import SumLayer
from cirkit.new.reparams import Reparameterization


class DenseLayer(SumLayer):
"""The sum layer for dense sum within a layer."""

def __init__(
self,
*,
num_input_units: int,
num_output_units: int,
arity: Literal[1] = 1,
reparam: Reparameterization,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units, must be input**arity.
arity (Literal[1], optional): The arity of the layer, must be 1. Defaults to 1.
reparam (Reparameterization): The reparameterization for layer parameters.
"""
assert arity == 1, "DenseLayer must have arity=1. For arity>1, use MixingLayer."
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=reparam,
)

self.params = reparam
self.params.materialize((num_output_units, num_input_units), dim=1)

self.reset_parameters()

def _forward_linear(self, x: Tensor) -> Tensor:
return torch.einsum("oi,...i->...o", self.params(), x) # shape (*B, Ki) -> (*B, Ko).

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
Args:
x (Tensor): The input to this layer, shape (H, *B, K).
Returns:
Tensor: The output of this layer, shape (*B, K).
"""
x = x.squeeze(dim=0) # shape (H=1, *B, K) -> (*B, K).
return self.comp_space.sum(self._forward_linear, x, dim=-1, keepdim=True) # shape (*B, K).
56 changes: 56 additions & 0 deletions cirkit/new/layers/inner/sum/mixing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
from torch import Tensor

from cirkit.new.layers.inner.sum.sum import SumLayer
from cirkit.new.reparams import Reparameterization


class MixingLayer(SumLayer):
"""The sum layer for mixture among layers."""

def __init__(
self,
*,
num_input_units: int,
num_output_units: int,
arity: int = 2,
reparam: Reparameterization,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units, must be the same as input.
arity (int, optional): The arity of the layer. Defaults to 2.
reparam (Reparameterization): The reparameterization for layer parameters.
"""
assert (
num_output_units == num_input_units
), "The number of input and output units must be the same for MixingLayer."
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=reparam,
)

self.params = reparam
self.params.materialize((num_output_units, arity), dim=1)

self.reset_parameters()

def _forward_linear(self, x: Tensor) -> Tensor:
return torch.einsum("kh,h...k->...k", self.params(), x) # shape (H, *B, K) -> (*B, K).

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
Args:
x (Tensor): The input to this layer, shape (H, *B, K).
Returns:
Tensor: The output of this layer, shape (*B, K).
"""
return self.comp_space.sum(
self._forward_linear, x, dim=0, keepdim=False
) # shape (H, *B, K) -> (*B, K).
41 changes: 41 additions & 0 deletions cirkit/new/layers/inner/sum/sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import functools

import torch
from torch import nn

from cirkit.new.layers.inner.inner import InnerLayer
from cirkit.new.reparams import Reparameterization


class SumLayer(InnerLayer):
"""The abstract base class for sum layers."""

def __init__(
self,
*,
num_input_units: int,
num_output_units: int,
arity: int = 2,
reparam: Reparameterization,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units.
num_output_units (int): The number of output units.
arity (int, optional): The arity of the layer. Defaults to 2.
reparam (Reparameterization): The reparameterization for layer parameters.
"""
super().__init__(
num_input_units=num_input_units,
num_output_units=num_output_units,
arity=arity,
reparam=reparam,
)

@torch.no_grad()
def reset_parameters(self) -> None:
"""Reset parameters to default: U(0.01, 0.99)."""
for child in self.children():
if isinstance(child, Reparameterization):
child.initialize(functools.partial(nn.init.uniform_, a=0.01, b=0.99))
3 changes: 3 additions & 0 deletions cirkit/new/layers/inner/sum_product/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .cp import CPLayer as CPLayer
from .sum_product import SumProductLayer as SumProductLayer
from .tucker import TuckerLayer as TuckerLayer
Loading

0 comments on commit 794fce5

Please sign in to comment.