diff --git a/qualtran/bloqs/data_loading/one_hot_encoding.py b/qualtran/bloqs/data_loading/one_hot_encoding.py new file mode 100644 index 000000000..23e8a5726 --- /dev/null +++ b/qualtran/bloqs/data_loading/one_hot_encoding.py @@ -0,0 +1,54 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import attrs +import cirq +from attr import field +from numpy._typing import NDArray + +from qualtran import GateWithRegisters, QAny, QUInt, Signature +from qualtran.bloqs.basic_gates import TwoBitCSwap + + +@attrs.frozen +class OneHotEncoding(GateWithRegisters): + """ + One-hot encode a binary integer into a target register. + + Registers: + a: an unsigned integer + b: the target to one-hot encode. + + References: + [Windowed quantum arithmetic](https://arxiv.org/pdf/1905.07682.pdf) + Figure 4] + """ + + binary_reg_size: int = field() + + @property + def signature(self) -> 'Signature': + return Signature.build_from_dtypes( + a=QUInt(self.binary_reg_size), b=QAny(2**self.binary_reg_size) + ) + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + a = quregs['a'] + b = quregs['b'] + + yield cirq.X(b[0]) + for i in range(len(a)): + for j in range(2**i): + yield TwoBitCSwap().on_registers(ctrl=a[i], x=b[j], y=b[2**i + j]) diff --git a/qualtran/bloqs/data_loading/one_hot_encoding_test.py b/qualtran/bloqs/data_loading/one_hot_encoding_test.py new file mode 100644 index 000000000..a07ec86a8 --- /dev/null +++ b/qualtran/bloqs/data_loading/one_hot_encoding_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import attrs +import cirq +import pytest +from attr import field +from numpy._typing import NDArray + +from qualtran import GateWithRegisters, QUInt, Signature +from qualtran.bloqs.data_loading.one_hot_encoding import OneHotEncoding +from qualtran.cirq_interop.bit_tools import iter_bits +from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim + + +@attrs.frozen +class OneHotEncodingTest(GateWithRegisters): + integer: int = field() + size: int = field() + + @property + def signature(self) -> 'Signature': + return Signature.build_from_dtypes(a=QUInt(self.size), b=QUInt(2**self.size)) + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + a = quregs['a'] + b = quregs['b'] + binary_repr = list(iter_bits(self.integer, self.size))[::-1] + for i in range(self.size): + if binary_repr[i] == 1: + yield cirq.X(a[i]) + yield OneHotEncoding(binary_reg_size=self.size).on_registers(a=a, b=b) + + +@pytest.mark.parametrize('integer', list(range(8))) +def test_one_hot_encoding(integer): + gate = OneHotEncodingTest(integer, 3) + qubits = cirq.LineQubit.range(3 + 2**3) + op = gate.on_registers(a=qubits[:3], b=qubits[3:]) + circuit0 = cirq.Circuit(op) + initial_state = [0] * (3 + 2**3) + final_state = [0] * (3 + 2**3) + final_state[:3] = list(iter_bits(integer, 3))[::-1] + final_state[3 + integer] = 1 + assert_circuit_inp_out_cirqsim(circuit0, qubits, initial_state, final_state) diff --git a/qualtran/bloqs/data_loading/qrom_adjoint.py b/qualtran/bloqs/data_loading/qrom_adjoint.py new file mode 100644 index 000000000..54e9230b8 --- /dev/null +++ b/qualtran/bloqs/data_loading/qrom_adjoint.py @@ -0,0 +1,108 @@ +import dataclasses +import itertools +from typing import List, Sequence, Tuple + +import attrs +import cirq +from attr import field +from cirq import Condition +from numpy._typing import NDArray + +from qualtran import Signature +from qualtran.bloqs.data_loading import QROM +from qualtran.bloqs.data_loading.one_hot_encoding import OneHotEncoding +from qualtran.bloqs.data_loading.qrom import _to_tuple +from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate +from qualtran.cirq_interop.bit_tools import iter_bits +from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt, is_symbolic, log2 + +@dataclasses.dataclass(frozen=True) +class QROMAdjCondition(Condition): + key: cirq.MeasurementKey + dx: List[int] + + def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey): + return QROMAdjCondition(replacement, self.dx) if self.key == current else self + + def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool: + y = classical_data.get_digits(self.key) + active = False + for yi, dxi in zip(y, self.dx): + active = not active if yi * dxi == 1 else active + return active + + + +@attrs.define +class QROMWithClassicalControls(QROM): + QROM_bloq: QROM = field(default=None) + mz_key: str = field(default="target_mzs") + + def calc_dx(self, x): + bitstring = [] + x_start = 0 + for i in range(len(self.QROM_bloq.target_bitsizes)): + bitsize = self.QROM_bloq.target_bitsizes[i] + data = self.QROM_bloq.data[i][x[x_start:x_start + bitsize]] + bitstring.extend(iter_bits(data, bitsize)) + return bitstring + + + def nth_operation( + self, context: cirq.DecompositionContext, control: cirq.Qid, **kwargs + ) -> cirq.OP_TREE: + selection_idx: int = kwargs[self.selection_registers[0].name] + target_regs = {reg.name: kwargs[reg.name] for reg in self.target_registers} + # yield self._load_nth_data(selection_idx, lambda q: CNOT().on(control, q), **target_regs) + # for i, d in enumerate(self.data): + # target = target_regs.get(f'target{i}_', ()) + target = target_regs.get(f'target{0}_', ()) + # for q, bit in zip(target, f'{int(self.data[0][selection_idx]):0{len(target)}b}'): + # if int(bit): + # yield gate(q) + N = int(log2(len(target))) + selection_bits = iter_bits(selection_idx, self.selection_bitsizes[0]) + for i in range(len(target)): + target_bits = iter_bits(i, N) + dx = self.calc_dx(list(itertools.chain(selection_bits, target_bits))) + yield cirq.X(target[i]).with_classical_controls(QROMAdjCondition(cirq.MeasurementKey(self.mz_key), dx)) + + +@attrs.frozen +class QROMAdj(): + QROM_Bloq: QROM = field() + num_ancilla: SymbolicInt = field(default=1) + mz_key: str = field(default="target_mzs") + + @num_ancilla.validator + def num_ancilla_validator(self, field, val): + if is_symbolic(val): + return + if not log2(val).is_integer(): + raise ValueError(f"num_ancilla must be a power of 2, but got {val}") + + def signature(self) -> Signature: + return self.QROM_Bloq.signature + + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + num_targets = len(self.QROM_Bloq.target_registers) + for i in range(num_targets): + targets = quregs[f'target{i}'] + for target in targets: + yield cirq.H(target) + for i in range(num_targets): + targets = quregs[f'target{i}'] + for j, target in enumerate(targets): + yield cirq.measure(target, key=f"target_mzs") + ancillas = context.qubit_manager.qalloc(self.num_ancilla) + if len(self.QROM_Bloq.selection_registers) == 1: + selection_regs = quregs['selection'] + else: + selection_regs = [quregs[f"selection{i}"] for i in range(len(self.QROM_Bloq.selection_registers))] + selection_regs = selection_regs.flatten() + binary_int_size = int(log2(self.num_ancilla)) + yield OneHotEncoding(binary_int_size).on_registers(a=selection_regs[-binary_int_size:], b=ancillas) +