Merge pull request #160 from april-tools/new_reparam_refactored
Refactor reparams for the new impl for symbolic
Showing 15 changed files with 782 additions and 148 deletions.
cirkit/new/reparams/__init__.py
@@ -0,0 +1,12 @@
from .binary import BinaryReparam as BinaryReparam
from .composed import ComposedReparam as ComposedReparam
from .leaf import LeafReparam as LeafReparam
from .normalized import LogSoftmaxReparam as LogSoftmaxReparam
from .normalized import SoftmaxReparam as SoftmaxReparam
from .reparam import Reparameterization as Reparameterization
from .unary import ClampReparam as ClampReparam
from .unary import ExpReparam as ExpReparam
from .unary import LinearReparam as LinearReparam
from .unary import SigmoidReparam as SigmoidReparam
from .unary import SquareReparam as SquareReparam
from .unary import UnaryReparam as UnaryReparam
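The `X as X` idiom marks each name as an intentional re-export for type checkers, so downstream code can import everything from the package root. A minimal sketch of what that enables (hypothetical usage, not part of this diff, assuming the package from this PR is installed):

import torch

from cirkit.new.reparams import LeafReparam, SoftmaxReparam

# Flat imports rely only on the re-exports above; wrap a leaf explicitly, or pass None.
reparam = SoftmaxReparam(LeafReparam())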
cirkit/new/reparams/binary.py
@@ -0,0 +1,39 @@
from typing import Callable, Optional, Tuple, Union

from torch import Tensor

from cirkit.new.reparams.composed import ComposedReparam
from cirkit.new.reparams.reparam import Reparameterization


class BinaryReparam(ComposedReparam[Tensor, Tensor]):
    """The binary composed reparameterization."""

    def __init__(
        self,
        reparam1: Optional[Reparameterization] = None,
        reparam2: Optional[Reparameterization] = None,
        /,
        *,
        func: Callable[[Tensor, Tensor], Tensor],
        inv_func: Optional[Callable[[Tensor], Union[Tuple[Tensor, Tensor], Tensor]]] = None,
    ) -> None:
        # pylint: disable=line-too-long  # Disable: This long line is unavoidable.
        """Init class.

        Args:
            reparam1 (Optional[Reparameterization], optional): The first input reparameterization \
                to be composed. If None, a LeafReparam will be constructed in its place. Defaults \
                to None.
            reparam2 (Optional[Reparameterization], optional): The second input reparameterization \
                to be composed. If None, a LeafReparam will be constructed in its place. Defaults \
                to None.
            func (Callable[[Tensor, Tensor], Tensor]): The function that composes the output from \
                the parameters given by the two reparams.
            inv_func (Optional[Callable[[Tensor], Union[Tuple[Tensor, Tensor], Tensor]]], optional): \
                The inverse of func, used to transform the initialization. The initializer will \
                pass through directly if no inv_func is provided. Defaults to None.
        """
        # pylint: enable=line-too-long
        super().__init__(reparam1, reparam2, func=func, inv_func=inv_func)


# TODO: circuit product
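As an illustration of how BinaryReparam composes two inputs, here is a hypothetical sketch (not part of this diff, and not the circuit-product reparam the TODO above refers to). It materializes the children first so that the composed output shape is not pushed down to them, per the NOTE in ComposedReparam.materialize:

import functools

import torch

from cirkit.new.reparams import BinaryReparam, SoftmaxReparam

# Two softmax-normalized (4, 3) leaves, each summing to 1 along dim 1.
left, right = SoftmaxReparam(), SoftmaxReparam()
left.materialize((4, 3), dim=1)
right.materialize((4, 3), dim=1)

# Their Kronecker product has shape (16, 9); because the children are already
# materialized, materializing the composed reparam does not overwrite them.
prod = BinaryReparam(left, right, func=torch.kron)
prod.materialize((16, 9), dim=())
prod.initialize(functools.partial(torch.nn.init.uniform_, a=0.1, b=1.0))
assert prod().shape == (16, 9)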
cirkit/new/reparams/composed.py
@@ -0,0 +1,138 @@
from typing import Callable, Generic, List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import TypeVarTuple, Unpack  # TODO: in typing from 3.11

import torch
from torch import Tensor, nn

from cirkit.new.reparams.leaf import LeafReparam
from cirkit.new.reparams.reparam import Reparameterization

Ts = TypeVarTuple("Ts")


# TODO: For now the solution I found is using Generic[Unpack[TypeVarTuple]], but it does not bound
#       with Tuple[Tensor, ...], and an extra cast is needed. Any better solution?
class ComposedReparam(Reparameterization, Generic[Unpack[Ts]]):
    """The base class for composed reparameterization."""

    def __init__(
        self,
        *reparams: Optional[Reparameterization],
        func: Callable[[Unpack[Ts]], Tensor],
        inv_func: Optional[Callable[[Tensor], Union[Tuple[Unpack[Ts]], Tensor]]] = None,
    ) -> None:
        """Init class.

        Args:
            *reparams (Optional[Reparameterization]): The input reparameterizations to be \
                composed. If an entry is None, a LeafReparam will be constructed in its place, \
                but None must be provided explicitly rather than omitted so that the length is \
                correct.
            func (Callable[[*Ts], Tensor]): The function that composes the output from the \
                parameters given by reparams.
            inv_func (Optional[Callable[[Tensor], Union[Tuple[Unpack[Ts]], Tensor]]], optional): \
                The inverse of func, used to transform the initialization. It returns either one \
                Tensor shared by all reparams or a tuple with one Tensor per reparam. The \
                initializer will pass through directly if no inv_func is provided. Defaults to \
                None.
        """
        super().__init__()
        # TODO: make ModuleList a generic?
        # Ignore: Here we must use nn.ModuleList to register sub-modules, but we need
        #         List[Reparameterization] so that elements are properly typed.
        self.reparams: List[Reparameterization] = nn.ModuleList(  # type: ignore[assignment]
            reparam if reparam is not None else LeafReparam() for reparam in reparams
        )
        self.func = func
        self.inv_func = inv_func

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the output parameter."""
        dtype = self.reparams[0].dtype
        assert all(
            reparam.dtype == dtype for reparam in self.reparams
        ), "The dtype of all composing reparams should be the same."
        return dtype

    @property
    def device(self) -> torch.device:
        """The device of the output parameter."""
        device = self.reparams[0].device
        assert all(
            reparam.device == device for reparam in self.reparams
        ), "The device of all composing reparams should be the same."
        return device

    def materialize(
        self,
        shape: Sequence[int],
        /,
        *,
        dim: Union[int, Sequence[int]],
        mask: Optional[Tensor] = None,
        log_mask: Optional[Tensor] = None,
    ) -> None:
        """Materialize the internal parameter tensors with the given shape.

        The initial value of the parameter after materialization is not guaranteed, and explicit \
        initialization is expected.

        The three kwargs, dim and mask/log_mask, are used to hint the normalization of sum \
        weights. The dim kwarg must be supplied to hint the sum-to-1 dimension(s); mask and \
        log_mask are optional, and at most one of them may be provided.

        Args:
            shape (Sequence[int]): The shape of the output parameter.
            dim (Union[int, Sequence[int]]): The dimension(s) along which the normalization will \
                be applied.
            mask (Optional[Tensor], optional): The 0/1 mask for normalization positions. None for \
                no masking. The shape must be broadcastable to shape if not None. Defaults to None.
            log_mask (Optional[Tensor], optional): The -inf/0 mask for normalization positions. \
                None for no masking. The shape must be broadcastable to shape if not None. \
                Defaults to None.
        """
        super().materialize(shape, dim=dim, mask=mask, log_mask=log_mask)
        for reparam in self.reparams:
            if not reparam.is_materialized:
                # NOTE: Passing shape to all children reparams may not always be wanted. In that
                #       case, children reparams should be materialized first, so that the
                #       following is skipped by the above if.
                reparam.materialize(shape, dim=dim, mask=mask, log_mask=log_mask)

        assert self().shape == self.shape, "The actual shape does not match the given one."

    def initialize(self, initializer_: Callable[[Tensor], Tensor]) -> None:
        """Initialize the internal parameter tensors with the given initializer.

        Initialization will raise an error if not materialized first.

        Args:
            initializer_ (Callable[[Tensor], Tensor]): A function that initializes a tensor \
                in-place while also returning the value.
        """
        if self.inv_func is None:
            for reparam in self.reparams:
                reparam.initialize(initializer_)
        else:
            init = self.inv_func(initializer_(torch.zeros(self.shape)))
            # TODO: This cast is unavoidable because the type of init_value is Union[*Ts].
            init_values = (
                (init,) * len(self.reparams)
                if isinstance(init, Tensor)  # type: ignore[misc]  # Ignore: Tensor contains Any.
                else cast(Tuple[Tensor, ...], init)
            )
            for reparam, init_value in zip(self.reparams, init_values):
                # Disable: The following should be safe because the lambda is used immediately,
                #          before the next loop iteration.  # TODO: test if what I say is correct
                # pylint: disable-next=cell-var-from-loop
                reparam.initialize(lambda x: x.copy_(init_value))

    def forward(self) -> Tensor:
        """Get the reparameterized parameters.

        Returns:
            Tensor: The parameters after reparameterization.
        """
        # TODO: This cast is unavoidable because Tensor is not Ts.
        # NOTE: params is not a tuple, but unpacking with *(...) works on any iterable.
        params = cast(Tuple[Unpack[Ts]], (reparam() for reparam in self.reparams))
        return self.func(*params)
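A minimal sketch of the initialization path with an inv_func (hypothetical usage, not part of this diff, assuming the package from this PR is installed): the inverse splits the desired output across the two leaves, so the composed forward reproduces what the initializer wrote.

import torch

from cirkit.new.reparams import ComposedReparam, LeafReparam

# A sum of two leaves; the inverse halves the init value per leaf so that r()
# reproduces the initializer's output.
r = ComposedReparam(
    LeafReparam(),
    LeafReparam(),
    func=lambda a, b: a + b,
    inv_func=lambda t: (t / 2, t / 2),
)
r.materialize((2, 3), dim=())  # no normalization dims hinted, like LeafReparam
r.initialize(torch.nn.init.ones_)
assert torch.allclose(r(), torch.ones(2, 3))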
cirkit/new/reparams/leaf.py
@@ -0,0 +1,62 @@
from typing import Callable, Sequence, final
from typing_extensions import Unpack  # TODO: in typing from 3.12 for Unpack[dict]

import torch
from torch import Tensor, nn

from cirkit.new.reparams.reparam import Reparameterization
from cirkit.new.utils.type_aliases import MaterializeKwargs


# The LeafReparam only holds the tensor. Everything else should be a (unary) composed reparam.
# The @final is to prevent inheritance from LeafReparam.
@final
class LeafReparam(Reparameterization):
    """The leaf in reparameterizations that holds the parameter Tensor."""

    def __init__(self) -> None:
        """Init class."""
        super().__init__()
        self.param = nn.UninitializedParameter()

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the output parameter."""
        return self.param.dtype

    @property
    def device(self) -> torch.device:
        """The device of the output parameter."""
        return self.param.device

    def materialize(self, shape: Sequence[int], /, **_kwargs: Unpack[MaterializeKwargs]) -> None:
        """Materialize the internal parameter tensors with the given shape.

        The initial value of the parameter after materialization is not guaranteed, and explicit \
        initialization is expected.

        Args:
            shape (Sequence[int]): The shape of the output parameter.
            **_kwargs (Unpack[MaterializeKwargs]): Unused. See Reparameterization.materialize().
        """
        super().materialize(shape, dim=())
        self.param.materialize(self.shape)

    def initialize(self, initializer_: Callable[[Tensor], Tensor]) -> None:
        """Initialize the internal parameter tensors with the given initializer.

        Initialization will raise an error if not materialized first.

        Args:
            initializer_ (Callable[[Tensor], Tensor]): A function that initializes a tensor \
                in-place while also returning the value.
        """
        initializer_(self.param)

    def forward(self) -> Tensor:
        """Get the reparameterized parameters.

        Returns:
            Tensor: The parameters after reparameterization.
        """
        return self.param
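A quick sketch of the materialize, initialize, forward lifecycle that every reparameterization follows, shown on the leaf (hypothetical usage, not part of this diff, assuming the package from this PR is installed):

import torch

from cirkit.new.reparams import LeafReparam

leaf = LeafReparam()
leaf.materialize((3, 2))                # turn the UninitializedParameter into a (3, 2) Parameter
leaf.initialize(torch.nn.init.normal_)  # explicit init is required after materialization
assert leaf().shape == (3, 2)           # forward() simply returns the held parameter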
cirkit/new/reparams/normalized.py
@@ -0,0 +1,94 @@
# pylint: disable=too-few-public-methods
# Disable: For this file we disable the above because all classes trigger it, but it's intended.

from typing import Optional, Protocol, Tuple

import torch
from torch import Tensor

from cirkit.new.reparams.reparam import Reparameterization
from cirkit.new.reparams.unary import UnaryReparam
from cirkit.new.utils import flatten_dims, unflatten_dims


class _TensorFuncWithDim(Protocol):
    """The protocol for `(Tensor, dim: int) -> Tensor`."""

    def __call__(self, x: Tensor, /, dim: int) -> Tensor:
        ...


class _NormalizedReparamMixin:
    """A mixin with helpers for reparams that normalize over some dims."""

    dims: Tuple[int, ...]

    def _apply_normalizer(self, normalizer: _TensorFuncWithDim, x: Tensor, /) -> Tensor:
        """Apply a normalizer function to a Tensor over self.dims.

        Args:
            normalizer (_TensorFuncWithDim): The normalizer of a tensor with a dim arg.
            x (Tensor): The tensor input.

        Returns:
            Tensor: The normalized output.
        """
        return unflatten_dims(
            normalizer(flatten_dims(x, dims=self.dims), dim=self.dims[0]),
            dims=self.dims,
            shape=x.shape,
        )


class SoftmaxReparam(UnaryReparam, _NormalizedReparamMixin):
    """Softmax reparameterization.

    Range: (0, 1); 0 is available through masking, and 1 is available when only one element is valid.
    Constraints: sum to 1.
    """

    def __init__(self, reparam: Optional[Reparameterization] = None, /) -> None:
        """Init class.

        Args:
            reparam (Optional[Reparameterization], optional): The input reparameterization to be \
                composed. If None, a LeafReparam will be constructed in its place. Defaults to None.
        """
        # Softmax is just a scaled exp, so we take log as the inverse. The mask is ignored for
        # the inverse.
        super().__init__(reparam, func=self._func, inv_func=torch.log)

    def _func(self, x: Tensor) -> Tensor:
        if self.log_mask is not None:
            x = x + self.log_mask
        # torch.softmax only accepts a single dim, so we need to rearrange dims.
        x = self._apply_normalizer(torch.softmax, x)
        # nan appears when there is only one element and it is masked. In that case we project
        # nan to 1 (0 in log-space). +inf and -inf are (unsafely) projected but will not appear.
        return torch.nan_to_num(x, nan=1)


class LogSoftmaxReparam(UnaryReparam, _NormalizedReparamMixin):
    """LogSoftmax reparameterization, which is more numerically stable than log(softmax(...)).

    Range: (-inf, 0); -inf is available through masking, and 0 is available when only one element is valid.
    Constraints: logsumexp to 0.
    """

    def __init__(self, reparam: Optional[Reparameterization] = None, /) -> None:
        """Init class.

        Args:
            reparam (Optional[Reparameterization], optional): The input reparameterization to be \
                composed. If None, a LeafReparam will be constructed in its place. Defaults to None.
        """
        # Log-softmax is just an offset, so we take identity as the inverse. The mask is ignored
        # for the inverse.
        super().__init__(reparam, func=self._func)

    def _func(self, x: Tensor) -> Tensor:
        if self.log_mask is not None:
            x = x + self.log_mask
        # torch.log_softmax only accepts a single dim, so we need to rearrange dims.
        x = self._apply_normalizer(torch.log_softmax, x)
        # -inf still passes gradients, so we use a redundant projection to stop them. nan is
        # handled as in SoftmaxReparam, projected to 0 (in log-space). +inf will not appear.
        return torch.nan_to_num(x, nan=0, neginf=float("-inf"))
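A sketch of how the normalized reparams are used (hypothetical usage, not part of this diff, assuming the package from this PR is installed): the dim hint given at materialization is the dimension over which the softmax sums to 1.

import functools

import torch

from cirkit.new.reparams import SoftmaxReparam

sm = SoftmaxReparam()          # wraps a fresh LeafReparam
sm.materialize((4, 3), dim=1)  # normalize along dim 1
sm.initialize(functools.partial(torch.nn.init.uniform_, a=0.1, b=1.0))
w = sm()
assert torch.allclose(w.sum(dim=1), torch.ones(4))  # each row sums to 1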