Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing mechanism to identify non trainable gates #61

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/qiboml/interfaces/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __post_init__(
params = torch.as_tensor(self.backend.to_numpy(x=params)).ravel()
params.requires_grad = True
self.circuit_parameters = torch.nn.Parameter(params)
# This flag records whether the non-differentiable gates have already been marked,
# so as to reduce the computational cost of hardware-compatible differentiation
self._differentiability_checked = False

backend_string = (
f"{self.decoding.backend.name}-{self.decoding.backend.platform}"
Expand All @@ -59,6 +62,15 @@ def forward(self, x: torch.Tensor):
x = self.encoding(x) + self.circuit
x = self.decoding(x)
else:
# During the first iteration, also mark which gates are differentiable
if not self._differentiability_checked:
self._mark_differentiable_angles(
circuit=self.encoding(x) + self.circuit,
differentiate_wrt_data=False, # TODO: expose this feature if we think it's useful
)
# Inform the model the check is done
self._differentiability_checked = True

x = QuantumModelAutoGrad.apply(
x,
self.encoding,
Expand Down Expand Up @@ -86,6 +98,27 @@ def backend(
def output_shape(self):
return self.decoding.output_shape

def _mark_differentiable_angles(self, circuit, differentiate_wrt_data=False):
    """
    Flag every parametrized gate of ``circuit`` as trainable or not.

    A gate is set to ``trainable=True`` when at least one of its parameters
    has ``requires_grad`` set, and to ``trainable=False`` otherwise (this
    includes gates whose ``parameters`` collection is empty). The intent is
    to reduce the computational cost of the parameter shift rule by not
    computing gradients w.r.t. angles that only carry input data.

    The gates are modified in place; nothing is returned.

    Args:
        circuit (qibo.Circuit): qibo circuit to be checked. Only the gates
            listed in ``circuit.parametrized_gates`` are inspected.
        differentiate_wrt_data (bool): if True, gradients w.r.t. input data
            are wanted as well, so no gate is touched and every gate keeps
            its current ``trainable`` flag. Defaults to ``False``.

    Raises:
        AttributeError: if a gate parameter does not expose
            ``requires_grad`` (presumably parameters are torch tensors;
            verify against the caller).
    """
    if differentiate_wrt_data:
        # Nothing to restrict: leave the existing `trainable` flags as they are.
        return

    for gate in circuit.parametrized_gates:
        # One differentiable angle is enough to keep the whole gate trainable.
        gate.trainable = any(param.requires_grad for param in gate.parameters)


class QuantumModelAutoGrad(torch.autograd.Function):

Expand Down
21 changes: 12 additions & 9 deletions src/qiboml/operations/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,19 @@ def evaluate(self, x: ndarray, encoding, training, decoding, backend, *parameter
)

# compute second gradient part, wrt parameters
for i in range(len(parameters)):
gradient.append(
self.one_parameter_shift(
circuit=circuit,
decoding=decoding,
parameters=parameters,
parameter_index=i,
backend=backend,
for i, gate in enumerate(len(circuit.trainable_gates)):
if gate.trainable:
gradient.append(
self.one_parameter_shift(
circuit=circuit,
decoding=decoding,
parameters=parameters,
parameter_index=i,
backend=backend,
)
)
)
else:
gradient.append(0.0)
return gradient

def one_parameter_shift(
Expand Down
Loading