Skip to content

Commit

Permalink
Merge branch 'main' into activation-params
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc authored Dec 12, 2024
2 parents df163c8 + 5eefd70 commit cae54eb
Show file tree
Hide file tree
Showing 7 changed files with 1,784 additions and 2 deletions.
236 changes: 235 additions & 1 deletion cirkit/symbolic/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from enum import IntEnum, auto
from functools import cached_property
from typing import Any, Protocol, cast
from os import PathLike
from pathlib import Path

import graphviz

from cirkit.symbolic.layers import (
HadamardLayer,
Expand All @@ -14,7 +18,23 @@
ProductLayer,
SumLayer,
)
from cirkit.symbolic.parameters import ParameterFactory
from cirkit.symbolic.parameters import (
Parameter,
ParameterFactory,
TensorParameter
)
from cirkit.symbolic.initializers import ConstantTensorInitializer
from cirkit.templates.logic import (
BottomNode,
ConjunctionNode,
DisjunctionNode,
LiteralNode,
LogicCircuitNode,
LogicGraph,
NegatedLiteralNode,
TopNode,
default_literal_input_factory,
)
from cirkit.templates.region_graph import PartitionNode, RegionGraph, RegionGraphNode, RegionNode
from cirkit.utils.algorithms import (
DiAcyclicGraph,
Expand Down Expand Up @@ -851,6 +871,220 @@ def from_hmm(

return cls(num_channels, layers, in_layers, [layers[-1]])

@classmethod
def from_logic_circuit(
cls,
logic_graph: LogicGraph,
*,
literal_input_factory: InputLayerFactory = None,
negated_literal_input_factory: InputLayerFactory = None,
weight_factory: ParameterFactory | None = None,
num_channels: int = 1,
enforce_smoothness: bool = True
) -> "Circuit":
"""
Construct a symbolic circuit from a logic circuit graph.
If input factories for literals and their negation are not provided the it
falls back to a categorical input layer with two categories parametrized by
the constant vector [0, 1] for a literal and [1, 0] for its negation.
Args:
logic_graph: The logic circuit graph.
literal_input_factory: A factory that builds an input layer for literals.
negated_literal_input_factory: A factory that builds an input layer for negated literals.
weight_factory: The factory to construct the weight of sum layers. It can be None,
or a parameter factory, i.e., a map from a shape to a symbolic parameter.
If None is used, the default weight factory uses non-trainable unitary parameters,
which instantiate a regular boolean logic graph.
num_channels: The number of channels for each variable.
enforce_smoothness: Enforces smoothness of the circuit to support efficient marginalization.
Returns:
Circuit: A symbolic circuit.
Raises:
ValueError: If only one of literal_input_factory and negated_literal_input_factory is specified.
"""
if enforce_smoothness:
simplified_graph = logic_graph.smooth().simplify()
else:
simplified_graph = logic_graph.simplify()

in_layers: dict[Layer, Sequence[Layer]] = {}
node_to_layer: dict[LogicCircuitNode, Layer] = {}

if (literal_input_factory is None) ^ (negated_literal_input_factory is None):
raise ValueError(
"Either both 'literal_input_factory' and 'negated_literal_input_factory' \
must be provided or none."
)

if literal_input_factory is None and negated_literal_input_factory is None:
# default factory is locally imported when needed to avoid circular imports
literal_input_factory = default_literal_input_factory(negated=False)
negated_literal_input_factory = default_literal_input_factory(negated=True)

if weight_factory is None:
# default to unitary weights
def weight_factory(n: tuple[int]) -> Parameter:
# locally import numpy to avoid dependency on the whole file
initializer = ConstantTensorInitializer(1.0)
return Parameter.from_input(TensorParameter(*n, initializer=initializer))

# map each input literal to a symbolic input layer
for i in simplified_graph.inputs:
match i:
case LiteralNode():
node_to_layer[i] = literal_input_factory(
Scope([i.literal]), num_units=1, num_channels=num_channels
)
case NegatedLiteralNode():
node_to_layer[i] = negated_literal_input_factory(
Scope([i.literal]), num_units=1, num_channels=num_channels
)

for node in simplified_graph.topological_ordering():
match node:
case ConjunctionNode():
product_node = HadamardLayer(1, arity=len(simplified_graph.node_inputs(node)))
in_layers[product_node] = [node_to_layer[i] for i in simplified_graph.node_inputs(node)]
node_to_layer[node] = product_node
case DisjunctionNode():
sum_node = SumLayer(1, 1, arity=len(simplified_graph.node_inputs(node)), weight_factory=weight_factory)
in_layers[sum_node] = [node_to_layer[i] for i in simplified_graph.node_inputs(node)]
node_to_layer[node] = sum_node

layers = list(set(itertools.chain(*in_layers.values())).union(in_layers.keys()))

return cls(num_channels, layers, in_layers, [node_to_layer[simplified_graph.output]])

def plot(
self,
out_path: str | PathLike[str] | None = None,
orientation: str = "vertical",
node_shape: str = "box",
label_font: str = "times italic bold",
label_size: str = "21pt",
label_color: str = "white",
sum_label: str | Callable[[SumLayer], str] = "+",
sum_color: str | Callable[[SumLayer], str] = "#607d8b",
product_label: str | Callable[[ProductLayer], str] = "⊙",
product_color: str | Callable[[ProductLayer], str] = "#24a5af",
input_label: str | Callable[[InputLayer], str] = lambda l: " ".join(map(str, l.scope)),
input_color: str | Callable[[InputLayer], str] = "#ffbd2a",
) -> graphviz.Digraph:
"""Plot the current symbolic circuit using graphviz.
A graphviz object is returned, which can be visualized in jupyter notebooks.
If format is not provided, SVG is used for optimal rendering in notebooks.
Args:
out_path ( str | PathLike[str] | None, optional): The output path where the plot is save
If it is None, the plot is not saved to a file. Defaults to None.
The Output file format is deduce from the path. Possible formats are:
{'jp2', 'plain-ext', 'sgi', 'x11', 'pic', 'jpeg', 'imap', 'psd', 'pct',
'json', 'jpe', 'tif', 'tga', 'gif', 'tk', 'xlib', 'vmlz', 'json0', 'vrml',
'gd', 'xdot', 'plain', 'cmap', 'canon', 'cgimage', 'fig', 'svg', 'dot_json',
'bmp', 'png', 'cmapx', 'pdf', 'webp', 'ico', 'xdot_json', 'gtk', 'svgz',
'xdot1.4', 'cmapx_np', 'dot', 'tiff', 'ps2', 'gd2', 'gv', 'ps', 'jpg',
'imap_np', 'wbmp', 'vml', 'eps', 'xdot1.2', 'pov', 'pict', 'ismap', 'exr'}.
See https://graphviz.org/docs/outputs/ for more.
orientation (str, optional): Orientation of the graph. "vertical" puts the root
node at the top, "horizontal" at left. Defaults to "vertical".
node_shape (str, optional): Default shape for a node in the graph. Defaults to "box".
See https://graphviz.org/doc/info/shapes.html for the supported shapes.
label_font (str, optional): Font used to render labels. Defaults to "times italic bold".
See https://graphviz.org/faq/font/ for the available fonts.
label_size (str, optional): Size of the font for labels in points. Defaults to 21pt.
label_color (str, optional): Color for the labels in the nodes. Defaults to "white".
See https://graphviz.org/docs/attr-types/color/ for supported color.
sum_label (str | Callable[[SumLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a sum layer and returns a string
that will be used as label. Defaults to "+".
sum_color (str | Callable[[SumLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a sum layer and returns a string
that will be used as color for the sum node. Defaults to "#607d8b".
product_label (str | Callable[[ProductLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a product layer and returns a string
that will be used as label. Defaults to "⊙".
product_color (str | Callable[[ProductLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a product layer and returns a string
that will be used as color for the product node. Defaults to "#24a5af".
input_label (_type_, optional): Either a string or a function.
If a function is provided, then it must take as input an input layer and returns a string
that will be used as label. Defaults to using the scope of the layer.
input_color (str | Callable[[ProductLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input an input layer and returns a string
that will be used as color for the input layer node. Defaults to "#ffbd2a".
Raises:
ValueError: The format is not among the supported ones.
ValueError: The direction is not among the supported ones.
Returns:
graphviz.Digraph: _description_
"""
if out_path is None:
fmt: str = "svg"
else:
fmt: str = Path(out_path).suffix.replace(".", "")
if fmt not in graphviz.FORMATS:
raise ValueError(f"Supported formats are {graphviz.FORMATS}.")

if orientation not in ["vertical", "horizontal"]:
raise ValueError("Supported graph directions are only 'vertical' and 'horizontal'.")

dot: graphviz.Digraph = graphviz.Digraph(
format=fmt,
node_attr={
"shape": node_shape,
"style": "filled",
"fontcolor": label_color,
"fontsize": label_size,
"fontname": label_font,
},
engine="dot",
)

dot.graph_attr["rankdir"] = "BT" if orientation == "vertical" else "LR"

for layer in self.layers:
match layer:
case HadamardLayer():
dot.node(
str(id(layer)),
product_label if isinstance(product_label, str) else product_label(layer),
color=product_color
if isinstance(product_color, str)
else product_color(layer),
)
case SumLayer():
dot.node(
str(id(layer)),
sum_label if isinstance(sum_label, str) else sum_label(layer),
color=sum_color if isinstance(sum_color, str) else sum_color(layer),
)
case InputLayer():
dot.node(
str(id(layer)),
input_label if isinstance(input_label, str) else input_label(layer),
color=input_color if isinstance(input_color, str) else input_color(layer),
)

for node, inputs in self.layers_inputs.items():
for i in inputs:
dot.edge(str(id(i)), str(id(node)))

if out_path is not None:
out_path: Path = Path(out_path).with_suffix("")

if fmt == "dot":
with open(out_path, "w", encoding="utf8") as f:
f.write(dot.source)
else:
dot.format = fmt
dot.render(out_path)

return dot

def are_compatible(sc1: Circuit, sc2: Circuit) -> bool:
"""Check if two symbolic circuits are compatible.
Expand Down
10 changes: 10 additions & 0 deletions cirkit/templates/logic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .graph import BottomNode as BottomNode
from .graph import ConjunctionNode as ConjunctionNode
from .graph import DisjunctionNode as DisjunctionNode
from .graph import LiteralNode as LiteralNode
from .graph import LogicCircuitNode as LogicCircuitNode
from .graph import LogicGraph as LogicGraph
from .graph import NegatedLiteralNode as NegatedLiteralNode
from .graph import TopNode as TopNode
from .utils import default_literal_input_factory as default_literal_input_factory
from .sdd import SDD as SDD
Loading

0 comments on commit cae54eb

Please sign in to comment.