Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce topological ordering to RG #165

Merged
merged 7 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cirkit/new/layers/inner/inner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import Optional

from cirkit.new.layers.layer import Layer
Expand Down Expand Up @@ -32,3 +33,16 @@ def __init__(
arity=arity,
reparam=reparam,
)

@classmethod
@abstractmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.

Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.

Returns:
int: The inferred number of product units.
"""
13 changes: 13 additions & 0 deletions cirkit/new/layers/inner/product/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def __init__(
reparam=None,
)

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.

Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.

Returns:
int: The inferred number of product units.
"""
return num_input_units

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.

Expand Down
16 changes: 15 additions & 1 deletion cirkit/new/layers/inner/product/kronecker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal, Optional, cast

from torch import Tensor

Expand Down Expand Up @@ -38,6 +38,20 @@ def __init__(
reparam=None,
)

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.

Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.

Returns:
int: The inferred number of product units.
"""
# Cast: int**int is not guaranteed to be int.
return cast(int, num_input_units**arity)

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.

Expand Down
14 changes: 14 additions & 0 deletions cirkit/new/layers/inner/sum/sum.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
from typing import Literal

import torch
from torch import nn
Expand Down Expand Up @@ -33,6 +34,19 @@ def __init__(
reparam=reparam,
)

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> Literal[0]:
"""Infer the number of product units in the layer based on given information.

Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.

Returns:
Literal[0]: Sum layers have 0 product units.
"""
return 0

@torch.no_grad()
def reset_parameters(self) -> None:
"""Reset parameters to default: U(0.01, 0.99)."""
Expand Down
13 changes: 13 additions & 0 deletions cirkit/new/layers/inner/sum_product/cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def __init__(
)
# DenseLayer already invoked reset_parameters().

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.

Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.

Returns:
int: The inferred number of product units.
"""
return num_input_units

def forward(self, x: Tensor) -> Tensor:
"""Run forward pass.

Expand Down
16 changes: 15 additions & 1 deletion cirkit/new/layers/inner/sum_product/tucker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, cast

import torch
from torch import Tensor
Expand Down Expand Up @@ -43,6 +43,20 @@ def __init__(

self.reset_parameters()

@classmethod
def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
"""Infer the number of product units in the layer based on given information.

Args:
num_input_units (int): The number of input units.
arity (int, optional): The arity of the layer. Defaults to 2.

Returns:
int: The inferred number of product units.
"""
# Cast: int**int is not guaranteed to be int.
return cast(int, num_input_units**arity)

def _forward_linear(self, x0: Tensor, x1: Tensor) -> Tensor:
# shape (*B, I), (*B, J) -> (*B, O).
return torch.einsum("oij,...i,...j->...o", self.params(), x0, x1)
Expand Down
7 changes: 4 additions & 3 deletions cirkit/new/layers/input/input.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Optional

from cirkit.new.layers.layer import Layer
from cirkit.new.reparams import Reparameterization
Expand All @@ -22,15 +22,16 @@ def __init__(
num_input_units: int,
num_output_units: int,
arity: Literal[1] = 1,
reparam: Reparameterization,
reparam: Optional[Reparameterization] = None,
) -> None:
"""Init class.

Args:
num_input_units (int): The number of input units, i.e. number of channels for variables.
num_output_units (int): The number of output units.
arity (Literal[1], optional): The arity of the layer, must be 1. Defaults to 1.
reparam (Reparameterization): The reparameterization for layer parameters.
reparam (Optional[Reparameterization], optional): The reparameterization for layer \
parameters, can be None if the layer has no params. Defaults to None.
"""
assert arity == 1, "We define arity=1 for InputLayer."
super().__init__(
Expand Down
7 changes: 4 additions & 3 deletions cirkit/new/region_graph/algorithms/poon_domingos.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from collections import deque
from typing import Deque, Dict, FrozenSet, List, Optional, Sequence, Union, cast
from typing import Deque, Dict, List, Optional, Sequence, Union, cast

from cirkit.new.region_graph.algorithms.utils import HyperCube, HypercubeToScope
from cirkit.new.region_graph.region_graph import RegionGraph
from cirkit.new.region_graph.rg_node import RegionNode
from cirkit.new.utils import Scope

# TODO: test what is constructed here


def _get_region_node_by_scope(graph: RegionGraph, scope: FrozenSet[int]) -> RegionNode:
def _get_region_node_by_scope(graph: RegionGraph, scope: Scope) -> RegionNode:
"""Find a RegionNode with a specific scope in the RG, and construct one if not found.

Args:
graph (RegionGraph): The region graph to find in.
scope (Iterable[int]): The scope to find.
scope (Scope): The scope to find.

Returns:
RegionNode: The RegionNode found or constructed.
Expand Down
13 changes: 9 additions & 4 deletions cirkit/new/region_graph/algorithms/quad_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from cirkit.new.region_graph.algorithms.utils import HypercubeToScope
from cirkit.new.region_graph.region_graph import RegionGraph
from cirkit.new.region_graph.rg_node import RegionNode
from cirkit.new.utils import Scope

