From d440cd87c04680466682c15ee4d31d161c3474cd Mon Sep 17 00:00:00 2001 From: lkct Date: Fri, 8 Dec 2023 14:43:37 +0000 Subject: [PATCH 1/7] introduce new util classes --- cirkit/new/utils/__init__.py | 2 + cirkit/new/utils/ordered_set.py | 105 +++++++++++++++++ cirkit/new/utils/scope.py | 198 ++++++++++++++++++++++++++++++++ 3 files changed, 305 insertions(+) create mode 100644 cirkit/new/utils/ordered_set.py create mode 100644 cirkit/new/utils/scope.py diff --git a/cirkit/new/utils/__init__.py b/cirkit/new/utils/__init__.py index f6f1a0e6..e8adb600 100644 --- a/cirkit/new/utils/__init__.py +++ b/cirkit/new/utils/__init__.py @@ -5,3 +5,5 @@ from .comp_space import LogSpace as LogSpace from .flatten import flatten_dims as flatten_dims from .flatten import unflatten_dims as unflatten_dims +from .ordered_set import OrderedSet as OrderedSet +from .scope import Scope as Scope diff --git a/cirkit/new/utils/ordered_set.py b/cirkit/new/utils/ordered_set.py new file mode 100644 index 00000000..3fe4e2f1 --- /dev/null +++ b/cirkit/new/utils/ordered_set.py @@ -0,0 +1,105 @@ +from typing import Any, Collection, Dict, Iterable, Iterator, Literal, Protocol, TypeVar +from typing_extensions import Self # TODO: in typing from 3.11 + + +# TODO: pylint issue? protocol are expected to have few public methods +class _SupportsDunderLT(Protocol): # pylint: disable=too-few-public-methods + # Disable: This is the only way to get a TypeVar for Protocol with __lt__. Another option, using + # Protocol[T_contra], will introduce much more ignores. + def __lt__(self, other: Any, /) -> bool: # type: ignore[misc] + ... + + +ComparableT = TypeVar("ComparableT", bound=_SupportsDunderLT) +# This, to be used as (mutable)Collection[ComparableT], can't be either covariant or contravariant: +# - Function arguments cannot be covariant, therefore nor mutable generic types; +# - See explaination in https://github.com/python/mypy/issues/7049; +# - Containers can never be contravariant by nature. + + +class OrderedSet(Collection[ComparableT]): + """A mutable container of a set that preserves element ordering when iterated. + + The elements are required to support __lt__ comparison to make sorting work. + + This is designed for node (edge) lists in the graph data structure (incl. RG, SymbC, ...), but + does not comply with all standard builtin container (list, set, dict) interface. + + The implementation relies on the order preservation of dict introduced in Python 3.7. + """ + + # NOTE: We can also inherit Reversible[ComparableT] and implement __reversed__ based on + # reverse(dict), but currently this feature is not needed. + + def __init__(self, *iterables: Iterable[ComparableT]) -> None: + """Init class. + + Args: + *iterables (Iterable[ComparableT]): The initial content, if provided any. Will be \ + inserted in the order given. + """ + super().__init__() + # The dict values are unused and always set to True. + self._container: Dict[ComparableT, Literal[True]] = { + element: True for iterable in iterables for element in iterable + } + + # Ignore: We should only test the element type. + def __contains__(self, element: ComparableT) -> bool: # type: ignore[override] + """Test whether an element is contained in the set. + + Args: + element (ComparableT): The element to test. + + Returns: + bool: Whether the element exists. + """ + return element in self._container + + def __iter__(self) -> Iterator[ComparableT]: + """Iterate over the set in order. + + Returns: + Iterator[T_co]: The iterator over the set. + """ + return iter(self._container) + + def __len__(self) -> int: + """Get the length (number of elements) of the set. + + Returns: + int: The number of elements in the set. + """ + return len(self._container) + + def append(self, element: ComparableT) -> bool: + """Add an element to the end of the set, if it does not exist yet; otherwise no-op. + + New elements are always added at the end, and the order is preserved. Existing elements \ + will not be moved if appended again. + + Args: + element (ComparableT): The element to append. + + Returns: + bool: Whether the insertion actually happened at the end. + """ + if element in self: + return False # Meaning it's a no-op. + + self._container[element] = True + return True # Meaning a successful append. + + def sort(self) -> Self: + """Sort the set inplace and return self. + + It stably sorts the elements from the insertion order to the comparison order: + - If a < b, a always precedes b in the sorted order; + - If neither a < b nor b < a, the existing order is preserved. + + Returns: + Self: The self object. + """ + # This relies on that sorted() is stable. + self._container = {element: True for element in sorted(self._container)} + return self diff --git a/cirkit/new/utils/scope.py b/cirkit/new/utils/scope.py new file mode 100644 index 00000000..b9924bde --- /dev/null +++ b/cirkit/new/utils/scope.py @@ -0,0 +1,198 @@ +from typing import FrozenSet, Iterable, Iterator, Union, final +from typing_extensions import Self # TODO: in typing from 3.11 + +# TODO: convert to bitset + + +# We mark this final so that Scope==Self in typing. Also there's no need to inherit this. +@final +class Scope(FrozenSet[int]): + """An immutable container (Hashable Collection) of int to represent the scope of a unit in a \ + circuit. + + Scopes should always be subsets of range(num_vars), but for efficiency this is not checked. + """ + + # NOTE: The following also serves as the API for Scope. Even the methods defined in the base + # class can be reused, they should be overriden below to explicitly define the methods. + + # We should use __new__ instead of __init__ because it's immutable. + def __new__(cls, scope: Union["Scope", Iterable[int]]) -> Self: + """Create the scope. + + Args: + scope (Union[Scope, Iterable[int]]): The scope as an interable of variable ids. If \ + already a Scope object, the object passed in will be directly returned. + + Returns: + Self: The Scope object. + """ + if isinstance(scope, Scope): # Saves a copy. + return scope + # TODO: mypy bug? asking for Iterable[_T_co] but it's already FrozenSet[int] + return super().__new__(cls, scope) # type: ignore[arg-type] + + def __repr__(self) -> str: + """Generate the repr string of the scope, for repr(). + + Returns: + str: The str representation of the scope. + """ + return f"Scope({repr(set(self))})" # Scope({0, 1, ...}). + + ########################### + # collections.abc.Container + ########################### + # Ignore: We should only test int. + def __contains__(self, var: int) -> bool: # type: ignore[override] + """Test whether a variable is in the scope, for `in` and `not in` operators. + + Args: + var (int): The variable id to test. + + Returns: + bool: Whether the variable is in this scope. + """ + return super().__contains__(var) + + ########################## + # collections.abc.Iterable + ########################## + def __iter__(self) -> Iterator[int]: + """Iterate over the scope variables in the order of id, for convertion to other containers. + + Returns: + Iterator[int]: The iterator over the scope (sorted). + """ + return iter(sorted(super().__iter__())) # FrozenSet is not sorted. + + ####################### + # collections.abc.Sized + ####################### + # Disable: We require explicit definition. + def __len__(self) -> int: # pylint: disable=useless-parent-delegation + """Get the length (number of variables) of the scope, for len() as well as bool(). + + Returns: + int: The number of variables in the scope. + """ + return super().__len__() + + ############################ + # collections.abc.Collection + ############################ + # Collection = Sized Iterable Container + + ########################## + # collections.abc.Hashable + ########################## + def __hash__(self) -> int: + """Get the hash value of the scope, for use as dict/set keys. + + The same scope always has the same hash value. + + Returns: + int: The hash value. + """ + return super().__hash__() + + ################ + # Total Ordering + ################ + # Ignore: We should only compare scopes. + def __eq__(self, other: "Scope") -> bool: # type: ignore[override] + """Test equality between scopes, for == and != operators. + + Two scopes are equal when they contain the same set of variables. + + Args: + other (Scope): The other scope to compare with. + + Returns: + bool: Whether self == other. + """ + return super().__eq__(other) + + # Ignore: We should only compare scopes. + def __lt__(self, other: "Scope") -> bool: # type: ignore[override] + """Compare scopes for ordering, for < operator. + + It is guaranteed that exactly one of a == b, a < b, a > b is True. Can be used for sorting \ + and the order is guaranteed to be always stable. + + Two scopes compare by the following: + - If the lengths are different, the shorter one is smaller; + - If of same length, the one with the smallest non-shared variable id is smaller; + - They should be the same scope if the above cannot compare. + + Args: + other (Scope): The other scope to compare with. + + Returns: + bool: Whether self < other. + """ + return len(self) < len(other) or len(self) == len(other) and tuple(self) < tuple(other) + + # Ignore: We should only compare scopes. + def __gt__(self, other: "Scope") -> bool: # type: ignore[override] + """Compare scopes for ordering, for > operator. + + a > b is defined as b < a, so that the reflection relationship holds. + + It is guaranteed that exactly one of a == b, a < b, a > b is True. + + Args: + other (Scope): The other scope to compare with. + + Returns: + bool: Whether self > other. + """ + return other < self + + # Ignore: @functools.total_ordering won't work if the base class has defined the following, so + # we simply disable them. Yet the above already lead to a valid total ordering. + __le__ = __ge__ = None # type: ignore[assignment] + + ######################## + # Union and Intersection + ######################## + # Ignore: We should only intersect scopes. + def __and__(self, other: "Scope") -> "Scope": # type: ignore[override] + """Get the intersection of two scopes, for & operator. + + Args: + other (Scope): The other scope to take intersection with. + + Returns: + Scope: The intersection. + """ + return Scope(super().__and__(other)) + + # Ignore: We should only union scopes. + def __or__(self, other: "Scope") -> "Scope": # type: ignore[override] + """Get the union of two scopes, for | operator. + + Args: + other (Scope): The other scope to take union with. + + Returns: + Scope: The union. + """ + return Scope(super().__or__(other)) + + # Ignore: We should only union scopes. + # Disable: This is a hack that self goes as the first of scopes, so that self.union(...) and + # Scope.union(...) both work, even when ... is empty. + # pylint: disable-next=no-self-argument + def union(*scopes: "Scope") -> "Scope": # type: ignore[override] + """Take the union over multiple scopes, for use as n-ary | operator. + + Can be used as either self.union(...) or Scope.union(...). + + Args: + *scopes (Scope): The other scopes to take union with. + + Returns: + Scope: The union. + """ + return Scope(frozenset().union(*scopes)) From 0541083d07e7ab5f81384a7acbe92aa7b2046e8a Mon Sep 17 00:00:00 2001 From: lkct Date: Fri, 8 Dec 2023 17:48:48 +0000 Subject: [PATCH 2/7] use scope and orderedset in RG --- .../region_graph/algorithms/poon_domingos.py | 7 +- .../new/region_graph/algorithms/quad_tree.py | 13 +- .../algorithms/random_binary_tree.py | 8 +- cirkit/new/region_graph/algorithms/utils.py | 14 +- cirkit/new/region_graph/region_graph.py | 241 ++++++++++-------- cirkit/new/region_graph/rg_node.py | 73 +++--- 6 files changed, 203 insertions(+), 153 deletions(-) diff --git a/cirkit/new/region_graph/algorithms/poon_domingos.py b/cirkit/new/region_graph/algorithms/poon_domingos.py index 5ab9bb39..2b7495fe 100644 --- a/cirkit/new/region_graph/algorithms/poon_domingos.py +++ b/cirkit/new/region_graph/algorithms/poon_domingos.py @@ -1,19 +1,20 @@ from collections import deque -from typing import Deque, Dict, FrozenSet, List, Optional, Sequence, Union, cast +from typing import Deque, Dict, List, Optional, Sequence, Union, cast from cirkit.new.region_graph.algorithms.utils import HyperCube, HypercubeToScope from cirkit.new.region_graph.region_graph import RegionGraph from cirkit.new.region_graph.rg_node import RegionNode +from cirkit.new.utils import Scope # TODO: test what is constructed here -def _get_region_node_by_scope(graph: RegionGraph, scope: FrozenSet[int]) -> RegionNode: +def _get_region_node_by_scope(graph: RegionGraph, scope: Scope) -> RegionNode: """Find a RegionNode with a specific scope in the RG, and construct one if not found. Args: graph (RegionGraph): The region graph to find in. - scope (Iterable[int]): The scope to find. + scope (Scope): The scope to find. Returns: RegionNode: The RegionNode found or constructed. diff --git a/cirkit/new/region_graph/algorithms/quad_tree.py b/cirkit/new/region_graph/algorithms/quad_tree.py index 2e7d50b0..72e2ffd3 100644 --- a/cirkit/new/region_graph/algorithms/quad_tree.py +++ b/cirkit/new/region_graph/algorithms/quad_tree.py @@ -4,6 +4,7 @@ from cirkit.new.region_graph.algorithms.utils import HypercubeToScope from cirkit.new.region_graph.region_graph import RegionGraph from cirkit.new.region_graph.rg_node import RegionNode +from cirkit.new.utils import Scope # TODO: now should work with H!=W but need tests @@ -96,8 +97,12 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr graph = RegionGraph() hypercube_to_scope = HypercubeToScope(shape) - # The regions of the current layer, in shape (H, W). scope={-1} is for padding. - layer: List[List[RegionNode]] = [[RegionNode({-1})] * (width + 1) for _ in range(height + 1)] + # Padding using Scope({num_var}) which is one larger than range(num_var). + pad_scope = Scope({height * width}) + # The regions of the current layer, in shape (H, W). + layer: List[List[RegionNode]] = [ + [RegionNode(pad_scope)] * (width + 1) for _ in range(height + 1) + ] # Add univariate input nodes. for i, j in itertools.product(range(height), range(width)): @@ -114,7 +119,7 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr height = (prev_height + 1) // 2 width = (prev_width + 1) // 2 - layer = [[RegionNode({-1})] * (width + 1) for _ in range(height + 1)] + layer = [[RegionNode(pad_scope)] * (width + 1) for _ in range(height + 1)] for i, j in itertools.product(range(height), range(width)): regions = [ # Filter valid regions in the 4 possible sub-regions. @@ -125,7 +130,7 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr prev_layer[i * 2 + 1][j * 2], prev_layer[i * 2 + 1][j * 2 + 1], ) - if node.scope != {-1} + if node.scope != pad_scope ] if len(regions) == 1: node = regions[0] diff --git a/cirkit/new/region_graph/algorithms/random_binary_tree.py b/cirkit/new/region_graph/algorithms/random_binary_tree.py index 1c4743cd..08b59ddd 100644 --- a/cirkit/new/region_graph/algorithms/random_binary_tree.py +++ b/cirkit/new/region_graph/algorithms/random_binary_tree.py @@ -27,8 +27,8 @@ def _partition_node_randomly( Returns: List[RegionNode]: The region nodes forming the partitioning. """ - scope = list(node.scope) - random.shuffle(scope) + scope_list = list(node.scope) + random.shuffle(scope_list) split: NDArray[np.float64] # Unnormalized split points including 0 and 1. if proportions is None: @@ -39,7 +39,7 @@ def _partition_node_randomly( # Ignore: Numpy has typing issues. split_point: List[int] = ( - np.around(split / split[-1] * len(scope)) # type: ignore[assignment,misc] + np.around(split / split[-1] * len(scope_list)) # type: ignore[assignment,misc] .astype(np.int64) .tolist() ) @@ -47,7 +47,7 @@ def _partition_node_randomly( region_nodes: List[RegionNode] = [] for l, r in zip(split_point[:-1], split_point[1:]): if l < r: # A region must have as least one var, otherwise we skip it. - region_node = RegionNode(scope[l:r]) + region_node = RegionNode(scope_list[l:r]) region_nodes.append(region_node) if len(region_nodes) == 1: diff --git a/cirkit/new/region_graph/algorithms/utils.py b/cirkit/new/region_graph/algorithms/utils.py index 2c710505..78bcce7a 100644 --- a/cirkit/new/region_graph/algorithms/utils.py +++ b/cirkit/new/region_graph/algorithms/utils.py @@ -1,13 +1,15 @@ -from typing import Dict, FrozenSet, Sequence, Tuple +from typing import Dict, Sequence, Tuple import numpy as np from numpy.typing import NDArray +from cirkit.new.utils import Scope + HyperCube = Tuple[Tuple[int, ...], Tuple[int, ...]] # Just to shorten the annotation. """A hypercube represented by "top-left" and "bottom-right" coordinates (cut points).""" -class HypercubeToScope(Dict[HyperCube, FrozenSet[int]]): +class HypercubeToScope(Dict[HyperCube, Scope]): """Helper class to map sub-hypercubes to scopes with caching for variables arranged in a \ hypercube. @@ -29,7 +31,7 @@ def __init__(self, shape: Sequence[int]) -> None: # We assume it's feasible to save the whole hypercube, since it should be the whole region. self.hypercube: NDArray[np.int64] = np.arange(np.prod(shape), dtype=np.int64).reshape(shape) - def __missing__(self, key: HyperCube) -> FrozenSet[int]: + def __missing__(self, key: HyperCube) -> Scope: """Construct the item when not exist in the dict. Args: @@ -37,19 +39,19 @@ def __missing__(self, key: HyperCube) -> FrozenSet[int]: visited for the first time. Returns: - FrozenSet[int]: The value for the key, i.e., the corresponding scope. + Scope: The value for the key, i.e., the corresponding scope. """ point1, point2 = key # HyperCube is from point1 to point2. assert ( len(point1) == len(point2) == self.ndims - ), "The shape of the HyperCube is not correct." + ), "The dimension of the HyperCube is not correct." assert all( 0 <= x1 < x2 <= shape for x1, x2, shape in zip(point1, point2, self.shape) ), "The HyperCube is empty." # Ignore: Numpy has typing issues. - return frozenset( + return Scope( self.hypercube[ tuple(slice(x1, x2) for x1, x2 in zip(point1, point2)) # type: ignore[misc] ] diff --git a/cirkit/new/region_graph/region_graph.py b/cirkit/new/region_graph/region_graph.py index d6ef3cb9..a60ca2ee 100644 --- a/cirkit/new/region_graph/region_graph.py +++ b/cirkit/new/region_graph/region_graph.py @@ -1,32 +1,45 @@ import itertools import json -from typing import Dict, FrozenSet, Iterable, Iterator, Optional, Set, cast, final, overload +from typing import Dict, Iterable, Iterator, Optional, Set, Tuple, cast, final, overload from typing_extensions import Self # TODO: in typing from 3.11 import numpy as np from numpy.typing import NDArray from cirkit.new.region_graph.rg_node import PartitionNode, RegionNode, RGNode +from cirkit.new.utils import OrderedSet, Scope from cirkit.new.utils.type_aliases import RegionGraphJson -# We mark RG as final to hint that RG algorithms should not be its subclasses but factories. -# Disable: It's designed to have these many attributes. +# We mark RG as final to hint that RG algorithms should not be its subclasses but factories, so that +# constructed RGs and loaded RGs are all of type RegionGraph. @final -class RegionGraph: # pylint: disable=too-many-instance-attributes +class RegionGraph: """The region graph that holds the high-level abstraction of circuit structure. - This class is initiated empty and nodes can be pushed into the graph with edges. It can also \ - serve as a container of RGNode for use in the RG construction algorithms. + This class is initiated empty, and RG construction algorithms decides how to push nodes and \ + edges into the graph. + + After construction, the graph must be freezed before being used, so that some finalization \ + work for construction can be done properly. """ def __init__(self) -> None: - """Init class.""" - super().__init__() - # The nodes container will not be visible to the user. Instead, node views are provided for + """Init class. + + The graph is empty upon creation. + """ + # This node container will not be visible to the user. Instead, node views are provided for # read-only access to an iterable of nodes. - self._nodes: Set[RGNode] = set() - self._frozen = False + self._nodes: OrderedSet[RGNode] = OrderedSet() + + # It's on purpose that some attributes are defined outside __init__ but in freeze(). + + @property + def _is_frozen(self) -> bool: + """Whether freeze() has been called on this graph.""" + # self.scope is not set in __init__ and will be set in freeze(). + return hasattr(self, "scope") # TODO: __repr__? @@ -41,8 +54,8 @@ def add_node(self, node: RGNode) -> None: Args: node (RGNode): The node to add. """ - assert not self._frozen, "The RG should not be modified after calling freeze()." - self._nodes.add(node) + assert not self._is_frozen, "The RG should not be modified after calling freeze()." + self._nodes.append(node) @overload def add_edge(self, tail: RegionNode, head: PartitionNode) -> None: @@ -55,16 +68,16 @@ def add_edge(self, tail: PartitionNode, head: RegionNode) -> None: def add_edge(self, tail: RGNode, head: RGNode) -> None: """Add a directed edge to the graph. - If the nodes are not present, they'll be automatically added. + If the nodes are not present yet, they'll be automatically added. Args: tail (RGNode): The tail of the edge (from). head (RGNode): The head of the edge (to). """ - # add_node will check for _frozen. + # add_node will check for _is_frozen. self.add_node(tail) self.add_node(head) - tail.outputs.append(head) + tail.outputs.append(head) # TODO: this insertion order may be different from add_node order head.inputs.append(tail) def add_partitioning(self, region: RegionNode, sub_regions: Iterable[RegionNode]) -> None: @@ -72,41 +85,48 @@ def add_partitioning(self, region: RegionNode, sub_regions: Iterable[RegionNode] Args: region (RegionNode): The region to be partitioned. - sub_regions (Iterable[RegionNode]): The partitioned regions. + sub_regions (Iterable[RegionNode]): The partitioned sub-regions. """ partition = PartitionNode(region.scope) self.add_edge(partition, region) for sub_region in sub_regions: self.add_edge(sub_region, partition) - ####################################### Validation ####################################### + ######################################## Freezing ######################################## # After construction, the RG should be validated and its properties will be calculated. The RG - # should not be modified after being validated and frozen. + # should not be modified after being frozen. def freeze(self) -> Self: - """Freeze the RG to prevent further modifications. + """Freeze the RG to mark the end of construction and return self. - With a frozen RG, we also validate the RG structure and calculate its properties. - - For convenience, self is returned after freezing. + The work here includes: + - Finalizing the maintenance on internal data structures; + - Validating the RG structure; + - Assigning public attributes/properties. Returns: Self: The self object. """ - self._frozen = True - # TODO: print repr of self? - assert not (reason := self._validate()), f"The RG structure is not valid: {reason}." - self._calc_properties() + self._sort_nodes() + # TODO: include repr(self) in error msg? + assert not (reason := self._validate()), f"Illegal RG structure: {reason}." + self._set_properties() return self - # NOTE: The reason returned should not include a period. + def _sort_nodes(self) -> None: + """Sort the OrderedSet of RGNode for node list and edge tables.""" + self._nodes.sort() + for node in self._nodes: + node.inputs.sort() + node.outputs.sort() + def _validate(self) -> str: """Validate the RG structure to make sure it's a legal computational graph. Returns: - str: Empty if the RG is valid, otherwise the reason. + str: The reason for error (NOTE: without period), empty for nothing wrong. """ - # These two if conditions are also quick checks for DAG. + # These two conditions are also quick checks for DAG. if next(self.input_nodes, None) is None: return "RG must have at least one input node" if next(self.output_nodes, None) is None: @@ -115,24 +135,30 @@ def _validate(self) -> str: if any(len(partition.outputs) != 1 for partition in self.partition_nodes): return "PartitionNode can only have one output RegionNode" - if not self._check_dag(): + if any( + Scope.union(*(node_input.scope for node_input in node.inputs)) != node.scope + for node in self.inner_nodes + ): + return "The scope of an inner node should be the union of scopes of its inputs" + + if not self._check_dag(): # It's a bit complex, so extracted as a standalone method. return "RG must be a DAG" # TODO: Anything else needed? return "" - # Checking DAG is a bit complex, so it's extracted as a standalone method. def _check_dag(self) -> bool: - """Check if the RG is a DAG. + """Check whether the graph is a DAG. Returns: - bool: Whether the RG is a DAG. + bool: Whether a DAG. """ visited: Set[RGNode] = set() # Visited nodes during all DFS runs. path: Set[RGNode] = set() # Path stack for the current DFS run. + # Here we don't care about order and there's no duplicate, so set is used for fast in check. def _dfs(node: RGNode) -> bool: - """Try to traverse and check for cycle from node. + """Traverse and check for cycle from node. Args: node (RGNode): The node to start with. @@ -141,7 +167,7 @@ def _dfs(node: RGNode) -> bool: bool: Whether it's OK (not cyclic). """ visited.add(node) - path.add(node) + path.add(node) # If OK, we need to pop node out, otherwise just propagate failure. for next_node in node.outputs: if next_node in path: # Loop to the current path, including next_node==node. return False @@ -154,39 +180,41 @@ def _dfs(node: RGNode) -> bool: return True # Nothing wrong in the current DFS run. # If visited, shortcut to True, otherwise run DFS from node. - return all(node in visited or _dfs(node) for node in self._nodes) + return all(node in visited or _dfs(node) for node in self.nodes) - def _calc_properties(self) -> None: - """Calculate the properties of the RG and save them to self. + def _set_properties(self) -> None: + """Set the attributes for RG properties in self. - These properties are not valid before calling this method. + Names set here are not valid in self before calling this method. """ - # It's intended to assign these attributes outside __init__: without calling into freeze(), - # these attrs, especially self.num_vars, will be undefined, and therefore blocks downstream - # usage. Thus freeze() will be enforced to run before using RG. + # It's intended to assign these attributes outside __init__. Without calling into freeze(), + # these attrs, especially self.scope and self.num_vars, will be undefined, and therefore + # blocks downstream usage. Thus freeze() will be enforced to run before using the RG. - self.scope = frozenset().union(*(node.scope for node in self.output_nodes)) + # Guaranteed to be non-empty by _validate(). + self.scope = Scope.union(*(node.scope for node in self.output_nodes)) self.num_vars = len(self.scope) self.is_smooth = all( - all(partition.scope == region.scope for partition in region.inputs) + partition.scope == region.scope for region in self.inner_region_nodes + for partition in region.inputs ) - self.is_decomposable = all( - not any( - region1.scope & region2.scope - for region1, region2 in itertools.combinations(partition.inputs, 2) - ) - and set().union(*(region.scope for region in partition.inputs)) == partition.scope + # Union of input scopes is guaranteed to be the node scope by _validate(). + self.is_decomposable = not any( + region1.scope & region2.scope for partition in self.partition_nodes + for region1, region2 in itertools.combinations(partition.inputs, 2) ) + # TODO: is this correct for more-than-2 partition? # Structured-decomposablity first requires smoothness and decomposability. self.is_structured_decomposable = self.is_smooth and self.is_decomposable - decompositions: Dict[FrozenSet[int], Set[FrozenSet[int]]] = {} + decompositions: Dict[Scope, Tuple[Scope, ...]] = {} for partition in self.partition_nodes: - decomp = set(region.scope for region in partition.inputs) + # The scopes are sorted by _sort_nodes(). + decomp = tuple(region.scope for region in partition.inputs) if partition.scope not in decompositions: decompositions[partition.scope] = decomp self.is_structured_decomposable &= decomp == decompositions[partition.scope] @@ -198,11 +226,11 @@ def _calc_properties(self) -> None: ####################################### Properties ####################################### # Here are the basic properties and some structural properties of the RG. Some of them are - # static and defined in the _calc_properties after the RG is freezed. Some requires further - # information and is define below to be calculated on the fly. - # We list everything here to add "docstrings" to them. + # static and defined in the _set_properties() when the RG is freezed. Some requires further + # information and is defined below to be calculated on the fly. We list everything here to add + # "docstrings" to them, but note that they're not valid before freeze(). - scope: FrozenSet[int] + scope: Scope """The scope of the RG, i.e., the union of scopes of all output units.""" num_vars: int @@ -212,8 +240,7 @@ def _calc_properties(self) -> None: """Whether the RG is smooth, i.e., all inputs to a region have the same scope.""" is_decomposable: bool - """Whether the RG is decomposable, i.e., inputs to a partition have disjoint scopes and their \ - union is the scope of the partition.""" + """Whether the RG is decomposable, i.e., inputs to a partition have disjoint scopes.""" is_structured_decomposable: bool """Whether the RG is structured-decomposable, i.e., compatible to itself.""" @@ -221,24 +248,26 @@ def _calc_properties(self) -> None: is_omni_compatible: bool """Whether the RG is omni-compatible, i.e., compatible to all circuits of the same scope.""" - def is_compatible(self, other: "RegionGraph", scope: Optional[Iterable[int]] = None) -> bool: + def is_compatible(self, other: "RegionGraph", *, scope: Optional[Iterable[int]] = None) -> bool: """Test compatibility with another region graph over the given scope. Args: other (RegionGraph): The other region graph to compare with. scope (Optional[Iterable[int]], optional): The scope over which to check. If None, \ - will use the intersection of the scopes of two RG. Defaults to None. + will use the intersection of the scopes of the two RG. Defaults to None. Returns: - bool: Whether the RG is compatible to the other. + bool: Whether self is compatible to other. """ + # _is_frozen is implicitly tested because is_smooth is set in freeze(). if not ( self.is_smooth and self.is_decomposable and other.is_smooth and other.is_decomposable ): # Compatiblility first requires smoothness and decomposability. return False - scope = frozenset(scope) if scope is not None else self.scope & other.scope + scope = Scope(scope) if scope is not None else self.scope & other.scope + # TODO: is this correct for more-than-2 partition? for partition1, partition2 in itertools.product( self.partition_nodes, other.partition_nodes ): @@ -249,7 +278,8 @@ def is_compatible(self, other: "RegionGraph", scope: Optional[Iterable[int]] = N for (i, region1), (j, region2) in itertools.product( enumerate(partition1.inputs), enumerate(partition2.inputs) ): - adj_mat[i, j] = bool(region1.scope & region2.scope) # I.e., scopes intersect. + # I.e., if scopes intersect over the scope to test. + adj_mat[i, j] = bool(region1.scope & region2.scope & scope) adj_mat = adj_mat @ adj_mat.T # Now we have adjencency from inputs1 (of self) to inputs1. An edge means the two # regions must be partitioned together. @@ -267,11 +297,11 @@ def is_compatible(self, other: "RegionGraph", scope: Optional[Iterable[int]] = N return True ####################################### Node views ####################################### - # These are iterable views of the nodes in the RG, available even when the graph is only - # partially constructed. For efficiency, all these views are iterators (implemented as a - # container iter or a generator), so that they can be chained for iteration without - # instantiating intermediate containers. - # NOTE: There's no ordering graranteed for these views. However RGNode can be sorted if needed. + # These are iterable views of the nodes in the RG, and the topological order is guaranteed (by a + # stronger ordering). For efficiency, all these views are iterators (a container iter or a + # generator), so that they can be chained without instantiating intermediate containers. + # NOTE: The views are even available when the graph is only partially constructed, but without + # freeze() there's no ordering graranteed for these views. @property def nodes(self) -> Iterator[RGNode]: @@ -290,24 +320,29 @@ def partition_nodes(self) -> Iterator[PartitionNode]: @property def input_nodes(self) -> Iterator[RegionNode]: - """Input nodes of the graph, which are guaranteed to be regions.""" + """Input nodes of the graph, which are always regions.""" return (node for node in self.region_nodes if not node.inputs) @property def output_nodes(self) -> Iterator[RegionNode]: - """Output nodes of the graph, which are guaranteed to be regions.""" + """Output nodes of the graph, which are always regions.""" return (node for node in self.region_nodes if not node.outputs) + @property + def inner_nodes(self) -> Iterator[RGNode]: + """Inner (non-input) nodes in the graph.""" + return (node for node in self.nodes if node.inputs) + @property def inner_region_nodes(self) -> Iterator[RegionNode]: - """Inner (non-input) region nodes in the graph.""" + """Inner region nodes in the graph.""" return (node for node in self.region_nodes if node.inputs) #################################### (De)Serialization ################################### # The RG can be dumped and loaded from json files, which can be useful when we want to save and - # share it. The load() is another way to construct a RG. + # share it. The load() is another way to construct a RG other than the RG algorithms. - # TODO: we can only deal with 2-partition here + # TODO: The RG json is only defined to 2-partition. def dump(self, filename: str) -> None: """Dump the region graph to the json file. @@ -317,28 +352,26 @@ def dump(self, filename: str) -> None: Args: filename (str): The file name for dumping. """ - graph_json: RegionGraphJson = {"regions": {}, "graph": []} + rg_json: RegionGraphJson = {"regions": {}, "graph": []} - region_id = {node: idx for idx, node in enumerate(self.region_nodes)} - graph_json["regions"] = {str(idx): list(node.scope) for node, idx in region_id.items()} + region_idx = {node: idx for idx, node in enumerate(self.region_nodes)} + # "regions" keeps ordered by the index, corresponding to the self.region_nodes order. + rg_json["regions"] = {str(idx): list(node.scope) for node, idx in region_idx.items()} + + # "graph" keeps ordered by the list, corresponding to the self.partition_nodes order. for partition in self.partition_nodes: - part_inputs = partition.inputs - assert len(part_inputs) == 2, "We can only dump RG with 2-partitions." - (part_output,) = partition.outputs + input_idxs = [region_idx[region_in] for region_in in partition.inputs] + assert len(input_idxs) == 2, "We can only dump RG with 2-partitions." # TODO: 2-part + # partition.outputs is guaranteed to have len==1 by _validate(). + output_idx = region_idx[next(iter(partition.outputs))] - graph_json["graph"].append( - { - "l": region_id[part_inputs[0]], - "r": region_id[part_inputs[1]], - "p": region_id[part_output], - } - ) + rg_json["graph"].append({"l": input_idxs[0], "r": input_idxs[1], "p": output_idx}) - # TODO: logging for graph_json? + # TODO: logging for dumping graph_json? with open(filename, "w", encoding="utf-8") as f: - json.dump(graph_json, f) + json.dump(rg_json, f) @staticmethod def load(filename: str) -> "RegionGraph": @@ -353,26 +386,26 @@ def load(filename: str) -> "RegionGraph": RegionGraph: The loaded region graph. """ with open(filename, "r", encoding="utf-8") as f: - graph_json: RegionGraphJson = json.load(f) + rg_json: RegionGraphJson = json.load(f) - id_region = {int(idx): RegionNode(scope) for idx, scope in graph_json["regions"].items()} + # By json standard, this is not guaranteed to be sorted. + idx_region = {int(idx): RegionNode(scope) for idx, scope in rg_json["regions"].items()} graph = RegionGraph() - if not graph_json["graph"]: - # A corner case: no edge in RG, meaning ther's only one region node, and the following - # for-loop does not work, so we need to handle it here. - assert len(id_region) == 1 - graph.add_node(id_region[0]) - - for partition in graph_json["graph"]: - part_inputs = id_region[partition["l"]], id_region[partition["r"]] - part_output = id_region[partition["p"]] + # Iterate regions by the order of index so that the order of graph.region_nodes is + # preserved. + # TODO: enumerate does not work on dict + for idx in range(len(idx_region)): # pylint: disable=consider-using-enumerate + graph.add_node(idx_region[idx]) - partition_node = PartitionNode(part_output.scope) + # Iterate partitions by the order of list so that the order of graph.partition_nodes is + # preserved. + for partition in rg_json["graph"]: + regions_in = [idx_region[idx_in] for idx_in in (partition["l"], partition["r"])] + region_out = idx_region[partition["p"]] - for part_input in part_inputs: - graph.add_edge(part_input, partition_node) - graph.add_edge(partition_node, part_output) + # TODO: is the order of edge table saved in nodes preserved? + graph.add_partitioning(region_out, regions_in) return graph.freeze() diff --git a/cirkit/new/region_graph/rg_node.py b/cirkit/new/region_graph/rg_node.py index 99a274bc..33fc1213 100644 --- a/cirkit/new/region_graph/rg_node.py +++ b/cirkit/new/region_graph/rg_node.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable + +from cirkit.new.utils import OrderedSet, Scope class RGNode(ABC): @@ -14,12 +16,12 @@ def __init__(self, scope: Iterable[int]) -> None: scope (Iterable[int]): The scope of this node. """ super().__init__() - self.scope = frozenset(scope) + self.scope = Scope(scope) assert self.scope, "The scope of a node in RG must not be empty." - # The edge lists are initiated empty because a node may be contructed without the whole RG. - self.inputs: List[RGNode] = [] - self.outputs: List[RGNode] = [] + # The edge tables are initiated empty because a node may be contructed without the whole RG. + self.inputs: OrderedSet[RGNode] = OrderedSet() + self.outputs: OrderedSet[RGNode] = OrderedSet() # TODO: we might want to save something, but this is not used yet. self._metadata: Dict[str, Any] = {} # type: ignore[misc] @@ -30,43 +32,51 @@ def __repr__(self) -> str: Returns: str: The str representation of the node. """ - # Here we convert scope to set so that we don't get "fronzenset(...)" in output. - return f"{self.__class__.__name__}({set(self.scope)})" + return f"{self.__class__.__name__}({self.scope})" + + # __hash__ and __eq__ are defined by default to compare on object identity, i.e., + # (a is b) <=> (a == b) <=> (hash(a) == hash(b)). # `other: Self` is wrong as it can be RGNode instead of just same as self. def __lt__(self, other: "RGNode") -> bool: - """Compare the node with another node, can be used for sorting. - - The default comparison is: - - First, RegionNode is smaller than PartitionNode; - - Then, the node with smaller scope (by frozenset.__lt__) is smaller; - - Finally, same type of nodes with same scope are ordered by hash (mem addr by default). - - This guarantees two nodes compare equal only when they're the same object. + """Compare the node with another node, for < operator implicitly used in sorting. + + TODO: the following is currently NOT correct because the sorting rule is not complete. + It is guaranteed that exactly one of a == b, a < b, a > b is True. Can be used for \ + sorting and order is guaranteed to be always stable. + TODO: alternative if we ignore the above todo: + Note that a != b does not imply a < b or b < a, as the order within the the same type of \ + node with the same scope is not defined, in which case a == b, a < b, b < a are all false. \ + Yet although there's no total ordering, sorting can still be performed. + + The comparison between two RGNode is: + - If they have different scopes, the one with a smaller scope is smaller; + - With the same scope, PartitionNode is smaller than RegionNode; + - For same type of node and same scope, __lt__ is always False, indicating "equality for \ + the purpose of sorting". + + This comparison guarantees the topological order in a (smooth and decomposable) RG: + - For a RegionNode->PartitionNode edge, Region.scope < Partition.scope; + - For a PartitionNode->RegionNode edge, they have the same scope and Partition < Region. Args: - other (RGNode): The other node to compare. + other (RGNode): The other node to compare with. Returns: bool: Whether self < other. """ # A trick to compare classes: if the class name is equal, then the class is the same; - # otherwise "P" < "R" but RegionNode < PartitionNode, so class names are reversed below. - return (other.__class__.__name__, self.scope, hash(self)) < ( - self.__class__.__name__, - other.scope, - hash(other), - ) + # otherwise "P" < "R" and PartitionNode < RegionNode, so comparison of class names works. + return (self.scope, self.__class__.__name__) < (other.scope, other.__class__.__name__) -# Disable: It's intended for RegionNode. It's only used to provide a concrete RGNode with nothing. +# Disable: It's intended for RegionNode to only have few methods. class RegionNode(RGNode): # pylint: disable=too-few-public-methods """The region node in the region graph.""" - # Ignore: Mutable types are typically invariant, so there's no other choice, and we can only - # enforce the typing with ignore. - inputs: List["PartitionNode"] # type: ignore[assignment] - outputs: List["PartitionNode"] # type: ignore[assignment] + # Ignore: Mutable containers are invariant, so there's no other choice. + inputs: OrderedSet["PartitionNode"] # type: ignore[assignment] + outputs: OrderedSet["PartitionNode"] # type: ignore[assignment] # TODO: better way to impl this? we must have an abstract method in RGNode def __init__(self, scope: Iterable[int]) -> None: # pylint: disable=useless-parent-delegation @@ -78,14 +88,13 @@ def __init__(self, scope: Iterable[int]) -> None: # pylint: disable=useless-par super().__init__(scope) -# Disable: It's intended for RegionNode. It's only used to provide a concrete RGNode with nothing. +# Disable: It's intended for PartitionNode to only have few methods. class PartitionNode(RGNode): # pylint: disable=too-few-public-methods """The partition node in the region graph.""" - # Ignore: Mutable types are typically invariant, so there's no other choice, and we can only - # enforce the typing with ignore. - inputs: List["RegionNode"] # type: ignore[assignment] - outputs: List["RegionNode"] # type: ignore[assignment] + # Ignore: Mutable containers are invariant, so there's no other choice. + inputs: OrderedSet["RegionNode"] # type: ignore[assignment] + outputs: OrderedSet["RegionNode"] # type: ignore[assignment] # TODO: better way to impl this? we must have an abstract method in RGNode def __init__(self, scope: Iterable[int]) -> None: # pylint: disable=useless-parent-delegation From a80aaedee5ce9e31ad53340ead54e0e67b5a7b5c Mon Sep 17 00:00:00 2001 From: lkct Date: Fri, 8 Dec 2023 17:52:18 +0000 Subject: [PATCH 3/7] use scope in SymbC --- cirkit/new/symbolic/symbolic_circuit.py | 16 ++++++++------- cirkit/new/symbolic/symbolic_layer.py | 3 ++- tests/new/symbolic/test_symbolic_circuit.py | 12 +++++------ tests/new/symbolic/test_symbolic_layer.py | 22 +++++++++------------ 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/cirkit/new/symbolic/symbolic_circuit.py b/cirkit/new/symbolic/symbolic_circuit.py index a4ad44ad..ad0c5dde 100644 --- a/cirkit/new/symbolic/symbolic_circuit.py +++ b/cirkit/new/symbolic/symbolic_circuit.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, FrozenSet, Iterable, Iterator, Optional, Set, Type +from typing import Any, Dict, Iterable, Iterator, Optional, Set, Type from cirkit.new.layers import InputLayer, SumProductLayer from cirkit.new.region_graph import RegionGraph, RGNode @@ -9,13 +9,14 @@ SymbolicProductLayer, SymbolicSumLayer, ) +from cirkit.new.utils import Scope # TODO: double check docs and __repr__ # Disable: It's designed to have these many attributes. class SymbolicCircuit: # pylint: disable=too-many-instance-attributes - """The Symbolic Circuit.""" + """The symbolic representation of a tensorized circuit.""" # TODO: how to design interface? require kwargs only? # TODO: how to deal with too-many? @@ -156,8 +157,9 @@ def _from_region_node( # type: ignore[misc] # Ignore: Unavoidable for kwargs. assert len(inputs) == 2, "Partition nodes should have exactly two inputs." assert len(outputs) > 0, "Partition nodes should have at least one output." - left_input_units = num_inner_units if inputs[0].inputs else num_input_units - right_input_units = num_inner_units if inputs[1].inputs else num_input_units + input0, input1 = inputs + left_input_units = num_inner_units if input0.inputs else num_input_units + right_input_units = num_inner_units if input1.inputs else num_input_units assert ( left_input_units == right_input_units @@ -196,7 +198,7 @@ def _add_edge(self, tail: SymbolicLayer, head: SymbolicLayer) -> None: # simply defined in __init__. Some requires additional treatment and is define below. # We list everything here to add "docstrings" to them. - scope: FrozenSet[int] + scope: Scope """The scope of the SymbC, i.e., the union of scopes of all output layers.""" num_vars: int @@ -216,7 +218,7 @@ def _add_edge(self, tail: SymbolicLayer, head: SymbolicLayer) -> None: """Whether the SymbC is omni-compatible, i.e., compatible to all circuits of the same scope.""" def is_compatible( - self, other: "SymbolicCircuit", scope: Optional[Iterable[int]] = None + self, other: "SymbolicCircuit", *, scope: Optional[Iterable[int]] = None ) -> bool: """Test compatibility with another symbolic circuit over the given scope. @@ -228,7 +230,7 @@ def is_compatible( Returns: bool: Whether the SymbC is compatible to the other. """ - return self.region_graph.is_compatible(other.region_graph, scope) + return self.region_graph.is_compatible(other.region_graph, scope=scope) ####################################### Layer views ###################################### # These are iterable views of the nodes in the SymbC. For efficiency, all these views are diff --git a/cirkit/new/symbolic/symbolic_layer.py b/cirkit/new/symbolic/symbolic_layer.py index 914aa1ba..e9a8b541 100644 --- a/cirkit/new/symbolic/symbolic_layer.py +++ b/cirkit/new/symbolic/symbolic_layer.py @@ -3,6 +3,7 @@ from cirkit.new.layers import InputLayer, SumProductLayer from cirkit.new.reparams import Reparameterization +from cirkit.new.utils import Scope # TODO: double check docs and __repr__ @@ -19,7 +20,7 @@ def __init__(self, scope: Iterable[int]) -> None: scope (Iterable[int]): The scope of this layer. """ super().__init__() - self.scope = frozenset(scope) + self.scope = Scope(scope) assert self.scope, "The scope of a layer in SymbC must be non-empty." # TODO: should this be a List? what do we need on ordering? diff --git a/tests/new/symbolic/test_symbolic_circuit.py b/tests/new/symbolic/test_symbolic_circuit.py index da9b40ae..2cfc9b3b 100644 --- a/tests/new/symbolic/test_symbolic_circuit.py +++ b/tests/new/symbolic/test_symbolic_circuit.py @@ -4,8 +4,8 @@ from cirkit.new.layers import CategoricalLayer, CPLayer from cirkit.new.region_graph import QuadTree, RegionGraph, RegionNode from cirkit.new.reparams import ExpReparam -from cirkit.new.symbolic.symbolic_circuit import SymbolicCircuit -from cirkit.new.symbolic.symbolic_layer import SymbolicInputLayer, SymbolicSumLayer +from cirkit.new.symbolic import SymbolicCircuit, SymbolicInputLayer, SymbolicSumLayer +from cirkit.new.utils import Scope def test_symbolic_circuit() -> None: @@ -16,9 +16,9 @@ def test_symbolic_circuit() -> None: reparam = ExpReparam() rg = RegionGraph() - node1 = RegionNode((1,)) - node2 = RegionNode((2,)) - region = RegionNode((1, 2)) + node1 = RegionNode({0}) + node2 = RegionNode({1}) + region = RegionNode({0, 1}) rg.add_partitioning(region, [node1, node2]) rg.freeze() @@ -62,4 +62,4 @@ def test_symbolic_circuit() -> None: assert len(list(circuit_2.layers)) == 46 assert len(list(circuit_2.input_layers)) == 16 assert len(list(circuit_2.output_layers)) == 1 - assert (circuit_2.scope) == frozenset(range(16)) + assert circuit_2.scope == Scope(range(16)) diff --git a/tests/new/symbolic/test_symbolic_layer.py b/tests/new/symbolic/test_symbolic_layer.py index f5d17ae0..1c9f2e20 100644 --- a/tests/new/symbolic/test_symbolic_layer.py +++ b/tests/new/symbolic/test_symbolic_layer.py @@ -2,53 +2,49 @@ from cirkit.new.layers import CategoricalLayer, CPLayer, TuckerLayer from cirkit.new.reparams import ExpReparam -from cirkit.new.symbolic.symbolic_layer import ( - SymbolicInputLayer, - SymbolicProductLayer, - SymbolicSumLayer, -) +from cirkit.new.symbolic import SymbolicInputLayer, SymbolicProductLayer, SymbolicSumLayer def test_symbolic_sum_layer() -> None: - scope = [1, 2] + scope = {0, 1} num_units = 3 layer = SymbolicSumLayer(scope, num_units, TuckerLayer, reparam=ExpReparam()) assert "SymbolicSumLayer" in repr(layer) - assert "Scope: frozenset({1, 2})" in repr(layer) + assert "Scope: Scope({0, 1})" in repr(layer) assert "Layer Class: TuckerLayer" in repr(layer) assert "Number of Units: 3" in repr(layer) def test_symbolic_sum_layer_cp() -> None: - scope = [1, 2] + scope = {0, 1} num_units = 3 layer_kwargs = {"collapsed": False, "shared": False, "arity": 2} layer = SymbolicSumLayer(scope, num_units, CPLayer, layer_kwargs, reparam=ExpReparam()) assert "SymbolicSumLayer" in repr(layer) - assert "Scope: frozenset({1, 2})" in repr(layer) + assert "Scope: Scope({0, 1})" in repr(layer) assert "Layer Class: CPLayer" in repr(layer) assert "Number of Units: 3" in repr(layer) def test_symbolic_product_node() -> None: - scope = [1, 2] + scope = {0, 1} num_input_units = 2 layer = SymbolicProductLayer(scope, num_input_units, TuckerLayer) assert "SymbolicProductLayer" in repr(layer) - assert "Scope: frozenset({1, 2})" in repr(layer) + assert "Scope: Scope({0, 1})" in repr(layer) assert "Layer Class: TuckerLayer" in repr(layer) assert "Number of Units: 2" in repr(layer) def test_symbolic_input_node() -> None: - scope = [1, 2] + scope = {0, 1} num_units = 3 input_kwargs = {"num_categories": 5} layer = SymbolicInputLayer( scope, num_units, CategoricalLayer, input_kwargs, reparam=ExpReparam() ) assert "SymbolicInputLayer" in repr(layer) - assert "Scope: frozenset({1, 2})" in repr(layer) + assert "Scope: Scope({0, 1})" in repr(layer) assert "Input Exp Family Class: CategoricalLayer" in repr(layer) assert "Layer KWArgs: {'num_categories': 5}" in repr(layer) assert "Number of Units: 3" in repr(layer) From 210cb463f7df73da22bf0cd4d94e0db7060a410e Mon Sep 17 00:00:00 2001 From: lkct Date: Fri, 8 Dec 2023 18:28:29 +0000 Subject: [PATCH 4/7] some input layers may have no params --- cirkit/new/layers/input/input.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cirkit/new/layers/input/input.py b/cirkit/new/layers/input/input.py index 5a3fefbc..b4d9e6e0 100644 --- a/cirkit/new/layers/input/input.py +++ b/cirkit/new/layers/input/input.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional from cirkit.new.layers.layer import Layer from cirkit.new.reparams import Reparameterization @@ -22,7 +22,7 @@ def __init__( num_input_units: int, num_output_units: int, arity: Literal[1] = 1, - reparam: Reparameterization, + reparam: Optional[Reparameterization] = None, ) -> None: """Init class. @@ -30,7 +30,8 @@ def __init__( num_input_units (int): The number of input units, i.e. number of channels for variables. num_output_units (int): The number of output units. arity (Literal[1], optional): The arity of the layer, must be 1. Defaults to 1. - reparam (Reparameterization): The reparameterization for layer parameters. + reparam (Optional[Reparameterization], optional): The reparameterization for layer \ + parameters, can be None if the layer has no params. Defaults to None. """ assert arity == 1, "We define arity=1 for InputLayer." super().__init__( From 34373aa0bd8cde415f72bb3d5e9f7ae29ed781ea Mon Sep 17 00:00:00 2001 From: lkct Date: Fri, 8 Dec 2023 19:07:50 +0000 Subject: [PATCH 5/7] refactor symb layer --- cirkit/new/symbolic/symbolic_layer.py | 197 ++++++++++++++++++-------- 1 file changed, 138 insertions(+), 59 deletions(-) diff --git a/cirkit/new/symbolic/symbolic_layer.py b/cirkit/new/symbolic/symbolic_layer.py index e9a8b541..21f27e57 100644 --- a/cirkit/new/symbolic/symbolic_layer.py +++ b/cirkit/new/symbolic/symbolic_layer.py @@ -1,31 +1,63 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Set, Type +from typing import Any, Dict, Iterable, Optional, Type -from cirkit.new.layers import InputLayer, SumProductLayer +from cirkit.new.layers import InnerLayer, InputLayer, Layer +from cirkit.new.region_graph import PartitionNode, RegionNode, RGNode from cirkit.new.reparams import Reparameterization -from cirkit.new.utils import Scope +from cirkit.new.utils import OrderedSet -# TODO: double check docs and __repr__ +# TODO: double check __repr__ -# Disable: It's intended for SymbolicLayer to have only these methods. -class SymbolicLayer(ABC): # pylint: disable=too-few-public-methods +# Disable: It's intended for SymbolicLayer to have these many attrs. +class SymbolicLayer(ABC): # pylint: disable=too-many-instance-attributes """The abstract base class for symbolic layers in symbolic circuits.""" - # TODO: Save a RGNode here? allow comparison here? - def __init__(self, scope: Iterable[int]) -> None: + # We accept structure as positional args, and layer spec as kw-only. + # Disable: This __init__ is designed to have these arguments. + # pylint: disable-next=too-many-arguments + def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs. + self, + rg_node: RGNode, + layers_in: Iterable["SymbolicLayer"], + *, + num_units: int, + layer_cls: Type[Layer], + layer_kwargs: Optional[Dict[str, Any]] = None, + reparam: Optional[Reparameterization] = None, + ) -> None: """Construct the SymbolicLayer. Args: - scope (Iterable[int]): The scope of this layer. + rg_node (RGNode): The region graph node corresponding to this layer. + layers_in (Iterable[SymbolicLayer]): The input to this layer, empty for input layers. + num_units (int): The number of units in this layer. + layer_cls (Type[Layer]): The concrete layer class to become. + layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs to initialize \ + layer_cls. Defaults to None. + reparam (Optional[Reparameterization], optional): The reparameterization for layer \ + parameters, can be None if layer_cls has no params. Defaults to None. """ super().__init__() - self.scope = Scope(scope) - assert self.scope, "The scope of a layer in SymbC must be non-empty." + self.rg_node = rg_node + self.scope = rg_node.scope + + # self.inputs is filled using layers_in, while self.outputs is empty until self appears in + # another layer's layers_in. + self.inputs: OrderedSet[SymbolicLayer] = OrderedSet() + self.outputs: OrderedSet[SymbolicLayer] = OrderedSet() + for layer_in in layers_in: + self.inputs.append(layer_in) + layer_in.outputs.append(self) + assert len(self.inputs) == len( + rg_node.inputs + ), "The number of inputs to this layer does not match the RG." - # TODO: should this be a List? what do we need on ordering? - self.inputs: Set[SymbolicLayer] = set() - self.outputs: Set[SymbolicLayer] = set() + self.num_units = num_units + self.layer_cls = layer_cls + # Ignore: Unavoidable for kwargs. + self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} # type: ignore[misc] + self.reparam = reparam # We require subclasses to implement __repr__ on their own. This also forbids the instantiation # of this abstract class. @@ -37,43 +69,64 @@ def __repr__(self) -> str: str: The str representation of the layer. """ + # __hash__ and __eq__ are defined by default to compare on object identity, i.e., + # (a is b) <=> (a == b) <=> (hash(a) == hash(b)). + + def __lt__(self, other: "SymbolicLayer") -> bool: + """Compare the layer with another layer, for < operator implicitly used in sorting. + + SymbolicLayer is compared by the corresponding RGNode, so that SymbolicCircuit obtains the \ + same ordering as the RegionGraph. + + Args: + other (SymbolicLayer): The other layer to compare with. + + Returns: + bool: Whether self < other. + """ + return self.rg_node < other.rg_node + # Disable: It's intended for SymbolicSumLayer to have only these methods. class SymbolicSumLayer(SymbolicLayer): # pylint: disable=too-few-public-methods """The sum layer in symbolic circuits.""" - # TODO: how to design interface? require kwargs only? + reparam: Reparameterization # Sum layer always have params. + + # Note that the typing for layers_in cannot be refined because all layers are mixed in one + # container in SymbolicCircuit. Same the following two layers. # Disable: This __init__ is designed to have these arguments. # pylint: disable-next=too-many-arguments def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs. self, - scope: Iterable[int], + rg_node: RegionNode, + layers_in: Iterable[SymbolicLayer], + *, num_units: int, - layer_cls: Type[SumProductLayer], # TODO: is it correct to use SumProductLayer? + layer_cls: Type[InnerLayer], # TODO: more specific? layer_kwargs: Optional[Dict[str, Any]] = None, - *, - reparam: Reparameterization, # TODO: how to set default here? + reparam: Reparameterization, ) -> None: """Construct the SymbolicSumLayer. Args: - scope (Iterable[int]): The scope of this layer. - num_units (int): Number of output units in this layer. - layer_cls (Type[SumProductLayer]): The inner (sum) layer class. - layer_kwargs (Optional[Dict[str, Any]]): The parameters for the inner layer class. - reparam (Reparameterization): The reparam. - - Raises: - NotImplementedError: If the shared uncollapsed CP is not implemented. + rg_node (RegionNode): The region node corresponding to this layer. + layers_in (Iterable[SymbolicLayer]): The input to this layer. + num_units (int): The number of units in this layer. + layer_cls (Type[InnerLayer]): The concrete layer class to become. + layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs to initialize \ + layer_cls. Defaults to None. + reparam (Reparameterization): The reparameterization for layer parameters. """ - super().__init__(scope) - self.num_units = num_units - self.layer_cls = layer_cls - # Ignore: Unavoidable for kwargs. - self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} # type: ignore[misc] - self.params = reparam # TODO: this is not correct, but will be reviewed in new layers. - self.params_in = reparam - self.params_out = reparam + assert rg_node.inputs, "SymbolicSumLayer must be based on an inner RegionNode." + super().__init__( + rg_node, + layers_in, + num_units=num_units, + layer_cls=layer_cls, + layer_kwargs=layer_kwargs, # type: ignore[misc] # Ignore: Unavoidable for kwargs. + reparam=reparam, + ) def __repr__(self) -> str: """Generate the repr string of the layer. @@ -97,19 +150,40 @@ def __repr__(self) -> str: class SymbolicProductLayer(SymbolicLayer): # pylint: disable=too-few-public-methods """The product layer in symbolic circuits.""" - def __init__( # TODO: is it correct to use SumProductLayer? - self, scope: Iterable[int], num_units: int, layer_cls: Type[SumProductLayer] + reparam: None # Product layer has no params. + + # Disable: This __init__ is designed to have these arguments. + # pylint: disable-next=too-many-arguments + def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs. + self, + rg_node: PartitionNode, + layers_in: Iterable[SymbolicLayer], + *, + num_units: int, + layer_cls: Type[InnerLayer], # TODO: more specific? + layer_kwargs: Optional[Dict[str, Any]] = None, + reparam: Optional[Reparameterization] = None, ) -> None: """Construct the SymbolicProductLayer. Args: - scope (Iterable[int]): The scope of this layer. - num_units (int): Number of input units. - layer_cls (Type[SumProductLayer]): The inner (sum) layer class. + rg_node (PartitionNode): The partition node corresponding to this layer. + layers_in (Iterable[SymbolicLayer]): The input to this layer. + num_units (int): The number of units in this layer. + layer_cls (Type[InnerLayer]): The concrete layer class to become. + layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs to initialize \ + layer_cls. Defaults to None. + reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \ + Defaults to None. """ - super().__init__(scope) - self.num_units = num_units - self.layer_cls = layer_cls + super().__init__( + rg_node, + layers_in, + num_units=num_units, + layer_cls=layer_cls, + layer_kwargs=layer_kwargs, # type: ignore[misc] # Ignore: Unavoidable for kwargs. + reparam=None, + ) def __repr__(self) -> str: """Generate the repr string of the layer. @@ -136,30 +210,35 @@ class SymbolicInputLayer(SymbolicLayer): # pylint: disable=too-few-public-metho # pylint: disable-next=too-many-arguments def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs. self, - scope: Iterable[int], + rg_node: RegionNode, + layers_in: Iterable[SymbolicLayer], + *, num_units: int, layer_cls: Type[InputLayer], layer_kwargs: Optional[Dict[str, Any]] = None, - *, - reparam: Reparameterization, # TODO: how to set default here? + reparam: Optional[Reparameterization] = None, ) -> None: """Construct the SymbolicInputLayer. Args: - scope (Iterable[int]): The scope of this layer. - num_units (int): Number of output units. - layer_cls (Type[ExpFamilyLayer]): The exponential family class. - layer_kwargs (Optional[Dict[str, Any]]): The parameters for - the exponential family class. - reparam (Reparameterization): The reparam. + rg_node (RegionNode): The region node corresponding to this layer. + layers_in (Iterable[SymbolicLayer]): Empty iterable. + num_units (int): The number of units in this layer. + layer_cls (Type[InputLayer]): The concrete layer class to become. + layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs to initialize \ + layer_cls. Defaults to None. + reparam (Optional[Reparameterization], optional): The reparameterization for layer \ + parameters, can be None if layer_cls has no params. Defaults to None. """ - # TODO: many things can be merged to SymbolicLayer.__init__. - super().__init__(scope) - self.num_units = num_units - self.layer_cls = layer_cls - # Ignore: Unavoidable for kwargs. - self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} # type: ignore[misc] - self.params = reparam + assert not rg_node.inputs, "SymbolicInputLayer must be based on an input RegionNode." + super().__init__( + rg_node, + layers_in, # Should be empty, will be tested in super().__init__ by its length. + num_units=num_units, + layer_cls=layer_cls, + layer_kwargs=layer_kwargs, # type: ignore[misc] # Ignore: Unavoidable for kwargs. + reparam=reparam, + ) def __repr__(self) -> str: """Generate the repr string of the layer. From 384bd62aa8b304fedde778b4613d50986ffe03b3 Mon Sep 17 00:00:00 2001 From: lkct Date: Sat, 9 Dec 2023 00:16:49 +0000 Subject: [PATCH 6/7] refactor symb circuit also add _infer_num_prod_units to inner layers --- cirkit/new/layers/inner/inner.py | 14 + cirkit/new/layers/inner/product/hadamard.py | 13 + cirkit/new/layers/inner/product/kronecker.py | 16 +- cirkit/new/layers/inner/sum/sum.py | 14 + cirkit/new/layers/inner/sum_product/cp.py | 13 + cirkit/new/layers/inner/sum_product/tucker.py | 16 +- cirkit/new/symbolic/symbolic_circuit.py | 260 ++++++------------ 7 files changed, 175 insertions(+), 171 deletions(-) diff --git a/cirkit/new/layers/inner/inner.py b/cirkit/new/layers/inner/inner.py index 443d7870..a967869e 100644 --- a/cirkit/new/layers/inner/inner.py +++ b/cirkit/new/layers/inner/inner.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import Optional from cirkit.new.layers.layer import Layer @@ -32,3 +33,16 @@ def __init__( arity=arity, reparam=reparam, ) + + @classmethod + @abstractmethod + def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int: + """Infer the number of product units in the layer based on given information. + + Args: + num_input_units (int): The number of input units. + arity (int, optional): The arity of the layer. Defaults to 2. + + Returns: + int: The inferred number of product units. + """ diff --git a/cirkit/new/layers/inner/product/hadamard.py b/cirkit/new/layers/inner/product/hadamard.py index c64fcf6e..5d4be5fd 100644 --- a/cirkit/new/layers/inner/product/hadamard.py +++ b/cirkit/new/layers/inner/product/hadamard.py @@ -36,6 +36,19 @@ def __init__( reparam=None, ) + @classmethod + def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int: + """Infer the number of product units in the layer based on given information. + + Args: + num_input_units (int): The number of input units. + arity (int, optional): The arity of the layer. Defaults to 2. + + Returns: + int: The inferred number of product units. + """ + return num_input_units + def forward(self, x: Tensor) -> Tensor: """Run forward pass. diff --git a/cirkit/new/layers/inner/product/kronecker.py b/cirkit/new/layers/inner/product/kronecker.py index 937b7c4b..8d371651 100644 --- a/cirkit/new/layers/inner/product/kronecker.py +++ b/cirkit/new/layers/inner/product/kronecker.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal, Optional, cast from torch import Tensor @@ -38,6 +38,20 @@ def __init__( reparam=None, ) + @classmethod + def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int: + """Infer the number of product units in the layer based on given information. + + Args: + num_input_units (int): The number of input units. + arity (int, optional): The arity of the layer. Defaults to 2. + + Returns: + int: The inferred number of product units. + """ + # Cast: int**int is not guaranteed to be int. + return cast(int, num_input_units**arity) + def forward(self, x: Tensor) -> Tensor: """Run forward pass. diff --git a/cirkit/new/layers/inner/sum/sum.py b/cirkit/new/layers/inner/sum/sum.py index a6f84b6f..f595c1ab 100644 --- a/cirkit/new/layers/inner/sum/sum.py +++ b/cirkit/new/layers/inner/sum/sum.py @@ -1,4 +1,5 @@ import functools +from typing import Literal import torch from torch import nn @@ -33,6 +34,19 @@ def __init__( reparam=reparam, ) + @classmethod + def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> Literal[0]: + """Infer the number of product units in the layer based on given information. + + Args: + num_input_units (int): The number of input units. + arity (int, optional): The arity of the layer. Defaults to 2. + + Returns: + Literal[0]: Sum layers have 0 product units. + """ + return 0 + @torch.no_grad() def reset_parameters(self) -> None: """Reset parameters to default: U(0.01, 0.99).""" diff --git a/cirkit/new/layers/inner/sum_product/cp.py b/cirkit/new/layers/inner/sum_product/cp.py index 164f8c7b..f240b32a 100644 --- a/cirkit/new/layers/inner/sum_product/cp.py +++ b/cirkit/new/layers/inner/sum_product/cp.py @@ -51,6 +51,19 @@ def __init__( ) # DenseLayer already invoked reset_parameters(). + @classmethod + def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int: + """Infer the number of product units in the layer based on given information. + + Args: + num_input_units (int): The number of input units. + arity (int, optional): The arity of the layer. Defaults to 2. + + Returns: + int: The inferred number of product units. + """ + return num_input_units + def forward(self, x: Tensor) -> Tensor: """Run forward pass. diff --git a/cirkit/new/layers/inner/sum_product/tucker.py b/cirkit/new/layers/inner/sum_product/tucker.py index 60d63ff8..a298c9c6 100644 --- a/cirkit/new/layers/inner/sum_product/tucker.py +++ b/cirkit/new/layers/inner/sum_product/tucker.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, cast import torch from torch import Tensor @@ -43,6 +43,20 @@ def __init__( self.reset_parameters() + @classmethod + def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int: + """Infer the number of product units in the layer based on given information. + + Args: + num_input_units (int): The number of input units. + arity (int, optional): The arity of the layer. Defaults to 2. + + Returns: + int: The inferred number of product units. + """ + # Cast: int**int is not guaranteed to be int. + return cast(int, num_input_units**arity) + def _forward_linear(self, x0: Tensor, x1: Tensor) -> Tensor: # shape (*B, I), (*B, J) -> (*B, O). return torch.einsum("oij,...i,...j->...o", self.params(), x0, x1) diff --git a/cirkit/new/symbolic/symbolic_circuit.py b/cirkit/new/symbolic/symbolic_circuit.py index ad0c5dde..9d7adac3 100644 --- a/cirkit/new/symbolic/symbolic_circuit.py +++ b/cirkit/new/symbolic/symbolic_circuit.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, Iterable, Iterator, Optional, Set, Type +from typing import Any, Dict, Iterable, Iterator, Optional, Type -from cirkit.new.layers import InputLayer, SumProductLayer -from cirkit.new.region_graph import RegionGraph, RGNode +from cirkit.new.layers import InnerLayer, InputLayer +from cirkit.new.region_graph import PartitionNode, RegionGraph, RegionNode, RGNode from cirkit.new.reparams import Reparameterization from cirkit.new.symbolic.symbolic_layer import ( SymbolicInputLayer, @@ -24,29 +24,38 @@ class SymbolicCircuit: # pylint: disable=too-many-instance-attributes def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs. self, region_graph: RegionGraph, - layer_cls: Type[SumProductLayer], # TODO: is it correct to use SumProductLayer? - input_cls: Type[InputLayer], - layer_kwargs: Optional[Dict[str, Any]] = None, - input_kwargs: Optional[Dict[str, Any]] = None, *, - # TODO: reparam for input and inner? - reparam: Reparameterization, # TODO: how to set default here? - num_inner_units: int = 2, - num_input_units: int = 2, + num_input_units: int, + num_sum_units: int, num_classes: int = 1, + input_layer_cls: Type[InputLayer], + input_layer_kwargs: Optional[Dict[str, Any]] = None, + input_reparam: Optional[Reparameterization] = None, + sum_layer_cls: Type[InnerLayer], # TODO: more specific? + sum_layer_kwargs: Optional[Dict[str, Any]] = None, + sum_reparam: Reparameterization, + prod_layer_cls: Type[InnerLayer], # TODO: more specific? + prod_layer_kwargs: Optional[Dict[str, Any]] = None, ): """Construct symbolic circuit from a region graph. Args: region_graph (RegionGraph): The region graph to convert. - layer_cls (Type[SumProductLayer]): The layer class for inner layers. - input_cls (Type[ExpFamilyLayer]): The layer class for input layers. - layer_kwargs (Optional[Dict[str, Any]]): The parameters for inner layer class. - input_kwargs (Optional[Dict[str, Any]]): The parameters for input layer class. - reparam (ReparamFactory): The reparametrization function. - num_inner_units (int): Number of units for inner layers. - num_input_units (int): Number of units for input layers. - num_classes (int): Number of classes for the PC. + num_input_units (int): _description_ + num_sum_units (int): _description_ + num_classes (int, optional): _description_. Defaults to 1. + input_layer_cls (Type[InputLayer]): The layer class for input layers. + input_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for \ + input layer class. Defaults to None. + input_reparam (Optional[Reparameterization], optional): The reparameterization for \ + input layer parameters, can be None if it has no params. Defaults to None. + sum_layer_cls (Type[InnerLayer]): The layer class for sum layers. + sum_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for sum \ + layer class. Defaults to None. + sum_reparam (Reparameterization): The reparameterization for sum layer parameters. + prod_layer_cls (Type[InnerLayer]): The layer class for product layers. + prod_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for \ + product layer class. Defaults to None. """ self.region_graph = region_graph self.scope = region_graph.scope @@ -56,147 +65,51 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs. self.is_structured_decomposable = region_graph.is_structured_decomposable self.is_omni_compatible = region_graph.is_omni_compatible - self._layers: Set[SymbolicLayer] = set() - - existing_symbolic_layers: Dict[RGNode, SymbolicLayer] = {} - - # TODO: we need to refactor the construction algorithm. better directly assign inputs or - # outputs to SymbLayers instead of adding them later. - # TODO: too many ignores, need to be checked. - for input_node in region_graph.input_nodes: - rg_node_stack = [(input_node, None)] - - # TODO: verify this while. - while rg_node_stack: # pylint: disable=while-used - rg_node, prev_symbolic_layer = rg_node_stack.pop() - if rg_node in existing_symbolic_layers: - symbolic_layer = existing_symbolic_layers[rg_node] - else: - # Construct a symbolic layer from the region node - symbolic_layer = self._from_region_node( - rg_node, - region_graph, - layer_cls, - input_cls, - layer_kwargs, # type: ignore[misc] - input_kwargs, # type: ignore[misc] - reparam, - num_inner_units, - num_input_units, - num_classes, - ) - existing_symbolic_layers[rg_node] = symbolic_layer - - # Connect previous symbolic layer to the current one - if prev_symbolic_layer: - self._add_edge(prev_symbolic_layer, symbolic_layer) # type: ignore[unreachable] - - # Handle multiple source nodes - for output_rg_node in rg_node.outputs: - rg_node_stack.append((output_rg_node, symbolic_layer)) # type: ignore[arg-type] - - # TODO: the name is not correct. it's not region node. - # TODO: disable for the moment - # pylint: disable-next=no-self-use,too-many-arguments,too-many-locals - def _from_region_node( # type: ignore[misc] # Ignore: Unavoidable for kwargs. - self, - rg_node: RGNode, - region_graph: RegionGraph, - layer_cls: Type[SumProductLayer], - input_cls: Type[InputLayer], - layer_kwargs: Optional[Dict[str, Any]], - input_kwargs: Optional[Dict[str, Any]], - reparam: Reparameterization, - num_inner_units: int, - num_input_units: int, - num_classes: int, - ) -> SymbolicLayer: - """Create a symbolic layer based on the given region node. - - Args: - prev_symbolic_layer (SymbolicLayer): The parent symbolic layer - (starting from input layer) that the current layer grown from. - rg_node (RGNode): The current region graph node to convert to symbolic layer. - region_graph (RegionGraph): The region graph. - layer_cls (Type[SumProductLayer]): The layer class for inner layers. - input_cls (Type[ExpFamilyLayer]): The layer class for input layers. - layer_kwargs (Optional[Dict[str, Any]]): The parameters for inner layer class. - input_kwargs (Optional[Dict[str, Any]]): The parameters for input layer class. - reparam (ReparamFactory): The reparametrization function. - num_inner_units (int): Number of units for inner layers. - num_input_units (int): Number of units for input layers. - num_channels (int): Number of channels (e.g., 3 for RGB pixel) for input layers. - num_classes (int): Number of classes for the PC. - - Returns: - SymbolicLayer: The constructed symbolic layer. - - Raises: - ValueError: If the region node is not valid. - """ - scope = rg_node.scope - inputs = rg_node.inputs - outputs = rg_node.outputs - - symbolic_layer: SymbolicLayer - - if rg_node in region_graph.inner_region_nodes: # type: ignore[operator] - assert len(inputs) == 1, "Inner region nodes should have exactly one input." - - output_units = ( - num_classes - if rg_node in region_graph.output_nodes # type: ignore[operator] - else num_inner_units - ) - - symbolic_layer = SymbolicSumLayer( - scope, output_units, layer_cls, layer_kwargs, reparam=reparam # type: ignore[misc] - ) - - elif rg_node in region_graph.partition_nodes: # type: ignore[operator] - assert len(inputs) == 2, "Partition nodes should have exactly two inputs." - assert len(outputs) > 0, "Partition nodes should have at least one output." - - input0, input1 = inputs - left_input_units = num_inner_units if input0.inputs else num_input_units - right_input_units = num_inner_units if input1.inputs else num_input_units - - assert ( - left_input_units == right_input_units - ), "Input units for partition nodes should match." - - symbolic_layer = SymbolicProductLayer(scope, left_input_units, layer_cls) - - elif rg_node in region_graph.input_nodes: # type: ignore[operator] - symbolic_layer = SymbolicInputLayer( - scope, - num_input_units, - input_cls, - input_kwargs, # type: ignore[misc] - reparam=reparam, - ) - - else: - raise ValueError("Region node not valid.") - - return symbolic_layer - - def _add_edge(self, tail: SymbolicLayer, head: SymbolicLayer) -> None: - """Add edge and layer. - - Args: - tail (SymbolicLayer): The layer the edge originates from. - head (SymbolicLayer): The layer the edge points to. - """ - self._layers.add(tail) - self._layers.add(head) - tail.outputs.add(head) - head.inputs.add(tail) + node_layer: Dict[RGNode, SymbolicLayer] = {} + + for rg_node in region_graph.nodes: + layers_in = (node_layer[node_in] for node_in in rg_node.inputs) + layer: SymbolicLayer + # Ignore: Unavoidable for kwargs. + if isinstance(rg_node, RegionNode) and not rg_node.inputs: # Input node. + layer = SymbolicInputLayer( + rg_node, + layers_in, + num_units=num_input_units, + layer_cls=input_layer_cls, + layer_kwargs=input_layer_kwargs, # type: ignore[misc] + reparam=input_reparam, + ) + elif isinstance(rg_node, RegionNode) and rg_node.inputs: # Inner region node. + layer = SymbolicSumLayer( + rg_node, + layers_in, + num_units=num_sum_units if rg_node.outputs else num_classes, + layer_cls=sum_layer_cls, + layer_kwargs=sum_layer_kwargs, # type: ignore[misc] + reparam=sum_reparam, + ) + elif isinstance(rg_node, PartitionNode): # Partition node. + layer = SymbolicProductLayer( + rg_node, + layers_in, + num_units=prod_layer_cls._infer_num_prod_units( + num_sum_units, len(rg_node.inputs) + ), + layer_cls=prod_layer_cls, + layer_kwargs=prod_layer_kwargs, # type: ignore[misc] + reparam=None, + ) + else: + assert False, "This should not happen." + node_layer[rg_node] = layer + + self._node_layer = node_layer # Insertion order is preserved by dict@py3.7+. ####################################### Properties ####################################### # Here are the basic properties and some structural properties of the SymbC. Some of them are - # simply defined in __init__. Some requires additional treatment and is define below. - # We list everything here to add "docstrings" to them. + # simply defined in __init__. Some requires additional treatment and is define below. We list + # everything here to add "docstrings" to them. scope: Scope """The scope of the SymbC, i.e., the union of scopes of all output layers.""" @@ -208,8 +121,7 @@ def _add_edge(self, tail: SymbolicLayer, head: SymbolicLayer) -> None: """Whether the SymbC is smooth, i.e., all inputs to a sum have the same scope.""" is_decomposable: bool - """Whether the SymbC is decomposable, i.e., inputs to a product have disjoint scopes and their \ - union is the scope of the product.""" + """Whether the SymbC is decomposable, i.e., inputs to a product have disjoint scopes.""" is_structured_decomposable: bool """Whether the SymbC is structured-decomposable, i.e., compatible to itself.""" @@ -228,20 +140,20 @@ def is_compatible( will use the intersection of the scopes of two SymbC. Defaults to None. Returns: - bool: Whether the SymbC is compatible to the other. + bool: Whether self is compatible to other. """ return self.region_graph.is_compatible(other.region_graph, scope=scope) ####################################### Layer views ###################################### - # These are iterable views of the nodes in the SymbC. For efficiency, all these views are - # iterators (implemented as a container iter or a generator), so that they can be chained for - # iteration without instantiating intermediate containers. - # NOTE: There's no ordering graranteed for these views. However RGNode can be sorted if needed. + # These are iterable views of the nodes in the SymbC, and the topological order is guaranteed + # (by a stronger ordering). For efficiency, all these views are iterators (implemented as a + # container iter or a generator), so that they can be chained for iteration without + # instantiating intermediate containers. @property def layers(self) -> Iterator[SymbolicLayer]: - """All the layers in the circuit.""" - return iter(self._layers) + """All layers in the circuit.""" + return iter(self._node_layer.values()) @property def sum_layers(self) -> Iterator[SymbolicSumLayer]: @@ -256,7 +168,12 @@ def sum_layers(self) -> Iterator[SymbolicSumLayer]: @property def product_layers(self) -> Iterator[SymbolicProductLayer]: """Product layers in the circuit, which are always inner layers.""" - return (layer for layer in self.layers if isinstance(layer, SymbolicProductLayer)) + # Ignore: SymbolicProductLayer contains Any. + return ( + layer + for layer in self.layers + if isinstance(layer, SymbolicProductLayer) # type: ignore[misc] + ) @property def input_layers(self) -> Iterator[SymbolicInputLayer]: @@ -270,8 +187,13 @@ def input_layers(self) -> Iterator[SymbolicInputLayer]: @property def output_layers(self) -> Iterator[SymbolicSumLayer]: - """Output layer of the circuit, which are guaranteed to be sum layers.""" + """Output layers of the circuit, which are always sum layers.""" return (layer for layer in self.sum_layers if not layer.outputs) + @property + def inner_layers(self) -> Iterator[SymbolicLayer]: + """Inner (non-input) layers in the circuit.""" + return (layer for layer in self.layers if layer.inputs) + #################################### (De)Serialization ################################### - # TODO: impl? + # TODO: impl? or just save RG and kwargs of SymbC? From c571443d3a9b5d9e814132530738fea7a0c58c6b Mon Sep 17 00:00:00 2001 From: lkct Date: Sat, 9 Dec 2023 00:16:53 +0000 Subject: [PATCH 7/7] fix tests --- tests/new/symbolic/test_symbolic_circuit.py | 59 ++------ tests/new/symbolic/test_symbolic_layer.py | 154 +++++++++++++++----- tests/new/symbolic/test_utils.py | 39 +++++ 3 files changed, 168 insertions(+), 84 deletions(-) create mode 100644 tests/new/symbolic/test_utils.py diff --git a/tests/new/symbolic/test_symbolic_circuit.py b/tests/new/symbolic/test_symbolic_circuit.py index 2cfc9b3b..59d89a6a 100644 --- a/tests/new/symbolic/test_symbolic_circuit.py +++ b/tests/new/symbolic/test_symbolic_circuit.py @@ -1,38 +1,15 @@ # pylint: disable=missing-function-docstring -from typing import Dict -from cirkit.new.layers import CategoricalLayer, CPLayer -from cirkit.new.region_graph import QuadTree, RegionGraph, RegionNode -from cirkit.new.reparams import ExpReparam -from cirkit.new.symbolic import SymbolicCircuit, SymbolicInputLayer, SymbolicSumLayer +from cirkit.new.region_graph import QuadTree +from cirkit.new.symbolic import SymbolicInputLayer, SymbolicSumLayer from cirkit.new.utils import Scope +from tests.new.symbolic.test_utils import get_simple_rg, get_symbolic_circuit_on_rg -def test_symbolic_circuit() -> None: - input_cls = CategoricalLayer - input_kwargs = {"num_categories": 256} - layer_cls = CPLayer - layer_kwargs: Dict[str, None] = {} - reparam = ExpReparam() +def test_symbolic_circuit_simple() -> None: + rg = get_simple_rg() - rg = RegionGraph() - node1 = RegionNode({0}) - node2 = RegionNode({1}) - region = RegionNode({0, 1}) - rg.add_partitioning(region, [node1, node2]) - rg.freeze() - - circuit = SymbolicCircuit( - rg, - layer_cls, - input_cls, - layer_kwargs, - input_kwargs, - reparam=reparam, - num_inner_units=4, - num_input_units=4, - num_classes=1, - ) + circuit = get_symbolic_circuit_on_rg(rg) assert len(list(circuit.layers)) == 4 # Ignore: SymbolicInputLayer contains Any. @@ -45,21 +22,13 @@ def test_symbolic_circuit() -> None: isinstance(layer, SymbolicSumLayer) for layer in circuit.output_layers # type: ignore[misc] ) - rg_2 = QuadTree((4, 4), struct_decomp=True) - circuit_2 = SymbolicCircuit( - rg_2, - layer_cls, - input_cls, - layer_kwargs, - input_kwargs, - reparam=reparam, - num_inner_units=4, - num_input_units=4, - num_classes=1, - ) +def test_symbolic_circuit_qt() -> None: + rg = QuadTree((4, 4), struct_decomp=True) + + circuit = get_symbolic_circuit_on_rg(rg) - assert len(list(circuit_2.layers)) == 46 - assert len(list(circuit_2.input_layers)) == 16 - assert len(list(circuit_2.output_layers)) == 1 - assert circuit_2.scope == Scope(range(16)) + assert len(list(circuit.layers)) == 46 + assert len(list(circuit.input_layers)) == 16 + assert len(list(circuit.output_layers)) == 1 + assert circuit.scope == Scope(range(16)) diff --git a/tests/new/symbolic/test_symbolic_layer.py b/tests/new/symbolic/test_symbolic_layer.py index 1c9f2e20..1b00ded5 100644 --- a/tests/new/symbolic/test_symbolic_layer.py +++ b/tests/new/symbolic/test_symbolic_layer.py @@ -1,50 +1,126 @@ # pylint: disable=missing-function-docstring -from cirkit.new.layers import CategoricalLayer, CPLayer, TuckerLayer +from typing import Dict + +from cirkit.new.layers import CategoricalLayer, DenseLayer, HadamardLayer, TuckerLayer from cirkit.new.reparams import ExpReparam from cirkit.new.symbolic import SymbolicInputLayer, SymbolicProductLayer, SymbolicSumLayer +from tests.new.symbolic.test_utils import get_simple_rg +# TODO: avoid repetition? -def test_symbolic_sum_layer() -> None: - scope = {0, 1} - num_units = 3 - layer = SymbolicSumLayer(scope, num_units, TuckerLayer, reparam=ExpReparam()) - assert "SymbolicSumLayer" in repr(layer) - assert "Scope: Scope({0, 1})" in repr(layer) - assert "Layer Class: TuckerLayer" in repr(layer) - assert "Number of Units: 3" in repr(layer) +def test_symbolic_layers_sum_and_prod() -> None: + rg = get_simple_rg() + input_node0, input_node1 = rg.input_nodes + (partition_node,) = rg.partition_nodes + (region_node,) = rg.inner_region_nodes -def test_symbolic_sum_layer_cp() -> None: - scope = {0, 1} num_units = 3 - layer_kwargs = {"collapsed": False, "shared": False, "arity": 2} - layer = SymbolicSumLayer(scope, num_units, CPLayer, layer_kwargs, reparam=ExpReparam()) - assert "SymbolicSumLayer" in repr(layer) - assert "Scope: Scope({0, 1})" in repr(layer) - assert "Layer Class: CPLayer" in repr(layer) - assert "Number of Units: 3" in repr(layer) - - -def test_symbolic_product_node() -> None: - scope = {0, 1} - num_input_units = 2 - layer = SymbolicProductLayer(scope, num_input_units, TuckerLayer) - assert "SymbolicProductLayer" in repr(layer) - assert "Scope: Scope({0, 1})" in repr(layer) - assert "Layer Class: TuckerLayer" in repr(layer) - assert "Number of Units: 2" in repr(layer) - - -def test_symbolic_input_node() -> None: - scope = {0, 1} + input_kwargs = {"num_categories": 5} + sum_kwargs: Dict[str, None] = {} # Avoid Any. + reparam = ExpReparam() + + input_layer0 = SymbolicInputLayer( + input_node0, + (), + num_units=num_units, + layer_cls=CategoricalLayer, + layer_kwargs=input_kwargs, + reparam=reparam, + ) + assert "SymbolicInputLayer" in repr(input_layer0) + assert "Scope: Scope({0})" in repr(input_layer0) + assert "Input Exp Family Class: CategoricalLayer" in repr(input_layer0) + assert "Layer KWArgs: {'num_categories': 5}" in repr(input_layer0) + assert "Number of Units: 3" in repr(input_layer0) + input_layer1 = SymbolicInputLayer( + input_node1, + (), + num_units=num_units, + layer_cls=CategoricalLayer, + layer_kwargs=input_kwargs, + reparam=reparam, + ) + + prod_layer = SymbolicProductLayer( + partition_node, + (input_layer0, input_layer1), + num_units=num_units, + layer_cls=HadamardLayer, + ) + assert "SymbolicProductLayer" in repr(prod_layer) + assert "Scope: Scope({0, 1})" in repr(prod_layer) + assert "Layer Class: HadamardLayer" in repr(prod_layer) + assert "Number of Units: 3" in repr(prod_layer) + + sum_layer = SymbolicSumLayer( + region_node, + (prod_layer,), + num_units=num_units, + layer_cls=DenseLayer, + layer_kwargs=sum_kwargs, + reparam=reparam, + ) + assert "SymbolicSumLayer" in repr(sum_layer) + assert "Scope: Scope({0, 1})" in repr(sum_layer) + assert "Layer Class: DenseLayer" in repr(sum_layer) + assert "Number of Units: 3" in repr(sum_layer) + + +def test_symbolic_layers_sum_prod() -> None: + rg = get_simple_rg() + input_node0, input_node1 = rg.input_nodes + (partition_node,) = rg.partition_nodes + (region_node,) = rg.inner_region_nodes + num_units = 3 input_kwargs = {"num_categories": 5} - layer = SymbolicInputLayer( - scope, num_units, CategoricalLayer, input_kwargs, reparam=ExpReparam() - ) - assert "SymbolicInputLayer" in repr(layer) - assert "Scope: Scope({0, 1})" in repr(layer) - assert "Input Exp Family Class: CategoricalLayer" in repr(layer) - assert "Layer KWArgs: {'num_categories': 5}" in repr(layer) - assert "Number of Units: 3" in repr(layer) + sum_kwargs: Dict[str, None] = {} # Avoid Any. + reparam = ExpReparam() + + input_layer0 = SymbolicInputLayer( + input_node0, + (), + num_units=num_units, + layer_cls=CategoricalLayer, + layer_kwargs=input_kwargs, + reparam=reparam, + ) + assert "SymbolicInputLayer" in repr(input_layer0) + assert "Scope: Scope({0})" in repr(input_layer0) + assert "Input Exp Family Class: CategoricalLayer" in repr(input_layer0) + assert "Layer KWArgs: {'num_categories': 5}" in repr(input_layer0) + assert "Number of Units: 3" in repr(input_layer0) + input_layer1 = SymbolicInputLayer( + input_node1, + (), + num_units=num_units, + layer_cls=CategoricalLayer, + layer_kwargs=input_kwargs, + reparam=reparam, + ) + + prod_layer = SymbolicProductLayer( + partition_node, + (input_layer0, input_layer1), + num_units=num_units**2, + layer_cls=TuckerLayer, + ) + assert "SymbolicProductLayer" in repr(prod_layer) + assert "Scope: Scope({0, 1})" in repr(prod_layer) + assert "Layer Class: TuckerLayer" in repr(prod_layer) + assert "Number of Units: 9" in repr(prod_layer) + + sum_layer = SymbolicSumLayer( + region_node, + (prod_layer,), + num_units=num_units, + layer_cls=TuckerLayer, + layer_kwargs=sum_kwargs, + reparam=reparam, + ) + assert "SymbolicSumLayer" in repr(sum_layer) + assert "Scope: Scope({0, 1})" in repr(sum_layer) + assert "Layer Class: TuckerLayer" in repr(sum_layer) + assert "Number of Units: 3" in repr(sum_layer) diff --git a/tests/new/symbolic/test_utils.py b/tests/new/symbolic/test_utils.py new file mode 100644 index 00000000..16c7fa0e --- /dev/null +++ b/tests/new/symbolic/test_utils.py @@ -0,0 +1,39 @@ +# pylint: disable=missing-function-docstring,missing-return-doc +from typing import Dict + +from cirkit.new.layers import CategoricalLayer, CPLayer +from cirkit.new.region_graph import RegionGraph, RegionNode +from cirkit.new.reparams import ExpReparam +from cirkit.new.symbolic import SymbolicCircuit + + +def get_simple_rg() -> RegionGraph: + rg = RegionGraph() + node1 = RegionNode({0}) + node2 = RegionNode({1}) + region = RegionNode({0, 1}) + rg.add_partitioning(region, (node1, node2)) + return rg.freeze() + + +def get_symbolic_circuit_on_rg(rg: RegionGraph) -> SymbolicCircuit: + num_units = 4 + input_cls = CategoricalLayer + input_kwargs = {"num_categories": 256} + inner_cls = CPLayer + inner_kwargs: Dict[str, None] = {} # Avoid Any. + reparam = ExpReparam() + + return SymbolicCircuit( + rg, + num_input_units=num_units, + num_sum_units=num_units, + input_layer_cls=input_cls, + input_layer_kwargs=input_kwargs, + input_reparam=reparam, + sum_layer_cls=inner_cls, + sum_layer_kwargs=inner_kwargs, + sum_reparam=reparam, + prod_layer_cls=inner_cls, + prod_layer_kwargs=None, + )