Skip to content

Commit

Permalink
clean fields (#21269)
Browse files Browse the repository at this point in the history
doc https://github.com/matrixorigin/docs/blob/main/design/mo/sql/20250116-daviszhen-clean-rec-cte.md

调整递归cte bind过程
- 删除不必要字段
- 调整加sink node的位置

Approved by: @badboynt1, @ouyuanning
  • Loading branch information
daviszhen authored Jan 17, 2025
1 parent 48f33fc commit 02d0e42
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 63 deletions.
23 changes: 10 additions & 13 deletions pkg/sql/plan/bind_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,10 @@ func NewBindContext(builder *QueryBuilder, parent *BindContext) *BindContext {

if parent != nil {
bc.defaultDatabase = parent.defaultDatabase
bc.normalCTE = parent.normalCTE
bc.cteName = parent.cteName
if parent.recSelect || parent.initSelect || parent.finalSelect {
if parent.bindingCte() {
bc.cteByName = parent.cteByName
bc.initSelect = parent.initSelect
bc.recSelect = parent.recSelect
bc.finalSelect = parent.finalSelect
bc.recRecursiveScanNodeId = parent.recRecursiveScanNodeId
bc.cteState = parent.cteState
}
bc.snapshot = parent.snapshot
}
Expand All @@ -65,7 +61,7 @@ func NewBindContext(builder *QueryBuilder, parent *BindContext) *BindContext {
}

func (bc *BindContext) rootTag() int32 {
if bc.initSelect || bc.recSelect {
if bc.bindingRecurCte() {
return bc.sinkTag
} else if bc.resultTag > 0 {
return bc.resultTag
Expand All @@ -84,23 +80,23 @@ func (bc *BindContext) topTag() int32 {

func (bc *BindContext) findCTE(name string) *CTERef {
// the cte is masked already, we don't go further
if bc.maskedCTEs[name] {
if bc.cteState.masked(name) {
return nil
}
if cte, ok := bc.cteByName[name]; ok {
if !bc.maskedCTEs[name] {
if !bc.cteState.masked(name) {
return cte
}
}

parent := bc.parent
for parent != nil {
// the cte is masked already, we don't go further
if _, ok2 := parent.maskedCTEs[name]; ok2 {
if parent.cteState.masked(name) {
break
}
if cte, ok := parent.cteByName[name]; ok {
if !parent.maskedCTEs[name] {
if !parent.cteState.masked(name) {
return cte
}
}
Expand All @@ -111,8 +107,9 @@ func (bc *BindContext) findCTE(name string) *CTERef {
return nil
}

func (bc *BindContext) recordCteInBinding(name string, cte *CTERef) {
bc.boundCtes[name] = cte
func (bc *BindContext) recordCteInBinding(name string, cte CteBindState) {
bc.boundCtes[name] = cte.cte
bc.cteState = cte
}

func (bc *BindContext) cteInBinding(name string) bool {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/plan/build_sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func validSample(ctx *BindContext, builder *QueryBuilder) error {
if len(ctx.aggregates) > 0 {
return moerr.NewSyntaxError(builder.GetContext(), "cannot fixed non-scalar function and scalar function in the same query")
}
if ctx.recSelect || builder.isForUpdate {
if ctx.bindingRecurStmt() || builder.isForUpdate {
return moerr.NewInternalError(builder.GetContext(), "not support sample function recursive cte or for update")
}
if len(ctx.windows) > 0 {
Expand Down
110 changes: 76 additions & 34 deletions pkg/sql/plan/query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1981,7 +1981,6 @@ func (builder *QueryBuilder) buildUnion(stmt *tree.UnionClause, astOrderBy tree.
var nodeID int32
for idx, sltStmt := range selectStmts {
subCtx := NewBindContext(builder, ctx)
subCtx.unionSelect = subCtx.initSelect
if slt, ok := sltStmt.(*tree.Select); ok {
nodeID, err = builder.bindSelect(slt, subCtx, isRoot)
} else {
Expand Down Expand Up @@ -2254,10 +2253,6 @@ func (builder *QueryBuilder) buildUnion(stmt *tree.UnionClause, astOrderBy tree.
ctx.results = ctx.projects
}

if ctx.initSelect {
lastNodeID = appendSinkNodeWithTag(builder, ctx, lastNodeID, ctx.sinkTag)
}

// set heading
if isRoot {
builder.qry.Headings = append(builder.qry.Headings, ctx.headings...)
Expand Down Expand Up @@ -2307,11 +2302,13 @@ func (builder *QueryBuilder) bindNoRecursiveCte(
cteRef *CTERef,
table string) (nodeID int32, err error) {
subCtx := NewBindContext(builder, ctx)
subCtx.normalCTE = true
subCtx.maskedCTEs = cteRef.maskedCTEs
subCtx.cteName = table
subCtx.snapshot = cteRef.snapshot
subCtx.recordCteInBinding(table, cteRef)
subCtx.recordCteInBinding(table,
CteBindState{
cteBindType: CteBindTypeNonRecur,
cte: cteRef,
recScanNodeId: -1})
cteRef.isRecursive = false

oldSnapshot := builder.compCtx.GetSnapshot()
Expand Down Expand Up @@ -2351,42 +2348,62 @@ func (builder *QueryBuilder) bindRecursiveCte(
if len(s.OrderBy) > 0 {
return 0, moerr.NewParseError(builder.GetContext(), "not support ORDER BY in recursive cte")
}
// initial statement
//1. bind initial statement
initCtx := NewBindContext(builder, ctx)
initCtx.initSelect = true
initCtx.cteName = table
initCtx.recordCteInBinding(table,
CteBindState{
cteBindType: CteBindTypeInitStmt,
cte: cteRef,
recScanNodeId: -1})
initCtx.sinkTag = builder.genNewTag()
initLastNodeID, err1 := builder.bindSelect(&tree.Select{Select: *left}, initCtx, false)
if err1 != nil {
err = err1
return
}

//2. add Sink Node on top of initial statement
initLastNodeID = appendSinkNodeWithTag(builder, initCtx, initLastNodeID, initCtx.sinkTag)
builder.qry.Nodes[initLastNodeID].RecursiveCte = false

projects := builder.qry.Nodes[builder.qry.Nodes[initLastNodeID].Children[0]].ProjectList
// recursive statement
recursiveLastNodeID := initLastNodeID
initSourceStep := int32(len(builder.qry.Steps))
recursiveSteps := make([]int32, len(stmts))
recursiveNodeIDs := make([]int32, len(stmts))

//3. bind recursive parts
for i, r := range stmts {
subCtx := NewBindContext(builder, ctx)
subCtx.maskedCTEs = cteRef.maskedCTEs
subCtx.cteName = table
subCtx.recSelect = true
subCtx.sinkTag = initCtx.sinkTag
subCtx.cteByName = make(map[string]*CTERef)
subCtx.cteByName[table] = cteRef
//3.0 add initial statement as table binding into the subCtx of recursive part
err = builder.addBinding(initLastNodeID, *cteRef.ast.Name, subCtx)
if err != nil {
return
}
//3.1 add recursive cte Node
_ = builder.appendStep(recursiveLastNodeID)
subCtx.recRecursiveScanNodeId = appendRecursiveScanNode(builder, subCtx, initSourceStep, subCtx.sinkTag)
recursiveNodeIDs[i] = subCtx.recRecursiveScanNodeId
recScanId := appendRecursiveScanNode(builder, subCtx, initSourceStep, subCtx.sinkTag)
recursiveNodeIDs[i] = recScanId
recursiveSteps[i] = int32(len(builder.qry.Steps))

subCtx.recordCteInBinding(table,
CteBindState{
cteBindType: CteBindTypeRecurStmt,
cte: cteRef,
recScanNodeId: recScanId})

recursiveLastNodeID, err = builder.bindSelect(&tree.Select{Select: r}, subCtx, false)
if err != nil {
return
}

//3.2 add Sink Node on the top of single recursive part
recursiveLastNodeID = appendSinkNodeWithTag(builder, subCtx, recursiveLastNodeID, subCtx.sinkTag)
builder.qry.Nodes[recursiveLastNodeID].RecursiveCte = true
if !checkOnly {
// some check
n := builder.qry.Nodes[builder.qry.Nodes[recursiveLastNodeID].Children[0]]
Expand Down Expand Up @@ -2445,35 +2462,63 @@ func (builder *QueryBuilder) bindRecursiveCte(
}
}

//4. add CTE Scan Node
_ = builder.appendStep(recursiveLastNodeID)
nodeID = appendCTEScanNode(builder, ctx, initSourceStep, initCtx.sinkTag)
if limitExpr != nil || offsetExpr != nil {
node := builder.qry.Nodes[nodeID]
node.Limit = limitExpr
node.Offset = offsetExpr
}
//4.1 make recursive parts as the source step of the CTE Scan Node besides initSourceStep of initial statement
for i := 0; i < len(recursiveSteps); i++ {
builder.qry.Nodes[nodeID].SourceStep = append(builder.qry.Nodes[nodeID].SourceStep, recursiveSteps[i])
}

//4.2 make CTE scan as the source step of the Recursive cte Node
curStep := int32(len(builder.qry.Steps))
for _, id := range recursiveNodeIDs {
builder.qry.Nodes[id].SourceStep[0] = curStep
}

//4.3 add Sink Node on top of CTE Scan Node
unionAllLastNodeID := appendSinkNodeWithTag(builder, ctx, nodeID, ctx.sinkTag)
builder.qry.Nodes[unionAllLastNodeID].RecursiveSink = true

// final statement
ctx.finalSelect = true
//5. bind final statement
ctx.sinkTag = initCtx.sinkTag
//5.0 add initial statement as table binding into the ctx of main query
err = builder.addBinding(initLastNodeID, *cteRef.ast.Name, ctx)
if err != nil {
return
}
//5.1 add Sink Scan Node as the scan node of the recursive cte
sourceStep := builder.appendStep(unionAllLastNodeID)
nodeID = appendSinkScanNodeWithTag(builder, ctx, sourceStep, initCtx.sinkTag)
return
}

// check if binding cte currently
func (bc *BindContext) bindingCte() bool {
return bc.cteState.cteBindType != CteBindTypeNone
}

// check if binding non recursive cte currently
func (bc *BindContext) bindingNonRecurCte() bool {
return bc.cteState.cteBindType == CteBindTypeNonRecur
}

// check if binding recursive cte currently
func (bc *BindContext) bindingRecurCte() bool {
return bc.cteState.cteBindType == CteBindTypeInitStmt ||
bc.cteState.cteBindType == CteBindTypeRecurStmt
}

// check if binding recursive part of recursive cte currently
func (bc *BindContext) bindingRecurStmt() bool {
return bc.cteState.cteBindType == CteBindTypeRecurStmt
}

func (builder *QueryBuilder) bindCte(
ctx *BindContext,
stmt tree.NodeFormatter,
Expand Down Expand Up @@ -2783,7 +2828,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR
return 0, err
}
} else {
if ctx.recSelect && clause.Distinct {
if ctx.bindingRecurStmt() && clause.Distinct {
return 0, moerr.NewParseError(builder.GetContext(), "not support DISTINCT in recursive cte")
}

Expand Down Expand Up @@ -2936,7 +2981,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR
}
// bind GROUP BY clause
if clause.GroupBy != nil || astTimeWindow != nil {
if ctx.recSelect {
if ctx.bindingRecurStmt() {
return 0, moerr.NewParseErrorf(builder.GetContext(), "not support group by in recursive cte: '%v'", tree.String(clause.GroupBy, dialect.MYSQL))
}
groupBinder := NewGroupBinder(builder, ctx, selectList)
Expand Down Expand Up @@ -2996,7 +3041,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR
// bind HAVING clause
havingBinder = NewHavingBinder(builder, ctx)
if clause.Having != nil {
if ctx.recSelect {
if ctx.bindingRecurStmt() {
return 0, moerr.NewParseErrorf(builder.GetContext(), "not support having in recursive cte: '%v'", tree.String(clause.Having, dialect.MYSQL))
}
ctx.binder = havingBinder
Expand Down Expand Up @@ -3149,7 +3194,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR
// bind ORDER BY clause
var orderBys []*plan.OrderBySpec
if astOrderBy != nil {
if ctx.recSelect {
if ctx.bindingRecurStmt() {
return 0, moerr.NewParseErrorf(builder.GetContext(), "not support order by in recursive cte: '%v'", tree.String(&astOrderBy, dialect.MYSQL))
}
orderBinder := NewOrderBinder(projectionBinder, selectList)
Expand Down Expand Up @@ -3290,7 +3335,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR
}
} else {
if len(ctx.groups) > 0 || len(ctx.aggregates) > 0 {
if ctx.recSelect {
if ctx.bindingRecurStmt() {
return 0, moerr.NewInternalError(builder.GetContext(), "not support aggregate function recursive cte")
}
if builder.isForUpdate {
Expand Down Expand Up @@ -3339,7 +3384,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR

// append TIME WINDOW node
if len(ctx.times) > 0 {
if ctx.recSelect {
if ctx.bindingRecurStmt() {
return 0, moerr.NewInternalError(builder.GetContext(), "not support time window in recursive cte")
}
nodeID = builder.appendNode(&plan.Node{
Expand Down Expand Up @@ -3373,7 +3418,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR

// append WINDOW node
if len(ctx.windows) > 0 {
if ctx.recSelect {
if ctx.bindingRecurStmt() {
return 0, moerr.NewInternalError(builder.GetContext(), "not support window function in recursive cte")
}

Expand Down Expand Up @@ -3499,11 +3544,6 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR
ctx.results = ctx.projects
}

if (ctx.initSelect || ctx.recSelect) && !ctx.unionSelect {
nodeID = appendSinkNodeWithTag(builder, ctx, nodeID, ctx.sinkTag)
builder.qry.Nodes[nodeID].RecursiveCte = ctx.recSelect
}

if isRoot {
builder.qry.Headings = append(builder.qry.Headings, ctx.headings...)
}
Expand Down Expand Up @@ -4045,16 +4085,18 @@ func (builder *QueryBuilder) buildTable(stmt tree.TableExpr, ctx *BindContext, p
break
}

if len(schema) == 0 && ctx.normalCTE && table == ctx.cteName {
if len(schema) == 0 && ctx.bindingNonRecurCte() && table == ctx.cteName {
return 0, moerr.NewParseErrorf(builder.GetContext(), "In recursive query block of Recursive Common Table Expression %s, the recursive table must be referenced only once, and not in any subquery", table)
} else if len(schema) == 0 {
cteRef := ctx.findCTE(table)
if cteRef != nil {
if ctx.cteInBinding(table) {
return 0, moerr.NewParseErrorf(builder.GetContext(), "cte %s reference itself", table)
if ctx.bindingNonRecurCte() {
return 0, moerr.NewParseErrorf(builder.GetContext(), "cte %s reference itself", table)
}
}
if ctx.recSelect {
nodeID = ctx.recRecursiveScanNodeId
if ctx.bindingRecurStmt() {
nodeID = ctx.cteState.recScanNodeId
return
}

Expand Down
Loading

0 comments on commit 02d0e42

Please sign in to comment.