Skip to content

Commit

Permalink
Merge pull request #296 from Consensys/293-refactor-evaluable-to-retu…
Browse files Browse the repository at this point in the history
…rn-frelement

feat: refactor `Evaluable` to return `fr.Element`
  • Loading branch information
DavePearce authored Sep 5, 2024
2 parents 7b5d9c1 + 83c68c4 commit 553c3fa
Show file tree
Hide file tree
Showing 33 changed files with 225 additions and 257 deletions.
31 changes: 14 additions & 17 deletions pkg/air/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,57 +8,54 @@ import (
// EvalAt evaluates a column access at a given row in a trace, which returns the
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) fr.Element {
return tr.Column(e.Column).Get(k + e.Shift)
}

// EvalAt evaluates a constant at a given row in a trace, which simply returns
// that constant.
func (e *Constant) EvalAt(k int, tr trace.Trace) *fr.Element {
func (e *Constant) EvalAt(k int, tr trace.Trace) fr.Element {
return e.Value
}

// EvalAt evaluates a sum at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Add) EvalAt(k int, tr trace.Trace) *fr.Element {
var val fr.Element
func (e *Add) EvalAt(k int, tr trace.Trace) fr.Element {
// Evaluate first argument
val.Set(e.Args[0].EvalAt(k, tr))
val := e.Args[0].EvalAt(k, tr)
// Continue evaluating the rest
for i := 1; i < len(e.Args); i++ {
ith := e.Args[i].EvalAt(k, tr)
val.Add(&val, ith)
val.Add(&val, &ith)
}
// Done
return &val
return val
}

// EvalAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Mul) EvalAt(k int, tr trace.Trace) *fr.Element {
var val fr.Element
func (e *Mul) EvalAt(k int, tr trace.Trace) fr.Element {
// Evaluate first argument
val.Set(e.Args[0].EvalAt(k, tr))
val := e.Args[0].EvalAt(k, tr)
// Continue evaluating the rest
for i := 1; i < len(e.Args); i++ {
ith := e.Args[i].EvalAt(k, tr)
val.Mul(&val, ith)
val.Mul(&val, &ith)
}
// Done
return &val
return val
}

// EvalAt evaluates a subtraction at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Sub) EvalAt(k int, tr trace.Trace) *fr.Element {
var val fr.Element
func (e *Sub) EvalAt(k int, tr trace.Trace) fr.Element {
// Evaluate first argument
val.Set(e.Args[0].EvalAt(k, tr))
val := e.Args[0].EvalAt(k, tr)
// Continue evaluating the rest
for i := 1; i < len(e.Args); i++ {
ith := e.Args[i].EvalAt(k, tr)
val.Sub(&val, ith)
val.Sub(&val, &ith)
}
// Done
return &val
return val
}
17 changes: 3 additions & 14 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,18 @@ func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }
// ============================================================================

// Constant represents a constant value within an expression.
type Constant struct{ Value *fr.Element }
type Constant struct{ Value fr.Element }

