diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 50d3a4b6bbf..77257ed8303 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -1475,6 +1475,41 @@ func (cached *VStream) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.Position))) return size } +func (cached *ValuesJoin) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(128) + } + // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Left.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Right vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Right.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field CopyColumnsToRHS []int + { + size += hack.RuntimeAllocSize(int64(cap(cached.CopyColumnsToRHS)) * int64(8)) + } + // field BindVarName string + size += hack.RuntimeAllocSize(int64(len(cached.BindVarName))) + // field Cols []int + { + size += hack.RuntimeAllocSize(int64(cap(cached.Cols)) * int64(8)) + } + // field ColNames []string + { + size += hack.RuntimeAllocSize(int64(cap(cached.ColNames)) * int64(16)) + for _, elem := range cached.ColNames { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } + return size +} func (cached *Verify) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/engine/delete_test.go b/go/vt/vtgate/engine/delete_test.go index 18dcef5cbe4..56d3467aac3 100644 --- a/go/vt/vtgate/engine/delete_test.go +++ b/go/vt/vtgate/engine/delete_test.go @@ -45,7 +45,7 @@ func TestDeleteUnsharded(t *testing.T) { }, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -80,7 +80,7 @@ func TestDeleteEqual(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -112,7 +112,7 @@ func TestDeleteEqualMultiCol(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -148,7 +148,7 @@ func TestDeleteEqualNoRoute(t *testing.T) { }, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -181,7 +181,7 @@ func TestDeleteEqualNoScatter(t *testing.T) { }, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.EqualError(t, err, "cannot map vindex to unique keyspace id: DestinationKeyRange(-)") } @@ -213,7 +213,7 @@ func TestDeleteOwnedVindex(t *testing.T) { "1|4|5|6", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -231,7 +231,7 @@ func TestDeleteOwnedVindex(t *testing.T) { }) // No rows changing - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") _, err = del.TryExecute(context.Background(), vc, 
map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -252,7 +252,7 @@ func TestDeleteOwnedVindex(t *testing.T) { "1|4|5|6", "1|7|8|9", )} - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.results = results _, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -300,7 +300,7 @@ func TestDeleteOwnedVindexMultiCol(t *testing.T) { "1|2|4", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -371,7 +371,7 @@ func TestDeleteSharded(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -399,7 +399,7 @@ func TestDeleteShardedStreaming(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") err := del.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false, func(result *sqltypes.Result) error { return nil }) @@ -435,7 +435,7 @@ func TestDeleteScatterOwnedVindex(t *testing.T) { "1|4|5|6", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -453,7 +453,7 @@ func TestDeleteScatterOwnedVindex(t *testing.T) { }) // No rows changing - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") _, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -475,7 +475,7 @@ func TestDeleteScatterOwnedVindex(t *testing.T) { "1|4|5|6", "1|7|8|9", )} - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.results = results _, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -528,7 +528,7 @@ func TestDeleteInChangedVindexMultiCol(t *testing.T) { "1|3|6", "2|3|7", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -565,7 +565,7 @@ func TestDeleteEqualSubshard(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -602,7 +602,7 @@ func TestDeleteMultiEqual(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -635,7 +635,7 @@ func TestDeleteInUnique(t *testing.T) { Type: querypb.Type_TUPLE, Values: append([]*querypb.Value{sqltypes.ValueToProto(sqltypes.NewInt64(1))}, sqltypes.ValueToProto(sqltypes.NewInt64(2)), sqltypes.ValueToProto(sqltypes.NewInt64(4))), } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{"__vals": tupleBV}, false) require.NoError(t, err) diff --git a/go/vt/vtgate/engine/dml_with_input_test.go 
b/go/vt/vtgate/engine/dml_with_input_test.go index 6fcf2040dfc..38d9068b433 100644 --- a/go/vt/vtgate/engine/dml_with_input_test.go +++ b/go/vt/vtgate/engine/dml_with_input_test.go @@ -51,7 +51,7 @@ func TestDeleteWithInputSingleOffset(t *testing.T) { OutputCols: [][]int{{0}}, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -95,7 +95,7 @@ func TestDeleteWithInputMultiOffset(t *testing.T) { OutputCols: [][]int{{1, 0}}, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -160,7 +160,7 @@ func TestDeleteWithMultiTarget(t *testing.T) { OutputCols: [][]int{{0}, {1, 2}}, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -210,7 +210,7 @@ func TestUpdateWithInputNonLiteral(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = []*sqltypes.Result{ {RowsAffected: 1}, {RowsAffected: 1}, {RowsAffected: 1}, } diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index f3ab5ad5336..bddbca87664 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -46,6 +46,8 @@ type fakePrimitive struct { allResultsInOneCall bool async bool + + useNewPrintBindVars bool } func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) { @@ -72,7 +74,12 @@ func (f *fakePrimitive) GetTableName() string { } func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) + if f.useNewPrintBindVars { + f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) + } else { + f.log = append(f.log, fmt.Sprintf("Execute %v %v", deprecatedPrintBindVars(bindVars), wantfields)) + } + if f.results == nil { return nil, f.sendErr } @@ -87,7 +94,7 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { if !f.noLog { - f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", deprecatedPrintBindVars(bindVars), wantfields)) } if f.results == nil { return f.sendErr @@ -171,7 +178,7 @@ func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error { } func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars))) + f.log = append(f.log, fmt.Sprintf("GetFields %v", deprecatedPrintBindVars(bindVars))) return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */) } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 
060d2ebcfcb..03df2aaea88 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -597,7 +597,7 @@ func (f *loggingVCursor) Execute(ctx context.Context, method string, query strin case vtgatepb.CommitOrder_AUTOCOMMIT: name = "ExecuteAutocommit" } - f.log = append(f.log, fmt.Sprintf("%s %s %v %v", name, query, printBindVars(bindvars), rollbackOnError)) + f.log = append(f.log, fmt.Sprintf("%s %s %v %v", name, query, deprecatedPrintBindVars(bindvars), rollbackOnError)) return f.nextResult() } @@ -621,7 +621,7 @@ func (f *loggingVCursor) AutocommitApproval() bool { } func (f *loggingVCursor) ExecuteStandalone(ctx context.Context, _ Primitive, query string, bindvars map[string]*querypb.BindVariable, rs *srvtopo.ResolvedShard, fetchLastInsertID bool) (*sqltypes.Result, error) { - f.log = append(f.log, fmt.Sprintf("ExecuteStandalone %s %v %s %s", query, printBindVars(bindvars), rs.Target.Keyspace, rs.Target.Shard)) + f.log = append(f.log, fmt.Sprintf("ExecuteStandalone %s %v %s %s", query, deprecatedPrintBindVars(bindvars), rs.Target.Keyspace, rs.Target.Shard)) return f.nextResult() } @@ -943,6 +943,24 @@ func expectResultAnyOrder(t *testing.T, result, want *sqltypes.Result) { } } +// deprecatedPrintBindVars does not print bind variables, specifically tuples, correctly. +// We should use printBindVars instead. +func deprecatedPrintBindVars(bindvars map[string]*querypb.BindVariable) string { + var keys []string + for k := range bindvars { + keys = append(keys, k) + } + sort.Strings(keys) + buf := &bytes.Buffer{} + for i, k := range keys { + if i > 0 { + fmt.Fprintf(buf, " ") + } + fmt.Fprintf(buf, "%s: %v", k, bindvars[k]) + } + return buf.String() +} + func printBindVars(bindvars map[string]*querypb.BindVariable) string { var keys []string for k := range bindvars { @@ -954,6 +972,27 @@ func printBindVars(bindvars map[string]*querypb.BindVariable) string { if i > 0 { fmt.Fprintf(buf, " ") } + + if bindvars[k].Type == querypb.Type_TUPLE { + fmt.Fprintf(buf, "%s: [", k) + for _, val := range bindvars[k].Values { + if val.Type != querypb.Type_TUPLE { + fmt.Fprintf(buf, "[%s]", val.String()) + continue + } + var s []string + v := sqltypes.ProtoToValue(val) + err := v.ForEachValue(func(bv sqltypes.Value) { + s = append(s, bv.String()) + }) + if err != nil { + panic(err) + } + fmt.Fprintf(buf, "[%s]", strings.Join(s, " ")) + } + fmt.Fprintf(buf, "]") + continue + } fmt.Fprintf(buf, "%s: %v", k, bindvars[k]) } return buf.String() @@ -962,7 +1001,7 @@ func printBindVars(bindvars map[string]*querypb.BindVariable) string { func printResolvedShardQueries(rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery) string { buf := &bytes.Buffer{} for i, rs := range rss { - fmt.Fprintf(buf, "%s.%s: %s {%s} ", rs.Target.Keyspace, rs.Target.Shard, queries[i].Sql, printBindVars(queries[i].BindVariables)) + fmt.Fprintf(buf, "%s.%s: %s {%s} ", rs.Target.Keyspace, rs.Target.Shard, queries[i].Sql, deprecatedPrintBindVars(queries[i].BindVariables)) } return buf.String() } @@ -970,7 +1009,7 @@ func printResolvedShardQueries(rss []*srvtopo.ResolvedShard, queries []*querypb. 
func printResolvedShardsBindVars(rss []*srvtopo.ResolvedShard, bvs []map[string]*querypb.BindVariable) string { buf := &bytes.Buffer{} for i, rs := range rss { - fmt.Fprintf(buf, "%s.%s: {%v} ", rs.Target.Keyspace, rs.Target.Shard, printBindVars(bvs[i])) + fmt.Fprintf(buf, "%s.%s: {%v} ", rs.Target.Keyspace, rs.Target.Shard, deprecatedPrintBindVars(bvs[i])) } return buf.String() } diff --git a/go/vt/vtgate/engine/fk_cascade_test.go b/go/vt/vtgate/engine/fk_cascade_test.go index 942fe44a709..c93e487067b 100644 --- a/go/vt/vtgate/engine/fk_cascade_test.go +++ b/go/vt/vtgate/engine/fk_cascade_test.go @@ -62,7 +62,7 @@ func TestDeleteCascade(t *testing.T) { Parent: parentP, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) @@ -123,7 +123,7 @@ func TestUpdateCascade(t *testing.T) { Parent: parentP, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) @@ -195,7 +195,7 @@ func TestNonLiteralUpdateCascade(t *testing.T) { Parent: parentP, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) diff --git a/go/vt/vtgate/engine/fk_verify_test.go b/go/vt/vtgate/engine/fk_verify_test.go index 5c9ff83c2ec..465dd81d3b2 100644 --- a/go/vt/vtgate/engine/fk_verify_test.go +++ b/go/vt/vtgate/engine/fk_verify_test.go @@ -58,7 +58,7 @@ func TestFKVerifyUpdate(t *testing.T) { t.Run("foreign key verification success", func(t *testing.T) { fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64")) - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) @@ -83,7 +83,7 @@ func TestFKVerifyUpdate(t *testing.T) { t.Run("parent foreign key verification failure", func(t *testing.T) { // No results from select, should cause the foreign key verification to fail. fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64"), "1", "1", "1") - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.ErrorContains(t, err, "Cannot add or update a child row: a foreign key constraint fails") @@ -105,7 +105,7 @@ func TestFKVerifyUpdate(t *testing.T) { t.Run("child foreign key verification failure", func(t *testing.T) { // No results from select, should cause the foreign key verification to fail. 
fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64"), "1", "1", "1") - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.ErrorContains(t, err, "Cannot delete or update a parent row: a foreign key constraint fails") diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 2de95e5d186..5e66649f82e 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -42,7 +42,7 @@ func TestInsertUnsharded(t *testing.T) { "dummy_insert", ) - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{{ InsertID: 4, }} @@ -91,7 +91,7 @@ func TestInsertUnshardedGenerate(t *testing.T) { ), } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{ sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -144,7 +144,7 @@ func TestInsertUnshardedGenerate_Zeros(t *testing.T) { ), } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") vc.results = []*sqltypes.Result{ sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -215,7 +215,7 @@ func TestInsertShardedSimple(t *testing.T) { }, nil, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -254,7 +254,7 @@ func TestInsertShardedSimple(t *testing.T) { }, nil, ) - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -297,7 +297,7 @@ func TestInsertShardedSimple(t *testing.T) { ) ins.MultiShardAutocommit = true - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -366,7 +366,7 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { }, }}}, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{ @@ -412,7 +412,7 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, }, ) - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -457,7 +457,7 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { ) ins.MultiShardAutocommit = true - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -590,7 +590,7 @@ func TestInsertShardedGenerate(t *testing.T) { ), } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ sqltypes.MakeTestResult( @@ -715,7 +715,7 @@ func TestInsertShardedOwned(t *testing.T) { nil, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") 
vc.shardForKsid = []string{"20-", "-20", "20-"} _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -807,7 +807,7 @@ func TestInsertShardedOwnedWithNull(t *testing.T) { nil, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -893,7 +893,7 @@ func TestInsertShardedGeo(t *testing.T) { nil, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20"} _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -1029,7 +1029,7 @@ func TestInsertShardedIgnoreOwned(t *testing.T) { "\x00", ) noresult := &sqltypes.Result{} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20"} vc.results = []*sqltypes.Result{ // primary vindex lookups: fail row 2. @@ -1147,7 +1147,7 @@ func TestInsertShardedIgnoreOwnedWithNull(t *testing.T) { ), "\x00", ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} vc.results = []*sqltypes.Result{ ksid0, @@ -1267,7 +1267,7 @@ func TestInsertShardedUnownedVerify(t *testing.T) { "1", ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ nonemptyResult, @@ -1381,7 +1381,7 @@ func TestInsertShardedIgnoreUnownedVerify(t *testing.T) { "1", ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20"} vc.results = []*sqltypes.Result{ nonemptyResult, @@ -1472,7 +1472,7 @@ func TestInsertShardedIgnoreUnownedVerifyFail(t *testing.T) { nil, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.EqualError(t, err, `values [[INT64(2)]] for column [c3] does not map to keyspace ids`) @@ -1578,7 +1578,7 @@ func TestInsertShardedUnownedReverseMap(t *testing.T) { "1", ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ nonemptyResult, @@ -1663,7 +1663,7 @@ func TestInsertShardedUnownedReverseMapSuccess(t *testing.T) { nil, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -1694,7 +1694,7 @@ func TestInsertSelectSimple(t *testing.T) { Keyspace: ks.Keyspace}} ins := newInsertSelect(false, ks.Keyspace, ks.Tables["t1"], "prefix ", nil, [][]int{{1}}, rb) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ sqltypes.MakeTestResult( @@ -1787,7 +1787,7 @@ func TestInsertSelectOwned(t *testing.T) { rb, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ sqltypes.MakeTestResult( @@ -1894,7 +1894,7 @@ func TestInsertSelectGenerate(t *testing.T) { Offset: 1, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ // This 
is the result from the input SELECT @@ -1987,7 +1987,7 @@ func TestStreamingInsertSelectGenerate(t *testing.T) { Offset: 1, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ // This is the result from the input SELECT @@ -2082,7 +2082,7 @@ func TestInsertSelectGenerateNotProvided(t *testing.T) { Offset: 2, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ // This is the result from the input SELECT @@ -2169,7 +2169,7 @@ func TestStreamingInsertSelectGenerateNotProvided(t *testing.T) { Offset: 2, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ // This is the result from the input SELECT @@ -2258,7 +2258,7 @@ func TestInsertSelectUnowned(t *testing.T) { rb, ) - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} vc.results = []*sqltypes.Result{ sqltypes.MakeTestResult(sqltypes.MakeTestFields("id", "int64"), "1", "3", "2"), diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 51976396cba..8134d78ff4a 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -220,10 +220,10 @@ func joinFields(lfields, rfields []*querypb.Field, cols []int) []*querypb.Field fields := make([]*querypb.Field, len(cols)) for i, index := range cols { if index < 0 { - fields[i] = lfields[-index-1] + fields[i] = lfields[-index-1].CloneVT() continue } - fields[i] = rfields[index-1] + fields[i] = rfields[index-1].CloneVT() } return fields } diff --git a/go/vt/vtgate/engine/routing.go b/go/vt/vtgate/engine/routing.go index 067278c1a93..dd6143f6aa4 100644 --- a/go/vt/vtgate/engine/routing.go +++ b/go/vt/vtgate/engine/routing.go @@ -431,6 +431,7 @@ func (rp *RoutingParameters) multiEqual(ctx context.Context, vcursor VCursor, bi if err != nil { return nil, nil, err } + multiBindVars := make([]map[string]*querypb.BindVariable, len(rss)) for i := range multiBindVars { multiBindVars[i] = bindVars @@ -480,7 +481,13 @@ func setReplaceSchemaName(bindVars map[string]*querypb.BindVariable) { bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1) } -func resolveShards(ctx context.Context, vcursor VCursor, vindex vindexes.SingleColumn, keyspace *vindexes.Keyspace, vindexKeys []sqltypes.Value) ([]*srvtopo.ResolvedShard, [][]*querypb.Value, error) { +func resolveShards( + ctx context.Context, + vcursor VCursor, + vindex vindexes.SingleColumn, + keyspace *vindexes.Keyspace, + vindexKeys []sqltypes.Value, +) ([]*srvtopo.ResolvedShard, [][]*querypb.Value, error) { // Convert vindexKeys to []*querypb.Value ids := make([]*querypb.Value, len(vindexKeys)) for i, vik := range vindexKeys { diff --git a/go/vt/vtgate/engine/routing_parameter_test.go b/go/vt/vtgate/engine/routing_parameter_test.go new file mode 100644 index 00000000000..596a2f7f424 --- /dev/null +++ b/go/vt/vtgate/engine/routing_parameter_test.go @@ -0,0 +1,71 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestFindRouteValuesJoin(t *testing.T) { + vindex, err := vindexes.CreateVindex("hash", "", nil) + require.NoError(t, err) + + const valueBvName = "v" + rp := &RoutingParameters{ + Opcode: MultiEqual, + + Keyspace: &vindexes.Keyspace{ + Name: "ks", + Sharded: true, + }, + + Vindex: vindex, + + Values: []evalengine.Expr{ + &evalengine.TupleBindVariable{Key: valueBvName, Index: 0, Collation: collations.Unknown}, + }, + } + + bv := &querypb.BindVariable{ + Type: querypb.Type_TUPLE, + Values: []*querypb.Value{ + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewVarBinary("hello")}), + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(2), sqltypes.NewVarBinary("good morning")}), + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(3), sqltypes.NewVarBinary("bonjour")}), + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(4), sqltypes.NewVarBinary("bonjour")}), + }, + } + + vc := newTestVCursor("-20", "20-") + vc.shardForKsid = []string{"-20", "-20", "20-", "20-"} + rss, bvs, err := rp.findRoute(context.Background(), vc, map[string]*querypb.BindVariable{ + valueBvName: bv, + }) + + require.NoError(t, err) + require.Len(t, rss, 2) + require.Len(t, bvs, 2) +} diff --git a/go/vt/vtgate/engine/update_test.go b/go/vt/vtgate/engine/update_test.go index eb6af5a5299..e29ffeccd6f 100644 --- a/go/vt/vtgate/engine/update_test.go +++ b/go/vt/vtgate/engine/update_test.go @@ -50,7 +50,7 @@ func TestUpdateUnsharded(t *testing.T) { }, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -85,7 +85,7 @@ func TestUpdateEqual(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -116,7 +116,7 @@ func TestUpdateEqualMultiCol(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -142,7 +142,7 @@ func TestUpdateScatter(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -168,7 +168,7 @@ func TestUpdateScatter(t *testing.T) { }, } - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -199,7 +199,7 @@ func TestUpdateEqualNoRoute(t 
*testing.T) { }, } - vc := newDMLTestVCursor("0") + vc := newTestVCursor("0") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -250,7 +250,7 @@ func TestUpdateEqualChangedVindex(t *testing.T) { ), "1|4|5|6|0|0", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -272,7 +272,7 @@ func TestUpdateEqualChangedVindex(t *testing.T) { }) // No rows changing - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -294,7 +294,7 @@ func TestUpdateEqualChangedVindex(t *testing.T) { "1|4|5|6|0|0", "1|7|8|9|0|0", )} - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.results = results _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -330,7 +330,7 @@ func TestUpdateEqualChangedVindex(t *testing.T) { "1|4|5|6|0|1", // twocol changes "1|7|8|9|1|0", // onecol changes )} - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.results = results _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -387,7 +387,7 @@ func TestUpdateEqualMultiColChangedVindex(t *testing.T) { ), "1|2|4|0", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -514,7 +514,7 @@ func TestUpdateScatterChangedVindex(t *testing.T) { ), "1|4|5|6|0|0", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -534,7 +534,7 @@ func TestUpdateScatterChangedVindex(t *testing.T) { }) // No rows changing - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) if err != nil { @@ -558,7 +558,7 @@ func TestUpdateScatterChangedVindex(t *testing.T) { "1|4|5|6|0|0", "1|7|8|9|0|0", )} - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.results = results _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -604,7 +604,7 @@ func TestUpdateIn(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -628,7 +628,7 @@ func TestUpdateInStreamExecute(t *testing.T) { Query: "dummy_update", }} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") err := upd.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false, func(result *sqltypes.Result) error { return nil }) @@ -655,7 +655,7 @@ func TestUpdateInMultiCol(t *testing.T) { Query: "dummy_update", }} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) vc.ExpectLog(t, []string{ @@ -710,7 +710,7 @@ func TestUpdateInChangedVindex(t *testing.T) { "1|4|5|6|0|0", 
"2|21|22|23|0|0", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -738,7 +738,7 @@ func TestUpdateInChangedVindex(t *testing.T) { }) // No rows changing - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -761,7 +761,7 @@ func TestUpdateInChangedVindex(t *testing.T) { "1|7|8|9|0|0", "2|21|22|23|0|0", )} - vc = newDMLTestVCursor("-20", "20-") + vc = newTestVCursor("-20", "20-") vc.results = results _, err = upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -835,7 +835,7 @@ func TestUpdateInChangedVindexMultiCol(t *testing.T) { "1|3|6|0", "2|3|7|0", )} - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.results = results _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) @@ -874,7 +874,7 @@ func TestUpdateEqualSubshard(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -911,7 +911,7 @@ func TestUpdateMultiEqual(t *testing.T) { }, } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) @@ -944,7 +944,7 @@ func TestUpdateInUnique(t *testing.T) { Type: querypb.Type_TUPLE, Values: append([]*querypb.Value{sqltypes.ValueToProto(sqltypes.NewInt64(1))}, sqltypes.ValueToProto(sqltypes.NewInt64(2)), sqltypes.ValueToProto(sqltypes.NewInt64(4))), } - vc := newDMLTestVCursor("-20", "20-") + vc := newTestVCursor("-20", "20-") vc.shardForKsid = []string{"-20", "20-"} _, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{"__vals": tupleBV}, false) require.NoError(t, err) @@ -1033,6 +1033,6 @@ func buildTestVSchema() *vindexes.VSchema { return vs } -func newDMLTestVCursor(shards ...string) *loggingVCursor { +func newTestVCursor(shards ...string) *loggingVCursor { return &loggingVCursor{shards: shards, resolvedTargetTabletType: topodatapb.TabletType_PRIMARY} } diff --git a/go/vt/vtgate/engine/values_join.go b/go/vt/vtgate/engine/values_join.go new file mode 100644 index 00000000000..ced6283dbe8 --- /dev/null +++ b/go/vt/vtgate/engine/values_join.go @@ -0,0 +1,180 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package engine + +import ( + "context" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vterrors" +) + +var _ Primitive = (*ValuesJoin)(nil) + +// ValuesJoin is a primitive that joins two primitives by constructing a table from the rows of the LHS primitive. +// The table is passed in as a bind variable to the RHS primitive. +// It's called ValuesJoin because the LHS of the join is sent to the RHS as a VALUES clause. +type ValuesJoin struct { + // Left and Right are the LHS and RHS primitives + // of the Join. They can be any primitive. + Left, Right Primitive + + // The name for the bind var containing the tuple-of-tuples being sent to the RHS + BindVarName string + + // RowID specifies whether the LHS row ID is appended to the rows sent to the RHS, so that columns from the LHS can be used in the output + // If RowID is false, the output will be the same as the RHS, so the following fields are ignored - Cols, ColNames. + // We copy everything from the LHS to the RHS in this case, and column names are taken from the RHS. + RowID bool + + // CopyColumnsToRHS are the offsets of columns from LHS we are copying over to the RHS + // []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset + CopyColumnsToRHS []int + + // Cols tells us which side the output columns come from: + // negative numbers are offsets to the left, and positive to the right + Cols []int + + // ColNames are the output column names + ColNames []string +} + +// TryExecute performs a non-streaming exec. +func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + lresult, err := vcursor.ExecutePrimitive(ctx, jv.Left, bindVars, wantfields) + if err != nil { + return nil, err + } + bv := &querypb.BindVariable{ + Type: querypb.Type_TUPLE, + } + if len(lresult.Rows) == 0 && wantfields { + // If there are no rows, we still need to construct a single row + // to send down to RHS for Values Table to execute correctly. + // It will be used to execute the field query to provide the output fields.
+ var vals []sqltypes.Value + for _, field := range lresult.Fields { + val, _ := sqltypes.NewValue(field.Type, nil) + vals = append(vals, val) + } + bv.Values = append(bv.Values, sqltypes.TupleToProto(vals)) + + bindVars[jv.BindVarName] = bv + if jv.RowID { + panic("implement me") + } + return jv.Right.GetFields(ctx, vcursor, bindVars) + } + + rowSize := len(jv.CopyColumnsToRHS) + if jv.RowID { + rowSize++ // +1 since we add the row ID + } + for i, row := range lresult.Rows { + newRow := make(sqltypes.Row, 0, rowSize) + + if jv.RowID { + for _, loffset := range jv.CopyColumnsToRHS { + newRow = append(newRow, row[loffset]) + } + newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID + } else { + newRow = row + } + + bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow)) + } + + bindVars[jv.BindVarName] = bv + rresult, err := vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields) + if err != nil { + return nil, err + } + + if !jv.RowID { + // if we are not using the row ID, we can just return the result from the RHS + return rresult, nil + } + + result := &sqltypes.Result{} + + result.Fields = joinFields(lresult.Fields, rresult.Fields, jv.Cols) + for i := range result.Fields { + result.Fields[i].Name = jv.ColNames[i] + } + + for _, rrow := range rresult.Rows { + lhsRowID, err := rrow[len(rrow)-1].ToCastInt64() + if err != nil { + return nil, vterrors.VT13001("values joins cannot fetch lhs row ID: " + err.Error()) + } + + result.Rows = append(result.Rows, joinRows(lresult.Rows[lhsRowID], rrow, jv.Cols)) + } + + return result, nil +} + +// TryStreamExecute performs a streaming exec. +func (jv *ValuesJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + panic("implement me") +} + +// GetFields fetches the field info. +func (jv *ValuesJoin) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return jv.Right.GetFields(ctx, vcursor, bindVars) +} + +// Inputs returns the input primitives for this join +func (jv *ValuesJoin) Inputs() ([]Primitive, []map[string]any) { + return []Primitive{jv.Left, jv.Right}, nil +} + +// RouteType returns a description of the query routing type used by the primitive +func (jv *ValuesJoin) RouteType() string { + return "ValuesJoin" +} + +// GetKeyspaceName specifies the Keyspace that this primitive routes to. +func (jv *ValuesJoin) GetKeyspaceName() string { + if jv.Left.GetKeyspaceName() == jv.Right.GetKeyspaceName() { + return jv.Left.GetKeyspaceName() + } + return jv.Left.GetKeyspaceName() + "_" + jv.Right.GetKeyspaceName() +} + +// GetTableName specifies the table that this primitive routes to. 
+func (jv *ValuesJoin) GetTableName() string { + return jv.Left.GetTableName() + "_" + jv.Right.GetTableName() +} + +// NeedsTransaction implements the Primitive interface +func (jv *ValuesJoin) NeedsTransaction() bool { + return jv.Right.NeedsTransaction() || jv.Left.NeedsTransaction() +} + +func (jv *ValuesJoin) description() PrimitiveDescription { + return PrimitiveDescription{ + OperatorType: "Join", + Variant: "Values", + Other: map[string]any{ + "BindVarName": jv.BindVarName, + "CopyColumnsToRHS": jv.CopyColumnsToRHS, + }, + } +} diff --git a/go/vt/vtgate/engine/values_join_test.go b/go/vt/vtgate/engine/values_join_test.go new file mode 100644 index 00000000000..29297d6aa32 --- /dev/null +++ b/go/vt/vtgate/engine/values_join_test.go @@ -0,0 +1,149 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + querypb "vitess.io/vitess/go/vt/proto/query" + + "vitess.io/vitess/go/sqltypes" +) + +func TestJoinValuesExecute(t *testing.T) { + + type testCase struct { + rowID bool + cols []int + CopyColumnsToRHS []int + rhsResults []*sqltypes.Result + expectedRHSLog []string + } + + testCases := []testCase{ + { + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4 + */ + + rowID: true, + cols: []int{-1, -2, -3, -1, 1, 2}, + CopyColumnsToRHS: []int{0}, + rhsResults: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col5|col6|id", + "varchar|varchar|int64", + ), + "d|dd|0", + "e|ee|1", + "f|ff|2", + "g|gg|3", + ), + }, + expectedRHSLog: []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(1) INT64(0)][INT64(2) INT64(1)][INT64(3) INT64(2)][INT64(4) INT64(3)]] true`, + }, + }, { + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col1, col2, col3, col4, col5, col6 from (values row(1,2,3), ...) 
left(col1,col2,col3) join right on left.col1 = right.col4 + */ + + rowID: false, + rhsResults: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", + ), + }, + expectedRHSLog: []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(1) VARCHAR("a") VARCHAR("aa")][INT64(2) VARCHAR("b") VARCHAR("bb")][INT64(3) VARCHAR("c") VARCHAR("cc")][INT64(4) VARCHAR("d") VARCHAR("dd")]] true`, + }, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("rowID:%t", tc.rowID), func(t *testing.T) { + leftPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "4|d|dd", + ), + }, + } + rightPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: tc.rhsResults, + } + + bv := map[string]*querypb.BindVariable{ + "a": sqltypes.Int64BindVariable(10), + } + + vjn := &ValuesJoin{ + Left: leftPrim, + Right: rightPrim, + CopyColumnsToRHS: tc.CopyColumnsToRHS, + BindVarName: "v", + Cols: tc.cols, + ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"}, + RowID: tc.rowID, + } + + r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `Execute a: type:INT64 value:"10" true`, + }) + rightPrim.ExpectLog(t, tc.expectedRHSLog) + + result := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", + ) + expectResult(t, r, result) + }) + } +} diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 916c5e200f4..f75ac0f8202 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -378,6 +378,16 @@ func valueToEval(value sqltypes.Value, collation collations.TypedCollation, valu } switch tt := value.Type(); { + case tt == sqltypes.Tuple: + t := &evalTuple{} + err := value.ForEachValue(func(bv sqltypes.Value) { + e, err := valueToEval(bv, collation, values) + if err != nil { + return + } + t.t = append(t.t, e) + }) + return t, wrap(err) case sqltypes.IsSigned(tt): ival, err := value.ToInt64() return newEvalInt64(ival), wrap(err) diff --git a/go/vt/vtgate/evalengine/eval_tuple.go b/go/vt/vtgate/evalengine/eval_tuple.go index 81fa3317977..1faff68e155 100644 --- a/go/vt/vtgate/evalengine/eval_tuple.go +++ b/go/vt/vtgate/evalengine/eval_tuple.go @@ -27,7 +27,15 @@ type evalTuple struct { var _ eval = (*evalTuple)(nil) func (e *evalTuple) ToRawBytes() []byte { - return nil + var vals []sqltypes.Value + for _, e2 := range e.t { + v, err := sqltypes.NewValue(e2.SQLType(), e2.ToRawBytes()) + if err != nil { + panic(err) + } + vals = append(vals, v) + } + return sqltypes.TupleToProto(vals).Value } func (e *evalTuple) SQLType() sqltypes.Type { diff --git a/go/vt/vtgate/evalengine/expr_tuple_bvar.go b/go/vt/vtgate/evalengine/expr_tuple_bvar.go index 14cfbd95a8b..754ed8cf4f8 100644 --- a/go/vt/vtgate/evalengine/expr_tuple_bvar.go +++ b/go/vt/vtgate/evalengine/expr_tuple_bvar.go @@ -30,7 +30,6 @@ type ( Key string Index int - Type sqltypes.Type Collation collations.ID } ) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go 
b/go/vt/vtgate/planbuilder/operator_transformers.go index b51eac449fc..311388d8ff5 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -79,6 +79,24 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato return transformRecurseCTE(ctx, op) case *operators.PercentBasedMirror: return transformPercentBasedMirror(ctx, op) + case *operators.ValuesJoin: + lhs, err := transformToPrimitive(ctx, op.LHS) + if err != nil { + return nil, err + } + rhs, err := transformToPrimitive(ctx, op.RHS) + if err != nil { + return nil, err + } + + return &engine.ValuesJoin{ + Left: lhs, + Right: rhs, + CopyColumnsToRHS: op.CopyColumnsToRHS, + BindVarName: op.BindVarName, + Cols: op.Columns, + ColNames: op.ColumnName, + }, nil } return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToPrimitive)", op)) @@ -172,7 +190,7 @@ func transformInsertionSelection(ctx *plancontext.PlanningContext, op *operators return nil, vterrors.VT13001(fmt.Sprintf("Incorrect type encountered: %T (transformInsertionSelection)", op.Insert)) } - stmt, dmlOp, err := operators.ToSQL(ctx, rb.Source) + stmt, dmlOp, err := operators.ToAST(ctx, rb.Source) if err != nil { return nil, err } @@ -579,7 +597,7 @@ func getHints(cmt *sqlparser.ParsedComments) *queryHints { } func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) (engine.Primitive, error) { - stmt, dmlOp, err := operators.ToSQL(ctx, op.Source) + stmt, dmlOp, err := operators.ToAST(ctx, op.Source) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index ced81df147a..d46458d6379 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -509,7 +509,7 @@ func splitGroupingToLeftAndRight( rhs.addGrouping(ctx, groupBy) columns.addRight(groupBy.Inner) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): - jc := breakExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID) + jc := breakApplyJoinExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID) for _, lhsExpr := range jc.LHSExprs { e := lhsExpr.Expr lhs.addGrouping(ctx, NewGroupBy(e)) diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 80bf74708a8..ed634bdd0ff 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -146,7 +146,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql rhs := aj.RHS predicates := sqlparser.SplitAndExpression(nil, expr) for _, pred := range predicates { - col := breakExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS)) + col := breakApplyJoinExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS)) aj.JoinPredicates.add(col) ctx.AddJoinPredicates(pred, col.RHSExpr) rhs = rhs.AddPredicate(ctx, col.RHSExpr) @@ -199,7 +199,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq case deps.IsSolvedBy(rhs): col.RHSExpr = e case deps.IsSolvedBy(both): - col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS)) + col = breakApplyJoinExpressionInLHSandRHS(ctx, e, TableID(aj.LHS)) default: panic(vterrors.VT13001(fmt.Sprintf("expression depends on tables outside this join: %s", sqlparser.String(e)))) } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go 
index 2e3781c94db..259a83213a3 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -156,7 +156,7 @@ func (jpc *joinPredicateCollector) inspectPredicate( // then we can use this predicate to connect the subquery to the outer query if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) { jpc.predicates = append(jpc.predicates, predicate) - jc := breakExpressionInLHSandRHS(ctx, predicate, jpc.outerID) + jc := breakApplyJoinExpressionInLHSandRHS(ctx, predicate, jpc.outerID) jpc.joinColumns = append(jpc.joinColumns, jc) pred = jc.RHSExpr } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index f42ec87404d..4ba04aaf1ad 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -22,9 +22,9 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) -// breakExpressionInLHSandRHS takes an expression and +// breakApplyJoinExpressionInLHSandRHS takes an expression and // extracts the parts that are coming from one of the sides into `ColName`s that are needed -func breakExpressionInLHSandRHS( +func breakApplyJoinExpressionInLHSandRHS( ctx *plancontext.PlanningContext, expr sqlparser.Expr, lhs semantics.TableSet, @@ -129,3 +129,24 @@ func getFirstSelect(selStmt sqlparser.TableStatement) *sqlparser.Select { } return firstSelect } + +func breakValuesJoinExpressionInLHS(ctx *plancontext.PlanningContext, + expr sqlparser.Expr, + lhs semantics.TableSet, +) (result valuesJoinColumn) { + result.Original = expr + result.PureLHS = true + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + if ctx.SemTable.RecursiveDeps(col) == lhs { + result.LHS = append(result.LHS, col) + } else { + result.PureLHS = false + } + return true, nil + }, expr) + return +} diff --git a/go/vt/vtgate/planbuilder/operators/expressions_test.go b/go/vt/vtgate/planbuilder/operators/expressions_test.go new file mode 100644 index 00000000000..9738fec7e0b --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/expressions_test.go @@ -0,0 +1,58 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func TestSplitComplexPredicateToLHS(t *testing.T) { + ast, err := sqlparser.NewTestParser().ParseExpr("l.foo + r.bar - l.baz / r.tata = 0") + require.NoError(t, err) + lID := semantics.SingleTableSet(0) + rID := semantics.SingleTableSet(1) + ctx := plancontext.CreateEmptyPlanningContext() + ctx.SemTable = semantics.EmptySemTable() + // simple sem analysis using the column prefix + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + if col.Qualifier.Name.String() == "l" { + ctx.SemTable.Recursive[col] = lID + } else { + ctx.SemTable.Recursive[col] = rID + } + return false, nil + }, ast) + + valuesJoinCols := breakValuesJoinExpressionInLHS(ctx, ast, lID) + nodes := slice.Map(valuesJoinCols.LHS, func(from *sqlparser.ColName) string { + return sqlparser.String(from) + }) + + assert.Equal(t, []string{"l.foo", "l.baz"}, nodes) +} diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index ff4915527a7..ed1271d539e 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -158,7 +158,7 @@ func addCTEPredicate( } func breakCTEExpressionInLhsAndRhs(ctx *plancontext.PlanningContext, pred sqlparser.Expr, lhsID semantics.TableSet) *plancontext.RecurseExpression { - col := breakExpressionInLHSandRHS(ctx, pred, lhsID) + col := breakApplyJoinExpressionInLHSandRHS(ctx, pred, lhsID) lhsExprs := slice.Map(col.LHSExprs, func(bve BindVarExpr) plancontext.BindVarExpr { col, ok := bve.Expr.(*sqlparser.ColName) diff --git a/go/vt/vtgate/planbuilder/operators/op_to_ast.go b/go/vt/vtgate/planbuilder/operators/op_to_ast.go new file mode 100644 index 00000000000..5fea21b11eb --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/op_to_ast.go @@ -0,0 +1,430 @@ +/* +Copyright 2022 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "fmt" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +func ToAST(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement, _ Operator, err error) { + defer PanicHandler(&err) + + q := &queryBuilder{ctx: ctx} + buildAST(op, q) + if ctx.SemTable != nil { + q.sortTables() + } + return q.stmt, q.dmlOperator, nil +} + +func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { + switch expr := expr.(type) { + case *sqlparser.AliasedExpr: + sqlparser.RemoveKeyspaceInCol(expr.Expr) + case *sqlparser.StarExpr: + expr.TableName.Qualifier = sqlparser.NewIdentifierCS("") + } +} + +func stripDownQuery(from, to sqlparser.TableStatement) { + switch node := from.(type) { + case *sqlparser.Select: + toNode, ok := to.(*sqlparser.Select) + if !ok { + panic(vterrors.VT13001("AST did not match")) + } + toNode.Distinct = node.Distinct + toNode.GroupBy = node.GroupBy + toNode.Having = node.Having + toNode.OrderBy = node.OrderBy + toNode.Comments = node.Comments + toNode.Limit = node.Limit + toNode.SelectExprs = node.SelectExprs + for _, expr := range toNode.SelectExprs { + removeKeyspaceFromSelectExpr(expr) + } + case *sqlparser.Union: + toNode, ok := to.(*sqlparser.Union) + if !ok { + panic(vterrors.VT13001("AST did not match")) + } + stripDownQuery(node.Left, toNode.Left) + stripDownQuery(node.Right, toNode.Right) + toNode.OrderBy = node.OrderBy + default: + panic(vterrors.VT13001(fmt.Sprintf("this should not happen - we have covered all implementations of SelectStatement %T", from))) + } +} + +// buildAST recursively builds the query into an AST, from an operator tree +func buildAST(op Operator, qb *queryBuilder) { + switch op := op.(type) { + case *Table: + buildTable(op, qb) + case *Projection: + buildProjection(op, qb) + case *ApplyJoin: + buildApplyJoin(op, qb) + case *Filter: + buildFilter(op, qb) + case *Horizon: + if op.TableId != nil { + buildDerived(op, qb) + return + } + buildHorizon(op, qb) + case *Limit: + buildLimit(op, qb) + case *Ordering: + buildOrdering(op, qb) + case *Aggregator: + buildAggregation(op, qb) + case *Union: + buildUnion(op, qb) + case *Distinct: + buildDistinct(op, qb) + case *Update: + buildUpdate(op, qb) + case *Delete: + buildDelete(op, qb) + case *Insert: + buildDML(op, qb) + case *RecurseCTE: + buildRecursiveCTE(op, qb) + case *Values: + buildValues(op, qb) + case *ValuesJoin: + buildValuesJoin(op, qb) + default: + panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) + } +} + +func buildDistinct(op *Distinct, qb *queryBuilder) { + buildAST(op.Source, qb) + statement := qb.asSelectStatement() + d, ok := statement.(sqlparser.Distinctable) + if !ok { + panic(vterrors.VT13001("expected a select statement with distinct")) + } + d.MakeDistinct() +} + +func buildValuesJoin(op *ValuesJoin, qb *queryBuilder) { + qb.ctx.SkipValuesArgument(op.BindVarName) + buildAST(op.LHS, qb) + qbR := &queryBuilder{ctx: qb.ctx} + buildAST(op.RHS, qbR) + qb.joinWith(qbR, nil, sqlparser.NormalJoinType) +} + +func buildValues(op *Values, qb *queryBuilder) { + buildAST(op.Source, qb) + if qb.ctx.IsValuesArgumentSkipped(op.Arg) { + return + } + + qb.addTableExpr(op.Name, op.Name, TableID(op), &sqlparser.DerivedTable{ + Select: &sqlparser.ValuesStatement{ + ListArg: sqlparser.NewListArg(op.Arg), + }, + }, nil, op.getColsFromCtx(qb.ctx)) +} + +func buildDelete(op *Delete, qb *queryBuilder) { + qb.stmt 
= &sqlparser.Delete{ + Ignore: op.Ignore, + Targets: sqlparser.TableNames{op.Target.Name}, + } + buildAST(op.Source, qb) + + qb.dmlOperator = op +} + +func buildUpdate(op *Update, qb *queryBuilder) { + updExprs := getUpdateExprs(op) + upd := &sqlparser.Update{ + Ignore: op.Ignore, + Exprs: updExprs, + } + qb.stmt = upd + qb.dmlOperator = op + buildAST(op.Source, qb) +} + +func getUpdateExprs(op *Update) sqlparser.UpdateExprs { + updExprs := make(sqlparser.UpdateExprs, 0, len(op.Assignments)) + for _, se := range op.Assignments { + updExprs = append(updExprs, &sqlparser.UpdateExpr{ + Name: se.Name, + Expr: se.Expr.EvalExpr, + }) + } + return updExprs +} + +type OpWithAST interface { + Operator + Statement() sqlparser.Statement +} + +func buildDML(op OpWithAST, qb *queryBuilder) { + qb.stmt = op.Statement() + qb.dmlOperator = op +} + +func buildAggregation(op *Aggregator, qb *queryBuilder) { + buildAST(op.Source, qb) + + qb.clearProjections() + + cols := op.GetColumns(qb.ctx) + for _, column := range cols { + qb.addProjection(column) + } + + for _, by := range op.Grouping { + qb.addGroupBy(by.Inner) + simplified := by.Inner + if by.WSOffset != -1 { + qb.addGroupBy(weightStringFor(simplified)) + } + } + if op.WithRollup { + qb.setWithRollup() + } + + if op.DT != nil { + sel := qb.asSelectStatement() + qb.stmt = nil + qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: sel, + }, nil, op.DT.Columns) + } +} + +func buildOrdering(op *Ordering, qb *queryBuilder) { + buildAST(op.Source, qb) + + for _, order := range op.Order { + qb.asOrderAndLimit().AddOrder(order.Inner) + } +} + +func buildLimit(op *Limit, qb *queryBuilder) { + buildAST(op.Source, qb) + qb.asOrderAndLimit().SetLimit(op.AST) +} + +func buildTable(op *Table, qb *queryBuilder) { + if !qb.includeTable(op) { + return + } + + dbName := "" + + if op.QTable.IsInfSchema { + dbName = op.QTable.Table.Qualifier.String() + } + qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), TableID(op), op.QTable.Alias.Hints) + for _, pred := range op.QTable.Predicates { + qb.addPredicate(pred) + } + for _, name := range op.Columns { + qb.addProjection(&sqlparser.AliasedExpr{Expr: name}) + } +} + +func buildProjection(op *Projection, qb *queryBuilder) { + buildAST(op.Source, qb) + + _, isSel := qb.stmt.(*sqlparser.Select) + if isSel { + qb.clearProjections() + cols := op.GetSelectExprs(qb.ctx) + for _, column := range cols { + qb.addProjection(column) + } + } + + // if the projection is on derived table, we use the select we have + // created above and transform it into a derived table + if op.DT != nil { + sel := qb.asSelectStatement() + qb.stmt = nil + qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: sel, + }, nil, op.DT.Columns) + } + + if !isSel { + cols := op.GetSelectExprs(qb.ctx) + for _, column := range cols { + qb.addProjection(column) + } + } +} + +func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { + predicates := slice.Map(op.JoinPredicates.columns, func(jc applyJoinColumn) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) 
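+ // build each side of the join into its own queryBuilder, then combine them on the ANDed join predicates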
+ + buildAST(op.LHS, qb) + + qbR := &queryBuilder{ctx: qb.ctx} + buildAST(op.RHS, qbR) + + switch { + // if we have a recursive cte, we might be missing a statement from one of the sides + case qbR.stmt == nil: + // do nothing + case qb.stmt == nil: + qb.stmt = qbR.stmt + default: + qb.joinWith(qbR, pred, op.JoinType) + } +} + +func buildUnion(op *Union, qb *queryBuilder) { + // the first input is built first + buildAST(op.Sources[0], qb) + + for i, src := range op.Sources { + if i == 0 { + continue + } + + // now we can go over the remaining inputs and UNION them together + qbOther := &queryBuilder{ctx: qb.ctx} + buildAST(src, qbOther) + qb.unionWith(qbOther, op.distinct) + } +} + +func buildFilter(op *Filter, qb *queryBuilder) { + buildAST(op.Source, qb) + + for _, pred := range op.Predicates { + qb.addPredicate(pred) + } +} + +func buildDerived(op *Horizon, qb *queryBuilder) { + buildAST(op.Source, qb) + + sqlparser.RemoveKeyspaceInCol(op.Query) + + stmt := qb.stmt + qb.stmt = nil + switch sel := stmt.(type) { + case *sqlparser.Select: + buildDerivedSelect(op, qb, sel) + return + case *sqlparser.Union: + buildDerivedUnion(op, qb, sel) + return + } + panic(fmt.Sprintf("unknown select statement type: %T", stmt)) +} + +func buildDerivedUnion(op *Horizon, qb *queryBuilder, union *sqlparser.Union) { + opQuery, ok := op.Query.(*sqlparser.Union) + if !ok { + panic(vterrors.VT12001("Horizon contained SELECT but statement was UNION")) + } + + union.Limit = opQuery.Limit + union.OrderBy = opQuery.OrderBy + union.Distinct = opQuery.Distinct + + qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: union, + }, nil, op.ColumnAliases) +} + +func buildDerivedSelect(op *Horizon, qb *queryBuilder, sel *sqlparser.Select) { + opQuery, ok := op.Query.(*sqlparser.Select) + if !ok { + panic(vterrors.VT12001("Horizon contained UNION but statement was SELECT")) + } + sel.Limit = opQuery.Limit + sel.OrderBy = opQuery.OrderBy + sel.GroupBy = opQuery.GroupBy + sel.Having = mergeHaving(sel.Having, opQuery.Having) + sel.SelectExprs = opQuery.SelectExprs + sel.Distinct = opQuery.Distinct + qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: sel, + }, nil, op.ColumnAliases) + for _, col := range op.Columns { + qb.addProjection(&sqlparser.AliasedExpr{Expr: col}) + } +} + +func buildHorizon(op *Horizon, qb *queryBuilder) { + buildAST(op.Source, qb) + stripDownQuery(op.Query, qb.asSelectStatement()) + sqlparser.RemoveKeyspaceInCol(qb.stmt) +} + +func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { + predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) 
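+ // the seed query is built as-is; the ANDed predicates are added to the recursive term before both are combined into the CTE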
+ buildAST(op.Seed(), qb) + qbR := &queryBuilder{ctx: qb.ctx} + buildAST(op.Term(), qbR) + qbR.addPredicate(pred) + infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) + if err != nil { + panic(err) + } + + qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct, op.Def.Columns) +} + +func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { + switch { + case h1 == nil && h2 == nil: + return nil + case h1 == nil: + return h2 + case h2 == nil: + return h1 + default: + h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr) + return h1 + } +} diff --git a/go/vt/vtgate/planbuilder/operators/op_to_ast_test.go b/go/vt/vtgate/planbuilder/operators/op_to_ast_test.go new file mode 100644 index 00000000000..303e3fb7eb5 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/op_to_ast_test.go @@ -0,0 +1,119 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +func TestToSQLValues(t *testing.T) { + ctx := plancontext.CreateEmptyPlanningContext() + bindVarName := "toto" + ctx.ValuesJoinColumns[bindVarName] = sqlparser.Columns{sqlparser.NewIdentifierCI("user_id")} + + tableName := sqlparser.NewTableName("x") + tableColumn := sqlparser.NewColName("id") + op := &Values{ + unaryOperator: newUnaryOp(&Table{ + QTable: &QueryTable{ + Table: tableName, + Alias: sqlparser.NewAliasedTableExpr(tableName, ""), + }, + Columns: []*sqlparser.ColName{tableColumn}, + }), + Name: "t", + Arg: bindVarName, + } + + stmt, _, err := ToAST(ctx, op) + require.NoError(t, err) + require.Equal(t, "select id from x, (values ::toto) as t(user_id)", sqlparser.String(stmt)) + + // Now do the same test but with a projection on top + proj := newAliasedProjection(op) + proj.addUnexploredExpr(sqlparser.NewAliasedExpr(tableColumn, ""), tableColumn) + + userIdColName := sqlparser.NewColNameWithQualifier("user_id", sqlparser.NewTableName("t")) + proj.addUnexploredExpr( + sqlparser.NewAliasedExpr(userIdColName, ""), + userIdColName, + ) + + stmt, _, err = ToAST(ctx, proj) + require.NoError(t, err) + require.Equal(t, "select id, t.user_id from x, (values ::toto) as t(user_id)", sqlparser.String(stmt)) +} + +func TestToSQLValuesJoin(t *testing.T) { + // Build a SQL AST from a values join that has been pushed under a route + ctx := plancontext.CreateEmptyPlanningContext() + parser := sqlparser.NewTestParser() + + lhsTableName := sqlparser.NewTableName("x") + lhsTableColumn := sqlparser.NewColName("id") + lhsFilterPred, err := parser.ParseExpr("x.id = 42") + require.NoError(t, err) + + LHS := &Filter{ + unaryOperator: newUnaryOp(&Table{ + QTable: &QueryTable{ + Table: lhsTableName, + Alias: sqlparser.NewAliasedTableExpr(lhsTableName, ""), + }, + Columns: []*sqlparser.ColName{lhsTableColumn}, + }), + Predicates: []sqlparser.Expr{lhsFilterPred}, + } + + const argumentName = "v" + 
ctx.ValuesJoinColumns[argumentName] = sqlparser.Columns{sqlparser.NewIdentifierCI("id")} + rhsTableName := sqlparser.NewTableName("y") + rhsTableColumn := sqlparser.NewColName("tata") + rhsFilterPred, err := parser.ParseExpr("y.tata = 42") + require.NoError(t, err) + rhsJoinFilterPred, err := parser.ParseExpr("y.tata = x.id") + require.NoError(t, err) + + RHS := &Filter{ + unaryOperator: newUnaryOp(&Values{ + unaryOperator: newUnaryOp(&Table{ + QTable: &QueryTable{ + Table: rhsTableName, + Alias: sqlparser.NewAliasedTableExpr(rhsTableName, ""), + }, + Columns: []*sqlparser.ColName{rhsTableColumn}, + }), + Name: lhsTableName.Name.String(), + Arg: argumentName, + }), + Predicates: []sqlparser.Expr{rhsFilterPred, rhsJoinFilterPred}, + } + + vj := &ValuesJoin{ + binaryOperator: newBinaryOp(LHS, RHS), + BindVarName: argumentName, + } + + stmt, _, err := ToAST(ctx, vj) + require.NoError(t, err) + require.Equal(t, "select id, tata from x, y where x.id = 42 and y.tata = 42 and y.tata = x.id", sqlparser.String(stmt)) +} diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/query_builder.go similarity index 51% rename from go/vt/vtgate/planbuilder/operators/SQL_builder.go rename to go/vt/vtgate/planbuilder/operators/query_builder.go index ca15b5c9134..8937f99e347 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/query_builder.go @@ -1,5 +1,5 @@ /* -Copyright 2022 The Vitess Authors. +Copyright 2025 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ import ( "slices" "sort" - "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -45,17 +44,6 @@ func (qb *queryBuilder) asOrderAndLimit() sqlparser.OrderAndLimit { return qb.stmt.(sqlparser.OrderAndLimit) } -func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement, _ Operator, err error) { - defer PanicHandler(&err) - - q := &queryBuilder{ctx: ctx} - buildQuery(op, q) - if ctx.SemTable != nil { - q.sortTables() - } - return q.stmt, q.dmlOperator, nil -} - // includeTable will return false if the table is a CTE, and it is not merged // it will return true if the table is not a CTE or if it is a CTE and it is merged func (qb *queryBuilder) includeTable(op *Table) bool { @@ -197,38 +185,6 @@ func (qb *queryBuilder) pushUnionInsideDerived() { qb.stmt = sel } -func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExprs) { - for _, col := range exprs { - switch col := col.(type) { - case *sqlparser.AliasedExpr: - expr := sqlparser.NewColName(col.ColumnName()) - selectExprs = append(selectExprs, &sqlparser.AliasedExpr{Expr: expr}) - default: - selectExprs = append(selectExprs, col) - } - } - return -} - -func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableStatement) { - colName := column.Name.String() - firstSelect := getFirstSelect(sel) - exprs := firstSelect.SelectExprs - offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { - switch ae := expr.(type) { - case *sqlparser.StarExpr: - return true - case *sqlparser.AliasedExpr: - // When accessing columns on top of a UNION, we fall back to this simple strategy of string comparisons - return ae.ColumnName() == colName - } - return false - }) - if offset == -1 { - panic(vterrors.VT12001(fmt.Sprintf("did not find 
column [%s] on UNION", sqlparser.String(column)))) - } -} - func (qb *queryBuilder) clearProjections() { sel, isSel := qb.stmt.(*sqlparser.Select) if !isSel { @@ -266,17 +222,6 @@ func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string qb.addTable("", name, alias, "", nil) } -type FromStatement interface { - GetFrom() []sqlparser.TableExpr - SetFrom([]sqlparser.TableExpr) - GetWherePredicate() sqlparser.Expr - SetWherePredicate(sqlparser.Expr) -} - -var _ FromStatement = (*sqlparser.Select)(nil) -var _ FromStatement = (*sqlparser.Update)(nil) -var _ FromStatement = (*sqlparser.Delete)(nil) - func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) { stmt := qb.stmt.(FromStatement) otherStmt := other.stmt.(FromStatement) @@ -314,48 +259,6 @@ func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) { } } -func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { - var lhs sqlparser.TableExpr - fromClause := stmt.GetFrom() - if len(fromClause) == 1 { - lhs = fromClause[0] - } else { - lhs = &sqlparser.ParenTableExpr{Exprs: fromClause} - } - var rhs sqlparser.TableExpr - otherFromClause := otherStmt.GetFrom() - if len(otherFromClause) == 1 { - rhs = otherFromClause[0] - } else { - rhs = &sqlparser.ParenTableExpr{Exprs: otherFromClause} - } - - return &sqlparser.JoinTableExpr{ - LeftExpr: lhs, - RightExpr: rhs, - Join: joinType, - Condition: &sqlparser.JoinCondition{ - On: onCondition, - }, - } -} - -func (qb *queryBuilder) sortTables() { - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - sel, isSel := node.(*sqlparser.Select) - if !isSel { - return true, nil - } - ts := &tableSorter{ - sel: sel, - tbl: qb.ctx.SemTable, - } - sort.Sort(ts) - return true, nil - }, qb.stmt) - -} - type tableSorter struct { sel *sqlparser.Select tbl *semantics.SemTable @@ -387,366 +290,86 @@ func (ts *tableSorter) Swap(i, j int) { ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i] } -func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { - switch expr := expr.(type) { - case *sqlparser.AliasedExpr: - sqlparser.RemoveKeyspaceInCol(expr.Expr) - case *sqlparser.StarExpr: - expr.TableName.Qualifier = sqlparser.NewIdentifierCS("") - } -} - -func stripDownQuery(from, to sqlparser.TableStatement) { - switch node := from.(type) { - case *sqlparser.Select: - toNode, ok := to.(*sqlparser.Select) - if !ok { - panic(vterrors.VT13001("AST did not match")) - } - toNode.Distinct = node.Distinct - toNode.GroupBy = node.GroupBy - toNode.Having = node.Having - toNode.OrderBy = node.OrderBy - toNode.Comments = node.Comments - toNode.Limit = node.Limit - toNode.SelectExprs = node.SelectExprs - for _, expr := range toNode.SelectExprs { - removeKeyspaceFromSelectExpr(expr) - } - case *sqlparser.Union: - toNode, ok := to.(*sqlparser.Union) - if !ok { - panic(vterrors.VT13001("AST did not match")) - } - stripDownQuery(node.Left, toNode.Left) - stripDownQuery(node.Right, toNode.Right) - toNode.OrderBy = node.OrderBy - default: - panic(vterrors.VT13001(fmt.Sprintf("this should not happen - we have covered all implementations of SelectStatement %T", from))) - } -} - -// buildQuery recursively builds the query into an AST, from an operator tree -func buildQuery(op Operator, qb *queryBuilder) { - switch op := op.(type) { - case *Table: - buildTable(op, qb) - case *Projection: - buildProjection(op, qb) - case 
*ApplyJoin: - buildApplyJoin(op, qb) - case *Filter: - buildFilter(op, qb) - case *Horizon: - if op.TableId != nil { - buildDerived(op, qb) - return - } - buildHorizon(op, qb) - case *Limit: - buildLimit(op, qb) - case *Ordering: - buildOrdering(op, qb) - case *Aggregator: - buildAggregation(op, qb) - case *Union: - buildUnion(op, qb) - case *Distinct: - buildQuery(op.Source, qb) - statement := qb.asSelectStatement() - d, ok := statement.(sqlparser.Distinctable) - if !ok { - panic(vterrors.VT13001("expected a select statement with distinct")) +func (qb *queryBuilder) sortTables() { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + sel, isSel := node.(*sqlparser.Select) + if !isSel { + return true, nil } - d.MakeDistinct() - case *Update: - buildUpdate(op, qb) - case *Delete: - buildDelete(op, qb) - case *Insert: - buildDML(op, qb) - case *RecurseCTE: - buildRecursiveCTE(op, qb) - default: - panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) - } -} - -func buildDelete(op *Delete, qb *queryBuilder) { - qb.stmt = &sqlparser.Delete{ - Ignore: op.Ignore, - Targets: sqlparser.TableNames{op.Target.Name}, - } - buildQuery(op.Source, qb) - - qb.dmlOperator = op -} - -func buildUpdate(op *Update, qb *queryBuilder) { - updExprs := getUpdateExprs(op) - upd := &sqlparser.Update{ - Ignore: op.Ignore, - Exprs: updExprs, - } - qb.stmt = upd - qb.dmlOperator = op - buildQuery(op.Source, qb) -} - -func getUpdateExprs(op *Update) sqlparser.UpdateExprs { - updExprs := make(sqlparser.UpdateExprs, 0, len(op.Assignments)) - for _, se := range op.Assignments { - updExprs = append(updExprs, &sqlparser.UpdateExpr{ - Name: se.Name, - Expr: se.Expr.EvalExpr, - }) - } - return updExprs -} - -type OpWithAST interface { - Operator - Statement() sqlparser.Statement -} - -func buildDML(op OpWithAST, qb *queryBuilder) { - qb.stmt = op.Statement() - qb.dmlOperator = op -} - -func buildAggregation(op *Aggregator, qb *queryBuilder) { - buildQuery(op.Source, qb) - - qb.clearProjections() - - cols := op.GetColumns(qb.ctx) - for _, column := range cols { - qb.addProjection(column) - } - - for _, by := range op.Grouping { - qb.addGroupBy(by.Inner) - simplified := by.Inner - if by.WSOffset != -1 { - qb.addGroupBy(weightStringFor(simplified)) + ts := &tableSorter{ + sel: sel, + tbl: qb.ctx.SemTable, } - } - if op.WithRollup { - qb.setWithRollup() - } - - if op.DT != nil { - sel := qb.asSelectStatement() - qb.stmt = nil - qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, - }, nil, op.DT.Columns) - } -} - -func buildOrdering(op *Ordering, qb *queryBuilder) { - buildQuery(op.Source, qb) - - for _, order := range op.Order { - qb.asOrderAndLimit().AddOrder(order.Inner) - } -} - -func buildLimit(op *Limit, qb *queryBuilder) { - buildQuery(op.Source, qb) - qb.asOrderAndLimit().SetLimit(op.AST) -} - -func buildTable(op *Table, qb *queryBuilder) { - if !qb.includeTable(op) { - return - } - - dbName := "" - - if op.QTable.IsInfSchema { - dbName = op.QTable.Table.Qualifier.String() - } - qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), TableID(op), op.QTable.Alias.Hints) - for _, pred := range op.QTable.Predicates { - qb.addPredicate(pred) - } - for _, name := range op.Columns { - qb.addProjection(&sqlparser.AliasedExpr{Expr: name}) - } + sort.Sort(ts) + return true, nil + }, qb.stmt) } -func buildProjection(op *Projection, qb *queryBuilder) { - buildQuery(op.Source, qb) - - _, isSel := 
qb.stmt.(*sqlparser.Select) - if isSel { - qb.clearProjections() - cols := op.GetSelectExprs(qb.ctx) - for _, column := range cols { - qb.addProjection(column) - } - } - - // if the projection is on derived table, we use the select we have - // created above and transform it into a derived table - if op.DT != nil { - sel := qb.asSelectStatement() - qb.stmt = nil - qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, - }, nil, op.DT.Columns) - } - - if !isSel { - cols := op.GetSelectExprs(qb.ctx) - for _, column := range cols { - qb.addProjection(column) +func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExprs) { + for _, col := range exprs { + switch col := col.(type) { + case *sqlparser.AliasedExpr: + expr := sqlparser.NewColName(col.ColumnName()) + selectExprs = append(selectExprs, &sqlparser.AliasedExpr{Expr: expr}) + default: + selectExprs = append(selectExprs, col) } } + return } -func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { - predicates := slice.Map(op.JoinPredicates.columns, func(jc applyJoinColumn) sqlparser.Expr { - // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done - err := qb.ctx.SkipJoinPredicates(jc.Original) - if err != nil { - panic(err) +func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableStatement) { + colName := column.Name.String() + firstSelect := getFirstSelect(sel) + exprs := firstSelect.SelectExprs + offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { + switch ae := expr.(type) { + case *sqlparser.StarExpr: + return true + case *sqlparser.AliasedExpr: + // When accessing columns on top of a UNION, we fall back to this simple strategy of string comparisons + return ae.ColumnName() == colName } - return jc.Original + return false }) - pred := sqlparser.AndExpressions(predicates...) 
- - buildQuery(op.LHS, qb) - - qbR := &queryBuilder{ctx: qb.ctx} - buildQuery(op.RHS, qbR) - - switch { - // if we have a recursive cte, we might be missing a statement from one of the sides - case qbR.stmt == nil: - // do nothing - case qb.stmt == nil: - qb.stmt = qbR.stmt - default: - qb.joinWith(qbR, pred, op.JoinType) - } -} - -func buildUnion(op *Union, qb *queryBuilder) { - // the first input is built first - buildQuery(op.Sources[0], qb) - - for i, src := range op.Sources { - if i == 0 { - continue - } - - // now we can go over the remaining inputs and UNION them together - qbOther := &queryBuilder{ctx: qb.ctx} - buildQuery(src, qbOther) - qb.unionWith(qbOther, op.distinct) - } -} - -func buildFilter(op *Filter, qb *queryBuilder) { - buildQuery(op.Source, qb) - - for _, pred := range op.Predicates { - qb.addPredicate(pred) + if offset == -1 { + panic(vterrors.VT12001(fmt.Sprintf("did not find column [%s] on UNION", sqlparser.String(column)))) } } -func buildDerived(op *Horizon, qb *queryBuilder) { - buildQuery(op.Source, qb) - - sqlparser.RemoveKeyspaceInCol(op.Query) - - stmt := qb.stmt - qb.stmt = nil - switch sel := stmt.(type) { - case *sqlparser.Select: - buildDerivedSelect(op, qb, sel) - return - case *sqlparser.Union: - buildDerivedUnion(op, qb, sel) - return - } - panic(fmt.Sprintf("unknown select statement type: %T", stmt)) +type FromStatement interface { + GetFrom() []sqlparser.TableExpr + SetFrom([]sqlparser.TableExpr) + GetWherePredicate() sqlparser.Expr + SetWherePredicate(sqlparser.Expr) } -func buildDerivedUnion(op *Horizon, qb *queryBuilder, union *sqlparser.Union) { - opQuery, ok := op.Query.(*sqlparser.Union) - if !ok { - panic(vterrors.VT12001("Horizon contained SELECT but statement was UNION")) - } - - union.Limit = opQuery.Limit - union.OrderBy = opQuery.OrderBy - union.Distinct = opQuery.Distinct - - qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: union, - }, nil, op.ColumnAliases) -} +var _ FromStatement = (*sqlparser.Select)(nil) +var _ FromStatement = (*sqlparser.Update)(nil) +var _ FromStatement = (*sqlparser.Delete)(nil) -func buildDerivedSelect(op *Horizon, qb *queryBuilder, sel *sqlparser.Select) { - opQuery, ok := op.Query.(*sqlparser.Select) - if !ok { - panic(vterrors.VT12001("Horizon contained UNION but statement was SELECT")) - } - sel.Limit = opQuery.Limit - sel.OrderBy = opQuery.OrderBy - sel.GroupBy = opQuery.GroupBy - sel.Having = mergeHaving(sel.Having, opQuery.Having) - sel.SelectExprs = opQuery.SelectExprs - sel.Distinct = opQuery.Distinct - qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, - }, nil, op.ColumnAliases) - for _, col := range op.Columns { - qb.addProjection(&sqlparser.AliasedExpr{Expr: col}) +func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { + var lhs sqlparser.TableExpr + fromClause := stmt.GetFrom() + if len(fromClause) == 1 { + lhs = fromClause[0] + } else { + lhs = &sqlparser.ParenTableExpr{Exprs: fromClause} } -} - -func buildHorizon(op *Horizon, qb *queryBuilder) { - buildQuery(op.Source, qb) - stripDownQuery(op.Query, qb.asSelectStatement()) - sqlparser.RemoveKeyspaceInCol(qb.stmt) -} - -func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { - predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { - // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done - err 
:= qb.ctx.SkipJoinPredicates(jc.Original) - if err != nil { - panic(err) - } - return jc.Original - }) - pred := sqlparser.AndExpressions(predicates...) - buildQuery(op.Seed(), qb) - qbR := &queryBuilder{ctx: qb.ctx} - buildQuery(op.Term(), qbR) - qbR.addPredicate(pred) - infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) - if err != nil { - panic(err) + var rhs sqlparser.TableExpr + otherFromClause := otherStmt.GetFrom() + if len(otherFromClause) == 1 { + rhs = otherFromClause[0] + } else { + rhs = &sqlparser.ParenTableExpr{Exprs: otherFromClause} } - qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct, op.Def.Columns) -} - -func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { - switch { - case h1 == nil && h2 == nil: - return nil - case h1 == nil: - return h2 - case h2 == nil: - return h1 - default: - h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr) - return h1 + return &sqlparser.JoinTableExpr{ + LeftExpr: lhs, + RightExpr: rhs, + Join: joinType, + Condition: &sqlparser.JoinCondition{ + On: onCondition, + }, } } diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index db716966d47..2c061c5a36b 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -104,7 +104,8 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return tryPushUpdate(in) case *RecurseCTE: return tryMergeRecurse(ctx, in) - + case *Values: + return tryPushValues(in) default: return in, NoRewrite } @@ -120,6 +121,13 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return FixedPointBottomUp(root, TableID, visitor, stopAtRoute) } +func tryPushValues(in *Values) (Operator, *ApplyResult) { + if src, ok := in.Source.(*Route); ok { + return Swap(in, src, "pushed values under route") + } + return in, NoRewrite +} + func tryPushDelete(in *Delete) (Operator, *ApplyResult) { if src, ok := in.Source.(*Route); ok { return pushDMLUnderRoute(in, src, "pushed delete under route") diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index 90eb16e1562..18166b48a44 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -287,6 +287,31 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op Operator) (requ return } +func newJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) JoinOp { + lhsID := TableID(lhs) + if lhsID.NumberOfTables() > 1 || !joinType.IsInner() { + return NewApplyJoin(ctx, lhs, rhs, nil, joinType) + } + lhsTableInfo, err := ctx.SemTable.TableInfoFor(lhsID) + if err != nil { + panic(vterrors.VT13001(err.Error())) + } + lhsTableName, err := lhsTableInfo.Name() + if err != nil { + panic(vterrors.VT13001(err.Error())) + } + bindVariableName := ctx.ReservedVars.ReserveVariable("values") + v := &Values{ + unaryOperator: newUnaryOp(rhs), + Name: lhsTableName.Name.String(), + Arg: bindVariableName, + } + return &ValuesJoin{ + binaryOperator: newBinaryOp(lhs, v), + BindVarName: bindVariableName, + } +} + func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, joinType sqlparser.JoinType) (Operator, *ApplyResult) { jm := newJoinMerge(joinPredicates, joinType) newPlan := jm.mergeJoinInputs(ctx, lhs, rhs, joinPredicates) @@ -305,14 +330,14 @@ func 
mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredic return join, Rewrote("use a hash join because we have LIMIT on the LHS") } - join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, joinType) + join := newJoin(ctx, Clone(rhs), Clone(lhs), joinType) for _, pred := range joinPredicates { join.AddJoinPredicate(ctx, pred) } return join, Rewrote("logical join to applyJoin, switching side because LIMIT") } - join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, joinType) + join := newJoin(ctx, Clone(lhs), Clone(rhs), joinType) for _, pred := range joinPredicates { join.AddJoinPredicate(ctx, pred) } diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index 2c8873dee07..2737f74fcd7 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -613,7 +613,6 @@ func (tr *ShardedRouting) planCompositeInOpArg( Index: idx, } if typ, found := ctx.TypeForExpr(col); found { - value.Type = typ.Type() value.Collation = typ.Collation() } diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 9610a2b10c9..b6a21501225 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -101,7 +101,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer Opera } sq.outerID = outerID mapper := func(in sqlparser.Expr) (applyJoinColumn, error) { - return breakExpressionInLHSandRHS(ctx, in, outerID), nil + return breakApplyJoinExpressionInLHSandRHS(ctx, in, outerID), nil } joinPredicates, err := slice.MapWithError(sq.Predicates, mapper) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 06ca69dd7f3..174d39db266 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -289,7 +289,7 @@ func extractLHSExpr( lhs semantics.TableSet, ) func(expr sqlparser.Expr) sqlparser.Expr { return func(expr sqlparser.Expr) sqlparser.Expr { - col := breakExpressionInLHSandRHS(ctx, expr, lhs) + col := breakApplyJoinExpressionInLHSandRHS(ctx, expr, lhs) if col.IsPureLeft() { panic(vterrors.VT13001("did not expect to find any predicates that do not need data from the inner here")) } @@ -668,7 +668,7 @@ func (s *subqueryRouteMerger) merge(ctx *plancontext.PlanningContext, inner, out // We really need to figure out why this is not working as expected func (s *subqueryRouteMerger) rewriteASTExpression(ctx *plancontext.PlanningContext, inner *Route) Operator { src := s.outer.Source - stmt, _, err := ToSQL(ctx, inner.Source) + stmt, _, err := ToAST(ctx, inner.Source) if err != nil { panic(err) } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 18a81175f7b..158a34e2cc2 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -212,7 +212,7 @@ func prepareUpdateExpressionList(ctx *plancontext.PlanningContext, upd *sqlparse for _, ue := range upd.Exprs { target := ctx.SemTable.DirectDeps(ue.Name) exprDeps := ctx.SemTable.RecursiveDeps(ue.Expr) - jc := breakExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target)) + jc := breakApplyJoinExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target)) ueMap[target] = append(ueMap[target], updColumn{ue.Name, jc}) } diff 
--git a/go/vt/vtgate/planbuilder/operators/values.go b/go/vt/vtgate/planbuilder/operators/values.go new file mode 100644 index 00000000000..d8de43b7e79 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/values.go @@ -0,0 +1,93 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +type Values struct { + unaryOperator + + Name string + Arg string +} + +func (v *Values) Clone(inputs []Operator) Operator { + clone := *v + return &clone +} + +func (v *Values) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { + return newFilter(v, expr) +} + +func (v *Values) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + panic(vterrors.VT13001("we cannot add new columns to a Values operator")) +} + +func (v *Values) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + panic(vterrors.VT13001("we cannot add new columns to a Values operator")) +} + +func (v *Values) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + col, ok := expr.(*sqlparser.ColName) + if !ok { + return -1 + } + for i, column := range v.getColsFromCtx(ctx) { + if col.Name.Equal(column) { + return i + } + } + return -1 +} + +func (v *Values) getColsFromCtx(ctx *plancontext.PlanningContext) sqlparser.Columns { + columns, found := ctx.ValuesJoinColumns[v.Arg] + if !found { + panic(vterrors.VT13001("columns not found")) + } + return columns +} + +func (v *Values) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + var cols []*sqlparser.AliasedExpr + for _, column := range v.getColsFromCtx(ctx) { + cols = append(cols, sqlparser.NewAliasedExpr(sqlparser.NewColName(column.String()), "")) + } + return cols +} + +func (v *Values) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + r := v.GetColumns(ctx) + var selectExprs sqlparser.SelectExprs + for _, expr := range r { + selectExprs = append(selectExprs, expr) + } + return selectExprs +} + +func (v *Values) ShortDescription() string { + return v.Name +} + +func (v *Values) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { + return v.Source.GetOrdering(ctx) +} diff --git a/go/vt/vtgate/planbuilder/operators/values_join.go b/go/vt/vtgate/planbuilder/operators/values_join.go new file mode 100644 index 00000000000..9103e413254 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/values_join.go @@ -0,0 +1,180 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +type ( + ValuesJoin struct { + binaryOperator + + BindVarName string + + JoinColumns []valuesJoinColumn + JoinPredicates []valuesJoinColumn + + // After offset planning + + // CopyColumnsToRHS are the offsets of columns from LHS we are copying over to the RHS + // []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset + CopyColumnsToRHS []int + + Columns []int + ColumnName []string + } + + valuesJoinColumn struct { + Original sqlparser.Expr + LHS []*sqlparser.ColName + PureLHS bool + } +) + +var _ Operator = (*ValuesJoin)(nil) +var _ JoinOp = (*ValuesJoin)(nil) + +func (vj *ValuesJoin) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + if reuseExisting { + if offset := vj.FindCol(ctx, expr.Expr, false); offset >= 0 { + return offset + } + } + + vj.JoinColumns = append(vj.JoinColumns, breakValuesJoinExpressionInLHS(ctx, expr.Expr, TableID(vj.LHS))) + vj.ColumnName = append(vj.ColumnName, expr.ColumnName()) + return len(vj.JoinColumns) - 1 +} + +// AddWSColumn is used to add a weight_string column to the operator +func (vj *ValuesJoin) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + panic("oh no") +} + +func (vj *ValuesJoin) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + for offset, column := range vj.JoinColumns { + if ctx.SemTable.EqualsExpr(column.Original, expr) { + return offset + } + } + return -1 +} + +func (vj *ValuesJoin) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + results := make([]*sqlparser.AliasedExpr, len(vj.JoinColumns)) + for i, column := range vj.JoinColumns { + results[i] = sqlparser.NewAliasedExpr(column.Original, vj.ColumnName[i]) + } + return results +} + +func (vj *ValuesJoin) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + return transformColumnsToSelectExprs(ctx, vj) +} + +func (vj *ValuesJoin) GetLHS() Operator { + return vj.LHS +} + +func (vj *ValuesJoin) GetRHS() Operator { + return vj.RHS +} + +func (vj *ValuesJoin) SetLHS(operator Operator) { + vj.LHS = operator +} + +func (vj *ValuesJoin) SetRHS(operator Operator) { + vj.RHS = operator +} + +func (vj *ValuesJoin) MakeInner() { + // no-op for values-join +} + +func (vj *ValuesJoin) IsInner() bool { + return true +} + +func (vj *ValuesJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { + if expr == nil { + return + } + lID := TableID(vj.LHS) + lhsJoinCols := breakValuesJoinExpressionInLHS(ctx, expr, lID) + if lhsJoinCols.PureLHS { + vj.LHS = vj.LHS.AddPredicate(ctx, expr) + return + } + vj.RHS = vj.RHS.AddPredicate(ctx, expr) + vj.JoinPredicates = append(vj.JoinPredicates, lhsJoinCols) +} + +func (vj *ValuesJoin) Clone(inputs []Operator) Operator { + clone := *vj + clone.LHS = inputs[0] + clone.RHS = inputs[1] + return &clone +} + +func (vj *ValuesJoin) ShortDescription() string { + 
return "" +} + +func (vj *ValuesJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { + return vj.RHS.GetOrdering(ctx) +} + +func (vj *ValuesJoin) planOffsets(ctx *plancontext.PlanningContext) Operator { + valuesColumns := ctx.ValuesJoinColumns[vj.BindVarName] + for i, jc := range vj.JoinColumns { + if jc.PureLHS { + offset := vj.LHS.AddColumn(ctx, true, false, sqlparser.NewAliasedExpr(jc.Original, vj.ColumnName[i])) + vj.Columns = append(vj.Columns, ToLeftOffset(offset)) + continue + } + vj.planOffsetsForValueJoinPredicate(ctx, jc.LHS, &valuesColumns) + ctx.ValuesJoinColumns[vj.BindVarName] = valuesColumns + + offset := vj.RHS.AddColumn(ctx, true, false, aeWrap(jc.Original)) + vj.Columns = append(vj.Columns, ToRightOffset(offset)) + } + + for _, predicate := range vj.JoinPredicates { + vj.planOffsetsForValueJoinPredicate(ctx, predicate.LHS, &valuesColumns) + } + + ctx.ValuesJoinColumns[vj.BindVarName] = valuesColumns + return vj +} + +func (vj *ValuesJoin) planOffsetsForValueJoinPredicate(ctx *plancontext.PlanningContext, lhsPred []*sqlparser.ColName, valuesColumns *sqlparser.Columns) { +outer: + for _, lh := range lhsPred { + for _, ci := range *valuesColumns { + if ci.Equal(lh.Name) { + // already there, no need to add it again + continue outer + } + } + offset := vj.LHS.AddColumn(ctx, true, false, aeWrap(lh)) + vj.CopyColumnsToRHS = append(vj.CopyColumnsToRHS, offset) + *valuesColumns = append(*valuesColumns, lh.Name) + } +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 016f5c877cf..eab9d1c18d4 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -43,6 +43,12 @@ type PlanningContext struct { // a join predicate is reverted to its original form during planning. skipPredicates map[sqlparser.Expr]any + // skipValuesArgument tracks Values operator that should be skipped when + // rewriting the operator tree to an AST tree. + // This happens when a ValuesJoin is pushed under a route and we do not + // need to have a Values operator anymore on its RHS. + skipValuesArgument map[string]any + PlannerVersion querypb.ExecuteOptions_PlannerVersion // If we during planning have turned this expression into an argument name, @@ -77,10 +83,23 @@ type PlanningContext struct { // isMirrored indicates that mirrored tables should be used. isMirrored bool + // ValuesJoinColumns stores the columns we need for each values statement in the plan. + ValuesJoinColumns map[string]sqlparser.Columns + emptyEnv *evalengine.ExpressionEnv constantCfg *evalengine.Config } +func CreateEmptyPlanningContext() *PlanningContext { + return &PlanningContext{ + joinPredicates: make(map[sqlparser.Expr][]sqlparser.Expr), + skipPredicates: make(map[sqlparser.Expr]any), + skipValuesArgument: make(map[string]any), + ReservedArguments: make(map[sqlparser.Expr]string), + ValuesJoinColumns: make(map[string]sqlparser.Columns), + } +} + // CreatePlanningContext initializes a new PlanningContext with the given parameters. // It analyzes the SQL statement within the given virtual schema context, // handling default keyspace settings and semantic analysis. 
@@ -104,14 +123,16 @@ func CreatePlanningContext(stmt sqlparser.Statement, vschema.PlannerWarning(semTable.Warning) return &PlanningContext{ - ReservedVars: reservedVars, - SemTable: semTable, - VSchema: vschema, - joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, - skipPredicates: map[sqlparser.Expr]any{}, - PlannerVersion: version, - ReservedArguments: map[sqlparser.Expr]string{}, - Statement: stmt, + ReservedVars: reservedVars, + SemTable: semTable, + VSchema: vschema, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, + skipValuesArgument: map[string]any{}, + PlannerVersion: version, + ReservedArguments: map[sqlparser.Expr]string{}, + ValuesJoinColumns: make(map[string]sqlparser.Columns), + Statement: stmt, }, nil } @@ -176,6 +197,15 @@ func (ctx *PlanningContext) SkipJoinPredicates(joinPred sqlparser.Expr) error { return vterrors.VT13001("predicate does not exist: " + sqlparser.String(joinPred)) } +func (ctx *PlanningContext) SkipValuesArgument(name string) { + ctx.skipValuesArgument[name] = "" +} + +func (ctx *PlanningContext) IsValuesArgumentSkipped(name string) bool { + _, ok := ctx.skipValuesArgument[name] + return ok +} + // KeepPredicateInfo transfers join predicate information from another context. // This is useful when nesting queries, ensuring consistent predicate handling across contexts. func (ctx *PlanningContext) KeepPredicateInfo(other *PlanningContext) { diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index 9d653b2f6e9..a0f3630869f 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,7 +1,7 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "select user.foo, user_extra.user_id from user, user_extra where user.id = user_extra.toto", "plan": { } }