diff --git a/packages/circuits/lib/fp.circom b/packages/circuits/lib/fp.circom index e6059a2e..03ecd25f 100644 --- a/packages/circuits/lib/fp.circom +++ b/packages/circuits/lib/fp.circom @@ -12,7 +12,7 @@ include "./bigint-func.circom"; /// @param a Input 1 to FpMul; assumes to consist of `k` chunks, each of which must fit in `n` bits /// @param b Input 2 to FpMul; assumes to consist of `k` chunks, each of which must fit in `n` bits /// @param p The modulus; assumes to consist of `k` chunks, each of which must fit in `n` bits -/// @output out The result of the FpMul +/// @output out The result of the FpMul; asserted to be less than `p` template FpMul(n, k) { assert(n + n + log_ceil(k) + 2 <= 252); @@ -41,6 +41,7 @@ template FpMul(n, k) { component q_range_check[k]; signal r[k]; component r_range_check[k]; + component r_p_lt_check = BigLessThan(n,k); for (var i = 0; i < k; i++) { q[i] <-- long_div_out[0][i]; q_range_check[i] = Num2Bits(n); @@ -49,7 +50,11 @@ template FpMul(n, k) { r[i] <-- long_div_out[1][i]; r_range_check[i] = Num2Bits(n); r_range_check[i].in <== r[i]; + + r_p_lt_check.a[i] <== r[i]; + r_p_lt_check.b[i] <== p[i]; } + r_p_lt_check.out === 1; signal v_pq_r[2*k-1]; for (var x = 0; x < 2*k-1; x++) { diff --git a/packages/circuits/tests/fp-mul.test.ts b/packages/circuits/tests/fp-mul.test.ts new file mode 100644 index 00000000..0a454dde --- /dev/null +++ b/packages/circuits/tests/fp-mul.test.ts @@ -0,0 +1,66 @@ +import { wasm as wasm_tester } from 'circom_tester'; +import path from 'path'; + +describe('FpMul', () => { + let circuit1: any; + let circuit2: any; + + beforeAll(async () => { + circuit1 = await wasm_tester( + path.join( + __dirname, + './test-circuits/fp-mul-test.circom' + ), + { + recompile: true, + include: path.join(__dirname, '../../../node_modules'), + output: path.join(__dirname, './compiled-test-circuits'), + } + ); + circuit2 = await wasm_tester( + path.join( + __dirname, + './test-circuits/fp-mul-test-range-check.circom' + ), + { + recompile: true, + include: path.join(__dirname, '../../../node_modules'), + output: path.join(__dirname, './compiled-test-circuits'), + } + ); + }); + + it('should correctly match with the output', async () => { + const input = { + a: [1, 0, 1, 0], + b: [0, 1, 1, 0], + p: [1, 1, 1, 1] + }; + + const witness = await circuit1.calculateWitness(input); + await circuit1.checkConstraints(witness); + + await circuit1.assertOut(witness, { + out: [0, 0, 0, 0], + }); + }); + + it('should fail when r exceeds p', async () => { + const input = { + a: [4, 0], + b: [4, 0], + p: [5, 0], + q: [2, 0], + r: [6, 0] + }; + + expect.assertions(1); + try { + const witness = await circuit2.calculateWitness(input); + await circuit2.checkConstraints(witness); + } catch (error) { + expect((error as Error).message).toMatch("Assert Failed"); + } + }); + +}); diff --git a/packages/circuits/tests/test-circuits/fp-mul-test-range-check.circom b/packages/circuits/tests/test-circuits/fp-mul-test-range-check.circom new file mode 100644 index 00000000..21bc3659 --- /dev/null +++ b/packages/circuits/tests/test-circuits/fp-mul-test-range-check.circom @@ -0,0 +1,81 @@ +pragma circom 2.1.6; + + +include "circomlib/circuits/bitify.circom"; +include "circomlib/circuits/comparators.circom"; +include "circomlib/circuits/sign.circom"; +include "../../lib/bigint.circom"; +include "../../lib/bigint-func.circom"; + + +/// @title FpMul_TestRangeCheck +/// @notice Multiple two numbers in Fp, where q and r are also provided as inputs only for test purposes +/// @param a Input 1 to FpMul; assumes to consist of `k` chunks, each of which must fit in `n` bits +/// @param b Input 2 to FpMul; assumes to consist of `k` chunks, each of which must fit in `n` bits +/// @param p The modulus; assumes to consist of `k` chunks, each of which must fit in `n` bits +/// @param q The quotient; assumes to consist of `k` chunks, each of which must fit in `n` bits +/// @param r The remainder; assumes to consist of `k` chunks, each of which must fit in `n` bits +/// @output out The result of the FpMul +template FpMul_TestRangeCheck(n, k) { + assert(n + n + log_ceil(k) + 2 <= 252); + + signal input a[k]; + signal input b[k]; + signal input p[k]; + signal input q[k]; + signal input r[k]; + + signal output out[k]; + + signal v_ab[2*k-1]; + for (var x = 0; x < 2*k-1; x++) { + var v_a = poly_eval(k, a, x); + var v_b = poly_eval(k, b, x); + v_ab[x] <== v_a * v_b; + } + + + + // Since we're only computing a*b, we know that q < p will suffice, so we + // know it fits into k chunks and can do size n range checks. + component q_range_check[k]; + component r_range_check[k]; + component r_p_lt_check = BigLessThan(n,k); + for (var i = 0; i < k; i++) { + q_range_check[i] = Num2Bits(n); + q_range_check[i].in <== q[i]; + + r_range_check[i] = Num2Bits(n); + r_range_check[i].in <== r[i]; + + r_p_lt_check.a[i] <== r[i]; + r_p_lt_check.b[i] <== p[i]; + } + r_p_lt_check.out === 1; + + signal v_pq_r[2*k-1]; + for (var x = 0; x < 2*k-1; x++) { + var v_p = poly_eval(k, p, x); + var v_q = poly_eval(k, q, x); + var v_r = poly_eval(k, r, x); + v_pq_r[x] <== v_p * v_q + v_r; + } + + signal v_t[2*k-1]; + for (var x = 0; x < 2*k-1; x++) { + v_t[x] <== v_ab[x] - v_pq_r[x]; + } + + var t[200] = poly_interp(2*k-1, v_t); + component tCheck = CheckCarryToZero(n, n + n + log_ceil(k) + 2, 2*k-1); + for (var i = 0; i < 2*k-1; i++) { + tCheck.in[i] <== t[i]; + } + + for (var i = 0; i < k; i++) { + out[i] <== r[i]; + } +} + + +component main = FpMul_TestRangeCheck(4, 2); diff --git a/packages/circuits/tests/test-circuits/fp-mul-test.circom b/packages/circuits/tests/test-circuits/fp-mul-test.circom new file mode 100644 index 00000000..60142b46 --- /dev/null +++ b/packages/circuits/tests/test-circuits/fp-mul-test.circom @@ -0,0 +1,5 @@ +pragma circom 2.1.6; + +include "../../lib/fp.circom"; + +component main = FpMul(2, 4);