Skip to content

Commit

Permalink
fix symb circuit/layers based on new reparams
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Dec 6, 2023
1 parent 7d88ac6 commit a8cff53
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 145 deletions.
40 changes: 11 additions & 29 deletions cirkit/new/symbolic/symbolic_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
from cirkit.layers.input.exp_family import ExpFamilyLayer
from cirkit.layers.sum_product import SumProductLayer
from cirkit.new.region_graph import RegionGraph, RGNode
from cirkit.new.reparams import Reparameterization
from cirkit.new.symbolic.symbolic_layer import (
SymbolicInputLayer,
SymbolicLayer,
SymbolicProductLayer,
SymbolicSumLayer,
)
from cirkit.reparams.leaf import ReparamIdentity
from cirkit.utils.type_aliases import ReparamFactory


# Disable: It's designed to have these many attributes.
class SymbolicCircuit: # pylint: disable=too-many-instance-attributes
"""The Symbolic Circuit."""

# TODO: how to design interface? require kwargs only?
# TODO: how to deal with too-many?
# pylint: disable-next=too-many-arguments,too-many-locals
def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
Expand All @@ -26,10 +26,10 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
efamily_cls: Type[ExpFamilyLayer],
layer_kwargs: Optional[Dict[str, Any]] = None,
efamily_kwargs: Optional[Dict[str, Any]] = None,
reparam: ReparamFactory = ReparamIdentity,
*,
reparam: Reparameterization, # TODO: how to set default here?
num_inner_units: int = 2,
num_input_units: int = 2,
num_channels: int = 1,
num_classes: int = 1,
):
"""Construct symbolic circuit from a region graph.
Expand All @@ -43,9 +43,7 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
reparam (ReparamFactory): The reparametrization function.
num_inner_units (int): Number of units for inner layers.
num_input_units (int): Number of units for input layers.
num_channels (int): Number of channels (e.g., 3 for RGB pixel) for input layers.
num_classes (int): Number of classes for the PC.
"""
self.region_graph = region_graph
self.scope = region_graph.scope
Expand Down Expand Up @@ -73,7 +71,6 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
else:
# Construct a symbolic layer from the region node
symbolic_layer = self._from_region_node(
prev_symbolic_layer, # type: ignore[arg-type]
rg_node,
region_graph,
layer_cls,
Expand All @@ -83,7 +80,6 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
reparam,
num_inner_units,
num_input_units,
num_channels,
num_classes,
)
existing_symbolic_layers[rg_node] = symbolic_layer
Expand All @@ -101,17 +97,15 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
# pylint: disable-next=no-self-use,too-many-arguments,too-many-locals
def _from_region_node( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
self,
prev_symbolic_layer: SymbolicLayer,
rg_node: RGNode,
region_graph: RegionGraph,
layer_cls: Type[SumProductLayer],
efamily_cls: Type[ExpFamilyLayer],
layer_kwargs: Optional[Dict[str, Any]],
efamily_kwargs: Optional[Dict[str, Any]],
reparam: ReparamFactory,
reparam: Reparameterization,
num_inner_units: int,
num_input_units: int,
num_channels: int,
num_classes: int,
) -> SymbolicLayer:
"""Create a symbolic layer based on the given region node.
Expand Down Expand Up @@ -151,20 +145,10 @@ def _from_region_node( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
if rg_node in region_graph.output_nodes # type: ignore[operator]
else num_inner_units
)
# Ignore: SymbolicInputLayer contains Any.
input_units = (
num_input_units
if any(
isinstance(layer, SymbolicInputLayer) # type: ignore[misc]
for layer in prev_symbolic_layer.inputs
)
else num_inner_units
)

symbolic_layer = SymbolicSumLayer(
scope, output_units, layer_cls, layer_kwargs # type: ignore[misc]
scope, output_units, layer_cls, layer_kwargs, reparam=reparam # type: ignore[misc]
)
symbolic_layer.set_placeholder_params(input_units, output_units, reparam)

elif rg_node in region_graph.partition_nodes: # type: ignore[operator]
assert len(inputs) == 2, "Partition nodes should have exactly two inputs."
Expand All @@ -180,13 +164,13 @@ def _from_region_node( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
symbolic_layer = SymbolicProductLayer(scope, left_input_units, layer_cls)

elif rg_node in region_graph.input_nodes: # type: ignore[operator]
# TODO: this is removed from RGNode. but we need it until refactor the input layers.
num_replicas = 1

symbolic_layer = SymbolicInputLayer(
scope, num_input_units, efamily_cls, efamily_kwargs # type: ignore[misc]
scope,
num_input_units,
efamily_cls,
efamily_kwargs, # type: ignore[misc]
reparam=reparam,
)
symbolic_layer.set_placeholder_params(num_channels, num_replicas, reparam)

else:
raise ValueError("Region node not valid.")
Expand All @@ -205,8 +189,6 @@ def _add_edge(self, tail: SymbolicLayer, head: SymbolicLayer) -> None:
tail.outputs.add(head)
head.inputs.add(tail)

########################## Properties #########################

####################################### Properties #######################################
# Here are the basic properties and some structural properties of the SymbC. Some of them are
# simply defined in __init__. Some requires additional treatment and is define below.
Expand Down
117 changes: 21 additions & 96 deletions cirkit/new/symbolic/symbolic_layer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Set, Type

from cirkit.layers.input.exp_family import (
BinomialLayer,
CategoricalLayer,
ExpFamilyLayer,
NormalLayer,
)
from cirkit.layers.input.exp_family import ExpFamilyLayer
from cirkit.layers.sum_product import (
CollapsedCPLayer,
SharedCPLayer,
SumProductLayer,
TuckerLayer,
UncollapsedCPLayer,
)
from cirkit.reparams.leaf import ReparamIdentity
from cirkit.reparams.reparam import Reparameterization
from cirkit.utils.type_aliases import ReparamFactory
from cirkit.new.reparams import Reparameterization

# TODO: double check docs and __repr__

Expand Down Expand Up @@ -51,16 +44,21 @@ def __repr__(self) -> str:
"""


