Skip to content

Commit

Permalink
support passing additional activation parameters in parameterizations
Browse files Browse the repository at this point in the history
  • Loading branch information
isurulucky committed Dec 11, 2024
1 parent 125e4a1 commit 6882355
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion cirkit/templates/circuit_templates/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Parameterization:
"""The activation function. Defaults to 'none', i.e., no activation."""
dtype: str = "real"
"""The data type. Defaults to 'real', i.e., real numbers."""
activation_kwargs: dict[str, object] | None = None
"""Additional keyword-based arguments to be used in activation"""


def build_image_region_graph(
Expand Down Expand Up @@ -118,7 +120,9 @@ def parameterization_to_factory(param: Parameterization) -> ParameterFactory:
Raises:
ValueError: If one of the settings in the given parameterization is unknown.
"""
unary_op_factory = name_to_parameter_activation(param.activation)
unary_op_factory = name_to_parameter_activation(
param.activation, **param.activation_kwargs if param.activation_kwargs else {}
)
dtype = name_to_dtype(param.dtype)
initializer = name_to_initializer(param.initialization)
return functools.partial(
Expand Down
Empty file.
13 changes: 13 additions & 0 deletions tests/templates/circuit_templates/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from cirkit.templates import circuit_templates
from cirkit.templates.circuit_templates.utils import parameterization_to_factory


def test_param_activation_kwargs():
min_var = 0.1
max_var = 1.0
stddev_param = circuit_templates.Parameterization(initialization="normal", activation="positive-clamp",
activation_kwargs={"vmin": min_var, "vmax": max_var})
stddev_param_factory = parameterization_to_factory(stddev_param)
parameter = stddev_param_factory(shape=(1, 1),)
assert (parameter.output.vmin == min_var), f"parameter min value expected: {min_var}, got: {parameter.output.vmin}"
assert (parameter.output.vmax == max_var), f"parameter ma value expected: {max_var}, got: {parameter.output.vmax}"

0 comments on commit 6882355

Please sign in to comment.