Merge pull request #160 from april-tools/new_reparam_refactored
Refactor reparams for the new impl for symbolic
lkct authored Dec 7, 2023
2 parents 04946ce + a8cff53 commit f6c1034
Showing 15 changed files with 782 additions and 148 deletions.
2 changes: 1 addition & 1 deletion cirkit/new/region_graph/region_graph.py
@@ -15,7 +15,7 @@
@final
class RegionGraph: # pylint: disable=too-many-instance-attributes
"""The region graph that holds the high-level abstraction of circuit structure.
This class is initialized empty, and nodes can be pushed into the graph along with edges. It can also \
serve as a container of RGNode for use in the RG construction algorithms.
"""
12 changes: 12 additions & 0 deletions cirkit/new/reparams/__init__.py
@@ -0,0 +1,12 @@
from .binary import BinaryReparam as BinaryReparam
from .composed import ComposedReparam as ComposedReparam
from .leaf import LeafReparam as LeafReparam
from .normalized import LogSoftmaxReparam as LogSoftmaxReparam
from .normalized import SoftmaxReparam as SoftmaxReparam
from .reparam import Reparameterization as Reparameterization
from .unary import ClampReparam as ClampReparam
from .unary import ExpReparam as ExpReparam
from .unary import LinearReparam as LinearReparam
from .unary import SigmoidReparam as SigmoidReparam
from .unary import SquareReparam as SquareReparam
from .unary import UnaryReparam as UnaryReparam
39 changes: 39 additions & 0 deletions cirkit/new/reparams/binary.py
@@ -0,0 +1,39 @@
from typing import Callable, Optional, Tuple, Union

from torch import Tensor

from cirkit.new.reparams.composed import ComposedReparam
from cirkit.new.reparams.reparam import Reparameterization


class BinaryReparam(ComposedReparam[Tensor, Tensor]):
"""The binary composed reparameterization."""

def __init__(
self,
reparam1: Optional[Reparameterization] = None,
reparam2: Optional[Reparameterization] = None,
/,
*,
func: Callable[[Tensor, Tensor], Tensor],
inv_func: Optional[Callable[[Tensor], Union[Tuple[Tensor, Tensor], Tensor]]] = None,
) -> None:
# pylint: disable=line-too-long # Disable: This long line is unavoidable.
"""Init class.
Args:
reparam1 (Optional[Reparameterization], optional): The input reparameterization to be \
composed. If None, a LeafReparam will be constructed in its place. Defaults to None.
reparam2 (Optional[Reparameterization], optional): The input reparameterization to be \
composed. If None, a LeafReparam will be constructed in its place. Defaults to None.
func (Callable[[Tensor, Tensor], Tensor]): The function to compose the output from the \
parameters given by reparam.
inv_func (Optional[Callable[[Tensor], Union[Tuple[Tensor, Tensor], Tensor]]], optional): \
The inverse of func, used to transform the initialization. The initializer will be \
passed through directly if no inv_func is provided. Defaults to None.
"""
# pylint: enable=line-too-long
super().__init__(reparam1, reparam2, func=func, inv_func=inv_func)


# TODO: circuit product
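
For orientation (not part of the diff): a minimal usage sketch of BinaryReparam under the signatures shown above. The shape, dim, torch.add, and nn.init.normal_ choices are illustrative only, and the base-class materialize() kwargs are assumed to behave as documented in composed.py below.

import torch
from torch import nn

from cirkit.new.reparams import BinaryReparam

# Two None inputs -> two LeafReparams are constructed; the output is their element-wise sum.
reparam = BinaryReparam(None, None, func=torch.add)
reparam.materialize((4, 3), dim=-1)  # dim hints normalization; unused by plain addition
reparam.initialize(nn.init.normal_)  # in-place initializer that also returns the tensor
params = reparam()                   # torch.add(leaf1, leaf2), shape (4, 3)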
138 changes: 138 additions & 0 deletions cirkit/new/reparams/composed.py
@@ -0,0 +1,138 @@
from typing import Callable, Generic, List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import TypeVarTuple, Unpack # TODO: in typing from 3.11

import torch
from torch import Tensor, nn

from cirkit.new.reparams.leaf import LeafReparam
from cirkit.new.reparams.reparam import Reparameterization

Ts = TypeVarTuple("Ts")


# TODO: for now the solution I found is using Generic[Unpack[TypeVarTuple]], but it cannot be
# bound to Tuple[Tensor, ...], and an extra cast is needed. Any better solution?
class ComposedReparam(Reparameterization, Generic[Unpack[Ts]]):
"""The base class for composed reparameterization."""

def __init__(
self,
*reparams: Optional[Reparameterization],
func: Callable[[Unpack[Ts]], Tensor],
inv_func: Optional[Callable[[Tensor], Union[Tuple[Unpack[Ts]], Tensor]]] = None,
) -> None:
"""Init class.
Args:
*reparams (Optional[Reparameterization]): The input reparameterizations to be \
composed. If an entry is None, a LeafReparam will be constructed in its place, but \
None must be passed explicitly rather than omitted so that the length is correct.
func (Callable[[*Ts], Tensor]): The function to compose the output from the \
parameters given by reparams.
inv_func (Optional[Callable[[Tensor], Union[Tuple[Unpack[Ts]], Tensor]]], optional): \
The inverse of func, used to transform the initialization. It returns either one \
Tensor shared by all reparams or a tuple with one Tensor per reparam. The \
initializer will be passed through directly if no inv_func is provided. Defaults to None.
"""
super().__init__()
# TODO: make ModuleList a generic?
# Ignore: Here we must use nn.ModuleList to register sub-modules, but we need
# List[Reparameterization] so that elements are properly typed.
self.reparams: List[Reparameterization] = nn.ModuleList( # type: ignore[assignment]
reparam if reparam is not None else LeafReparam() for reparam in reparams
)
self.func = func
self.inv_func = inv_func

@property
def dtype(self) -> torch.dtype:
"""The dtype of the output parameter."""
dtype = self.reparams[0].dtype
assert all(
reparam.dtype == dtype for reparam in self.reparams
), "The dtype of all composing reparams should be the same."
return dtype

@property
def device(self) -> torch.device:
"""The device of the output parameter."""
device = self.reparams[0].device
assert all(
reparam.device == device for reparam in self.reparams
), "The device of all composing reparams should be the same."
return device

def materialize(
self,
shape: Sequence[int],
/,
*,
dim: Union[int, Sequence[int]],
mask: Optional[Tensor] = None,
log_mask: Optional[Tensor] = None,
) -> None:
"""Materialize the internal parameter tensors with given shape.
The initial value of the parameter after materialization is not guaranteed, and explicit \
initialization is expected.
The three kwargs, dim and mask/log_mask, are used to hint the normalization of sum \
weights. The dim kwarg must be supplied to hint the sum-to-1 dimension(s), while \
mask/log_mask are optional and at most one of them may be provided.
Args:
shape (Sequence[int]): The shape of the output parameter.
dim (Union[int, Sequence[int]]): The dimension(s) along which the normalization will \
be applied.
mask (Optional[Tensor], optional): The 0/1 mask for normalization positions. None for \
no masking. The shape must be broadcastable to shape if not None. Defaults to None.
log_mask (Optional[Tensor], optional): The -inf/0 mask for normalization positions. \
None for no masking. The shape must be broadcastable to shape if not None. \
Defaults to None.
"""
super().materialize(shape, dim=dim, mask=mask, log_mask=log_mask)
for reparam in self.reparams:
if not reparam.is_materialized:
# NOTE: Passing shape to all children reparams may not always be wanted. In that
# case, children reparams should be materialized first, so that the following
# is skipped by the above if.
reparam.materialize(shape, dim=dim, mask=mask, log_mask=log_mask)

assert self().shape == self.shape, "The actual shape does not match the given one."

def initialize(self, initializer_: Callable[[Tensor], Tensor]) -> None:
"""Initialize the internal parameter tensors with the given initializer.
Initialization will cause an error if not materialized first.
Args:
initializer_ (Callable[[Tensor], Tensor]): A function that can initialize a tensor \
inplace while also returning the value.
"""
if self.inv_func is None:
for reparam in self.reparams:
reparam.initialize(initializer_)
else:
init = self.inv_func(initializer_(torch.zeros(self.shape)))
# TODO: This cast is unavoidable because the type of init_value is Union[*Ts].
init_values = (
(init,) * len(self.reparams)
if isinstance(init, Tensor) # type: ignore[misc] # Ignore: Tensor contains Any.
else cast(Tuple[Tensor, ...], init)
)
for reparam, init_value in zip(self.reparams, init_values):
# Disable: The following should be safe because the lambda is immediately used
# before the next loop iteration. # TODO: test if what I say is correct
# pylint: disable-next=cell-var-from-loop
reparam.initialize(lambda x: x.copy_(init_value))

def forward(self) -> Tensor:
"""Get the reparameterized parameters.
Returns:
Tensor: The parameters after reparameterization.
"""
# TODO: This cast is unavoidable because Tensor is not Ts.
# NOTE: params is a generator rather than a tuple, but unpacking with *(...) still works.
params = cast(Tuple[Unpack[Ts]], (reparam() for reparam in self.reparams))
return self.func(*params)
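
Not part of the diff: a sketch of how inv_func interacts with initialize(), assuming the base-class materialize() records the shape as used by the code above; torch.exp/torch.log and the uniform initializer are illustrative only.

import torch

from cirkit.new.reparams import ComposedReparam

# One None input -> a single LeafReparam; forward() returns exp(leaf), so initialization
# values are first mapped through log and forward() then reproduces the drawn values.
reparam = ComposedReparam(None, func=torch.exp, inv_func=torch.log)
reparam.materialize((2, 5), dim=-1)
reparam.initialize(lambda t: t.uniform_(0.5, 1.5))  # in-place, returns the tensor
out = reparam()  # strictly positive, approximately within (0.5, 1.5)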
62 changes: 62 additions & 0 deletions cirkit/new/reparams/leaf.py
@@ -0,0 +1,62 @@
from typing import Callable, Sequence, final
from typing_extensions import Unpack # TODO: in typing from 3.12 for Unpack[dict]

import torch
from torch import Tensor, nn

from cirkit.new.reparams.reparam import Reparameterization
from cirkit.new.utils.type_aliases import MaterializeKwargs


# The LeafReparam only holds the tensor. Everything else should be a (unary) composed reparam.
# The @final is to prevent inheritance from LeafReparam.
@final
class LeafReparam(Reparameterization):
"""The leaf in reparameterizations that holds the parameter Tensor."""

def __init__(self) -> None:
"""Init class."""
super().__init__()
self.param = nn.UninitializedParameter()

@property
def dtype(self) -> torch.dtype:
"""The dtype of the output parameter."""
return self.param.dtype

@property
def device(self) -> torch.device:
"""The device of the output parameter."""
return self.param.device

def materialize(self, shape: Sequence[int], /, **_kwargs: Unpack[MaterializeKwargs]) -> None:
"""Materialize the internal parameter tensors with given shape.
The initial value of the parameter after materialization is not guaranteed, and explicit \
initialization is expected.
Args:
shape (Sequence[int]): The shape of the output parameter.
**_kwargs (Unpack[MaterializeKwargs]): Unused. See Reparameterization.materialize().
"""
super().materialize(shape, dim=())
self.param.materialize(self.shape)

def initialize(self, initializer_: Callable[[Tensor], Tensor]) -> None:
"""Initialize the internal parameter tensors with the given initializer.
Initialization will cause an error if not materialized first.
Args:
initializer_ (Callable[[Tensor], Tensor]): A function that can initialize a tensor \
inplace while also returning the value.
"""
initializer_(self.param)

def forward(self) -> Tensor:
"""Get the reparameterized parameters.
Returns:
Tensor: The parameters after reparameterization.
"""
return self.param
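
Not part of the diff: a minimal LeafReparam sketch using only the methods shown above; the shape and zero initializer are illustrative only.

from torch import nn

from cirkit.new.reparams import LeafReparam

leaf = LeafReparam()             # holds an nn.UninitializedParameter until materialized
leaf.materialize((8, 2))         # dim/mask kwargs are accepted but unused for a leaf
leaf.initialize(nn.init.zeros_)  # any in-place initializer that returns the tensor
print(leaf().shape, leaf.dtype, leaf.device)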
94 changes: 94 additions & 0 deletions cirkit/new/reparams/normalized.py
@@ -0,0 +1,94 @@
# pylint: disable=too-few-public-methods
# Disable: For this file we disable the above because all classes trigger this but it's intended.

from typing import Optional, Protocol, Tuple

import torch
from torch import Tensor

from cirkit.new.reparams.reparam import Reparameterization
from cirkit.new.reparams.unary import UnaryReparam
from cirkit.new.utils import flatten_dims, unflatten_dims


class _TensorFuncWithDim(Protocol):
"""The protocol for `(Tensor, dim: int) -> Tensor`."""

def __call__(self, x: Tensor, /, dim: int) -> Tensor:
...


class _NormalizedReparamMixin:
"""A mixin for helpers useful for reparams with normalization on some dims."""

dims: Tuple[int, ...]

def _apply_normalizer(self, normalizer: _TensorFuncWithDim, x: Tensor, /) -> Tensor:
"""Apply a normalizer function on a Tensor over self.dims.
Args:
normalizer (_TensorFuncWithDim): The normalizer of a tensor with a dim arg.
x (Tensor): The tensor input.
Returns:
Tensor: The normalized output.
"""
return unflatten_dims(
normalizer(flatten_dims(x, dims=self.dims), dim=self.dims[0]),
dims=self.dims,
shape=x.shape,
)


class SoftmaxReparam(UnaryReparam, _NormalizedReparamMixin):
"""Softmax reparameterization.
Range: (0, 1), 0 available through mask, 1 available when only one element valid.
Constraints: sum to 1.
"""

def __init__(self, reparam: Optional[Reparameterization] = None, /) -> None:
"""Init class.
Args:
reparam (Optional[Reparameterization], optional): The input reparameterization to be \
composed. If None, a LeafReparam will be constructed in its place. Defaults to None.
"""
# Softmax is just scaled exp, so we take log as inv. Mask is ignored for inv.
super().__init__(reparam, func=self._func, inv_func=torch.log)

def _func(self, x: Tensor) -> Tensor:
if self.log_mask is not None:
x = x + self.log_mask
# torch.softmax can only accept one dim, so we need to rearrange dims.
x = self._apply_normalizer(torch.softmax, x)
# nan will appear when there's only 1 element and it's masked. In that case we project nan
# to 1 (0 in log-space). +inf and -inf are (unsafely) projected but will not appear.
return torch.nan_to_num(x, nan=1)


class LogSoftmaxReparam(UnaryReparam, _NormalizedReparamMixin):
"""LogSoftmax reparameterization, which is more numarically-stable than log(softmax(...)).
Range: (-inf, 0), -inf available through mask, 0 available when only one element valid.
Constraints: logsumexp to 0.
"""

def __init__(self, reparam: Optional[Reparameterization] = None, /) -> None:
"""Init class.
Args:
reparam (Optional[Reparameterization], optional): The input reparameterization to be \
composed. If None, a LeafReparam will be constructed in its place. Defaults to None.
"""
# Log_softmax is just an offset, so we take identity as inv. Mask is ignored for inv.
super().__init__(reparam, func=self._func)

def _func(self, x: Tensor) -> Tensor:
if self.log_mask is not None:
x = x + self.log_mask
# torch.log_softmax can only accept one dim, so we need to rearrange dims.
x = self._apply_normalizer(torch.log_softmax, x)
# -inf still passes gradients, so we use a redundant projection to stop it. nan is handled the
# same as in SoftmaxReparam, projected to 0 (in log-space). +inf will not appear.
return torch.nan_to_num(x, nan=0, neginf=float("-inf"))
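
Not part of the diff: a sketch of SoftmaxReparam, assuming the dim passed to materialize() populates the dims used by _NormalizedReparamMixin as the code above suggests; the shape and initializer are illustrative only.

from torch import nn

from cirkit.new.reparams import SoftmaxReparam

weight = SoftmaxReparam()           # wraps a freshly constructed LeafReparam
weight.materialize((3, 4), dim=-1)  # normalize over the last dimension
weight.initialize(nn.init.normal_)
w = weight()  # entries in (0, 1); w.sum(dim=-1) is all ones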