Skip to content

Commit

Permalink
Merge pull request #269 from Consensys/267-expandtrace-should-return-…
Browse files Browse the repository at this point in the history
…computed-columns

feat:  ExpandTrace should Return Computed Columns
  • Loading branch information
DavePearce authored Jul 30, 2024
2 parents 2a96b89 + 7bdd87e commit fe243d6
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 249 deletions.
10 changes: 5 additions & 5 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func summaryStats(tr trace.Trace) {
// contents of a given column.
// ColSummariser abstracts a function for summarising the contents of a given
// column. The name identifies the summary (e.g. "rows", "width") and summary
// produces the human-readable text for one column.
type ColSummariser struct {
	name    string
	summary func(*trace.Column) string
}

var colSummarisers []ColSummariser = []ColSummariser{
Expand All @@ -187,15 +187,15 @@ var colSummarisers []ColSummariser = []ColSummariser{
{"unique", uniqueSummariser},
}

func rowSummariser(col trace.Column) string {
func rowSummariser(col *trace.Column) string {
return fmt.Sprintf("%d rows", col.Data().Len())
}

func widthSummariser(col trace.Column) string {
func widthSummariser(col *trace.Column) string {
return fmt.Sprintf("%d bits", col.Data().BitWidth())
}

func bytesSummariser(col trace.Column) string {
func bytesSummariser(col *trace.Column) string {
bitwidth := col.Data().BitWidth()
byteWidth := bitwidth / 8
// Determine proper bytewidth
Expand All @@ -206,7 +206,7 @@ func bytesSummariser(col trace.Column) string {
return fmt.Sprintf("%d bytes", col.Data().Len()*byteWidth)
}

func uniqueSummariser(col trace.Column) string {
func uniqueSummariser(col *trace.Column) string {
data := col.Data()
elems := util.NewHashSet[util.BytesKey](data.Len() / 2)
// Add all the elements
Expand Down
27 changes: 12 additions & 15 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,38 +59,35 @@ func (p *ByteDecomposition) IsComputed() bool {
// Assignment Interface
// ============================================================================

// ExpandTrace expands a given trace to include the columns specified by a given
// ByteDecomposition. This requires computing the value of each byte column in
// the decomposition.
func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// This requires computing the value of each byte column in the decomposition.
func (p *ByteDecomposition) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
// Calculate how many bytes required.
n := len(p.targets)
// Identify source column
source := columns.Get(p.source)
// Determine padding values
padding := decomposeIntoBytes(source.Padding(), n)
// Construct byte column data
cols := make([]util.FrArray, n)
cols := make([]*trace.Column, n)
// Initialise columns
for i := 0; i < n; i++ {
ith := p.targets[i]
// Construct a byte array for ith byte
data := util.NewFrArray(source.Height(), 8)
// Construct a byte column for ith byte
cols[i] = util.NewFrArray(source.Height(), 8)
cols[i] = trace.NewColumn(ith.Context(), ith.Name(), data, padding[i])
}
// Decompose each row of each column
for i := uint(0); i < source.Height(); i = i + 1 {
ith := decomposeIntoBytes(source.Get(int(i)), n)
for j := 0; j < n; j++ {
cols[j].Set(i, ith[j])
cols[j].Data().Set(i, ith[j])
}
}
// Determine padding values
padding := decomposeIntoBytes(source.Padding(), n)
// Finally, add byte columns to trace
for i := 0; i < n; i++ {
ith := p.targets[i]
columns.Add(ith.Context(), ith.Name(), cols[i], padding[i])
}
// Done
return nil
return cols, nil
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
Expand Down
16 changes: 8 additions & 8 deletions pkg/schema/assignment/computed_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ func (p *ComputedColumn[E]) RequiredSpillage() uint {
return p.expr.Bounds().End
}

// ExpandTrace attempts to a new column to the trace which contains the result
// of evaluating a given expression on each row. If the column already exists,
// then an error is flagged.
func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// Specifically, this creates a new column which contains the result of
// evaluating a given expression on each row.
func (p *ComputedColumn[E]) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
// Check whether a column already exists with the given name.
if _, ok := columns.IndexOf(p.target.Context().Module(), p.Name()); ok {
mod := tr.Modules().Get(p.target.Context().Module())
return fmt.Errorf("computed column already exists ({%s.%s})", mod.Name(), p.Name())
return nil, fmt.Errorf("computed column already exists ({%s.%s})", mod.Name(), p.Name())
}
// Extract length multiplier
multiplier := p.target.Context().LengthMultiplier()
Expand All @@ -102,10 +102,10 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// that all columns return their padding value which is then used to compute
// the padding value for *this* column.
padding := p.expr.EvalAt(-1, tr)
// Colunm needs to be expanded.
columns.Add(p.target.Context(), p.Name(), data, padding)
// Construct column
col := trace.NewColumn(p.target.Context(), p.Name(), data, padding)
// Done
return nil
return []*trace.Column{col}, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
Expand Down
15 changes: 8 additions & 7 deletions pkg/schema/assignment/interleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -65,10 +66,10 @@ func (p *Interleaving) RequiredSpillage() uint {
return uint(0)
}

// ExpandTrace expands a given trace to include the columns specified by a given
// Interleaving. This requires copying the data in the source columns to create
// the interleaved column.
func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// This requires copying the data in the source columns to create the
// interleaved column.
func (p *Interleaving) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
ctx := p.target.Context()
// Byte width records the largest width of any column.
Expand All @@ -80,7 +81,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
bit_width = max(bit_width, ith.Type().BitWidth())
// Sanity check no column already exists with this name.
if _, ok := columns.IndexOf(ctx.Module(), ith.Name()); ok {
return fmt.Errorf("interleaved column already exists ({%s})", ith.Name())
return nil, fmt.Errorf("interleaved column already exists ({%s})", ith.Name())
}
}
// Determine interleaving width
Expand Down Expand Up @@ -110,9 +111,9 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// column in the interleaving.
padding := columns.Get(0).Padding()
// Column needs to be expanded.
columns.Add(ctx, p.target.Name(), data, padding)
col := trace.NewColumn(ctx, p.target.Name(), data, padding)
//
return nil
return []*trace.Column{col}, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
Expand Down
47 changes: 21 additions & 26 deletions pkg/schema/assignment/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,46 +67,49 @@ func (p *LexicographicSort) RequiredSpillage() uint {
return uint(0)
}

// ExpandTrace adds columns as needed to support the LexicographicSortingGadget.
// That includes the delta column, and the bit selectors.
func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
// ComputeColumns computes the values of columns defined as needed to support
// the LexicographicSortingGadget. That includes the delta column, and the bit
// selectors.
func (p *LexicographicSort) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
zero := fr.NewElement(0)
one := fr.NewElement(1)
first := p.targets[0]
// Exact number of columns involved in the sort
ncols := len(p.sources)
nbits := len(p.sources)
//
multiplier := p.context.LengthMultiplier()
// Determine how many rows to be constrained.
nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier
// Initialise new data columns
bit := make([]util.FrArray, ncols)
cols := make([]*trace.Column, nbits+1)
// Byte width records the largest width of any column.
bit_width := uint(0)

for i := 0; i < ncols; i++ {
// TODO: following can be optimised to use a single bit per element,
// rather than an entire byte.
bit[i] = util.NewFrArray(nrows, 1)
ith := columns.Get(p.sources[i])
bit_width = max(bit_width, ith.Data().BitWidth())
}

//
delta := util.NewFrArray(nrows, bit_width)
cols[0] = trace.NewColumn(first.Context(), first.Name(), delta, &zero)
//
for i := 0; i < nbits; i++ {
target := p.targets[1+i]
source := columns.Get(p.sources[i])
data := util.NewFrArray(nrows, 1)
cols[i+1] = trace.NewColumn(target.Context(), target.Name(), data, &zero)
bit_width = max(bit_width, source.Data().BitWidth())
}

for i := uint(0); i < nrows; i++ {
set := false
// Initialise delta to zero
delta.Set(i, &zero)
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
for j := 0; j < nbits; j++ {
prev := columns.Get(p.sources[j]).Get(int(i - 1))
curr := columns.Get(p.sources[j]).Get(int(i))

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element

bit[j].Set(i, &one)
cols[j+1].Data().Set(i, &one)
// Compute curr - prev
if p.signs[j] {
diff.Set(curr)
Expand All @@ -118,20 +121,12 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {

set = true
} else {
bit[j].Set(i, &zero)
cols[j+1].Data().Set(i, &zero)
}
}
}
// Add delta column data
first := p.targets[0]
columns.Add(first.Context(), first.Name(), delta, &zero)
// Add bit column data
for i := 0; i < ncols; i++ {
ith := p.targets[1+i]
columns.Add(ith.Context(), ith.Name(), bit[i], &zero)
}
// Done.
return nil
return cols, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
Expand Down
34 changes: 17 additions & 17 deletions pkg/schema/assignment/sorted_permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,42 +117,42 @@ func (p *SortedPermutation) RequiredSpillage() uint {
return uint(0)
}

// ExpandTrace expands a given trace to include the columns specified by a given
// SortedPermutation. This requires copying the data in the source columns, and
// sorting that data according to the permutation criteria.
func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
// ComputeColumns computes the values of columns defined by this assignment.
// This requires copying the data in the source columns, and sorting that data
// according to the permutation criteria.
func (p *SortedPermutation) ComputeColumns(tr trace.Trace) ([]*trace.Column, error) {
columns := tr.Columns()
// Ensure target columns don't exist
for i := p.Columns(); i.HasNext(); {
name := i.Next().Name()
// Sanity check no column already exists with this name.
if _, ok := columns.IndexOf(p.context.Module(), name); ok {
return fmt.Errorf("permutation column already exists ({%s})", name)
return nil, fmt.Errorf("permutation column already exists ({%s})", name)
}
}

cols := make([]util.FrArray, len(p.sources))
data := make([]util.FrArray, len(p.sources))
// Construct target columns
for i := 0; i < len(p.sources); i++ {
src := p.sources[i]
// Read column data
data := columns.Get(src).Data()
src_data := columns.Get(src).Data()
// Clone it to initialise permutation.
cols[i] = data.Clone()
data[i] = src_data.Clone()
}
// Sort target columns
util.PermutationSort(cols, p.signs)
// Physically add the columns
index := 0

for i := p.Columns(); i.HasNext(); index++ {
ith := i.Next()
util.PermutationSort(data, p.signs)
// Physically construct the columns
cols := make([]*trace.Column, len(p.sources))
//
for i, iter := 0, p.Columns(); iter.HasNext(); i++ {
ith := iter.Next()
dstColName := ith.Name()
srcCol := tr.Columns().Get(p.sources[index])
columns.Add(ith.Context(), dstColName, cols[index], srcCol.Padding())
srcCol := tr.Columns().Get(p.sources[i])
cols[i] = trace.NewColumn(ith.Context(), dstColName, data[i], srcCol.Padding())
}
//
return nil
return cols, nil
}

// Dependencies returns the set of columns that this assignment depends upon.
Expand Down
12 changes: 7 additions & 5 deletions pkg/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -53,11 +54,12 @@ type Declaration interface {
type Assignment interface {
Declaration

// ExpandTrace expands a given trace to include "computed
// columns". These are columns which do not exist in the
// original trace, but are added during trace expansion to
// form the final trace.
ExpandTrace(tr.Trace) error
// ComputeColumns computes the values of columns defined by this assignment.
// In order for this computation to make sense, all columns on which this
// assignment depends must exist (e.g. are either inputs or have been
// computed already). Computed columns do not exist in the original trace,
// but are added during trace expansion to form the final trace.
ComputeColumns(tr.Trace) ([]*trace.Column, error)
// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding. Note,
// spillage is currently assumed to be required only at the front of a
Expand Down
Loading

0 comments on commit fe243d6

Please sign in to comment.