Skip to content

Commit

Permalink
Support Sequential ParExec
Browse files Browse the repository at this point in the history
This puts in place support for using a sequential ParExec
implementation.  This is a first step towards parallelising trace
expansion.  Specifically, there are two outstanding issues to resolve:
firstly, we must be able to add expanded columns in arbitrary order;
secondly, we need to actually make `ParExec` operate in parallel (which
make require making trace thread-safe or something).

The issue around the order in which expanded columns are added to the
trace is problematic as the current design essentially expects them to
be added in a specific order.  Instead, we need assignments and input
columns to know what their target column index is.  Then, we expanding
the trace, they can specify this.
  • Loading branch information
DavePearce committed Jul 29, 2024
1 parent 5876153 commit b9c2786
Show file tree
Hide file tree
Showing 19 changed files with 665 additions and 19 deletions.
44 changes: 44 additions & 0 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ func (p *Add) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Add) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// Add two expressions together, producing a third.
func (p *Add) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down Expand Up @@ -75,6 +84,15 @@ func (p *Sub) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Sub) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// Add two expressions together, producing a third.
func (p *Sub) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down Expand Up @@ -104,6 +122,15 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Mul) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// Add two expressions together, producing a third.
func (p *Mul) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down Expand Up @@ -156,6 +183,13 @@ func (p *Constant) Context(schema sc.Schema) trace.Context {
return trace.VoidContext()
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Constant) RequiredColumns() *util.SortedSet[uint] {
return util.NewSortedSet[uint]()
}

// Add two expressions together, producing a third.
func (p *Constant) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down Expand Up @@ -196,6 +230,16 @@ func (p *ColumnAccess) Context(schema sc.Schema) trace.Context {
return col.Context()
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *ColumnAccess) RequiredColumns() *util.SortedSet[uint] {
r := util.NewSortedSet[uint]()
r.Insert(p.Column)
// Done
return r
}

// Add two expressions together, producing a third.
func (p *ColumnAccess) Add(other Expr) Expr { return &Add{Args: []Expr{p, other}} }

Expand Down
7 changes: 7 additions & 0 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ func (e *Inverse) Context(schema sc.Schema) tr.Context {
return e.Expr.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (e *Inverse) RequiredColumns() *util.SortedSet[uint] {
return e.Expr.RequiredColumns()
}

func (e *Inverse) String() string {
return fmt.Sprintf("(inv %s)", e.Expr)
}
84 changes: 84 additions & 0 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ func (p *Add) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Add) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Subtraction
// ============================================================================
Expand All @@ -69,6 +78,15 @@ func (p *Sub) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Sub) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Multiplication
// ============================================================================
Expand All @@ -86,6 +104,15 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Mul) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Exponentiation
// ============================================================================
Expand All @@ -106,6 +133,13 @@ func (p *Exp) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Exp) RequiredColumns() *util.SortedSet[uint] {
return p.Arg.RequiredColumns()
}

// ============================================================================
// List
// ============================================================================
Expand All @@ -123,6 +157,15 @@ func (p *List) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *List) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Constant
// ============================================================================
Expand All @@ -140,6 +183,13 @@ func (p *Constant) Context(schema sc.Schema) trace.Context {
return trace.VoidContext()
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Constant) RequiredColumns() *util.SortedSet[uint] {
return util.NewSortedSet[uint]()
}

// ============================================================================
// IfZero
// ============================================================================
Expand Down Expand Up @@ -190,6 +240,23 @@ func (p *IfZero) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *IfZero) RequiredColumns() *util.SortedSet[uint] {
set := p.Condition.RequiredColumns()
// Include true branch (if applicable)
if p.TrueBranch != nil {
set.InsertSorted(p.TrueBranch.RequiredColumns())
}
// Include false branch (if applicable)
if p.FalseBranch != nil {
set.InsertSorted(p.FalseBranch.RequiredColumns())
}
// Done
return set
}

// ============================================================================
// Normalise
// ============================================================================
Expand All @@ -210,6 +277,13 @@ func (p *Normalise) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Normalise) RequiredColumns() *util.SortedSet[uint] {
return p.Arg.RequiredColumns()
}

