Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testin some improvements in the JaxBackend #60

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion src/qiboml/backends/jax.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,57 @@
from functools import partial

import jax
import jax.numpy as jnp # pylint: disable=import-error
from qibo import __version__
from qibo.backends import einsum_utils
from qibo.backends.npmatrices import NumpyMatrices
from qibo.backends.numpy import NumpyBackend
from qibo.config import raise_error


class JaxMatrices(NumpyMatrices):

def __init__(self, dtype):
super().__init__(dtype)
self.np = jnp
self.dtype = dtype

def _cast(self, x, dtype):
return jnp.asarray(x, dtype=dtype)


@partial(jax.jit, static_argnums=(2, 3))
def _apply_gate(matrix, state, qubits, nqubits):
state = jnp.reshape(state, nqubits * (2,))
matrix = jnp.reshape(matrix, 2 * len(qubits) * (2,))
opstring = einsum_utils.apply_gate_string(qubits, nqubits)
state = jnp.einsum(opstring, state, matrix)
return jnp.reshape(state, (2**nqubits,))


@partial(jax.jit, static_argnums=(4, 5, 6))
def _apply_gate_controlled(
matrix, state, order, targets, control_qubits, target_qubits, nqubits
):
state = jnp.reshape(state, nqubits * (2,))
matrix = jnp.reshape(matrix, 2 * len(target_qubits) * (2,))
ncontrol = len(control_qubits)
nactive = nqubits - ncontrol
state = jnp.transpose(state, order)
# Apply `einsum` only to the part of the state where all controls
# are active. This should be `state[-1]`
state = jnp.reshape(state, (2**ncontrol,) + nactive * (2,))
opstring = einsum_utils.apply_gate_string(targets, nactive)
updates = jnp.einsum(opstring, state[-1], matrix)
# Concatenate the updated part of the state `updates` with the
# part of of the state that remained unaffected `state[:-1]`.
state = jnp.concatenate([state[:-1], updates[None]], axis=0)
state = jnp.reshape(state, nqubits * (2,))
# Put qubit indices back to their proper places
state = jnp.transpose(state, einsum_utils.reverse_order(order))
return jnp.reshape(state, (2**nqubits,))


class JaxBackend(NumpyBackend):
def __init__(self):
super().__init__()
Expand All @@ -22,7 +70,7 @@ def __init__(self):

self.np = jnp
self.tensor_types = (jnp.ndarray, numpy.ndarray)
self.matrices.np = jnp
self.matrices = JaxMatrices(self.dtype)

def set_precision(self, precision):
if precision != self.precision:
Expand Down Expand Up @@ -88,6 +136,20 @@ def update_frequencies(self, frequencies, probabilities, nsamples):
frequencies = frequencies.at[res].add(counts)
return frequencies

def apply_gate(self, gate, state, nqubits):
if gate.is_controlled_by:
order, targets = einsum_utils.control_order(gate, nqubits)
return _apply_gate_controlled(
gate.matrix(self),
state,
order,
targets,
gate.control_qubits,
gate.target_qubits,
nqubits,
)
return _apply_gate(gate.matrix(self), state, gate.qubits, nqubits)

def apply_gate_density_matrix(self, gate, state, nqubits):
state = self.cast(state)
state = self.np.reshape(state, 2 * nqubits * (2,))
Expand Down
Loading