diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index f764c70cb..a844f3a2b 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -23,6 +23,7 @@ bloq_example, BloqBuilder, BloqDocSpec, + CtrlSpec, DecomposeTypeError, QBit, QMontgomeryUInt, @@ -255,10 +256,6 @@ def on_classical_vals( QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)), int(self.mod), ) - # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit - # which flips f1 when lam and lam_r are equal. - if lam == lam_r: - f1 = (f1 + 1) % 2 else: lam = 0 return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} @@ -298,6 +295,12 @@ def build_composite_bloq( y=y, ) + # Allocate an ancilla qubit that acts as a flag for the rare condition that the + # pre-computed lambda_r is equal to the calculated lambda. This ancilla is used to properly + # clear the f1 qubit when lambda is set to lambda_r. + ancilla = bb.allocate() + z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla) + # If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p. z4_split = bb.split(z4) lam_split = bb.split(lam) @@ -325,7 +328,18 @@ def build_composite_bloq( lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) # If lam = lam_r: return f1 = 0. (If not we will flip f1 to 0 at the end iff x_r = y_r = 0). - lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1) + # Only flip when lam is set to lam_r. + ancilla, lam, lam_r, f1 = bb.add( + Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)), + ctrl=ancilla, + x=lam, + y=lam_r, + target=f1, + ) + + # Clear the ancilla bit and free it. + z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla) + bb.free(ancilla) # Uncompute the modular multiplication then the modular inversion. x, y = bb.add( @@ -345,7 +359,8 @@ def build_composite_bloq( def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: return { - Equals(QMontgomeryUInt(self.n)): 1, + Equals(QMontgomeryUInt(self.n)): 2, + Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)): 1, ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, KaliskiModInverse(bitsize=self.n, mod=self.mod): 1, @@ -654,6 +669,7 @@ class _ECAddStepFive(Bloq): will contain the x component of the resultant curve point. y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. lam: The lambda slope used in the addition operation. References: @@ -674,6 +690,7 @@ def signature(self) -> 'Signature': Register('b', QMontgomeryUInt(self.n)), Register('x', QMontgomeryUInt(self.n)), Register('y', QMontgomeryUInt(self.n)), + Register('lam_r', QMontgomeryUInt(self.n)), Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT), ] ) @@ -685,6 +702,7 @@ def on_classical_vals( b: 'ClassicalValT', x: 'ClassicalValT', y: 'ClassicalValT', + lam_r: 'ClassicalValT', lam: 'ClassicalValT', ) -> Dict[str, 'ClassicalValT']: if ctrl == 1: @@ -692,7 +710,7 @@ def on_classical_vals( y = (y - b) % self.mod else: x = (x + a) % self.mod - return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} def build_composite_bloq( self, @@ -702,6 +720,7 @@ def build_composite_bloq( b: Soquet, x: Soquet, y: Soquet, + lam_r: Soquet, lam: Soquet, ) -> Dict[str, 'SoquetT']: if is_symbolic(self.n): @@ -731,9 +750,31 @@ def build_composite_bloq( z4_split[i] = ctrls[1] z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) - # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda - # is not set to 0 before being freed. - bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam) + + # If the denominator of lambda is 0, lam = lam_r so we clear lam with lam_r. + ancilla = bb.allocate() + x_split = bb.split(x) + x_split, ancilla = bb.add( + MultiControlX(cvs=[0] * int(self.n)), controls=x_split, target=ancilla + ) + lam_r_split = bb.split(lam_r) + lam_split = bb.split(lam) + for i in range(int(self.n)): + ctrls = [ctrl, ancilla, lam_r_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i] + ) + ctrl = ctrls[0] + ancilla = ctrls[1] + lam_r_split[i] = ctrls[2] + lam_r = bb.join(lam_r_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + x_split, ancilla = bb.add( + MultiControlX(cvs=[0] * int(self.n)), controls=x_split, target=ancilla + ) + x = bb.join(x_split, dtype=QMontgomeryUInt(self.n)) + bb.free(ancilla) + bb.add(Free(QMontgomeryUInt(self.n)), reg=lam) # Uncompute multiplication and inverse. x, y = bb.add( @@ -758,9 +799,14 @@ def build_composite_bloq( ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) # Return the output registers. - return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * self.n + else: + cvs = HasLength(self.n) return { CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, KaliskiModInverse(bitsize=self.n, mod=self.mod): 1, @@ -773,6 +819,8 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(): 1, ModAdd(self.n, mod=self.mod): 1, MultiControlX(cvs=[1, 1]): self.n, + MultiControlX(cvs=cvs): 2, + MultiControlX(cvs=[1, 1, 1]): self.n, CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1, } @@ -865,6 +913,21 @@ def build_composite_bloq( f3 = f_ctrls[1] f4 = f_ctrls[2] + # Unset f2 if ((a, b) = (0, 0) AND y = 0) OR ((x, y) = (0, 0) AND b = 0). + aby_arr = np.concatenate([bb.split(a), bb.split(b), bb.split(y)]) + aby_arr, f2 = bb.add(MultiControlX(cvs=[0] * 3 * self.n), controls=aby_arr, target=f2) + aby_arr = np.split(aby_arr, 3) + a = bb.join(aby_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(aby_arr[1], dtype=QMontgomeryUInt(self.n)) + y = bb.join(aby_arr[2], dtype=QMontgomeryUInt(self.n)) + + xyb_arr = np.concatenate([bb.split(x), bb.split(y), bb.split(b)]) + xyb_arr, f2 = bb.add(MultiControlX(cvs=[0] * 3 * self.n), controls=xyb_arr, target=f2) + xyb_arr = np.split(xyb_arr, 3) + x = bb.join(xyb_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xyb_arr[1], dtype=QMontgomeryUInt(self.n)) + b = bb.join(xyb_arr[2], dtype=QMontgomeryUInt(self.n)) + # Set (x, y) to (a, b) if f4 is set. a_split = bb.split(a) x_split = bb.split(x) @@ -885,24 +948,6 @@ def build_composite_bloq( b = bb.join(b_split, QMontgomeryUInt(self.n)) y = bb.join(y_split, QMontgomeryUInt(self.n)) - # Unset f4 if (x, y) = (a, b). - ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) - xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) - ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) - ab_split = bb.split(ab) - a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) - b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) - xy_split = bb.split(xy) - x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) - y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) - - # Unset f3 if (a, b) = (0, 0). - ab_arr = np.concatenate([bb.split(a), bb.split(b)]) - ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) - ab_arr = np.split(ab_arr, 2) - a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) - b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) - # If f1 and f2 are set, subtract a from x and add b to y. ancilla = bb.add(ZeroState()) toff_ctrl = [f1, f2] @@ -925,6 +970,24 @@ def build_composite_bloq( f2 = toff_ctrl[1] bb.add(Free(QBit()), reg=ancilla) + # Unset f4 if (x, y) = (a, b). + ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) + xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) + ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) + ab_split = bb.split(ab) + a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) + xy_split = bb.split(xy) + x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) + + # Unset f3 if (a, b) = (0, 0). + ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + # Unset f1 and f2 if (x, y) = (0, 0). xy_arr = np.concatenate([bb.split(x), bb.split(y)]) xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr) @@ -941,33 +1004,35 @@ def build_composite_bloq( y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) # Free all ancilla qubits in the zero state. - # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1, - # f2, and f4 are freed before being set to 0. - bb.add(Free(QBit(), dirty=True), reg=f1) - bb.add(Free(QBit(), dirty=True), reg=f2) + bb.add(Free(QBit()), reg=f1) + bb.add(Free(QBit()), reg=f2) bb.add(Free(QBit()), reg=f3) - bb.add(Free(QBit(), dirty=True), reg=f4) + bb.add(Free(QBit()), reg=f4) bb.add(Free(QBit()), reg=ctrl) # Return the output registers. return {'a': a, 'b': b, 'x': x, 'y': y} def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: - cvs: Union[list[int], HasLength] + cvs2: Union[list[int], HasLength] + cvs3: Union[list[int], HasLength] if isinstance(self.n, int): - cvs = [0] * 2 * self.n + cvs2 = [0] * 2 * self.n + cvs3 = [0] * 3 * self.n else: - cvs = HasLength(2 * self.n) + cvs2 = HasLength(2 * self.n) + cvs3 = HasLength(3 * self.n) return { - MultiControlX(cvs=cvs): 1, + MultiControlX(cvs=cvs2): 1, + MultiControlX(cvs=cvs3): 2, MultiControlX(cvs=[0] * 3): 1, CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, Toffoli(): 2 * self.n + 4, Equals(QMontgomeryUInt(2 * self.n)): 1, - MultiAnd(cvs=cvs): 1, + MultiAnd(cvs=cvs2): 1, MultiTargetCNOT(2): 1, - MultiAnd(cvs=cvs).adjoint(): 1, + MultiAnd(cvs=cvs2).adjoint(): 1, } @@ -1046,13 +1111,14 @@ def build_composite_bloq( x, y, lam = bb.add( _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam ) - ctrl, a, b, x, y = bb.add( + ctrl, a, b, x, y, lam_r = bb.add( _ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size), ctrl=ctrl, a=a, b=b, x=x, y=y, + lam_r=lam_r, lam=lam, ) a, b, x, y = bb.add( diff --git a/qualtran/bloqs/factoring/ecc/ec_add_test.py b/qualtran/bloqs/factoring/ecc/ec_add_test.py index 7f7439d2d..e8fc66fa9 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_test.py @@ -110,6 +110,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y): b=step_3['b'], x=step_4['x'], y=step_4['y'], + lam_r=step_2['lam_r'], lam=step_4['lam'], ) ret2 = bloq.decompose_bloq().call_classically( @@ -118,6 +119,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y): b=step_3['b'], x=step_4['x'], y=step_4['y'], + lam_r=step_2['lam_r'], lam=step_4['lam'], ) assert ret1 == ret2 @@ -128,6 +130,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y): b=step_3['b'], x=step_4['x'], y=step_4['y'], + lam_r=step_2['lam_r'], lam=step_4['lam'], ) bloq = _ECAddStepSix(n=n, mod=p) @@ -250,6 +253,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y): b=step_3['b'], x=step_4['x'], y=step_4['y'], + lam_r=step_2['lam_r'], lam=step_4['lam'], ) ret2 = bloq.decompose_bloq().call_classically( @@ -258,6 +262,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y): b=step_3['b'], x=step_4['x'], y=step_4['y'], + lam_r=step_2['lam_r'], lam=step_4['lam'], ) assert ret1 == ret2 @@ -268,6 +273,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y): b=step_3['b'], x=step_4['x'], y=step_4['y'], + lam_r=step_2['lam_r'], lam=step_4['lam'], ) bloq = _ECAddStepSix(n=n, mod=p) @@ -413,12 +419,12 @@ def test_ec_add_symbolic_cost(): # Litinski 2023 https://arxiv.org/abs/2306.08585 # Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n. - # The following formula is 126.5n^2 + 195.5n - 31. We account for the discrepancy in the + # The following formula is 126.5n^2 + 217.5n - 36. We account for the discrepancy in the # coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, an increase in the # toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n - # 3-controlled toffolis in step 2. The expression is written with rationals because sympy - # comparison fails with floats. - assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31 + # 3-controlled toffolis in step 2, and a few extra gates added to fix bugs found in the + # circuit. The expression is written with rationals because sympy comparison fails with floats. + assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(435, 2) * n - 36 def test_ec_add(bloq_autotester):