// ============================================================================
// ColumnAccess
// ============================================================================
Expand Down Expand Up @@ -242,3 +316,13 @@ func (p *ColumnAccess) Context(schema sc.Schema) trace.Context {
col := schema.Columns().Nth(p.Column)
return col.Context()
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *ColumnAccess) RequiredColumns() *util.SortedSet[uint] {
r := util.NewSortedSet[uint]()
r.Insert(p.Column)
// Done
return r
}
14 changes: 14 additions & 0 deletions pkg/hir/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ func (p ZeroArrayTest) Context(schema sc.Schema) trace.Context {
return p.Expr.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p ZeroArrayTest) RequiredColumns() *util.SortedSet[uint] {
return p.Expr.RequiredColumns()
}

// ============================================================================
// UnitExpr
// ============================================================================
Expand Down Expand Up @@ -98,3 +105,10 @@ func (e UnitExpr) Bounds() util.Bounds {
func (e UnitExpr) Context(schema sc.Schema) trace.Context {
return e.expr.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (e UnitExpr) RequiredColumns() *util.SortedSet[uint] {
return e.expr.RequiredColumns()
}
58 changes: 58 additions & 0 deletions pkg/mir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ func (p *Add) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Add) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Subtraction
// ============================================================================
Expand All @@ -53,6 +62,15 @@ func (p *Sub) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Sub) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Multiplication
// ============================================================================
Expand All @@ -70,6 +88,15 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Mul) RequiredColumns() *util.SortedSet[uint] {
return util.UnionSortedSets(p.Args, func(e Expr) *util.SortedSet[uint] {
return e.RequiredColumns()
})
}

// ============================================================================
// Exponentiation
// ============================================================================
Expand All @@ -90,6 +117,13 @@ func (p *Exp) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Exp) RequiredColumns() *util.SortedSet[uint] {
return p.Arg.RequiredColumns()
}

// ============================================================================
// Constant
// ============================================================================
Expand All @@ -107,6 +141,13 @@ func (p *Constant) Context(schema sc.Schema) trace.Context {
return trace.VoidContext()
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Constant) RequiredColumns() *util.SortedSet[uint] {
return util.NewSortedSet[uint]()
}

// ============================================================================
// Normalise
// ============================================================================
Expand All @@ -125,6 +166,13 @@ func (p *Normalise) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *Normalise) RequiredColumns() *util.SortedSet[uint] {
return p.Arg.RequiredColumns()
}

// ============================================================================
// ColumnAccess
// ============================================================================
Expand Down Expand Up @@ -157,3 +205,13 @@ func (p *ColumnAccess) Context(schema sc.Schema) trace.Context {
col := schema.Columns().Nth(p.Column)
return col.Context()
}

// RequiredColumns returns the set of columns on which this term depends.
// That is, columns whose values may be accessed when evaluating this term
// on a given trace.
func (p *ColumnAccess) RequiredColumns() *util.SortedSet[uint] {
r := util.NewSortedSet[uint]()
r.Insert(p.Column)
// Done
return r
}
6 changes: 6 additions & 0 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ func (p *ByteDecomposition) RequiredSpillage() uint {
return uint(0)
}

// Dependencies returns the set of columns that this assignment depends upon.
// That can include both input columns, as well as other computed columns.
func (p *ByteDecomposition) Dependencies() []uint {
return []uint{p.source}
}

// Decompose a given element into n bytes in little endian form. For example,
// decomposing 41b into 2 bytes gives [0x1b,0x04].
func decomposeIntoBytes(val *fr.Element, n int) []*fr.Element {
Expand Down
6 changes: 6 additions & 0 deletions pkg/schema/assignment/computed_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,9 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// Done
return nil
}

// Dependencies returns the set of columns that this assignment depends upon.
// That can include both input columns, as well as other computed columns.
func (p *ComputedColumn[E]) Dependencies() []uint {
return *p.expr.RequiredColumns()
}
6 changes: 6 additions & 0 deletions pkg/schema/assignment/interleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,9 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
//
return nil
}

// Dependencies returns the set of columns that this assignment depends upon.
// That can include both input columns, as well as other computed columns.
func (p *Interleaving) Dependencies() []uint {
return p.sources
}
Loading

0 comments on commit b9c2786

Please sign in to comment.