Fix and test Jax automatic differentiation #14

Draft · wants to merge 10 commits into main
Conversation

MatteoRobbiati (Contributor)
No description provided.

src/qiboml/operations/expectation.py (outdated, resolved)
Comment on lines 64 to 67

        return _with_tf(**kwargs)

    if isinstance(qibo_backend, JaxBackend):
        return _with_jax(**kwargs)
Contributor

Maybe the _with_x functions should directly be part of the backend; this way you don't need to check here which backend is used, but can just do qibo_backend._with(...).
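A minimal sketch of the suggested dispatch pattern: each backend carries its own entry point, so the call site needs no isinstance branching. All class and method names here (Backend, _with, the string return values) are hypothetical stand-ins, not the actual qiboml API.

```python
# Hypothetical sketch: per-backend method dispatch instead of isinstance checks.
class Backend:
    def _with(self, **kwargs):
        raise NotImplementedError

class TensorflowBackend(Backend):
    def _with(self, **kwargs):
        return ("tf", kwargs)   # stand-in for _with_tf(**kwargs)

class JaxBackend(Backend):
    def _with(self, **kwargs):
        return ("jax", kwargs)  # stand-in for _with_jax(**kwargs)

def expectation(qibo_backend, **kwargs):
    # single call site, no per-backend branching
    return qibo_backend._with(**kwargs)
```

The trade-off discussed below is that this moves QML-specific logic (the differentiation rule) into the general-purpose backend classes.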

Member

To be fair, it would be even nicer if I could construct a Qibo observable without any specific backend...

Contributor (Author)

We need a way to detect the frontend, and then build the differentiation rule on top of it. Having an agnostic observable doesn't seem like a good idea to me, because the observable is how we detect the frontend from the user code.
Regarding the _with(...) suggestion, it could indeed be nice, but I see a problem: this function would be a third version of the expectation we already have in the backends. Moreover, it is a strange expectation, since it takes a differentiation_rule as an argument, which (in my opinion) only makes sense in this QML context.

Comment on lines 170 to 180
for p in range(len(params)):
gradients.append(
differentiation_rule(
circuit=circuit,
hamiltonian=observable,
parameter_index=p,
initial_state=initial_state,
nshots=nshots,
backend=backend,
)
)
Contributor

Batched execution would be great here; for example, you could perform the PSR on several different parameters in parallel. Namely, the batch would be

((theta_0 + eps, theta_1, ..., theta_n),
(theta_0, theta_1 + eps, ..., theta_n),
...,
(theta_0, theta_1, ..., theta_n + eps))
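A sketch of this batched parameter-shift idea using jax.vmap. The expectation function here is a toy stand-in (a sum of cosines), not the actual circuit expectation; the shift matrix is exactly the batch described above.

```python
import jax
import jax.numpy as jnp

def expectation(params):
    # toy stand-in for a circuit expectation value
    return jnp.sum(jnp.cos(params))

def psr_gradient(params, eps=jnp.pi / 2):
    n = params.shape[0]
    # rows of params + shifts are (theta_0 + eps, theta_1, ..., theta_n),
    # (theta_0, theta_1 + eps, ..., theta_n), ... as in the batch above
    shifts = eps * jnp.eye(n)
    plus = jax.vmap(expectation)(params + shifts)
    minus = jax.vmap(expectation)(params - shifts)
    return (plus - minus) / (2 * jnp.sin(eps))
```

With vmap, the shifted circuits are evaluated as one batched call instead of a Python loop over parameters.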

def _expectation(params):
    params = jax.numpy.array(params)

def grad(params):
Contributor

How about jitting grad? Is it possible?

Member
@alecandido May 7, 2024

It is, but in principle it could also be done later on: JAX transformations are composable, so you could do grad(grad(jit(grad(...)))).

(But I'm not sure which is best for performance, i.e. whether it's better to jit each building block or rely on the outer compilation. Also, in this case, I'm not sure there is any benefit in jitting at all, since the result could come from hardware execution...)

Contributor (Author)

I jitted all the functions we are returning now, but I think we should build a Jax version of the Qibojit backend, with more fine-grained jitting, to boost performance.

Contributor
@BrunoLiegiBastonLiegi left a comment

I am not entirely happy with the structure yet, but I think we need to start implementing our model object in order to get a better global view of the whole thing and truly understand which direction to take.

@@ -14,7 +16,7 @@ def parameter_shift(
initial_state: Optional[Union[np.ndarray, qibo.Circuit]] = None,
scale_factor: float = 1.0,
nshots: int = None,
backend: str = "qibojit",
exec_backend: qibo.backends.Backend = NumbaBackend,
Contributor
@BrunoLiegiBastonLiegi May 13, 2024

Suggested change
exec_backend: qibo.backends.Backend = NumbaBackend,
exec_backend: qibo.backends.Backend = NumbaBackend(),

I think this should be NumbaBackend(), otherwise you'll be passing just the class. In general I am not against using numba by default, but then qibojit would have to become a mandatory dependency of qiboml. Personally, I think the JaxBackend is the best candidate to be the default, if not the only one at all.
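A minimal illustration of the class-vs-instance pitfall behind this suggestion, using a hypothetical DummyBackend rather than the real qibojit classes: with the bare class as the default, attribute access hits the class object, not a usable backend instance.

```python
# Hypothetical example of defaulting to a class vs. an instance.
class DummyBackend:
    def __init__(self):
        self.name = "dummy"

def run(backend=DummyBackend):
    # `backend` is the class itself, which has no `name` attribute
    return backend.name  # raises AttributeError

def run_fixed(backend=DummyBackend()):
    # `backend` is a ready-to-use instance
    return backend.name
```

One caveat of the fixed version: the default instance is created once, at function definition time, and shared across calls, which is fine for a stateless backend but worth keeping in mind.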

def Unitary(self, u):
return self._cast(u, dtype=self.dtype)


class JaxBackend(NumpyBackend):
Contributor

I would probably rewrite the execute_circuit and apply_gate functions of the backend to properly make use of the jit decorator
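A sketch of what a jitted gate application could look like, assuming a dense state vector; this is not the actual qiboml implementation, just one common pattern (target and nqubits must be static arguments so the reshape is traceable).

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(2, 3))
def apply_gate(state, matrix, target, nqubits):
    # view the state as one axis per qubit
    state = state.reshape((2,) * nqubits)
    # contract the gate with the target qubit's axis
    state = jnp.tensordot(matrix, state, axes=[[1], [target]])
    # tensordot puts the contracted axis first; move it back to `target`
    state = jnp.moveaxis(state, 0, target)
    return state.reshape(-1)
```

Jitting at this granularity lets XLA fuse consecutive gate applications inside a jitted execute_circuit, which is presumably the performance gain this comment is after.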

@alecandido alecandido changed the base branch from backends to main May 17, 2024 13:01
3 participants