Skip to content

Commit

Permalink
Merge pull request #306 from Consensys/270-spillage-on-module-by-modu…
Browse files Browse the repository at this point in the history
…le-basis

feat: support spillage on a per-module basis
  • Loading branch information
DavePearce authored Sep 10, 2024
2 parents c690efc + ce841ac commit be5aebd
Show file tree
Hide file tree
Showing 26 changed files with 1,626 additions and 24 deletions.
7 changes: 6 additions & 1 deletion pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var checkCmd = &cobra.Command{
cfg.hir = getFlag(cmd, "hir")
cfg.expand = !getFlag(cmd, "raw")
cfg.report = getFlag(cmd, "report")
cfg.reportPadding = getUint(cmd, "report-context")
cfg.spillage = getInt(cmd, "spillage")
cfg.strict = !getFlag(cmd, "warn")
cfg.quiet = getFlag(cmd, "quiet")
Expand Down Expand Up @@ -92,6 +93,9 @@ type checkConfig struct {
// Specifies whether or not to report details of the failure (e.g. for
// debugging purposes).
report bool
// Specifies the number of additional rows to show eitherside of the failing
// area. This essentially allows more contextual information to be shown.
reportPadding uint
// Perform trace expansion in parallel (or not)
parallelExpansion bool
// Size of constraint batches to execute in parallel
Expand Down Expand Up @@ -255,7 +259,7 @@ func reportVanishingFailure(failure *constraint.VanishingFailure, trace tr.Trace
cols.Insert(c.Column)
}
// Construct & configure printer
tp := tr.NewPrinter().Start(start).End(end).MaxCellWidth(16)
tp := tr.NewPrinter().Start(start).End(end).MaxCellWidth(16).Padding(cfg.reportPadding)
// Determine whether to enable ANSI escapes (e.g. for colour in the terminal)
tp = tp.AnsiEscapes(cfg.ansiEscapes)
// Filter out columns not used in evaluating the constraint.
Expand Down Expand Up @@ -293,6 +297,7 @@ func reportErrors(error bool, ir string, errs []error) {
func init() {
rootCmd.AddCommand(checkCmd)
checkCmd.Flags().Bool("report", false, "report details of failure for debugging")
checkCmd.Flags().Uint("report-context", 2, "specify number of rows to show eitherside of failure in report")
checkCmd.Flags().Bool("raw", false, "assume input trace already expanded")
checkCmd.Flags().Bool("hir", false, "check at HIR level")
checkCmd.Flags().Bool("mir", false, "check at MIR level")
Expand Down
5 changes: 5 additions & 0 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func (p *ByteDecomposition) String() string {
// Declaration Interface
// ============================================================================

// Context returns the evaluation context for this declaration.
func (p *ByteDecomposition) Context() trace.Context {
return p.targets[0].Context()
}

// Columns returns the columns declared by this byte decomposition (in the order
// of declaration).
func (p *ByteDecomposition) Columns() util.Iterator[schema.Column] {
Expand Down
5 changes: 5 additions & 0 deletions pkg/schema/assignment/computed_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func (p *ComputedColumn[E]) Name() string {
// Declaration Interface
// ============================================================================

// Context returns the evaluation context for this computed column.
func (p *ComputedColumn[E]) Context() trace.Context {
return p.target.Context()
}

// Columns returns the columns declared by this computed column.
func (p *ComputedColumn[E]) Columns() util.Iterator[sc.Column] {
// TODO: figure out appropriate type for computed column
Expand Down
5 changes: 5 additions & 0 deletions pkg/schema/assignment/interleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func (p *Interleaving) Sources() []uint {
// Declaration Interface
// ============================================================================

// Context returns the evaluation context for this interleaving.
func (p *Interleaving) Context() trace.Context {
return p.target.Context()
}

// Columns returns the column declared by this interleaving.
func (p *Interleaving) Columns() util.Iterator[schema.Column] {
return util.NewUnitIterator(p.target)
Expand Down
5 changes: 5 additions & 0 deletions pkg/schema/assignment/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ func NewLexicographicSort(prefix string, context trace.Context,
// Declaration Interface
// ============================================================================

// Context returns the evaluation context for this declaration.
func (p *LexicographicSort) Context() trace.Context {
return p.context
}

// Columns returns the columns declared by this assignment.
func (p *LexicographicSort) Columns() util.Iterator[schema.Column] {
return util.NewArrayIterator(p.targets)
Expand Down
5 changes: 5 additions & 0 deletions pkg/schema/assignment/sorted_permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ func (p *SortedPermutation) String() string {
// Declaration Interface
// ============================================================================

// Context returns the evaluation context for this declaration.
func (p *SortedPermutation) Context() trace.Context {
return p.context
}

// Columns returns the columns declared by this sorted permutation (in the order
// of declaration).
func (p *SortedPermutation) Columns() util.Iterator[schema.Column] {
Expand Down
27 changes: 9 additions & 18 deletions pkg/schema/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (tb TraceBuilder) Build(columns []trace.RawColumn) (trace.Trace, []error) {
// Critical failure
return nil, errs
} else if tb.expand {
// TODO: this is not done properly.
padColumns(tr, requiredSpillage(tb.schema))
// Apply spillage
applySpillage(tr, tb.schema)
// Expand trace
if tb.parallel {
// Run (parallel) trace expansion
Expand Down Expand Up @@ -222,23 +222,14 @@ func validateTraceColumns(schema Schema, tr *trace.ArrayTrace) (error, []error)
return nil, warnings
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding. Spillage can
// only arise from computations as this is where values outside of the user's
// control are determined.
func requiredSpillage(schema Schema) uint {
// Ensures always at least one row of spillage (referred to as the "initial
// padding row")
mx := uint(1)
// Determine if any more spillage required
for i := schema.Assignments(); i.HasNext(); {
// Get ith assignment
ith := i.Next()
// Incorporate its spillage requirements
mx = max(mx, ith.RequiredSpillage())
// applySpillage pads each module with its given level of spillage
func applySpillage(tr *trace.ArrayTrace, schema Schema) {
n := tr.Modules().Count()
// Iterate over modules
for i := uint(0); i < n; i++ {
spillage := RequiredSpillage(i, schema)
tr.Pad(i, spillage)
}

return mx
}

// PadColumns pads every column in a given trace with a given amount of padding.
Expand Down
5 changes: 5 additions & 0 deletions pkg/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ type Declaration interface {
// Return the declared columns (in the order of declaration).
Columns() util.Iterator[Column]

// Context returns the evaluation context (i.e. enclosing module + length
// multiplier) for this declaration. Every declaration must have a single,
// unique context.
Context() tr.Context

// Determines whether or not this declaration is computed.
IsComputed() bool
}
Expand Down
22 changes: 22 additions & 0 deletions pkg/schema/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,28 @@ import (
"github.com/consensys/go-corset/pkg/util"
)

// RequiredSpillage returns the minimum amount of spillage required for a given
// module to ensure valid traces are accepted in the presence of arbitrary
// padding. Spillage can only arise from computations as this is where values
// outside of the user's control are determined.
func RequiredSpillage(module uint, schema Schema) uint {
// Ensures always at least one row of spillage (referred to as the "initial
// padding row")
mx := uint(1)
// Determine if any more spillage required
for i := schema.Assignments(); i.HasNext(); {
// Get ith assignment
ith := i.Next()
//
if ith.Context().Module() == module {
// Incorporate its spillage requirements
mx = max(mx, ith.RequiredSpillage())
}
}

return mx
}

// JoinContexts combines one or more evaluation contexts together. If all
// expressions have the void context, then this is returned. Likewise, if any
// expression has a conflicting context then this is returned. Finally, if any
Expand Down
22 changes: 21 additions & 1 deletion pkg/test/ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ func Test_Spillage_04(t *testing.T) {
Check(t, "spillage_04")
}

func Test_Spillage_05(t *testing.T) {
Check(t, "spillage_05")
}

func Test_Spillage_06(t *testing.T) {
Check(t, "spillage_06")
}

func Test_Spillage_07(t *testing.T) {
Check(t, "spillage_07")
}

func Test_Spillage_08(t *testing.T) {
Check(t, "spillage_08")
}

func Test_Spillage_09(t *testing.T) {
Check(t, "spillage_09")
}

// ===================================================================
// Normalisation Tests
// ===================================================================
Expand Down Expand Up @@ -456,7 +476,7 @@ func TestSlow_Mxp(t *testing.T) {

// Determines the maximum amount of padding to use when testing. Specifically,
// every trace is tested with varying amounts of padding upto this value.
const MAX_PADDING uint = 5
const MAX_PADDING uint = 7

// For a given set of constraints, check that all traces which we
// expect to be accepted are accepted, and all traces that we expect
Expand Down
16 changes: 12 additions & 4 deletions pkg/trace/printer.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,19 @@ func (p *Printer) Start(start uint) *Printer {
return p
}

// End configures tne ending row (inclusive) for this printer.
// End configures the ending row (inclusive) for this printer.
func (p *Printer) End(end uint) *Printer {
p.endRow = end
return p
}

// Padding configures the number of padding rows (i.e. rows outside the affected
// area) to include for additional context.
func (p *Printer) Padding(padding uint) *Printer {
p.padding = padding
return p
}

// Columns configures a filter which selects columns to be included in the final
// print out.
func (p *Printer) Columns(filter ColumnFilter) *Printer {
Expand Down Expand Up @@ -92,14 +99,15 @@ func (p *Printer) Print(trace Trace) {
var start uint
if p.startRow >= p.padding {
start = p.startRow - p.padding
} else if p.padding > 0 {
start = 0
} else {
start = p.startRow
}

end := p.startRow + p.padding + 1
end := min(MaxHeight(trace), p.startRow+p.padding+1)
columns := make([]uint, 0)
endRow := min(MaxHeight(trace), end)
width := 1 + endRow - start
width := 1 + end - start
// Filter columns
for i := uint(0); i < trace.Width(); i++ {
if p.colFilter(i, trace) {
Expand Down
Loading

0 comments on commit be5aebd

Please sign in to comment.