
Commit

Merge pull request #165 from april-tools/rgnode_sorting
Introduce topological ordering to RG
lkct authored Dec 13, 2023
2 parents 5bde802 + c571443 commit bc9d592
Showing 21 changed files with 1,000 additions and 478 deletions.
14 changes: 14 additions & 0 deletions cirkit/new/layers/inner/inner.py
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import Optional

from cirkit.new.layers.layer import Layer
@@ -32,3 +33,16 @@ def __init__(
arity=arity,
reparam=reparam,
)

@classmethod
@abstractmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.
Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.
Returns:
int: The inferred number of product units.
"""
13 changes: 13 additions & 0 deletions cirkit/new/layers/inner/product/hadamard.py
@@ -36,6 +36,19 @@ def __init__(
reparam=None,
)

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.
Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.
Returns:
int: The inferred number of product units.
"""
return num_input_units

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
16 changes: 15 additions & 1 deletion cirkit/new/layers/inner/product/kronecker.py
@@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal, Optional, cast

from torch import Tensor

@@ -38,6 +38,20 @@ def __init__(
reparam=None,
)

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.
Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.
Returns:
int: The inferred number of product units.
"""
# Cast: int**int is not guaranteed to be int.
return cast(int, num_input_units**arity)

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
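The two counts (num_input_units for Hadamard, num_input_units**arity for Kronecker) mirror the tensor shapes the products actually produce. A small illustrative torch sketch for arity 2, not the layers' real forward passes:

import torch

B, K = 3, 4  # batch size and number of input units
x0 = torch.randn(B, K)
x1 = torch.randn(B, K)

# Hadamard: element-wise product, the width stays K.
hadamard = x0 * x1
assert hadamard.shape == (B, K)

# Kronecker (arity 2): every pairwise combination, the width becomes K ** 2.
kron = torch.einsum("bi,bj->bij", x0, x1).reshape(B, K * K)
assert kron.shape == (B, K * K)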
14 changes: 14 additions & 0 deletions cirkit/new/layers/inner/sum/sum.py
@@ -1,4 +1,5 @@
import functools
from typing import Literal

import torch
from torch import nn
@@ -33,6 +34,19 @@ def __init__(
reparam=reparam,
)

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> Literal[0]:
"""Infer the number of product units in the layer based on given information.
Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.
Returns:
Literal[0]: Sum layers have 0 product units.
"""
return 0

@torch.no_grad()
def reset_parameters(self) -> None:
"""Reset parameters to default: U(0.01, 0.99)."""
13 changes: 13 additions & 0 deletions cirkit/new/layers/inner/sum_product/cp.py
@@ -51,6 +51,19 @@ def __init__(
)
# DenseLayer already invoked reset_parameters().

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.
Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.
Returns:
int: The inferred number of product units.
"""
return num_input_units

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.
16 changes: 15 additions & 1 deletion cirkit/new/layers/inner/sum_product/tucker.py
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, cast

import torch
from torch import Tensor
@@ -43,6 +43,20 @@ def __init__(

self.reset_parameters()

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.
Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.
Returns:
int: The inferred number of product units.
"""
# Cast: int**int is not guaranteed to be int.
return cast(int, num_input_units**arity)

def _forward_linear(self, x0: Tensor, x1: Tensor) -> Tensor:
# shape (*B, I), (*B, J) -> (*B, O).
return torch.einsum("oij,...i,...j->...o", self.params(), x0, x1)
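For Tucker, the num_input_units**arity count can be read off an equivalent two-step view of the einsum in _forward_linear: first form all pairwise products, then take a dense sum over them. A sketch under that assumption, with plain tensors and no reparameterization:

import torch

B, I, O = 2, 4, 5
x0, x1 = torch.randn(B, I), torch.randn(B, I)
params = torch.randn(O, I, I)

# Fused form, as in _forward_linear.
fused = torch.einsum("oij,...i,...j->...o", params, x0, x1)

# Two-step view: I ** 2 product units, then a dense sum over them.
prods = torch.einsum("...i,...j->...ij", x0, x1).reshape(B, I * I)
two_step = prods @ params.reshape(O, I * I).T

assert torch.allclose(fused, two_step, atol=1e-5)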
7 changes: 4 additions & 3 deletions cirkit/new/layers/input/input.py
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Optional

from cirkit.new.layers.layer import Layer
from cirkit.new.reparams import Reparameterization
@@ -22,15 +22,16 @@ def __init__(
num_input_units: int,
num_output_units: int,
arity: Literal[1] = 1,
reparam: Reparameterization,
reparam: Optional[Reparameterization] = None,
) -> None:
"""Init class.
Args:
num_input_units (int): The number of input units, i.e. number of channels for variables.
num_output_units (int): The number of output units.
arity (Literal[1], optional): The arity of the layer, must be 1. Defaults to 1.
reparam (Reparameterization): The reparameterization for layer parameters.
reparam (Optional[Reparameterization], optional): The reparameterization for layer \
parameters, can be None if the layer has no params. Defaults to None.
"""
assert arity == 1, "We define arity=1 for InputLayer."
super().__init__(
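With reparam now optional, parameter-free input layers no longer have to be given a placeholder reparameterization. A hedged sketch of the constructor pattern, with stand-in classes rather than the actual cirkit signatures:

from typing import Optional


class ReparamSketch:
    """Stand-in for Reparameterization."""


class InputLayerSketch:
    def __init__(self, num_output_units: int, reparam: Optional[ReparamSketch] = None) -> None:
        # A layer with no learnable parameters can simply omit reparam.
        self.num_output_units = num_output_units
        self.reparam = reparam


layer_without_params = InputLayerSketch(num_output_units=8)  # reparam defaults to None
layer_with_params = InputLayerSketch(num_output_units=8, reparam=ReparamSketch())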
7 changes: 4 additions & 3 deletions cirkit/new/region_graph/algorithms/poon_domingos.py
@@ -1,19 +1,20 @@
from collections import deque
from typing import Deque, Dict, FrozenSet, List, Optional, Sequence, Union, cast
from typing import Deque, Dict, List, Optional, Sequence, Union, cast

from cirkit.new.region_graph.algorithms.utils import HyperCube, HypercubeToScope
from cirkit.new.region_graph.region_graph import RegionGraph
from cirkit.new.region_graph.rg_node import RegionNode
from cirkit.new.utils import Scope

# TODO: test what is constructed here


def _get_region_node_by_scope(graph: RegionGraph, scope: FrozenSet[int]) -> RegionNode:
def _get_region_node_by_scope(graph: RegionGraph, scope: Scope) -> RegionNode:
"""Find a RegionNode with a specific scope in the RG, and construct one if not found.
Args:
graph (RegionGraph): The region graph to find in.
scope (Iterable[int]): The scope to find.
scope (Scope): The scope to find.
Returns:
RegionNode: The RegionNode found or constructed.
13 changes: 9 additions & 4 deletions cirkit/new/region_graph/algorithms/quad_tree.py
@@ -4,6 +4,7 @@
from cirkit.new.region_graph.algorithms.utils import HypercubeToScope
from cirkit.new.region_graph.region_graph import RegionGraph
from cirkit.new.region_graph.rg_node import RegionNode
from cirkit.new.utils import Scope

# TODO: now should work with H!=W but need tests

@@ -96,8 +97,12 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr
graph = RegionGraph()
hypercube_to_scope = HypercubeToScope(shape)

# The regions of the current layer, in shape (H, W). scope={-1} is for padding.
layer: List[List[RegionNode]] = [[RegionNode({-1})] * (width + 1) for _ in range(height + 1)]
# Padding using Scope({num_var}) which is one larger than range(num_var).
pad_scope = Scope({height * width})
# The regions of the current layer, in shape (H, W).
layer: List[List[RegionNode]] = [
[RegionNode(pad_scope)] * (width + 1) for _ in range(height + 1)
]

# Add univariate input nodes.
for i, j in itertools.product(range(height), range(width)):
@@ -114,7 +119,7 @@

height = (prev_height + 1) // 2
width = (prev_width + 1) // 2
layer = [[RegionNode({-1})] * (width + 1) for _ in range(height + 1)]
layer = [[RegionNode(pad_scope)] * (width + 1) for _ in range(height + 1)]

for i, j in itertools.product(range(height), range(width)):
regions = [ # Filter valid regions in the 4 possible sub-regions.
@@ -125,7 +130,7 @@
prev_layer[i * 2 + 1][j * 2],
prev_layer[i * 2 + 1][j * 2 + 1],
)
if node.scope != {-1}
if node.scope != pad_scope
]
if len(regions) == 1:
node = regions[0]
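The padding sentinel changes from {-1} to Scope({num_var}): real variables are indexed 0 .. num_var - 1, so a scope containing num_var can never collide with a real region's scope. A tiny sketch of that invariant, using frozenset as a stand-in for Scope:

height, width = 3, 3
num_var = height * width

pad_scope = frozenset({num_var})  # stand-in for Scope({num_var})
valid_scopes = [frozenset({i}) for i in range(num_var)]

assert all(scope != pad_scope for scope in valid_scopes)  # the sentinel never collides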
8 changes: 4 additions & 4 deletions cirkit/new/region_graph/algorithms/random_binary_tree.py
@@ -27,8 +27,8 @@ def _partition_node_randomly(
Returns:
List[RegionNode]: The region nodes forming the partitioning.
"""
scope = list(node.scope)
random.shuffle(scope)
scope_list = list(node.scope)
random.shuffle(scope_list)

split: NDArray[np.float64] # Unnormalized split points including 0 and 1.
if proportions is None:
@@ -39,15 +39,15 @@

# Ignore: Numpy has typing issues.
split_point: List[int] = (
np.around(split / split[-1] * len(scope)) # type: ignore[assignment,misc]
np.around(split / split[-1] * len(scope_list)) # type: ignore[assignment,misc]
.astype(np.int64)
.tolist()
)

region_nodes: List[RegionNode] = []
for l, r in zip(split_point[:-1], split_point[1:]):
if l < r: # A region must have at least one var, otherwise we skip it.
region_node = RegionNode(scope[l:r])
region_node = RegionNode(scope_list[l:r])
region_nodes.append(region_node)

if len(region_nodes) == 1:
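The split-point computation rounds normalized proportions onto indices of the shuffled variable list, and empty slices are skipped. A sketch of that step; the diff does not show how split is built from proportions, so the cumulative-sum construction below is an assumption:

import random

import numpy as np

scope_list = list(range(10))
random.shuffle(scope_list)

proportions = [1.0, 2.0, 2.0]  # relative sizes of the parts (hypothetical)
split = np.cumsum([0.0] + proportions)  # unnormalized split points including 0 (assumed construction)
split_point = np.around(split / split[-1] * len(scope_list)).astype(np.int64).tolist()
# Here split_point == [0, 2, 6, 10] -> parts scope_list[0:2], scope_list[2:6], scope_list[6:10].

parts = [scope_list[l:r] for l, r in zip(split_point[:-1], split_point[1:]) if l < r]
assert sum(len(p) for p in parts) == len(scope_list)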
14 changes: 8 additions & 6 deletions cirkit/new/region_graph/algorithms/utils.py
@@ -1,13 +1,15 @@
from typing import Dict, FrozenSet, Sequence, Tuple
from typing import Dict, Sequence, Tuple

import numpy as np
from numpy.typing import NDArray

from cirkit.new.utils import Scope

HyperCube = Tuple[Tuple[int, ...], Tuple[int, ...]] # Just to shorten the annotation.
"""A hypercube represented by "top-left" and "bottom-right" coordinates (cut points)."""


class HypercubeToScope(Dict[HyperCube, FrozenSet[int]]):
class HypercubeToScope(Dict[HyperCube, Scope]):
"""Helper class to map sub-hypercubes to scopes with caching for variables arranged in a \
hypercube.
@@ -29,27 +31,27 @@ def __init__(self, shape: Sequence[int]) -> None:
# We assume it's feasible to save the whole hypercube, since it should be the whole region.
self.hypercube: NDArray[np.int64] = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)

def __missing__(self, key: HyperCube) -> FrozenSet[int]:
def __missing__(self, key: HyperCube) -> Scope:
"""Construct the item when not exist in the dict.
Args:
key (HyperCube): The key that is missing from the dict, i.e., a hypercube that is \
visited for the first time.
Returns:
FrozenSet[int]: The value for the key, i.e., the corresponding scope.
Scope: The value for the key, i.e., the corresponding scope.
"""
point1, point2 = key # HyperCube is from point1 to point2.

assert (
len(point1) == len(point2) == self.ndims
), "The shape of the HyperCube is not correct."
), "The dimension of the HyperCube is not correct."
assert all(
0 <= x1 < x2 <= shape for x1, x2, shape in zip(point1, point2, self.shape)
), "The HyperCube is empty."

# Ignore: Numpy has typing issues.
return frozenset(
return Scope(
self.hypercube[
tuple(slice(x1, x2) for x1, x2 in zip(point1, point2)) # type: ignore[misc]
]
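HypercubeToScope now caches Scope values instead of frozensets, but the mapping itself is unchanged: the sub-hypercube given by its two corner points is sliced out of a row-major index grid, and the covered variable indices form the scope. A self-contained sketch using frozenset as a stand-in for Scope:

import numpy as np

shape = (4, 4)  # a 4x4 grid of variables 0..15
hypercube = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)


def scope_of(point1, point2):
    """Variables covered by the sub-hypercube from point1 (inclusive) to point2 (exclusive)."""
    block = hypercube[tuple(slice(x1, x2) for x1, x2 in zip(point1, point2))]
    return frozenset(block.reshape(-1).tolist())


assert scope_of((0, 0), (2, 2)) == frozenset({0, 1, 4, 5})
assert scope_of((2, 1), (4, 3)) == frozenset({9, 10, 13, 14})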