Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Dec 10, 2023
1 parent 384bd62 commit c571443
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 84 deletions.
59 changes: 14 additions & 45 deletions tests/new/symbolic/test_symbolic_circuit.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,15 @@
# pylint: disable=missing-function-docstring
from typing import Dict

from cirkit.new.layers import CategoricalLayer, CPLayer
from cirkit.new.region_graph import QuadTree, RegionGraph, RegionNode
from cirkit.new.reparams import ExpReparam
from cirkit.new.symbolic import SymbolicCircuit, SymbolicInputLayer, SymbolicSumLayer
from cirkit.new.region_graph import QuadTree
from cirkit.new.symbolic import SymbolicInputLayer, SymbolicSumLayer
from cirkit.new.utils import Scope
from tests.new.symbolic.test_utils import get_simple_rg, get_symbolic_circuit_on_rg


def test_symbolic_circuit() -> None:
input_cls = CategoricalLayer
input_kwargs = {"num_categories": 256}
layer_cls = CPLayer
layer_kwargs: Dict[str, None] = {}
reparam = ExpReparam()
def test_symbolic_circuit_simple() -> None:
rg = get_simple_rg()

rg = RegionGraph()
node1 = RegionNode({0})
node2 = RegionNode({1})
region = RegionNode({0, 1})
rg.add_partitioning(region, [node1, node2])
rg.freeze()

circuit = SymbolicCircuit(
rg,
layer_cls,
input_cls,
layer_kwargs,
input_kwargs,
reparam=reparam,
num_inner_units=4,
num_input_units=4,
num_classes=1,
)
circuit = get_symbolic_circuit_on_rg(rg)

assert len(list(circuit.layers)) == 4
# Ignore: SymbolicInputLayer contains Any.
Expand All @@ -45,21 +22,13 @@ def test_symbolic_circuit() -> None:
isinstance(layer, SymbolicSumLayer) for layer in circuit.output_layers # type: ignore[misc]
)

rg_2 = QuadTree((4, 4), struct_decomp=True)

circuit_2 = SymbolicCircuit(
rg_2,
layer_cls,
input_cls,
layer_kwargs,
input_kwargs,
reparam=reparam,
num_inner_units=4,
num_input_units=4,
num_classes=1,
)
def test_symbolic_circuit_qt() -> None:
rg = QuadTree((4, 4), struct_decomp=True)

circuit = get_symbolic_circuit_on_rg(rg)

assert len(list(circuit_2.layers)) == 46
assert len(list(circuit_2.input_layers)) == 16
assert len(list(circuit_2.output_layers)) == 1
assert circuit_2.scope == Scope(range(16))
assert len(list(circuit.layers)) == 46
assert len(list(circuit.input_layers)) == 16
assert len(list(circuit.output_layers)) == 1
assert circuit.scope == Scope(range(16))
154 changes: 115 additions & 39 deletions tests/new/symbolic/test_symbolic_layer.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,126 @@
# pylint: disable=missing-function-docstring

from cirkit.new.layers import CategoricalLayer, CPLayer, TuckerLayer
from typing import Dict

from cirkit.new.layers import CategoricalLayer, DenseLayer, HadamardLayer, TuckerLayer
from cirkit.new.reparams import ExpReparam
from cirkit.new.symbolic import SymbolicInputLayer, SymbolicProductLayer, SymbolicSumLayer
from tests.new.symbolic.test_utils import get_simple_rg

# TODO: avoid repetition?

def test_symbolic_sum_layer() -> None:
scope = {0, 1}
num_units = 3
layer = SymbolicSumLayer(scope, num_units, TuckerLayer, reparam=ExpReparam())
assert "SymbolicSumLayer" in repr(layer)
assert "Scope: Scope({0, 1})" in repr(layer)
assert "Layer Class: TuckerLayer" in repr(layer)
assert "Number of Units: 3" in repr(layer)

def test_symbolic_layers_sum_and_prod() -> None:
rg = get_simple_rg()
input_node0, input_node1 = rg.input_nodes
(partition_node,) = rg.partition_nodes
(region_node,) = rg.inner_region_nodes

def test_symbolic_sum_layer_cp() -> None:
scope = {0, 1}
num_units = 3
layer_kwargs = {"collapsed": False, "shared": False, "arity": 2}
layer = SymbolicSumLayer(scope, num_units, CPLayer, layer_kwargs, reparam=ExpReparam())
assert "SymbolicSumLayer" in repr(layer)
assert "Scope: Scope({0, 1})" in repr(layer)
assert "Layer Class: CPLayer" in repr(layer)
assert "Number of Units: 3" in repr(layer)


def test_symbolic_product_node() -> None:
scope = {0, 1}
num_input_units = 2
layer = SymbolicProductLayer(scope, num_input_units, TuckerLayer)
assert "SymbolicProductLayer" in repr(layer)
assert "Scope: Scope({0, 1})" in repr(layer)
assert "Layer Class: TuckerLayer" in repr(layer)
assert "Number of Units: 2" in repr(layer)


def test_symbolic_input_node() -> None:
scope = {0, 1}
input_kwargs = {"num_categories": 5}
sum_kwargs: Dict[str, None] = {} # Avoid Any.
reparam = ExpReparam()