// NewConst construct an AIR expression representing a given constant.
func NewConst(val *fr.Element) Expr {
func NewConst(val fr.Element) Expr {
return &Constant{val}
}

// NewConst64 construct an AIR expression representing a given constant from a
// uint64.
func NewConst64(val uint64) Expr {
element := fr.NewElement(val)
return &Constant{&element}
}

// NewConstCopy construct an AIR expression representing a given constant,
// and also clones that constant.
func NewConstCopy(val *fr.Element) Expr {
// Create ith term (for final sum)
var clone fr.Element
// Clone coefficient
clone.Set(val)
// DOne
return &Constant{&clone}
return &Constant{element}
}

// Context determines the evaluation context (i.e. enclosing module) for this
Expand Down
2 changes: 1 addition & 1 deletion pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
es[i] = air.NewColumnAccess(index+i, 0).Mul(air.NewConstCopy(&coefficient))
es[i] = air.NewColumnAccess(index+i, 0).Mul(air.NewConst(coefficient))

schema.AddRangeConstraint(index+i, &fr256)
// Update coefficient
Expand Down
9 changes: 6 additions & 3 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ type Inverse struct{ Expr air.Expr }

// EvalAt computes the multiplicative inverse of a given expression at a given
// row in the table.
func (e *Inverse) EvalAt(k int, tbl tr.Trace) *fr.Element {
inv := new(fr.Element)
func (e *Inverse) EvalAt(k int, tbl tr.Trace) fr.Element {
var inv fr.Element

val := e.Expr.EvalAt(k, tbl)
// Go syntax huh?
return inv.Inverse(val)
inv.Inverse(&val)
// Done
return inv
}

// Bounds returns max shift in either the negative (left) or positive
Expand Down
4 changes: 2 additions & 2 deletions pkg/binfile/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ func (e *jsonExprConst) ToHir(schema *hir.Schema) hir.Expr {
panic(fmt.Sprintf("Unknown BigInt sign: %d", sign))
}
// Construct Field Value
num := new(fr.Element)
num.SetBigInt(val)
var num fr.Element

num.SetBigInt(val)
// Done!
return &hir.Constant{Val: num}
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"math"
"os"

"github.com/consensys/go-corset/pkg/hir"
Expand Down Expand Up @@ -243,7 +244,7 @@ func validateColumn(colType sc.Type, col trace.Column, mod sc.Module) error {
jth := col.Get(j)
if !colType.Accept(jth) {
qualColName := trace.QualifiedColumnName(mod.Name(), col.Name())
return fmt.Errorf("row %d of column %s is out-of-bounds (%s)", j, qualColName, jth)
return fmt.Errorf("row %d of column %s is out-of-bounds (%s)", j, qualColName, jth.String())
}
}
// success
Expand Down Expand Up @@ -284,7 +285,7 @@ func init() {
checkCmd.Flags().BoolP("quiet", "q", false, "suppress output (e.g. warnings)")
checkCmd.Flags().Bool("sequential", false, "perform sequential trace expansion")
checkCmd.Flags().Uint("padding", 0, "specify amount of (front) padding to apply")
checkCmd.Flags().UintP("batch", "b", 1000, "specify batch size for constraint checking")
checkCmd.Flags().UintP("batch", "b", math.MaxUint, "specify batch size for constraint checking")
checkCmd.Flags().Int("spillage", -1,
"specify amount of splillage to account for (where -1 indicates this should be inferred)")
}
3 changes: 2 additions & 1 deletion pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ func printTrace(start uint, end uint, max_width uint, cols []trace.RawColumn) {
if start < ith.Len() {
ith_height := min(ith.Len(), end) - start
for j := uint(0); j < ith_height; j++ {
tbl.Set(j+1, i+1, ith.Get(j+start).String())
jth := ith.Get(j + start)
tbl.Set(j+1, i+1, jth.String())
}
}
}
Expand Down
64 changes: 28 additions & 36 deletions pkg/hir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,40 @@ import (
// EvalAllAt evaluates a column access at a given row in a trace, which returns the
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []fr.Element {
val := tr.Column(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
return []*fr.Element{clone.Set(val)}
return []fr.Element{val}
}

// EvalAllAt evaluates a constant at a given row in a trace, which simply returns
// that constant.
func (e *Constant) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
var clone fr.Element
// Clone original value
return []*fr.Element{clone.Set(e.Val)}
func (e *Constant) EvalAllAt(k int, tr trace.Trace) []fr.Element {
return []fr.Element{e.Val}
}

// EvalAllAt evaluates a sum at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Add) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
fn := func(l *fr.Element, r *fr.Element) { l.Add(l, r) }
func (e *Add) EvalAllAt(k int, tr trace.Trace) []fr.Element {
fn := func(l fr.Element, r fr.Element) fr.Element { l.Add(&l, &r); return l }
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalAllAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Mul) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
fn := func(l *fr.Element, r *fr.Element) { l.Mul(l, r) }
func (e *Mul) EvalAllAt(k int, tr trace.Trace) []fr.Element {
fn := func(l fr.Element, r fr.Element) fr.Element { l.Mul(&l, &r); return l }
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalAllAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Exp) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
func (e *Exp) EvalAllAt(k int, tr trace.Trace) []fr.Element {
vals := e.Arg.EvalAllAt(k, tr)
for _, v := range vals {
util.Pow(v, e.Pow)
for i := range vals {
util.Pow(&vals[i], e.Pow)
}

// Done
return vals
}

Expand All @@ -55,8 +51,8 @@ func (e *Exp) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
// (if applicable) is evaluated; otherwise if the condition is non-zero then
// false branch (if applicable) is evaluated). If the branch to be evaluated is
// missing (i.e. nil), then nil is returned.
func (e *IfZero) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
vals := make([]*fr.Element, 0)
func (e *IfZero) EvalAllAt(k int, tr trace.Trace) []fr.Element {
vals := make([]fr.Element, 0)
// Evaluate condition
conditions := e.Condition.EvalAllAt(k, tr)
// Check all results
Expand All @@ -73,8 +69,8 @@ func (e *IfZero) EvalAllAt(k int, tr trace.Trace) []*fr.Element {

// EvalAllAt evaluates a list at a given row in a trace by evaluating each of its
// arguments at that row.
func (e *List) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
vals := make([]*fr.Element, 0)
func (e *List) EvalAllAt(k int, tr trace.Trace) []fr.Element {
vals := make([]fr.Element, 0)

for _, e := range e.Args {
vs := e.EvalAllAt(k, tr)
Expand All @@ -87,13 +83,13 @@ func (e *List) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
// EvalAllAt evaluates the normalisation of some expression by first evaluating
// that expression. Then, zero is returned if the result is zero; otherwise one
// is returned.
func (e *Normalise) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
func (e *Normalise) EvalAllAt(k int, tr trace.Trace) []fr.Element {
// Check whether argument evaluates to zero or not.
vals := e.Arg.EvalAllAt(k, tr)
// Normalise values (as necessary)
for _, e := range vals {
if !e.IsZero() {
e.SetOne()
for i := range vals {
if !vals[i].IsZero() {
vals[i].SetOne()
}
}

Expand All @@ -102,14 +98,14 @@ func (e *Normalise) EvalAllAt(k int, tr trace.Trace) []*fr.Element {

// EvalAllAt evaluates a subtraction at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Sub) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
fn := func(l *fr.Element, r *fr.Element) { l.Sub(l, r) }
func (e *Sub) EvalAllAt(k int, tr trace.Trace) []fr.Element {
fn := func(l fr.Element, r fr.Element) fr.Element { l.Sub(&l, &r); return l }
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalExprsAt evaluates all expressions in a given slice at a given row on the
// table, and fold their results together using a combinator.
func evalExprsAt(k int, tr trace.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) []*fr.Element {
func evalExprsAt(k int, tr trace.Trace, exprs []Expr, fn func(fr.Element, fr.Element) fr.Element) []fr.Element {
// Evaluate first argument.
vals := exprs[0].EvalAllAt(k, tr)

Expand All @@ -124,25 +120,21 @@ func evalExprsAt(k int, tr trace.Trace, exprs []Expr, fn func(*fr.Element, *fr.E
}

// Perform a vector operation using the given primitive operator "fn".
func evalExprsAtApply(lhs []*fr.Element, rhs []*fr.Element, fn func(*fr.Element, *fr.Element)) []*fr.Element {
func evalExprsAtApply(lhs []fr.Element, rhs []fr.Element, fn func(fr.Element, fr.Element) fr.Element) []fr.Element {
if len(rhs) == 1 {
// Optimise for common case.
for _, ith := range lhs {
fn(ith, rhs[0])
for i, ith := range lhs {
lhs[i] = fn(ith, rhs[0])
}

return lhs
}
// Harder case
vals := make([]*fr.Element, 0)
vals := make([]fr.Element, 0)
// Perform n x m operations
for _, ith := range lhs {
for _, jth := range rhs {
var clone fr.Element

clone.Set(ith)
fn(&clone, jth)
vals = append(vals, &clone)
vals = append(vals, fn(ith, jth))
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Expr interface {
// undefined for several reasons: firstly, if it accesses a
// row which does not exist (e.g. at index -1); secondly, if
// it accesses a column which does not exist.
EvalAllAt(int, trace.Trace) []*fr.Element
EvalAllAt(int, trace.Trace) []fr.Element
// String produces a string representing this as an S-Expression.
String() string
}
Expand Down Expand Up @@ -171,7 +171,7 @@ func (p *List) RequiredColumns() *util.SortedSet[uint] {
// ============================================================================

// Constant represents a constant value within an expression.
type Constant struct{ Val *fr.Element }
type Constant struct{ Val fr.Element }

// Bounds returns max shift in either the negative (left) or positive
// direction (right). A constant has zero shift.
Expand Down
5 changes: 1 addition & 4 deletions pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,10 @@ func extractIfZeroCondition(e *IfZero, schema *mir.Schema) mir.Expr {
panic(fmt.Sprintf("unexpanded expression (%s)", e.String()))
} else if e.TrueBranch != nil {
// (1 - NORM(cb)) for true branch
one := new(fr.Element)
one.SetOne()

normBody := &mir.Normalise{Arg: cb}
oneMinusNormBody := &mir.Sub{
Args: []mir.Expr{
&mir.Constant{Value: one},
&mir.Constant{Value: fr.One()},
normBody,
},
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/hir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,15 @@ func beginParserRule(args []Expr) (Expr, error) {

func constantParserRule(symbol string) (Expr, bool, error) {
if symbol[0] >= '0' && symbol[0] < '9' {
num := new(fr.Element)
var num fr.Element
// Attempt to parse
c, err := num.SetString(symbol)
_, err := num.SetString(symbol)
// Check for errors
if err != nil {
return nil, true, err
}
// Done
return &Constant{Val: c}, true, nil
return &Constant{Val: num}, true, nil
}
// Not applicable
return nil, false, nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/hir/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (p ZeroArrayTest) TestAt(row int, tr trace.Trace) bool {
vals := p.Expr.EvalAllAt(row, tr)
// Check each value in turn against zero.
for _, val := range vals {
if val != nil && !val.IsZero() {
if !val.IsZero() {
// This expression does not evaluat to zero, hence failure.
return false
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func NewUnitExpr(expr Expr) UnitExpr {
// EvalAt evaluates a column access at a given row in a trace, which returns the
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e UnitExpr) EvalAt(k int, tr trace.Trace) *fr.Element {
func (e UnitExpr) EvalAt(k int, tr trace.Trace) fr.Element {
vals := e.expr.EvalAllAt(k, tr)
// Check we got exactly one thing
if len(vals) == 1 {
Expand Down
Loading

0 comments on commit 553c3fa

Please sign in to comment.