From 838df9bf530875326243286a58e227cff1b9efaa Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 4 Feb 2025 14:57:46 -0500 Subject: [PATCH 1/3] add eval_jaxpr method to null.qubit --- doc/releases/changelog-dev.md | 2 ++ pennylane/devices/null_qubit.py | 14 ++++++++++++++ tests/devices/test_null_qubit.py | 20 ++++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index adb442ed234..e7bc017d676 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -151,6 +151,8 @@ * The `qml.clifford_t_decomposition` has been improved to use less gates when decomposing `qml.PhaseShift`. [(#6842)](https://github.com/PennyLaneAI/pennylane/pull/6842) +* `null.qubit` can now execute jaxpr. +

Breaking changes 💔

* `MultiControlledX` no longer accepts strings as control values. diff --git a/pennylane/devices/null_qubit.py b/pennylane/devices/null_qubit.py index 004460ffaca..7c00899a99c 100644 --- a/pennylane/devices/null_qubit.py +++ b/pennylane/devices/null_qubit.py @@ -404,3 +404,17 @@ def execute_and_compute_vjp( results = tuple(self._simulate(c, _interface(execution_config)) for c in circuits) vjps = tuple(self._vjp(c, _interface(execution_config)) for c in circuits) return results, vjps + + def eval_jaxpr(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: + from pennylane.capture.primitives import ( # pylint: disable=import-outside-toplevel + AbstractMeasurement, + ) + + def zeros_like(var): + if isinstance(var.aval, AbstractMeasurement): + shots = self.shots.total_shots + s, dtype = var.aval.abstract_eval(num_device_wires=len(self.wires), shots=shots) + return math.zeros(s, dtype=dtype, like="jax") + return math.zeros(var.aval.shape, dtype=var.aval.dtype, like="jax") + + return [zeros_like(var) for var in jaxpr.outvars] diff --git a/tests/devices/test_null_qubit.py b/tests/devices/test_null_qubit.py index 174ddd87283..07ba457a6e6 100644 --- a/tests/devices/test_null_qubit.py +++ b/tests/devices/test_null_qubit.py @@ -1273,3 +1273,23 @@ def circuit(param): res = qml.QNode(circuit, nq)(x) target = qml.QNode(circuit, dq)(x) assert qml.math.shape(res) == qml.math.shape(target) + + +# pylint: disable=unused-argument +@pytest.mark.jax +def test_execute_plxpr(enable_disable_plxpr): + """Test that null.qubit can execute plxpr.""" + + import jax + + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)), qml.probs(), 4 + + jaxpr = jax.make_jaxpr(f)(0.5) + + dev = qml.device("null.qubit", wires=4) + res = dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1.0) + assert qml.math.allclose(res[0], 0) + assert qml.math.allclose(res[1], jax.numpy.zeros(2**4)) + assert qml.math.allclose(res[2], 0) # other values are still just zero From f3a69e15b0a0280c90a10336813431e198fb67c6 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Tue, 4 Feb 2025 15:00:35 -0500 Subject: [PATCH 2/3] Update doc/releases/changelog-dev.md --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e7bc017d676..1d889e8ca7b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -152,6 +152,7 @@ [(#6842)](https://github.com/PennyLaneAI/pennylane/pull/6842) * `null.qubit` can now execute jaxpr. + [(#6924)](https://github.com/PennyLaneAI/pennylane/pull/6924)

Breaking changes 💔

From 7be11a1742ebf18708830e9add46fa6b3c1688b1 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 4 Feb 2025 17:14:09 -0500 Subject: [PATCH 3/3] more tests --- tests/devices/test_null_qubit.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/devices/test_null_qubit.py b/tests/devices/test_null_qubit.py index 07ba457a6e6..577cd4efe1b 100644 --- a/tests/devices/test_null_qubit.py +++ b/tests/devices/test_null_qubit.py @@ -1284,7 +1284,7 @@ def test_execute_plxpr(enable_disable_plxpr): def f(x): qml.RX(x, 0) - return qml.expval(qml.Z(0)), qml.probs(), 4 + return qml.expval(qml.Z(0)), qml.probs(), 4, qml.var(qml.X(0)), qml.state() jaxpr = jax.make_jaxpr(f)(0.5) @@ -1293,3 +1293,22 @@ def f(x): assert qml.math.allclose(res[0], 0) assert qml.math.allclose(res[1], jax.numpy.zeros(2**4)) assert qml.math.allclose(res[2], 0) # other values are still just zero + assert qml.math.allclose(res[3], 0) + assert qml.math.allclose(res[4], jax.numpy.zeros(2**4, dtype=complex)) + + +@pytest.mark.jax +def test_execute_plxpr_shots(enable_disable_plxpr): + import jax + + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)), 5, qml.sample(wires=(0, 1)) + + jaxpr = jax.make_jaxpr(f)(0.5) + + dev = qml.device("null.qubit", wires=4, shots=50) + res = dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1.0) + assert qml.math.allclose(res[0], 0) + assert qml.math.allclose(res[1], 0) + assert qml.math.allclose(res[2], jax.numpy.zeros((50, 2)))