input_layer0 = SymbolicInputLayer(
input_node0,
(),
num_units=num_units,
layer_cls=CategoricalLayer,
layer_kwargs=input_kwargs,
reparam=reparam,
)
assert "SymbolicInputLayer" in repr(input_layer0)
assert "Scope: Scope({0})" in repr(input_layer0)
assert "Input Exp Family Class: CategoricalLayer" in repr(input_layer0)
assert "Layer KWArgs: {'num_categories': 5}" in repr(input_layer0)
assert "Number of Units: 3" in repr(input_layer0)
input_layer1 = SymbolicInputLayer(
input_node1,
(),
num_units=num_units,
layer_cls=CategoricalLayer,
layer_kwargs=input_kwargs,
reparam=reparam,
)

prod_layer = SymbolicProductLayer(
partition_node,
(input_layer0, input_layer1),
num_units=num_units,
layer_cls=HadamardLayer,
)
assert "SymbolicProductLayer" in repr(prod_layer)
assert "Scope: Scope({0, 1})" in repr(prod_layer)
assert "Layer Class: HadamardLayer" in repr(prod_layer)
assert "Number of Units: 3" in repr(prod_layer)

sum_layer = SymbolicSumLayer(
region_node,
(prod_layer,),
num_units=num_units,
layer_cls=DenseLayer,
layer_kwargs=sum_kwargs,
reparam=reparam,
)
assert "SymbolicSumLayer" in repr(sum_layer)
assert "Scope: Scope({0, 1})" in repr(sum_layer)
assert "Layer Class: DenseLayer" in repr(sum_layer)
assert "Number of Units: 3" in repr(sum_layer)


def test_symbolic_layers_sum_prod() -> None:
rg = get_simple_rg()
input_node0, input_node1 = rg.input_nodes
(partition_node,) = rg.partition_nodes
(region_node,) = rg.inner_region_nodes

num_units = 3
input_kwargs = {"num_categories": 5}
layer = SymbolicInputLayer(
scope, num_units, CategoricalLayer, input_kwargs, reparam=ExpReparam()
)
assert "SymbolicInputLayer" in repr(layer)
assert "Scope: Scope({0, 1})" in repr(layer)
assert "Input Exp Family Class: CategoricalLayer" in repr(layer)
assert "Layer KWArgs: {'num_categories': 5}" in repr(layer)
assert "Number of Units: 3" in repr(layer)
sum_kwargs: Dict[str, None] = {} # Avoid Any.
reparam = ExpReparam()

input_layer0 = SymbolicInputLayer(
input_node0,
(),
num_units=num_units,
layer_cls=CategoricalLayer,
layer_kwargs=input_kwargs,
reparam=reparam,
)
assert "SymbolicInputLayer" in repr(input_layer0)
assert "Scope: Scope({0})" in repr(input_layer0)
assert "Input Exp Family Class: CategoricalLayer" in repr(input_layer0)
assert "Layer KWArgs: {'num_categories': 5}" in repr(input_layer0)
assert "Number of Units: 3" in repr(input_layer0)
input_layer1 = SymbolicInputLayer(
input_node1,
(),
num_units=num_units,
layer_cls=CategoricalLayer,
layer_kwargs=input_kwargs,
reparam=reparam,
)

prod_layer = SymbolicProductLayer(
partition_node,
(input_layer0, input_layer1),
num_units=num_units**2,
layer_cls=TuckerLayer,
)
assert "SymbolicProductLayer" in repr(prod_layer)
assert "Scope: Scope({0, 1})" in repr(prod_layer)
assert "Layer Class: TuckerLayer" in repr(prod_layer)
assert "Number of Units: 9" in repr(prod_layer)

sum_layer = SymbolicSumLayer(
region_node,
(prod_layer,),
num_units=num_units,
layer_cls=TuckerLayer,
layer_kwargs=sum_kwargs,
reparam=reparam,
)
assert "SymbolicSumLayer" in repr(sum_layer)
assert "Scope: Scope({0, 1})" in repr(sum_layer)
assert "Layer Class: TuckerLayer" in repr(sum_layer)
assert "Number of Units: 3" in repr(sum_layer)
39 changes: 39 additions & 0 deletions tests/new/symbolic/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# pylint: disable=missing-function-docstring,missing-return-doc
from typing import Dict

from cirkit.new.layers import CategoricalLayer, CPLayer
from cirkit.new.region_graph import RegionGraph, RegionNode
from cirkit.new.reparams import ExpReparam
from cirkit.new.symbolic import SymbolicCircuit


def get_simple_rg() -> RegionGraph:
rg = RegionGraph()
node1 = RegionNode({0})
node2 = RegionNode({1})
region = RegionNode({0, 1})
rg.add_partitioning(region, (node1, node2))
return rg.freeze()


def get_symbolic_circuit_on_rg(rg: RegionGraph) -> SymbolicCircuit:
num_units = 4
input_cls = CategoricalLayer
input_kwargs = {"num_categories": 256}
inner_cls = CPLayer
inner_kwargs: Dict[str, None] = {} # Avoid Any.
reparam = ExpReparam()

return SymbolicCircuit(
rg,
num_input_units=num_units,
num_sum_units=num_units,
input_layer_cls=input_cls,
input_layer_kwargs=input_kwargs,
input_reparam=reparam,
sum_layer_cls=inner_cls,
sum_layer_kwargs=inner_kwargs,
sum_reparam=reparam,
prod_layer_cls=inner_cls,
prod_layer_kwargs=None,
)

0 comments on commit c571443

Please sign in to comment.