Skip to content

Commit

Permalink
Add some explanations on NonInterpPrimitive class (#6851)
Browse files Browse the repository at this point in the history
**Context:**
The capture module uses a variant of `jax.core.Primitive` called
`NonInterpPrimitive`.
There were questions about why we need this and what it does.

**Description of the Change:**
This PR only adds some explanations to the respective `md` file to
motivate our usage of this primitive variant.

**Benefits:**
Explain code

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**

---------

Co-authored-by: Mudit Pandey <[email protected]>
Co-authored-by: Pietropaolo Frisoni <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2025
1 parent 397273b commit 5eaaccb
Showing 1 changed file with 103 additions and 3 deletions.
106 changes: 103 additions & 3 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class MyClass(metaclass=MyMetaClass):
self.kwargs = kwargs
```

Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.
Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.


And that we have set a class property `a`
Expand All @@ -272,7 +272,7 @@ But can we actually create instances of these classes?
```python
>> obj = MyClass(0.1, a=2)
>>> obj
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```
Expand All @@ -294,7 +294,7 @@ class MyClass2(metaclass=MetaClass2):
self.args = args
```

You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.
You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.

Using a metaclass, we can hijack what happens when a type is called.

Expand Down Expand Up @@ -425,3 +425,103 @@ Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `A
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```

# Non-interpreted primitives

Some of the primitives in the capture module have a somewhat non-standard requirement for the
behaviour under differentiation or batching: they should ignore that an input is a differentiation
or batching tracer and just execute the standard implementation on them.

We will look at an example to make the necessity for such a non-interpreted primitive clear.

Consider a finite-difference differentiation routine together with some test function `fun`.

```python
def finite_diff_impl(x, fun, delta):
"""Finite difference differentiation routine. Only supports differentiating
a function `fun` with a single scalar argument, for simplicity."""

out_plus = fun(x + delta)
out_minus = fun(x - delta)
return tuple((out_p - out_m) / (2 * delta) for out_p, out_m in zip(out_plus, out_minus))

def fun(x):
return (x**2, 4 * x - 3, x**23)
```

Now suppose we want to turn this into a primitive. We could just promote it to a standard
`jax.core.Primitive` as

```python
import jax

fd_prim = jax.core.Primitive("finite_diff")
fd_prim.multiple_results = True
fd_prim.def_impl(finite_diff_impl)

def finite_diff(x, fun, delta=1e-5):
return fd_prim.bind(x, fun, delta)
```

This allows us to use the forward pass as usual (to compute the first-order derivative):

```pycon
>>> finite_diff(1., fun, delta=1e-6)
(2.000000000002, 3.999999999892978, 23.000000001216492)
```

Now if we want to make this primitive differentiable (with automatic
differentiation/backprop, not by using a higher-order finite difference scheme),
we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example
that we could use to implement a finite difference scheme that is readily differentiable. This is
somewhat beside the point because we did not identify the possibility of using any of those
alternatives in the PennyLane code).

However, the finite difference rule is just a standard
algebraic function making use of calls to `fun` and some elementary operations, so ideally
we would like to just use the chain rule as it is known to the automatic differentiation framework. A JVP rule would
then just manually re-implement this chain rule, which we'd rather not do.

Instead, we define a non-interpreted type of primitive and create such a primitive
for our finite difference method. We also create the usual method that binds the
primitive to inputs.

```python
class NonInterpPrimitive(jax.core.Primitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonInterpPrimitive`` with a trace.
If the trace is a ``JVPTrace``, it falls back to a standard Python function call.
Otherwise, the bind call of JAX's standard Primitive is used."""
if isinstance(trace, jax.interpreters.ad.JVPTrace):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)

fd_prim_2 = NonInterpPrimitive("finite_diff_2")
fd_prim_2.multiple_results = True
fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer

def finite_diff_2(x, fun, delta=1e-5):
return fd_prim_2.bind(x, fun, delta)
```

Now we can use the primitive in a differentiable workflow, without defining a JVP rule
that just repeats the chain rule:

```pycon
>>> # Forward execution of finite_diff_2 (-> first-order derivative)
>>> finite_diff_2(1., fun, delta=1e-6)
(2.000000000002, 3.999999999892978, 23.000000001216492)
>>> # Differentiation of finite_diff_2 (-> second-order derivative)
>>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6)
(Array(1.9375, dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True), Array(498., dtype=float32, weak_type=True))
```

In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, quantum operators
have non-interpreted primitives as well. This is because their differentiation is performed
by the surrounding QNode primitive rather than through the standard chain rule that acts
"locally" (in the circuit). In short, we only want gates to store their tracers (which will help
determine the differentiability of gate arguments, for example), but not to do anything with them.

0 comments on commit 5eaaccb

Please sign in to comment.