Skip to content

Commit

Permalink
Quagliato
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Feb 7, 2025
1 parent 1a17f95 commit f949897
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,29 @@ def __init__(self, gate_set=None, max_expansion=None):

super().__init__()

def sub_interpret_operation(self, op, current_depth):

if op.has_plxpr_decomposition:

if self.gate_set(op):
return super().interpret_operation(op)

args = (*op.parameters, *op.wires)
jaxpr_decomp = jax.make_jaxpr(
partial(op.compute_plxpr_decomposition, **op.hyperparameters)
)(*args)

return self.eval_dynamic_decomposition(
jaxpr_decomp.jaxpr,
jaxpr_decomp.consts,
*args,
current_depth=current_depth + 1,
)

else:

return self.interpret_operation(op)

def stopping_condition(self, op: qml.operation.Operator) -> bool:
"""Function to determine whether or not an operator needs to be decomposed or not.
Expand All @@ -106,6 +129,14 @@ def stopping_condition(self, op: qml.operation.Operator) -> bool:
bool: Whether or not ``op`` is valid or needs to be decomposed. ``True`` means
that the operator does not need to be decomposed.
"""
# if op.has_plxpr_decomposition:
# warnings.warn(
# f"The plxpr decomposition of operator {op.name} will be ignored, "
# f"since the is the result of a decomposition of a previous operator "
# f"which doesn't have a plxpr decomposition. "
# f"This limitation will be removed in the future.",
# UserWarning,
# )
if not op.has_decomposition:
if not self.gate_set(op):
warnings.warn(
Expand Down Expand Up @@ -148,7 +179,10 @@ def decompose_operation(self, op: qml.operation.Operator, current_depth: int = 0
)
)

return [self.interpret_operation(decomp_op) for decomp_op in decomposition]
return [
self.sub_interpret_operation(decomp_op, current_depth)
for decomp_op in decomposition
]

def eval_dynamic_decomposition(
self, jaxpr_decomp: "jax.core.Jaxpr", consts, *args, current_depth: int = 0
Expand Down

0 comments on commit f949897

Please sign in to comment.