# TODO: now should work with H!=W but need tests

Expand Down Expand Up @@ -96,8 +97,12 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr
graph = RegionGraph()
hypercube_to_scope = HypercubeToScope(shape)

# The regions of the current layer, in shape (H, W). scope={-1} is for padding.
layer: List[List[RegionNode]] = [[RegionNode({-1})] * (width + 1) for _ in range(height + 1)]
# Padding using Scope({num_var}) which is one larger than range(num_var).
pad_scope = Scope({height * width})
# The regions of the current layer, in shape (H, W).
layer: List[List[RegionNode]] = [
[RegionNode(pad_scope)] * (width + 1) for _ in range(height + 1)
]

# Add univariate input nodes.
for i, j in itertools.product(range(height), range(width)):
Expand All @@ -114,7 +119,7 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr

height = (prev_height + 1) // 2
width = (prev_width + 1) // 2
layer = [[RegionNode({-1})] * (width + 1) for _ in range(height + 1)]
layer = [[RegionNode(pad_scope)] * (width + 1) for _ in range(height + 1)]

for i, j in itertools.product(range(height), range(width)):
regions = [ # Filter valid regions in the 4 possible sub-regions.
Expand All @@ -125,7 +130,7 @@ def QuadTree(shape: Tuple[int, int], *, struct_decomp: bool = False) -> RegionGr
prev_layer[i * 2 + 1][j * 2],
prev_layer[i * 2 + 1][j * 2 + 1],
)
if node.scope != {-1}
if node.scope != pad_scope
]
if len(regions) == 1:
node = regions[0]
Expand Down
8 changes: 4 additions & 4 deletions cirkit/new/region_graph/algorithms/random_binary_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def _partition_node_randomly(
Returns:
List[RegionNode]: The region nodes forming the partitioning.
"""
scope = list(node.scope)
random.shuffle(scope)
scope_list = list(node.scope)
random.shuffle(scope_list)

split: NDArray[np.float64] # Unnormalized split points including 0 and 1.
if proportions is None:
Expand All @@ -39,15 +39,15 @@ def _partition_node_randomly(

# Ignore: Numpy has typing issues.
split_point: List[int] = (
np.around(split / split[-1] * len(scope)) # type: ignore[assignment,misc]
np.around(split / split[-1] * len(scope_list)) # type: ignore[assignment,misc]
.astype(np.int64)
.tolist()
)

region_nodes: List[RegionNode] = []
for l, r in zip(split_point[:-1], split_point[1:]):
if l < r: # A region must have as least one var, otherwise we skip it.
region_node = RegionNode(scope[l:r])
region_node = RegionNode(scope_list[l:r])
region_nodes.append(region_node)

if len(region_nodes) == 1:
Expand Down
14 changes: 8 additions & 6 deletions cirkit/new/region_graph/algorithms/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Dict, FrozenSet, Sequence, Tuple
from typing import Dict, Sequence, Tuple

import numpy as np
from numpy.typing import NDArray

from cirkit.new.utils import Scope

HyperCube = Tuple[Tuple[int, ...], Tuple[int, ...]] # Just to shorten the annotation.
"""A hypercube represented by "top-left" and "bottom-right" coordinates (cut points)."""


class HypercubeToScope(Dict[HyperCube, FrozenSet[int]]):
class HypercubeToScope(Dict[HyperCube, Scope]):
"""Helper class to map sub-hypercubes to scopes with caching for variables arranged in a \
hypercube.

Expand All @@ -29,27 +31,27 @@ def __init__(self, shape: Sequence[int]) -> None:
# We assume it's feasible to save the whole hypercube, since it should be the whole region.
self.hypercube: NDArray[np.int64] = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)

def __missing__(self, key: HyperCube) -> FrozenSet[int]:
def __missing__(self, key: HyperCube) -> Scope:
"""Construct the item when not exist in the dict.

Args:
key (HyperCube): The key that is missing from the dict, i.e., a hypercube that is \
visited for the first time.

Returns:
FrozenSet[int]: The value for the key, i.e., the corresponding scope.
Scope: The value for the key, i.e., the corresponding scope.
"""
point1, point2 = key # HyperCube is from point1 to point2.

assert (
len(point1) == len(point2) == self.ndims
), "The shape of the HyperCube is not correct."
), "The dimension of the HyperCube is not correct."
assert all(
0 <= x1 < x2 <= shape for x1, x2, shape in zip(point1, point2, self.shape)
), "The HyperCube is empty."

# Ignore: Numpy has typing issues.
return frozenset(
return Scope(
self.hypercube[
tuple(slice(x1, x2) for x1, x2 in zip(point1, point2)) # type: ignore[misc]
]
Expand Down
Loading