Support Assignment.ComputedColumn
This replaces the previous `Assignment.ExpandTrace` method, which
directly added columns to the trace. Instead, computed columns are
simply returned and it's left up to the caller to add them to the
trace. For now, I've put in place some hacky glue code which does this
... and passes the tests.
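
In essence, the change looks like this (a condensed sketch pieced together from the interface change in `pkg/schema/schema.go` and the interim glue in `pkg/schema/schemas.go` below; surrounding methods and error context trimmed):

```go
// Assignments now return their computed columns rather than mutating
// the trace directly.
type Assignment interface {
	Declaration
	// ComputeColumns computes the values of columns defined by this
	// assignment, assuming every column it depends upon already exists.
	ComputeColumns(tr.Trace) ([]*trace.Column, error)
	// ... (RequiredSpillage, Dependencies, etc.)
}

// Interim glue in ExpandTrace: the caller, not the assignment, adds
// each computed column back into the trace.
for i := schema.Assignments(); i.HasNext(); {
	cols, err := i.Next().ComputeColumns(trace)
	if err != nil {
		return err
	}
	for _, col := range cols {
		trace.Columns().Add(col.Context(), col.Name(), col.Data(), col.Padding())
	}
}
```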

The next step, I think, is to package everything to do with trace
expansion into a trace builder. That will provide a single point of
breakage for changes related to the trace API. In particular, the goal
is to dramatically simplify the notion of a trace to just a pair of
([]Module, []Column) arrays; one possible shape is sketched below.
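
Purely as an illustration of that direction (none of these names exist in the codebase yet; everything here is hypothetical):

```go
// TraceBuilder would become the single owner of trace mutation during
// expansion. All names here are hypothetical.
type TraceBuilder struct {
	modules []Module
	columns []Column
}

// Add appends computed columns, so assignments never touch the trace
// themselves.
func (b *TraceBuilder) Add(cols ...Column) {
	b.columns = append(b.columns, cols...)
}

// Build yields the simplified notion of a trace: just the two arrays.
func (b *TraceBuilder) Build() ([]Module, []Column) {
	return b.modules, b.columns
}
```
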
DavePearce committed Jul 30, 2024
1 parent fa79b0d commit 7bdd87e
Showing 7 changed files with 88 additions and 130 deletions.
27 changes: 12 additions & 15 deletions pkg/schema/assignment/byte_decomposition.go
@@ -59,38 +59,35 @@ func (p *ByteDecomposition) IsComputed() bool {
// Assignment Interface
// ============================================================================

// ExpandTrace expands a given trace to include the columns specified by a given
// ByteDecomposition. This requires computing the value of each byte column in
// the decomposition.
func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// This requires computing the value of each byte column in the decomposition.
func (p *ByteDecomposition) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
// Calculate how many bytes are required.
n := len(p.targets)
// Identify source column
source := columns.Get(p.source)
// Determine padding values
padding := decomposeIntoBytes(source.Padding(), n)
// Construct byte column data
cols := make([]util.FrArray, n)
cols := make([]*trace.Column, n)
// Initialise columns
for i := 0; i < n; i++ {
ith := p.targets[i]
// Construct a byte array for ith byte
data := util.NewFrArray(source.Height(), 8)
// Construct a byte column for ith byte
cols[i] = util.NewFrArray(source.Height(), 8)
cols[i] = trace.NewColumn(ith.Context(), ith.Name(), data, padding[i])
}
// Decompose each row of each column
for i := uint(0); i < source.Height(); i = i + 1 {
ith := decomposeIntoBytes(source.Get(int(i)), n)
for j := 0; j < n; j++ {
cols[j].Set(i, ith[j])
cols[j].Data().Set(i, ith[j])
}
}
// Determine padding values
padding := decomposeIntoBytes(source.Padding(), n)
// Finally, add byte columns to trace
for i := 0; i < n; i++ {
ith := p.targets[i]
columns.Add(ith.Context(), ith.Name(), cols[i], padding[i])
}
// Done
return nil
return cols, nil
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
16 changes: 8 additions & 8 deletions pkg/schema/assignment/computed_column.go
@@ -72,15 +72,15 @@ func (p *ComputedColumn[E]) RequiredSpillage() uint {
return p.expr.Bounds().End
}

// ExpandTrace attempts to add a new column to the trace which contains the
// result of evaluating a given expression on each row. If the column already
// exists, then an error is flagged.
func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// Specifically, this creates a new column which contains the result of
// evaluating a given expression on each row.
func (p *ComputedColumn[E]) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
// Check whether a column already exists with the given name.
if _, ok := columns.IndexOf(p.target.Context().Module(), p.Name()); ok {
mod := tr.Modules().Get(p.target.Context().Module())
return fmt.Errorf("computed column already exists ({%s.%s})", mod.Name(), p.Name())
return nil, fmt.Errorf("computed column already exists ({%s.%s})", mod.Name(), p.Name())
}
// Extract length multiplier
multiplier := p.target.Context().LengthMultiplier()
@@ -102,10 +102,10 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// that all columns return their padding value which is then used to compute
// the padding value for *this* column.
padding := p.expr.EvalAt(-1, tr)
// Column needs to be expanded.
columns.Add(p.target.Context(), p.Name(), data, padding)
// Construct column
col := trace.NewColumn(p.target.Context(), p.Name(), data, padding)
// Done
return nil
return []*trace.Column{col}, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
15 changes: 8 additions & 7 deletions pkg/schema/assignment/interleave.go
@@ -4,6 +4,7 @@ import (
"fmt"

"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
@@ -65,10 +66,10 @@ func (p *Interleaving) RequiredSpillage() uint {
return uint(0)
}

// ExpandTrace expands a given trace to include the columns specified by a given
// Interleaving. This requires copying the data in the source columns to create
// the interleaved column.
func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// This requires copying the data in the source columns to create the
// interleaved column.
func (p *Interleaving) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
ctx := p.target.Context()
// Byte width records the largest width of any column.
@@ -80,7 +81,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
bit_width = max(bit_width, ith.Type().BitWidth())
// Sanity check no column already exists with this name.
if _, ok := columns.IndexOf(ctx.Module(), ith.Name()); ok {
return fmt.Errorf("interleaved column already exists ({%s})", ith.Name())
return nil, fmt.Errorf("interleaved column already exists ({%s})", ith.Name())
}
}
// Determine interleaving width
@@ -110,9 +111,9 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// column in the interleaving.
padding := columns.Get(0).Padding()
// Column needs to be expanded.
columns.Add(ctx, p.target.Name(), data, padding)
col := trace.NewColumn(ctx, p.target.Name(), data, padding)
//
return nil
return []*trace.Column{col}, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
47 changes: 21 additions & 26 deletions pkg/schema/assignment/lexicographic_sort.go
@@ -67,46 +67,49 @@ func (p *LexicographicSort) RequiredSpillage() uint {
return uint(0)
}

// ExpandTrace adds columns as needed to support the LexicographicSortingGadget.
// That includes the delta column, and the bit selectors.
func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
// ComputeColumns computes the values of columns defined as needed to support
// the LexicographicSortingGadget. That includes the delta column, and the bit
// selectors.
func (p *LexicographicSort) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
zero := fr.NewElement(0)
one := fr.NewElement(1)
first := p.targets[0]
// Exact number of columns involved in the sort
ncols := len(p.sources)
nbits := len(p.sources)
//
multiplier := p.context.LengthMultiplier()
// Determine how many rows to be constrained.
nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier
// Initialise new data columns
bit := make([]util.FrArray, ncols)
cols := make([]*trace.Column, nbits+1)
// Byte width records the largest width of any column.
bit_width := uint(0)

for i := 0; i < ncols; i++ {
// TODO: following can be optimised to use a single bit per element,
// rather than an entire byte.
bit[i] = util.NewFrArray(nrows, 1)
ith := columns.Get(p.sources[i])
bit_width = max(bit_width, ith.Data().BitWidth())
}

//
delta := util.NewFrArray(nrows, bit_width)
cols[0] = trace.NewColumn(first.Context(), first.Name(), delta, &zero)
//
for i := 0; i < nbits; i++ {
target := p.targets[1+i]
source := columns.Get(p.sources[i])
data := util.NewFrArray(nrows, 1)
cols[i+1] = trace.NewColumn(target.Context(), target.Name(), data, &zero)
bit_width = max(bit_width, source.Data().BitWidth())
}

for i := uint(0); i < nrows; i++ {
set := false
// Initialise delta to zero
delta.Set(i, &zero)
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
for j := 0; j < nbits; j++ {
prev := columns.Get(p.sources[j]).Get(int(i - 1))
curr := columns.Get(p.sources[j]).Get(int(i))

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element

bit[j].Set(i, &one)
cols[j+1].Data().Set(i, &one)
// Compute curr - prev
if p.signs[j] {
diff.Set(curr)
@@ -118,20 +121,12 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {

set = true
} else {
bit[j].Set(i, &zero)
cols[j+1].Data().Set(i, &zero)
}
}
}
// Add delta column data
first := p.targets[0]
columns.Add(first.Context(), first.Name(), delta, &zero)
// Add bit column data
for i := 0; i < ncols; i++ {
ith := p.targets[1+i]
columns.Add(ith.Context(), ith.Name(), bit[i], &zero)
}
// Done.
return nil
return cols, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
34 changes: 17 additions & 17 deletions pkg/schema/assignment/sorted_permutation.go
@@ -117,42 +117,42 @@ func (p *SortedPermutation) RequiredSpillage() uint {
return uint(0)
}

// ExpandTrace expands a given trace to include the columns specified by a given
// SortedPermutation. This requires copying the data in the source columns, and
// sorting that data according to the permutation criteria.
func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// This requires copying the data in the source columns, and sorting that data
// according to the permutation criteria.
func (p *SortedPermutation) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
// Ensure target columns don't exist
for i := p.Columns(); i.HasNext(); {
name := i.Next().Name()
// Sanity check no column already exists with this name.
if _, ok := columns.IndexOf(p.context.Module(), name); ok {
return fmt.Errorf("permutation column already exists ({%s})", name)
return nil, fmt.Errorf("permutation column already exists ({%s})", name)
}
}

cols := make([]util.FrArray, len(p.sources))
data := make([]util.FrArray, len(p.sources))
// Construct target columns
for i := 0; i < len(p.sources); i++ {
src := p.sources[i]
// Read column data
data := columns.Get(src).Data()
src_data := columns.Get(src).Data()
// Clone it to initialise permutation.
cols[i] = data.Clone()
data[i] = src_data.Clone()
}
// Sort target columns
util.PermutationSort(cols, p.signs)
// Physically add the columns
index := 0

for i := p.Columns(); i.HasNext(); index++ {
ith := i.Next()
util.PermutationSort(data, p.signs)
// Physically construct the columns
cols := make([]*trace.Column, len(p.sources))
//
for i, iter := 0, p.Columns(); iter.HasNext(); i++ {
ith := iter.Next()
dstColName := ith.Name()
srcCol := tr.Columns().Get(p.sources[index])
columns.Add(ith.Context(), dstColName, cols[index], srcCol.Padding())
srcCol := tr.Columns().Get(p.sources[i])
cols[i] = trace.NewColumn(ith.Context(), dstColName, data[i], srcCol.Padding())
}
//
return nil
return cols, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
12 changes: 7 additions & 5 deletions pkg/schema/schema.go
@@ -4,6 +4,7 @@ import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
@@ -53,11 +54,12 @@ type Declaration interface {
type Assignment interface {
Declaration

// ExpandTrace expands a given trace to include "computed
// columns". These are columns which do not exist in the
// original trace, but are added during trace expansion to
// form the final trace.
ExpandTrace(tr.Trace) error
// ComputeColumns computes the values of columns defined by this assignment.
// In order for this computation to make sense, all columns on which this
// assignment depends must exist (e.g. are either inputs or have been
// computed already). Computed columns do not exist in the original trace,
// but are added during trace expansion to form the final trace.
ComputeColumns(tr.Trace) ([]*trace.Column, error)
// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding. Note,
// spillage is currently assumed to be required only at the front of a
67 changes: 15 additions & 52 deletions pkg/schema/schemas.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package schema

import (
"fmt"

tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// JoinContexts combines one or more evaluation contexts together. If all
@@ -64,19 +61,30 @@ func RequiredSpillage(schema Schema) uint {
// Observe that assignments have to be computed in the correct order.
func ExpandTrace(schema Schema, trace tr.Trace) error {
index := schema.InputColumns().Count()
m := schema.Assignments().Count()
batchjobs := make([]expandTraceJob, m)
//m := schema.Assignments().Count()
//batchjobs := make([]expandTraceJob, m)
// Compute each assignment in turn
for i, j := schema.Assignments(), uint(0); i.HasNext(); j++ {
// Get ith assignment
ith := i.Next()
// Compute ith assignment(s)
batchjobs[j] = expandTraceJob{index, ith, trace}
//batchjobs[j] = expandTraceJob{index, ith, trace}
cols, err := ith.ComputeColumns(trace)
// Check error
if err != nil {
return err
}
// Add all columns
for k := 0; k < len(cols); k++ {
kth := cols[k]
trace.Columns().Add(kth.Context(), kth.Name(), kth.Data(), kth.Padding())
}
// Update index
index += ith.Columns().Count()
}
//
return util.ParExec[expandTraceJob](batchjobs)
// return util.ParExec[expandTraceJob](batchjobs)
return nil
}

// Accepts determines whether this schema will accept a given trace. That
@@ -119,48 +127,3 @@ func ColumnIndexOf(schema Schema, module uint, name string) (uint, bool) {
return c.Context().Module() == module && c.Name() == name
})
}

// ----------------------------------------------------------------------------

// ExpandTraceJob represents a unit of work which can be parallelised during
// trace expansion. Specifically, the unit of work is a single assignment. In
// the terminology of ParExec, an assignment is a batch of jobs, each of which
// assigns values to a given column.
type expandTraceJob struct {
// Index of first column in the assignment
index uint
// Assignment itself being computed
assignment Assignment
// Trace being expanded
trace tr.Trace
}

// Jobs returns the underlying jobs that this "trace job" computes.
// Specifically, it identifies the columns that the trace job completes.
func (p expandTraceJob) Jobs() []uint {
n := p.assignment.Columns().Count()
cols := make([]uint, n)
// TODO: this is really a hack for now. I think we should use some kind of
// range iterator.
for i := uint(0); i < n; i++ {
cols[i] = p.index + i
}

return cols
}

// Dependencies returns the columns that this batch job depends upon. That is
// the set of source columns (if any) required before this job can be completed.
func (p expandTraceJob) Dependencies() []uint {
return p.assignment.Dependencies()
}

// Run computes the column values for a given assignment.
func (p expandTraceJob) Run() error {
n := p.trace.Columns().Len()
if n != p.index {
panic(fmt.Sprintf("internal failure (%d trace columns, versus index %d)", n, p.index))
}
//
return p.assignment.ExpandTrace(p.trace)
}
