From 5793d609fbaffa46fee01a781cd3e6fe7bb072db Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Tue, 28 Jan 2025 09:47:33 -0600 Subject: [PATCH] Add VALUES operator and SQL_builder bit Signed-off-by: Florent Poinsard --- .../planbuilder/operators/SQL_builder.go | 363 +---------------- .../planbuilder/operators/SQL_builder_test.go | 58 +++ .../planbuilder/operators/query_builder.go | 375 ++++++++++++++++++ go/vt/vtgate/planbuilder/operators/values.go | 90 +++++ 4 files changed, 535 insertions(+), 351 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/operators/SQL_builder_test.go create mode 100644 go/vt/vtgate/planbuilder/operators/query_builder.go create mode 100644 go/vt/vtgate/planbuilder/operators/values.go diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index ca15b5c9134..3c332412029 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -18,33 +18,13 @@ package operators import ( "fmt" - "slices" - "sort" "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" ) -type ( - queryBuilder struct { - ctx *plancontext.PlanningContext - stmt sqlparser.Statement - tableNames []string - dmlOperator Operator - } -) - -func (qb *queryBuilder) asSelectStatement() sqlparser.TableStatement { - return qb.stmt.(sqlparser.TableStatement) - -} -func (qb *queryBuilder) asOrderAndLimit() sqlparser.OrderAndLimit { - return qb.stmt.(sqlparser.OrderAndLimit) -} - func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement, _ Operator, err error) { defer PanicHandler(&err) @@ -56,337 +36,6 @@ func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement return q.stmt, q.dmlOperator, nil } -// includeTable will return false if the table is a CTE, and it is not merged -// it will return true if the table is not a CTE or if it is a CTE and it is merged -func (qb *queryBuilder) includeTable(op *Table) bool { - if qb.ctx.SemTable == nil { - return true - } - tbl, err := qb.ctx.SemTable.TableInfoFor(op.QTable.ID) - if err != nil { - panic(err) - } - cteTbl, isCTE := tbl.(*semantics.CTETable) - if !isCTE { - return true - } - - return cteTbl.Merged -} - -func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { - if tableID.NumberOfTables() == 1 && qb.ctx.SemTable != nil { - tblInfo, err := qb.ctx.SemTable.TableInfoFor(tableID) - if err != nil { - panic(err) - } - cte, isCTE := tblInfo.(*semantics.CTETable) - if isCTE { - tableName = cte.TableName - db = "" - } - } - tableExpr := sqlparser.TableName{ - Name: sqlparser.NewIdentifierCS(tableName), - Qualifier: sqlparser.NewIdentifierCS(db), - } - qb.addTableExpr(tableName, alias, tableID, tableExpr, hints, nil) -} - -func (qb *queryBuilder) addTableExpr( - tableName, alias string, - tableID semantics.TableSet, - tblExpr sqlparser.SimpleTableExpr, - hints sqlparser.IndexHints, - columnAliases sqlparser.Columns, -) { - if qb.stmt == nil { - qb.stmt = &sqlparser.Select{} - } - tbl := &sqlparser.AliasedTableExpr{ - Expr: tblExpr, - Partitions: nil, - As: sqlparser.NewIdentifierCS(alias), - Hints: hints, - Columns: columnAliases, - } - qb.ctx.SemTable.ReplaceTableSetFor(tableID, tbl) - qb.stmt.(FromStatement).SetFrom(append(qb.stmt.(FromStatement).GetFrom(), tbl)) - qb.tableNames = append(qb.tableNames, tableName) -} - -func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { - if qb.ctx.ShouldSkip(expr) { - // This is a predicate that was added to the RHS of an ApplyJoin. - // The original predicate will be added, so we don't have to add this here - return - } - - var addPred func(sqlparser.Expr) - - switch stmt := qb.stmt.(type) { - case *sqlparser.Select: - if qb.ctx.ContainsAggr(expr) { - addPred = stmt.AddHaving - } else { - addPred = stmt.AddWhere - } - case *sqlparser.Update: - addPred = stmt.AddWhere - case *sqlparser.Delete: - addPred = stmt.AddWhere - case nil: - // this would happen if we are adding a predicate on a dual query. - // we use this when building recursive CTE queries - sel := &sqlparser.Select{} - addPred = sel.AddWhere - qb.stmt = sel - default: - panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt)) - } - - for _, exp := range sqlparser.SplitAndExpression(nil, expr) { - addPred(exp) - } -} - -func (qb *queryBuilder) addGroupBy(original sqlparser.Expr) { - sel := qb.stmt.(*sqlparser.Select) - sel.AddGroupBy(original) -} - -func (qb *queryBuilder) setWithRollup() { - sel := qb.stmt.(*sqlparser.Select) - sel.GroupBy.WithRollup = true -} - -func (qb *queryBuilder) addProjection(projection sqlparser.SelectExpr) { - switch stmt := qb.stmt.(type) { - case *sqlparser.Select: - stmt.SelectExprs = append(stmt.SelectExprs, projection) - return - case *sqlparser.Union: - if ae, ok := projection.(*sqlparser.AliasedExpr); ok { - if col, ok := ae.Expr.(*sqlparser.ColName); ok { - checkUnionColumnByName(col, stmt) - return - } - } - - qb.pushUnionInsideDerived() - qb.addProjection(projection) - return - } - panic(vterrors.VT13001(fmt.Sprintf("unknown select statement type: %T", qb.stmt))) -} - -func (qb *queryBuilder) pushUnionInsideDerived() { - selStmt := qb.asSelectStatement() - dt := &sqlparser.DerivedTable{ - Lateral: false, - Select: selStmt, - } - sel := &sqlparser.Select{ - From: []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{ - Expr: dt, - As: sqlparser.NewIdentifierCS("dt"), - }}, - } - firstSelect := getFirstSelect(selStmt) - sel.SelectExprs = unionSelects(firstSelect.SelectExprs) - qb.stmt = sel -} - -func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExprs) { - for _, col := range exprs { - switch col := col.(type) { - case *sqlparser.AliasedExpr: - expr := sqlparser.NewColName(col.ColumnName()) - selectExprs = append(selectExprs, &sqlparser.AliasedExpr{Expr: expr}) - default: - selectExprs = append(selectExprs, col) - } - } - return -} - -func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableStatement) { - colName := column.Name.String() - firstSelect := getFirstSelect(sel) - exprs := firstSelect.SelectExprs - offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { - switch ae := expr.(type) { - case *sqlparser.StarExpr: - return true - case *sqlparser.AliasedExpr: - // When accessing columns on top of a UNION, we fall back to this simple strategy of string comparisons - return ae.ColumnName() == colName - } - return false - }) - if offset == -1 { - panic(vterrors.VT12001(fmt.Sprintf("did not find column [%s] on UNION", sqlparser.String(column)))) - } -} - -func (qb *queryBuilder) clearProjections() { - sel, isSel := qb.stmt.(*sqlparser.Select) - if !isSel { - return - } - sel.SelectExprs = nil -} - -func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { - qb.stmt = &sqlparser.Union{ - Left: qb.asSelectStatement(), - Right: other.asSelectStatement(), - Distinct: distinct, - } -} - -func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool, columns sqlparser.Columns) { - cteUnion := &sqlparser.Union{ - Left: qb.stmt.(sqlparser.TableStatement), - Right: other.stmt.(sqlparser.TableStatement), - Distinct: distinct, - } - - qb.stmt = &sqlparser.Select{ - With: &sqlparser.With{ - Recursive: true, - CTEs: []*sqlparser.CommonTableExpr{{ - ID: sqlparser.NewIdentifierCS(name), - Columns: columns, - Subquery: cteUnion, - }}, - }, - } - - qb.addTable("", name, alias, "", nil) -} - -type FromStatement interface { - GetFrom() []sqlparser.TableExpr - SetFrom([]sqlparser.TableExpr) - GetWherePredicate() sqlparser.Expr - SetWherePredicate(sqlparser.Expr) -} - -var _ FromStatement = (*sqlparser.Select)(nil) -var _ FromStatement = (*sqlparser.Update)(nil) -var _ FromStatement = (*sqlparser.Delete)(nil) - -func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) { - stmt := qb.stmt.(FromStatement) - otherStmt := other.stmt.(FromStatement) - - if sel, isSel := stmt.(*sqlparser.Select); isSel { - otherSel := otherStmt.(*sqlparser.Select) - sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...) - } - - qb.mergeWhereClauses(stmt, otherStmt) - - var newFromClause []sqlparser.TableExpr - switch joinType { - case sqlparser.NormalJoinType: - newFromClause = append(stmt.GetFrom(), otherStmt.GetFrom()...) - for _, pred := range sqlparser.SplitAndExpression(nil, onCondition) { - qb.addPredicate(pred) - } - default: - newFromClause = []sqlparser.TableExpr{buildJoin(stmt, otherStmt, onCondition, joinType)} - } - - stmt.SetFrom(newFromClause) -} - -func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) { - predicate := stmt.GetWherePredicate() - if otherPredicate := otherStmt.GetWherePredicate(); otherPredicate != nil { - predExprs := sqlparser.SplitAndExpression(nil, predicate) - otherExprs := sqlparser.SplitAndExpression(nil, otherPredicate) - predicate = qb.ctx.SemTable.AndExpressions(append(predExprs, otherExprs...)...) - } - if predicate != nil { - stmt.SetWherePredicate(predicate) - } -} - -func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { - var lhs sqlparser.TableExpr - fromClause := stmt.GetFrom() - if len(fromClause) == 1 { - lhs = fromClause[0] - } else { - lhs = &sqlparser.ParenTableExpr{Exprs: fromClause} - } - var rhs sqlparser.TableExpr - otherFromClause := otherStmt.GetFrom() - if len(otherFromClause) == 1 { - rhs = otherFromClause[0] - } else { - rhs = &sqlparser.ParenTableExpr{Exprs: otherFromClause} - } - - return &sqlparser.JoinTableExpr{ - LeftExpr: lhs, - RightExpr: rhs, - Join: joinType, - Condition: &sqlparser.JoinCondition{ - On: onCondition, - }, - } -} - -func (qb *queryBuilder) sortTables() { - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - sel, isSel := node.(*sqlparser.Select) - if !isSel { - return true, nil - } - ts := &tableSorter{ - sel: sel, - tbl: qb.ctx.SemTable, - } - sort.Sort(ts) - return true, nil - }, qb.stmt) - -} - -type tableSorter struct { - sel *sqlparser.Select - tbl *semantics.SemTable -} - -// Len implements the Sort interface -func (ts *tableSorter) Len() int { - return len(ts.sel.From) -} - -// Less implements the Sort interface -func (ts *tableSorter) Less(i, j int) bool { - lhs := ts.sel.From[i] - rhs := ts.sel.From[j] - left, ok := lhs.(*sqlparser.AliasedTableExpr) - if !ok { - return i < j - } - right, ok := rhs.(*sqlparser.AliasedTableExpr) - if !ok { - return i < j - } - - return ts.tbl.TableSetFor(left).TableOffset() < ts.tbl.TableSetFor(right).TableOffset() -} - -// Swap implements the Sort interface -func (ts *tableSorter) Swap(i, j int) { - ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i] -} - func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { switch expr := expr.(type) { case *sqlparser.AliasedExpr: @@ -467,11 +116,23 @@ func buildQuery(op Operator, qb *queryBuilder) { buildDML(op, qb) case *RecurseCTE: buildRecursiveCTE(op, qb) + case *Values: + buildValues(op, qb) default: panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) } } +func buildValues(op *Values, qb *queryBuilder) { + buildQuery(op.Source, qb) + qb.addTableExpr(op.Name, op.Name, TableID(op), &sqlparser.DerivedTable{ + Select: &sqlparser.ValuesStatement{ + ListArg: sqlparser.NewListArg(op.Arg), + }, + }, nil, op.Columns) + +} + func buildDelete(op *Delete, qb *queryBuilder) { qb.stmt = &sqlparser.Delete{ Ignore: op.Ignore, diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder_test.go b/go/vt/vtgate/planbuilder/operators/SQL_builder_test.go new file mode 100644 index 00000000000..153b0e2a4d7 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder_test.go @@ -0,0 +1,58 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +func TestToSQLValues(t *testing.T) { + ctx := plancontext.PlanningContext{} + + tableName := sqlparser.NewTableName("x") + tableColumn := sqlparser.NewColName("id") + valuesColumn := sqlparser.NewIdentifierCI("user_id") + op := &Values{ + unaryOperator: newUnaryOp(&Table{ + QTable: &QueryTable{ + Table: tableName, + Alias: sqlparser.NewAliasedTableExpr(tableName, ""), + }, + Columns: []*sqlparser.ColName{tableColumn}, + }), + Columns: sqlparser.Columns{valuesColumn}, + Name: "t", + Arg: "toto", + } + + stmt, _, err := ToSQL(&ctx, op) + require.NoError(t, err) + require.Equal(t, "select id from x, (values ::toto) as t(user_id)", sqlparser.String(stmt)) + + proj := newAliasedProjection(op) + proj.addUnexploredExpr(sqlparser.NewAliasedExpr(tableColumn, ""), tableColumn) + proj.addUnexploredExpr(sqlparser.NewAliasedExpr(sqlparser.NewColNameWithQualifier("user_id", sqlparser.NewTableName("t")), ""), sqlparser.NewColNameWithQualifier("user_id", sqlparser.NewTableName("t"))) + + stmt, _, err = ToSQL(&ctx, proj) + require.NoError(t, err) + require.Equal(t, "select id, t.user_id from x, (values ::toto) as t(user_id)", sqlparser.String(stmt)) +} diff --git a/go/vt/vtgate/planbuilder/operators/query_builder.go b/go/vt/vtgate/planbuilder/operators/query_builder.go new file mode 100644 index 00000000000..8937f99e347 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/query_builder.go @@ -0,0 +1,375 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "fmt" + "slices" + "sort" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type ( + queryBuilder struct { + ctx *plancontext.PlanningContext + stmt sqlparser.Statement + tableNames []string + dmlOperator Operator + } +) + +func (qb *queryBuilder) asSelectStatement() sqlparser.TableStatement { + return qb.stmt.(sqlparser.TableStatement) + +} +func (qb *queryBuilder) asOrderAndLimit() sqlparser.OrderAndLimit { + return qb.stmt.(sqlparser.OrderAndLimit) +} + +// includeTable will return false if the table is a CTE, and it is not merged +// it will return true if the table is not a CTE or if it is a CTE and it is merged +func (qb *queryBuilder) includeTable(op *Table) bool { + if qb.ctx.SemTable == nil { + return true + } + tbl, err := qb.ctx.SemTable.TableInfoFor(op.QTable.ID) + if err != nil { + panic(err) + } + cteTbl, isCTE := tbl.(*semantics.CTETable) + if !isCTE { + return true + } + + return cteTbl.Merged +} + +func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { + if tableID.NumberOfTables() == 1 && qb.ctx.SemTable != nil { + tblInfo, err := qb.ctx.SemTable.TableInfoFor(tableID) + if err != nil { + panic(err) + } + cte, isCTE := tblInfo.(*semantics.CTETable) + if isCTE { + tableName = cte.TableName + db = "" + } + } + tableExpr := sqlparser.TableName{ + Name: sqlparser.NewIdentifierCS(tableName), + Qualifier: sqlparser.NewIdentifierCS(db), + } + qb.addTableExpr(tableName, alias, tableID, tableExpr, hints, nil) +} + +func (qb *queryBuilder) addTableExpr( + tableName, alias string, + tableID semantics.TableSet, + tblExpr sqlparser.SimpleTableExpr, + hints sqlparser.IndexHints, + columnAliases sqlparser.Columns, +) { + if qb.stmt == nil { + qb.stmt = &sqlparser.Select{} + } + tbl := &sqlparser.AliasedTableExpr{ + Expr: tblExpr, + Partitions: nil, + As: sqlparser.NewIdentifierCS(alias), + Hints: hints, + Columns: columnAliases, + } + qb.ctx.SemTable.ReplaceTableSetFor(tableID, tbl) + qb.stmt.(FromStatement).SetFrom(append(qb.stmt.(FromStatement).GetFrom(), tbl)) + qb.tableNames = append(qb.tableNames, tableName) +} + +func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { + if qb.ctx.ShouldSkip(expr) { + // This is a predicate that was added to the RHS of an ApplyJoin. + // The original predicate will be added, so we don't have to add this here + return + } + + var addPred func(sqlparser.Expr) + + switch stmt := qb.stmt.(type) { + case *sqlparser.Select: + if qb.ctx.ContainsAggr(expr) { + addPred = stmt.AddHaving + } else { + addPred = stmt.AddWhere + } + case *sqlparser.Update: + addPred = stmt.AddWhere + case *sqlparser.Delete: + addPred = stmt.AddWhere + case nil: + // this would happen if we are adding a predicate on a dual query. + // we use this when building recursive CTE queries + sel := &sqlparser.Select{} + addPred = sel.AddWhere + qb.stmt = sel + default: + panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt)) + } + + for _, exp := range sqlparser.SplitAndExpression(nil, expr) { + addPred(exp) + } +} + +func (qb *queryBuilder) addGroupBy(original sqlparser.Expr) { + sel := qb.stmt.(*sqlparser.Select) + sel.AddGroupBy(original) +} + +func (qb *queryBuilder) setWithRollup() { + sel := qb.stmt.(*sqlparser.Select) + sel.GroupBy.WithRollup = true +} + +func (qb *queryBuilder) addProjection(projection sqlparser.SelectExpr) { + switch stmt := qb.stmt.(type) { + case *sqlparser.Select: + stmt.SelectExprs = append(stmt.SelectExprs, projection) + return + case *sqlparser.Union: + if ae, ok := projection.(*sqlparser.AliasedExpr); ok { + if col, ok := ae.Expr.(*sqlparser.ColName); ok { + checkUnionColumnByName(col, stmt) + return + } + } + + qb.pushUnionInsideDerived() + qb.addProjection(projection) + return + } + panic(vterrors.VT13001(fmt.Sprintf("unknown select statement type: %T", qb.stmt))) +} + +func (qb *queryBuilder) pushUnionInsideDerived() { + selStmt := qb.asSelectStatement() + dt := &sqlparser.DerivedTable{ + Lateral: false, + Select: selStmt, + } + sel := &sqlparser.Select{ + From: []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{ + Expr: dt, + As: sqlparser.NewIdentifierCS("dt"), + }}, + } + firstSelect := getFirstSelect(selStmt) + sel.SelectExprs = unionSelects(firstSelect.SelectExprs) + qb.stmt = sel +} + +func (qb *queryBuilder) clearProjections() { + sel, isSel := qb.stmt.(*sqlparser.Select) + if !isSel { + return + } + sel.SelectExprs = nil +} + +func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { + qb.stmt = &sqlparser.Union{ + Left: qb.asSelectStatement(), + Right: other.asSelectStatement(), + Distinct: distinct, + } +} + +func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool, columns sqlparser.Columns) { + cteUnion := &sqlparser.Union{ + Left: qb.stmt.(sqlparser.TableStatement), + Right: other.stmt.(sqlparser.TableStatement), + Distinct: distinct, + } + + qb.stmt = &sqlparser.Select{ + With: &sqlparser.With{ + Recursive: true, + CTEs: []*sqlparser.CommonTableExpr{{ + ID: sqlparser.NewIdentifierCS(name), + Columns: columns, + Subquery: cteUnion, + }}, + }, + } + + qb.addTable("", name, alias, "", nil) +} + +func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) { + stmt := qb.stmt.(FromStatement) + otherStmt := other.stmt.(FromStatement) + + if sel, isSel := stmt.(*sqlparser.Select); isSel { + otherSel := otherStmt.(*sqlparser.Select) + sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...) + } + + qb.mergeWhereClauses(stmt, otherStmt) + + var newFromClause []sqlparser.TableExpr + switch joinType { + case sqlparser.NormalJoinType: + newFromClause = append(stmt.GetFrom(), otherStmt.GetFrom()...) + for _, pred := range sqlparser.SplitAndExpression(nil, onCondition) { + qb.addPredicate(pred) + } + default: + newFromClause = []sqlparser.TableExpr{buildJoin(stmt, otherStmt, onCondition, joinType)} + } + + stmt.SetFrom(newFromClause) +} + +func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) { + predicate := stmt.GetWherePredicate() + if otherPredicate := otherStmt.GetWherePredicate(); otherPredicate != nil { + predExprs := sqlparser.SplitAndExpression(nil, predicate) + otherExprs := sqlparser.SplitAndExpression(nil, otherPredicate) + predicate = qb.ctx.SemTable.AndExpressions(append(predExprs, otherExprs...)...) + } + if predicate != nil { + stmt.SetWherePredicate(predicate) + } +} + +type tableSorter struct { + sel *sqlparser.Select + tbl *semantics.SemTable +} + +// Len implements the Sort interface +func (ts *tableSorter) Len() int { + return len(ts.sel.From) +} + +// Less implements the Sort interface +func (ts *tableSorter) Less(i, j int) bool { + lhs := ts.sel.From[i] + rhs := ts.sel.From[j] + left, ok := lhs.(*sqlparser.AliasedTableExpr) + if !ok { + return i < j + } + right, ok := rhs.(*sqlparser.AliasedTableExpr) + if !ok { + return i < j + } + + return ts.tbl.TableSetFor(left).TableOffset() < ts.tbl.TableSetFor(right).TableOffset() +} + +// Swap implements the Sort interface +func (ts *tableSorter) Swap(i, j int) { + ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i] +} + +func (qb *queryBuilder) sortTables() { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + sel, isSel := node.(*sqlparser.Select) + if !isSel { + return true, nil + } + ts := &tableSorter{ + sel: sel, + tbl: qb.ctx.SemTable, + } + sort.Sort(ts) + return true, nil + }, qb.stmt) +} + +func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExprs) { + for _, col := range exprs { + switch col := col.(type) { + case *sqlparser.AliasedExpr: + expr := sqlparser.NewColName(col.ColumnName()) + selectExprs = append(selectExprs, &sqlparser.AliasedExpr{Expr: expr}) + default: + selectExprs = append(selectExprs, col) + } + } + return +} + +func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableStatement) { + colName := column.Name.String() + firstSelect := getFirstSelect(sel) + exprs := firstSelect.SelectExprs + offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { + switch ae := expr.(type) { + case *sqlparser.StarExpr: + return true + case *sqlparser.AliasedExpr: + // When accessing columns on top of a UNION, we fall back to this simple strategy of string comparisons + return ae.ColumnName() == colName + } + return false + }) + if offset == -1 { + panic(vterrors.VT12001(fmt.Sprintf("did not find column [%s] on UNION", sqlparser.String(column)))) + } +} + +type FromStatement interface { + GetFrom() []sqlparser.TableExpr + SetFrom([]sqlparser.TableExpr) + GetWherePredicate() sqlparser.Expr + SetWherePredicate(sqlparser.Expr) +} + +var _ FromStatement = (*sqlparser.Select)(nil) +var _ FromStatement = (*sqlparser.Update)(nil) +var _ FromStatement = (*sqlparser.Delete)(nil) + +func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { + var lhs sqlparser.TableExpr + fromClause := stmt.GetFrom() + if len(fromClause) == 1 { + lhs = fromClause[0] + } else { + lhs = &sqlparser.ParenTableExpr{Exprs: fromClause} + } + var rhs sqlparser.TableExpr + otherFromClause := otherStmt.GetFrom() + if len(otherFromClause) == 1 { + rhs = otherFromClause[0] + } else { + rhs = &sqlparser.ParenTableExpr{Exprs: otherFromClause} + } + + return &sqlparser.JoinTableExpr{ + LeftExpr: lhs, + RightExpr: rhs, + Join: joinType, + Condition: &sqlparser.JoinCondition{ + On: onCondition, + }, + } +} diff --git a/go/vt/vtgate/planbuilder/operators/values.go b/go/vt/vtgate/planbuilder/operators/values.go new file mode 100644 index 00000000000..9d78f61f4b1 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/values.go @@ -0,0 +1,90 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "fmt" + "slices" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +type Values struct { + unaryOperator + + Columns sqlparser.Columns + Name string + Arg string +} + +func (v *Values) Clone(inputs []Operator) Operator { + clone := *v + clone.Columns = slices.Clone(v.Columns) + return &clone +} + +func (v *Values) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { + return newFilter(v, expr) +} + +func (v *Values) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + panic(vterrors.VT13001("we cannot add new columns to a Values operator")) +} + +func (v *Values) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + panic(vterrors.VT13001("we cannot add new columns to a Values operator")) +} + +func (v *Values) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + col, ok := expr.(*sqlparser.ColName) + if !ok { + return -1 + } + for i, column := range v.Columns { + if col.Name.Equal(column) { + return i + } + } + return -1 +} + +func (v *Values) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + var cols []*sqlparser.AliasedExpr + for _, column := range v.Columns { + cols = append(cols, sqlparser.NewAliasedExpr(sqlparser.NewColName(column.String()), "")) + } + return cols +} + +func (v *Values) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + r := v.GetColumns(ctx) + var selectExprs sqlparser.SelectExprs + for _, expr := range r { + selectExprs = append(selectExprs, expr) + } + return selectExprs +} + +func (v *Values) ShortDescription() string { + return fmt.Sprintf("%s (%s)", v.Name, sqlparser.String(v.Columns)) +} + +func (v *Values) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { + return v.Source.GetOrdering(ctx) +}