diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index f0de7d626ff..de61177f448 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -96,6 +96,29 @@ def __init__(self, gate_set=None, max_expansion=None): super().__init__() + def sub_interpret_operation(self, op, current_depth): + + if op.has_plxpr_decomposition: + + if self.gate_set(op): + return super().interpret_operation(op) + + args = (*op.parameters, *op.wires) + jaxpr_decomp = jax.make_jaxpr( + partial(op.compute_plxpr_decomposition, **op.hyperparameters) + )(*args) + + return self.eval_dynamic_decomposition( + jaxpr_decomp.jaxpr, + jaxpr_decomp.consts, + *args, + current_depth=current_depth + 1, + ) + + else: + + return self.interpret_operation(op) + def stopping_condition(self, op: qml.operation.Operator) -> bool: """Function to determine whether or not an operator needs to be decomposed or not. @@ -106,6 +129,14 @@ def stopping_condition(self, op: qml.operation.Operator) -> bool: bool: Whether or not ``op`` is valid or needs to be decomposed. ``True`` means that the operator does not need to be decomposed. """ + # if op.has_plxpr_decomposition: + # warnings.warn( + # f"The plxpr decomposition of operator {op.name} will be ignored, " + # f"since the is the result of a decomposition of a previous operator " + # f"which doesn't have a plxpr decomposition. " + # f"This limitation will be removed in the future.", + # UserWarning, + # ) if not op.has_decomposition: if not self.gate_set(op): warnings.warn( @@ -148,7 +179,10 @@ def decompose_operation(self, op: qml.operation.Operator, current_depth: int = 0 ) ) - return [self.interpret_operation(decomp_op) for decomp_op in decomposition] + return [ + self.sub_interpret_operation(decomp_op, current_depth) + for decomp_op in decomposition + ] def eval_dynamic_decomposition( self, jaxpr_decomp: "jax.core.Jaxpr", consts, *args, current_depth: int = 0