From e116a82bc8a4978163249fe07a706a07cc3dac88 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 12 Nov 2021 17:49:47 -0600 Subject: [PATCH] fix: incorrect handling of hints in r1cs solver --- internal/backend/bls12-377/cs/r1cs.go | 7 ++++++- internal/backend/bls12-381/cs/r1cs.go | 7 ++++++- internal/backend/bls24-315/cs/r1cs.go | 7 ++++++- internal/backend/bn254/cs/r1cs.go | 7 ++++++- internal/backend/bw6-761/cs/r1cs.go | 7 ++++++- .../backend/template/representations/r1cs.go.tmpl | 7 ++++++- 6 files changed, 36 insertions(+), 6 deletions(-) diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index c0c39f92ac..e21c870eae 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -208,7 +208,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // first we check if this is a hint wire if hint, ok := cs.MHints[vID]; ok { - return solution.solveWithHint(vID, hint) + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + v := solution.computeTerm(t) + val.Add(val, &v) + return nil } if loc != 0 { diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index fd5065ce66..ba9333b1e4 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -208,7 +208,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // first we check if this is a hint wire if hint, ok := cs.MHints[vID]; ok { - return solution.solveWithHint(vID, hint) + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + v := solution.computeTerm(t) + val.Add(val, &v) + return nil } if loc != 0 { diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index ddbab50d06..bd726a6152 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -208,7 +208,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // first we check if this is a hint wire if hint, ok := cs.MHints[vID]; ok { - return solution.solveWithHint(vID, hint) + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + v := solution.computeTerm(t) + val.Add(val, &v) + return nil } if loc != 0 { diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 8314157dee..50823245b8 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -208,7 +208,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // first we check if this is a hint wire if hint, ok := cs.MHints[vID]; ok { - return solution.solveWithHint(vID, hint) + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + v := solution.computeTerm(t) + val.Add(val, &v) + return nil } if loc != 0 { diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index d672fcbd03..37b383b43c 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -208,7 +208,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // first we check if this is a hint wire if hint, ok := cs.MHints[vID]; ok { - return solution.solveWithHint(vID, hint) + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + v := solution.computeTerm(t) + val.Add(val, &v) + return nil } if loc != 0 { diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 5e0574a76b..a584b052d9 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -202,7 +202,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // first we check if this is a hint wire if hint, ok := cs.MHints[vID]; ok { - return solution.solveWithHint(vID, hint) + if err := solution.solveWithHint(vID, hint); err != nil { + return err + } + v := solution.computeTerm(t) + val.Add(val, &v) + return nil } if loc != 0 {