Fix and test Jax automatic differentiation #14
base: main
Conversation
src/qiboml/operations/expectation.py (outdated)
    return _with_tf(**kwargs)

if isinstance(qibo_backend, JaxBackend):
    return _with_jax(**kwargs)
Maybe the _with_x functions should directly be part of the backend; this way you wouldn't need to check which backend is used here, but could just do qibo_backend._with(...).
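A minimal sketch of that dispatch pattern, using stub classes and the hypothetical _with method name from the comment above (not the actual qiboml API):

class Backend:
    def _with(self, **kwargs):
        raise NotImplementedError

class TensorflowBackend(Backend):
    def _with(self, **kwargs):
        # what _with_tf(**kwargs) does today
        return "tf expectation"

class JaxBackend(Backend):
    def _with(self, **kwargs):
        # what _with_jax(**kwargs) does today
        return "jax expectation"

def expectation(qibo_backend, **kwargs):
    # no isinstance checks: each backend carries its own implementation
    return qibo_backend._with(**kwargs)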
To be fair, it would be even nicer if I could construct a Qibo observable without any specific backend...
We need a way to detect which frontend is in use, and then build the differentiation rule on top of that. A backend-agnostic observable doesn't seem like a good idea to me, because the observable is how we detect the frontend from the user code.
Regarding the _with(...) suggestion, it could indeed be a nice idea, but then I see a problem: this function would be a third version of the expectation we already have in the backends. Moreover, it is a strange expectation, since it takes a differentiation_rule as an argument, which (in my opinion) makes sense only in this QML context.
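One hedged way to make the frontend-detection idea concrete: dispatch on the module of the array type the user hands in (purely illustrative, not how qiboml actually does it):

import numpy as np

def detect_frontend(params):
    module = type(params).__module__
    if module.startswith("jax"):        # also matches jaxlib array types
        return "jax"
    if module.startswith("tensorflow"):
        return "tensorflow"
    if isinstance(params, np.ndarray):
        return "numpy"
    raise ValueError(f"cannot infer frontend from {type(params)}")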
src/qiboml/operations/expectation.py (outdated)
for p in range(len(params)):
    gradients.append(
        differentiation_rule(
            circuit=circuit,
            hamiltonian=observable,
            parameter_index=p,
            initial_state=initial_state,
            nshots=nshots,
            backend=backend,
        )
    )
Batched execution would be great here: for example, you could perform the parameter shift rule (PSR) on several different parameters in parallel. Namely, the batch would be
((theta_0 + eps, theta_1, ..., theta_n),
(theta_0, theta_1 + eps, ..., theta_n),
...,
(theta_0, theta_1, ..., theta_n + eps))
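For illustration, a minimal sketch of building that batch and evaluating it in a single vectorized call with jax.vmap (the expectation function here is a hypothetical scalar stub, not the qiboml one):

import jax
import jax.numpy as jnp

def shifted_batch(params, eps):
    # row i is params with eps added only to theta_i
    return params[None, :] + eps * jnp.eye(params.shape[0])

def batched_psr_term(expectation_fn, params, eps):
    # evaluate the expectation on all shifted parameter vectors at once
    return jax.vmap(expectation_fn)(shifted_batch(params, eps))

# usage with a toy expectation
expectation_fn = lambda p: jnp.sum(jnp.cos(p))
print(batched_psr_term(expectation_fn, jnp.array([0.1, 0.2, 0.3]), jnp.pi / 2))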
def _expectation(params):
    params = jax.numpy.array(params)

def grad(params):
How about jitting grad? Is it possible?
It is, but in principle it could also be done later on: JAX transformations are composable, so you could do grad(grad(jit(grad(...)))).
(I'm not sure which approach is best for performance, though, whether it's better to jit each building block or rely on the outer compilation - also, in this case, I'm not sure there is any benefit in jitting at all, since the result could come from hardware execution...)
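As a small self-contained illustration of that composability (plain JAX, independent of the qiboml internals):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

df = jax.jit(jax.grad(f))             # compiled first derivative
d2f = jax.grad(jax.jit(jax.grad(f)))  # grad composed on top of a jitted grad

print(df(1.0), d2f(1.0))              # sin(2x) and 2*cos(2x) at x = 1.0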
I jitted all the functions we are returning now, but I think we should build a Jax version of the Qibojit backend, with more fine-grained jitting, to boost performance.
I am not entirely happy about the structure yet, but I think we need to start implementing our model object in order to have a better global view of the thing as a whole, and truly understand which direction to take.
@@ -14,7 +16,7 @@ def parameter_shift(
     initial_state: Optional[Union[np.ndarray, qibo.Circuit]] = None,
     scale_factor: float = 1.0,
     nshots: int = None,
-    backend: str = "qibojit",
+    exec_backend: qibo.backends.Backend = NumbaBackend,
Suggested change:
-    exec_backend: qibo.backends.Backend = NumbaBackend,
+    exec_backend: qibo.backends.Backend = NumbaBackend(),
I think this should be NumbaBackend(), otherwise you'll be passing just the class. In general I am not against using numba by default, but then qibojit has to become a mandatory dependency of qiboml. Personally, I think the JaxBackend is the best candidate to use as default, if not the only one at all.
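A quick generic-Python illustration of why the parentheses matter (stub class, not the real qibo backend):

class NumbaBackendStub:
    def execute_circuit(self):
        return "result"

def run(exec_backend=NumbaBackendStub):
    # fails: the class itself is the default, so execute_circuit is missing `self`
    return exec_backend.execute_circuit()

def run_fixed(exec_backend=NumbaBackendStub()):
    # works: the default is an instance
    return exec_backend.execute_circuit()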
def Unitary(self, u):
    return self._cast(u, dtype=self.dtype)


class JaxBackend(NumpyBackend):
I would probably rewrite the execute_circuit and apply_gate functions of the backend to properly make use of the jit decorator.
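A rough sketch of what a jitted apply_gate could look like (a simplified single-qubit gate application on a state vector, not the actual qibo backend signature):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(2, 3))
def apply_gate(state, gate, target, nqubits):
    # view the flat state vector as an n-qubit tensor
    psi = state.reshape((2,) * nqubits)
    # contract the gate with the target qubit's axis
    psi = jnp.tensordot(gate, psi, axes=([1], [target]))
    # tensordot moves the contracted axis to the front; restore qubit ordering
    psi = jnp.moveaxis(psi, 0, target)
    return psi.reshape(-1)

# example: Hadamard on qubit 0 of a two-qubit |00> state
H = jnp.array([[1.0, 1.0], [1.0, -1.0]]) / jnp.sqrt(2)
state = jnp.zeros(4).at[0].set(1.0)
print(apply_gate(state, H, 0, 2))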