diff --git a/pkg/sql/plan/bind_context.go b/pkg/sql/plan/bind_context.go index b633ff997e150..fb7c5bbe77db0 100644 --- a/pkg/sql/plan/bind_context.go +++ b/pkg/sql/plan/bind_context.go @@ -49,14 +49,10 @@ func NewBindContext(builder *QueryBuilder, parent *BindContext) *BindContext { if parent != nil { bc.defaultDatabase = parent.defaultDatabase - bc.normalCTE = parent.normalCTE bc.cteName = parent.cteName - if parent.recSelect || parent.initSelect || parent.finalSelect { + if parent.bindingCte() { bc.cteByName = parent.cteByName - bc.initSelect = parent.initSelect - bc.recSelect = parent.recSelect - bc.finalSelect = parent.finalSelect - bc.recRecursiveScanNodeId = parent.recRecursiveScanNodeId + bc.cteState = parent.cteState } bc.snapshot = parent.snapshot } @@ -65,7 +61,7 @@ func NewBindContext(builder *QueryBuilder, parent *BindContext) *BindContext { } func (bc *BindContext) rootTag() int32 { - if bc.initSelect || bc.recSelect { + if bc.bindingRecurCte() { return bc.sinkTag } else if bc.resultTag > 0 { return bc.resultTag @@ -84,11 +80,11 @@ func (bc *BindContext) topTag() int32 { func (bc *BindContext) findCTE(name string) *CTERef { // the cte is masked already, we don't go further - if bc.maskedCTEs[name] { + if bc.cteState.masked(name) { return nil } if cte, ok := bc.cteByName[name]; ok { - if !bc.maskedCTEs[name] { + if !bc.cteState.masked(name) { return cte } } @@ -96,11 +92,11 @@ func (bc *BindContext) findCTE(name string) *CTERef { parent := bc.parent for parent != nil { // the cte is masked already, we don't go further - if _, ok2 := parent.maskedCTEs[name]; ok2 { + if parent.cteState.masked(name) { break } if cte, ok := parent.cteByName[name]; ok { - if !parent.maskedCTEs[name] { + if !parent.cteState.masked(name) { return cte } } @@ -111,8 +107,9 @@ func (bc *BindContext) findCTE(name string) *CTERef { return nil } -func (bc *BindContext) recordCteInBinding(name string, cte *CTERef) { - bc.boundCtes[name] = cte +func (bc *BindContext) recordCteInBinding(name string, cte CteBindState) { + bc.boundCtes[name] = cte.cte + bc.cteState = cte } func (bc *BindContext) cteInBinding(name string) bool { diff --git a/pkg/sql/plan/build_sample.go b/pkg/sql/plan/build_sample.go index 0c394bbdda54a..f9bb0d33b2f87 100644 --- a/pkg/sql/plan/build_sample.go +++ b/pkg/sql/plan/build_sample.go @@ -111,7 +111,7 @@ func validSample(ctx *BindContext, builder *QueryBuilder) error { if len(ctx.aggregates) > 0 { return moerr.NewSyntaxError(builder.GetContext(), "cannot fixed non-scalar function and scalar function in the same query") } - if ctx.recSelect || builder.isForUpdate { + if ctx.bindingRecurStmt() || builder.isForUpdate { return moerr.NewInternalError(builder.GetContext(), "not support sample function recursive cte or for update") } if len(ctx.windows) > 0 { diff --git a/pkg/sql/plan/query_builder.go b/pkg/sql/plan/query_builder.go index 08426bb029813..3cb08ac72bc6e 100644 --- a/pkg/sql/plan/query_builder.go +++ b/pkg/sql/plan/query_builder.go @@ -1981,7 +1981,6 @@ func (builder *QueryBuilder) buildUnion(stmt *tree.UnionClause, astOrderBy tree. var nodeID int32 for idx, sltStmt := range selectStmts { subCtx := NewBindContext(builder, ctx) - subCtx.unionSelect = subCtx.initSelect if slt, ok := sltStmt.(*tree.Select); ok { nodeID, err = builder.bindSelect(slt, subCtx, isRoot) } else { @@ -2254,10 +2253,6 @@ func (builder *QueryBuilder) buildUnion(stmt *tree.UnionClause, astOrderBy tree. ctx.results = ctx.projects } - if ctx.initSelect { - lastNodeID = appendSinkNodeWithTag(builder, ctx, lastNodeID, ctx.sinkTag) - } - // set heading if isRoot { builder.qry.Headings = append(builder.qry.Headings, ctx.headings...) @@ -2307,11 +2302,13 @@ func (builder *QueryBuilder) bindNoRecursiveCte( cteRef *CTERef, table string) (nodeID int32, err error) { subCtx := NewBindContext(builder, ctx) - subCtx.normalCTE = true - subCtx.maskedCTEs = cteRef.maskedCTEs subCtx.cteName = table subCtx.snapshot = cteRef.snapshot - subCtx.recordCteInBinding(table, cteRef) + subCtx.recordCteInBinding(table, + CteBindState{ + cteBindType: CteBindTypeNonRecur, + cte: cteRef, + recScanNodeId: -1}) cteRef.isRecursive = false oldSnapshot := builder.compCtx.GetSnapshot() @@ -2351,15 +2348,25 @@ func (builder *QueryBuilder) bindRecursiveCte( if len(s.OrderBy) > 0 { return 0, moerr.NewParseError(builder.GetContext(), "not support ORDER BY in recursive cte") } - // initial statement + //1. bind initial statement initCtx := NewBindContext(builder, ctx) - initCtx.initSelect = true + initCtx.cteName = table + initCtx.recordCteInBinding(table, + CteBindState{ + cteBindType: CteBindTypeInitStmt, + cte: cteRef, + recScanNodeId: -1}) initCtx.sinkTag = builder.genNewTag() initLastNodeID, err1 := builder.bindSelect(&tree.Select{Select: *left}, initCtx, false) if err1 != nil { err = err1 return } + + //2. add Sink Node on top of initial statement + initLastNodeID = appendSinkNodeWithTag(builder, initCtx, initLastNodeID, initCtx.sinkTag) + builder.qry.Nodes[initLastNodeID].RecursiveCte = false + projects := builder.qry.Nodes[builder.qry.Nodes[initLastNodeID].Children[0]].ProjectList // recursive statement recursiveLastNodeID := initLastNodeID @@ -2367,26 +2374,36 @@ func (builder *QueryBuilder) bindRecursiveCte( recursiveSteps := make([]int32, len(stmts)) recursiveNodeIDs := make([]int32, len(stmts)) + //3. bind recursive parts for i, r := range stmts { subCtx := NewBindContext(builder, ctx) - subCtx.maskedCTEs = cteRef.maskedCTEs subCtx.cteName = table - subCtx.recSelect = true subCtx.sinkTag = initCtx.sinkTag - subCtx.cteByName = make(map[string]*CTERef) - subCtx.cteByName[table] = cteRef + //3.0 add initial statement as table binding into the subCtx of recursive part err = builder.addBinding(initLastNodeID, *cteRef.ast.Name, subCtx) if err != nil { return } + //3.1 add recursive cte Node _ = builder.appendStep(recursiveLastNodeID) - subCtx.recRecursiveScanNodeId = appendRecursiveScanNode(builder, subCtx, initSourceStep, subCtx.sinkTag) - recursiveNodeIDs[i] = subCtx.recRecursiveScanNodeId + recScanId := appendRecursiveScanNode(builder, subCtx, initSourceStep, subCtx.sinkTag) + recursiveNodeIDs[i] = recScanId recursiveSteps[i] = int32(len(builder.qry.Steps)) + + subCtx.recordCteInBinding(table, + CteBindState{ + cteBindType: CteBindTypeRecurStmt, + cte: cteRef, + recScanNodeId: recScanId}) + recursiveLastNodeID, err = builder.bindSelect(&tree.Select{Select: r}, subCtx, false) if err != nil { return } + + //3.2 add Sink Node on the top of single recursive part + recursiveLastNodeID = appendSinkNodeWithTag(builder, subCtx, recursiveLastNodeID, subCtx.sinkTag) + builder.qry.Nodes[recursiveLastNodeID].RecursiveCte = true if !checkOnly { // some check n := builder.qry.Nodes[builder.qry.Nodes[recursiveLastNodeID].Children[0]] @@ -2445,6 +2462,7 @@ func (builder *QueryBuilder) bindRecursiveCte( } } + //4. add CTE Scan Node _ = builder.appendStep(recursiveLastNodeID) nodeID = appendCTEScanNode(builder, ctx, initSourceStep, initCtx.sinkTag) if limitExpr != nil || offsetExpr != nil { @@ -2452,28 +2470,55 @@ func (builder *QueryBuilder) bindRecursiveCte( node.Limit = limitExpr node.Offset = offsetExpr } + //4.1 make recursive parts as the source step of the CTE Scan Node besides initSourceStep of initial statement for i := 0; i < len(recursiveSteps); i++ { builder.qry.Nodes[nodeID].SourceStep = append(builder.qry.Nodes[nodeID].SourceStep, recursiveSteps[i]) } + + //4.2 make CTE scan as the source step of the Recursive cte Node curStep := int32(len(builder.qry.Steps)) for _, id := range recursiveNodeIDs { builder.qry.Nodes[id].SourceStep[0] = curStep } + + //4.3 add Sink Node on top of CTE Scan Node unionAllLastNodeID := appendSinkNodeWithTag(builder, ctx, nodeID, ctx.sinkTag) builder.qry.Nodes[unionAllLastNodeID].RecursiveSink = true - // final statement - ctx.finalSelect = true + //5. bind final statement ctx.sinkTag = initCtx.sinkTag + //5.0 add initial statement as table binding into the ctx of main query err = builder.addBinding(initLastNodeID, *cteRef.ast.Name, ctx) if err != nil { return } + //5.1 add Sink Scan Node as the scan node of the recursive cte sourceStep := builder.appendStep(unionAllLastNodeID) nodeID = appendSinkScanNodeWithTag(builder, ctx, sourceStep, initCtx.sinkTag) return } +// check if binding cte currently +func (bc *BindContext) bindingCte() bool { + return bc.cteState.cteBindType != CteBindTypeNone +} + +// check if binding non recursive cte currently +func (bc *BindContext) bindingNonRecurCte() bool { + return bc.cteState.cteBindType == CteBindTypeNonRecur +} + +// check if binding recursive cte currently +func (bc *BindContext) bindingRecurCte() bool { + return bc.cteState.cteBindType == CteBindTypeInitStmt || + bc.cteState.cteBindType == CteBindTypeRecurStmt +} + +// check if binding recursive part of recursive cte currently +func (bc *BindContext) bindingRecurStmt() bool { + return bc.cteState.cteBindType == CteBindTypeRecurStmt +} + func (builder *QueryBuilder) bindCte( ctx *BindContext, stmt tree.NodeFormatter, @@ -2783,7 +2828,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR return 0, err } } else { - if ctx.recSelect && clause.Distinct { + if ctx.bindingRecurStmt() && clause.Distinct { return 0, moerr.NewParseError(builder.GetContext(), "not support DISTINCT in recursive cte") } @@ -2936,7 +2981,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR } // bind GROUP BY clause if clause.GroupBy != nil || astTimeWindow != nil { - if ctx.recSelect { + if ctx.bindingRecurStmt() { return 0, moerr.NewParseErrorf(builder.GetContext(), "not support group by in recursive cte: '%v'", tree.String(clause.GroupBy, dialect.MYSQL)) } groupBinder := NewGroupBinder(builder, ctx, selectList) @@ -2996,7 +3041,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR // bind HAVING clause havingBinder = NewHavingBinder(builder, ctx) if clause.Having != nil { - if ctx.recSelect { + if ctx.bindingRecurStmt() { return 0, moerr.NewParseErrorf(builder.GetContext(), "not support having in recursive cte: '%v'", tree.String(clause.Having, dialect.MYSQL)) } ctx.binder = havingBinder @@ -3149,7 +3194,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR // bind ORDER BY clause var orderBys []*plan.OrderBySpec if astOrderBy != nil { - if ctx.recSelect { + if ctx.bindingRecurStmt() { return 0, moerr.NewParseErrorf(builder.GetContext(), "not support order by in recursive cte: '%v'", tree.String(&astOrderBy, dialect.MYSQL)) } orderBinder := NewOrderBinder(projectionBinder, selectList) @@ -3290,7 +3335,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR } } else { if len(ctx.groups) > 0 || len(ctx.aggregates) > 0 { - if ctx.recSelect { + if ctx.bindingRecurStmt() { return 0, moerr.NewInternalError(builder.GetContext(), "not support aggregate function recursive cte") } if builder.isForUpdate { @@ -3339,7 +3384,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR // append TIME WINDOW node if len(ctx.times) > 0 { - if ctx.recSelect { + if ctx.bindingRecurStmt() { return 0, moerr.NewInternalError(builder.GetContext(), "not support time window in recursive cte") } nodeID = builder.appendNode(&plan.Node{ @@ -3373,7 +3418,7 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR // append WINDOW node if len(ctx.windows) > 0 { - if ctx.recSelect { + if ctx.bindingRecurStmt() { return 0, moerr.NewInternalError(builder.GetContext(), "not support window function in recursive cte") } @@ -3499,11 +3544,6 @@ func (builder *QueryBuilder) bindSelect(stmt *tree.Select, ctx *BindContext, isR ctx.results = ctx.projects } - if (ctx.initSelect || ctx.recSelect) && !ctx.unionSelect { - nodeID = appendSinkNodeWithTag(builder, ctx, nodeID, ctx.sinkTag) - builder.qry.Nodes[nodeID].RecursiveCte = ctx.recSelect - } - if isRoot { builder.qry.Headings = append(builder.qry.Headings, ctx.headings...) } @@ -4045,16 +4085,18 @@ func (builder *QueryBuilder) buildTable(stmt tree.TableExpr, ctx *BindContext, p break } - if len(schema) == 0 && ctx.normalCTE && table == ctx.cteName { + if len(schema) == 0 && ctx.bindingNonRecurCte() && table == ctx.cteName { return 0, moerr.NewParseErrorf(builder.GetContext(), "In recursive query block of Recursive Common Table Expression %s, the recursive table must be referenced only once, and not in any subquery", table) } else if len(schema) == 0 { cteRef := ctx.findCTE(table) if cteRef != nil { if ctx.cteInBinding(table) { - return 0, moerr.NewParseErrorf(builder.GetContext(), "cte %s reference itself", table) + if ctx.bindingNonRecurCte() { + return 0, moerr.NewParseErrorf(builder.GetContext(), "cte %s reference itself", table) + } } - if ctx.recSelect { - nodeID = ctx.recRecursiveScanNodeId + if ctx.bindingRecurStmt() { + nodeID = ctx.cteState.recScanNodeId return } diff --git a/pkg/sql/plan/types.go b/pkg/sql/plan/types.go index c1df617be3661..7e9a9bdd85097 100644 --- a/pkg/sql/plan/types.go +++ b/pkg/sql/plan/types.go @@ -222,6 +222,32 @@ type CTERef struct { snapshot *Snapshot } +type CteBindState struct { + cte *CTERef + cteBindType int + recScanNodeId int32 +} + +func (state CteBindState) masked(name string) bool { + if state.cte == nil { + return false + } else { + _, ok := state.cte.maskedCTEs[name] + return ok + } +} + +const ( + // does not bind cte currently + CteBindTypeNone = 0 + // bind initial select stmt of recursive cte currently + CteBindTypeInitStmt = 1 + // bind recursive parts of recursive cte currently + CteBindTypeRecurStmt = 2 + // bind non recursive cte currently + CteBindTypeNonRecur = 3 +) + type aliasItem struct { idx int32 astExpr tree.Expr @@ -230,21 +256,19 @@ type aliasItem struct { type BindContext struct { binder Binder - cteByName map[string]*CTERef - maskedCTEs map[string]bool - normalCTE bool - initSelect bool - recSelect bool - finalSelect bool - unionSelect bool - sliding bool - isDistinct bool - isCorrelated bool - hasSingleRow bool - forceWindows bool - isGroupingSet bool - recRecursiveScanNodeId int32 - + //cteByName saves all cte definitions in the current stmt + cteByName map[string]*CTERef + //cteState records state of binding cte + cteState CteBindState + sliding bool + isDistinct bool + isCorrelated bool + hasSingleRow bool + forceWindows bool + isGroupingSet bool + + //cteName denotes the alias of this BindContext. + //it may be from view name, cte name or subquery name cteName string //cte in binding or bound already boundCtes map[string]*CTERef