class SymbolicSumLayer(SymbolicLayer):
# Disable: It's intended for SymbolicSumLayer to have only these methods.
class SymbolicSumLayer(SymbolicLayer): # pylint: disable=too-few-public-methods
"""The sum layer in symbolic circuits."""

# TODO: how to design interface? require kwargs only?
# Disable: This __init__ is designed to have these arguments.
# pylint: disable-next=too-many-arguments
def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
self,
scope: Iterable[int],
num_units: int,
layer_cls: Type[SumProductLayer],
layer_kwargs: Optional[Dict[str, Any]] = None,
*,
reparam: Reparameterization, # TODO: how to set default here?
) -> None:
"""Construct the SymbolicSumLayer.
Expand All @@ -69,6 +67,7 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
num_units (int): Number of output units in this layer.
layer_cls (Type[SumProductLayer]): The inner (sum) layer class.
layer_kwargs (Optional[Dict[str, Any]]): The parameters for the inner layer class.
reparam (Reparameterization): The reparam.
Raises:
NotImplementedError: If the shared uncollapsed CP is not implemented.
Expand All @@ -77,9 +76,9 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
self.num_units = num_units
# Ignore: Unavoidable for kwargs.
self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} # type: ignore[misc]
self.params: Optional[Reparameterization] = None
self.params_in: Optional[Reparameterization] = None
self.params_out: Optional[Reparameterization] = None
self.params = reparam # TODO: this is not correct, but will be reviewed in new layers.
self.params_in = reparam
self.params_out = reparam

if layer_cls == TuckerLayer:
self.layer_cls = layer_cls
Expand All @@ -97,46 +96,6 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
else:
raise NotImplementedError("The shared uncollapsed CP is not implemented.")

def set_placeholder_params(
self,
num_input_units: int,
num_units: int,
reparam: ReparamFactory = ReparamIdentity,
) -> None:
"""Set un-initialized parameter placeholders for the symbolic sum layer.
Args:
num_input_units (int): Number of input units.
num_units (int): Number of output units.
reparam (ReparamFactory): Reparameterization function.
Raises:
NotImplementedError: If the shared uncollapsed CP is not implemented.
"""
assert self.num_units == num_units

# Handling different layer types
if self.layer_cls == TuckerLayer:
# number of fold = 1
self.params = reparam((1, num_input_units, num_input_units, num_units), dim=(1, 2))
else: # CP layer
# TODO: for unfolded layers we will not need these variants and ignore may be resolved
arity: int = self.layer_kwargs.get("arity", 2) # type: ignore[misc]
assert (
"fold_mask" not in self.layer_kwargs # type: ignore[misc]
or self.layer_kwargs["A"] is None # type: ignore[misc]
), "Do not support fold_mask yet"

if self.layer_cls == CollapsedCPLayer:
self.params_in = reparam((1, arity, num_input_units, num_units), dim=-2, mask=None)
elif self.layer_cls == UncollapsedCPLayer:
self.params_in = reparam((1, arity, num_input_units, 1), dim=-2, mask=None)
self.params_out = reparam((1, 1, num_units), dim=-2, mask=None)
elif self.layer_cls == SharedCPLayer:
self.params_in = reparam((arity, num_input_units, num_units), dim=-2, mask=None)
else:
raise NotImplementedError("The shared uncollapsed CP is not implemented.")

def __repr__(self) -> str:
"""Generate the repr string of the layer.
Expand All @@ -145,21 +104,13 @@ def __repr__(self) -> str:
"""
class_name = self.__class__.__name__
layer_cls_name = self.layer_cls.__name__
# TODO: review this part when we have a new reparams.
params_shape = self.params.shape if self.params is not None else None

params_in_shape = self.params_in.shape if self.params_in is not None else None
params_out_shape = self.params_out.shape if self.params_out is not None else None

return (
f"{class_name}:\n" # type: ignore[misc] # Ignore: Unavoidable for kwargs.
f"Scope: {repr(self.scope)}\n"
f"Layer Class: {layer_cls_name}\n"
f"Layer KWArgs: {repr(self.layer_kwargs)}\n"
f"Number of Units: {repr(self.num_units)}\n"
f"Parameter Shape: {repr(params_shape)}\n"
f"CP Layer Parameter in Shape: {repr(params_in_shape)}\n"
f"CP Layer Parameter out Shape: {repr(params_out_shape)}\n"
)


Expand Down Expand Up @@ -198,15 +149,20 @@ def __repr__(self) -> str:
)


class SymbolicInputLayer(SymbolicLayer):
# Disable: It's intended for SymbolicInputLayer to have only these methods.
class SymbolicInputLayer(SymbolicLayer): # pylint: disable=too-few-public-methods
"""The input layer in symbolic circuits."""

# Disable: This __init__ is designed to have these arguments.
# pylint: disable-next=too-many-arguments
def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
self,
scope: Iterable[int],
num_units: int,
layer_cls: Type[ExpFamilyLayer],
layer_kwargs: Optional[Dict[str, Any]] = None,
*,
reparam: Reparameterization, # TODO: how to set default here?
) -> None:
"""Construct the SymbolicInputLayer.
Expand All @@ -215,45 +171,16 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
num_units (int): Number of output units.
layer_cls (Type[ExpFamilyLayer]): The exponential family class.
layer_kwargs (Optional[Dict[str, Any]]): The parameters for
the exponential family class.
the exponential family class.
reparam (Reparameterization): The reparam.
"""
# TODO: many things can be merged to SymbolicLayer.__init__.
super().__init__(scope)
self.num_units = num_units
self.layer_cls = layer_cls
# Ignore: Unavoidable for kwargs.
self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} # type: ignore[misc]
self.params: Optional[Reparameterization] = None

def set_placeholder_params(
self,
num_channels: int = 1,
num_replicas: int = 1,
reparam: ReparamFactory = ReparamIdentity,
) -> None:
"""Set un-initialized parameter placeholders for the input layer.
Args:
num_channels (int): Number of channels.
num_replicas (int): Number of replicas.
reparam (ReparamFactory): Reparameterization function.
Raises:
NotImplementedError: Only support Normal, Categorical, and Binomial input layers.
"""
# Handling different exponential family layer types
if self.layer_cls == NormalLayer:
num_suff_stats = 2 * num_channels
elif self.layer_cls == CategoricalLayer:
num_suff_stats = (
self.layer_kwargs["num_categories"] * num_channels # type: ignore[misc]
)
elif self.layer_cls == BinomialLayer:
num_suff_stats = num_channels
else:
raise NotImplementedError("Only support Normal, Categorical, and Binomial input layers")

self.params = reparam((1, self.num_units, num_replicas, num_suff_stats), dim=-1)
self.params = reparam

def __repr__(self) -> str:
"""Generate the repr string of the layer.
Expand All @@ -263,13 +190,11 @@ def __repr__(self) -> str:
"""
class_name = self.__class__.__name__
efamily_cls_name = self.layer_cls.__name__ if self.layer_cls else "None"
params_shape = self.params.shape if self.params is not None else None

return (
f"{class_name}:\n" # type: ignore[misc] # Ignore: Unavoidable for kwargs.
f"Scope: {repr(self.scope)}\n"
f"Input Exp Family Class: {efamily_cls_name}\n"
f"Layer KWArgs: {repr(self.layer_kwargs)}\n"
f"Number of Units: {repr(self.num_units)}\n"
f"Parameter Shape: {repr(params_shape)}\n"
)
Loading

0 comments on commit a8cff53

Please sign in to comment.