Skip to content

Commit

Permalink
Make sure no AST types are bare slices (#17674)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
Signed-off-by: Vicent Marti <[email protected]>
Co-authored-by: Vicent Marti <[email protected]>
  • Loading branch information
systay and vmg authored Feb 6, 2025
1 parent cf28afa commit 9bf3326
Show file tree
Hide file tree
Showing 112 changed files with 3,130 additions and 1,960 deletions.
32 changes: 16 additions & 16 deletions go/test/endtoend/vtgate/queries/random/query_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (sg *selectGenerator) randomSelect() {
}

// make sure we have at least one select expression
for isRandomExpr || len(sg.sel.SelectExprs) == 0 {
for isRandomExpr || len(sg.sel.SelectExprs.Exprs) == 0 {
// TODO: if the random expression is an int literal,
// TODO: and if the query is (potentially) an aggregate query,
// TODO: then we must group by the random expression,
Expand Down Expand Up @@ -395,7 +395,7 @@ func (sg *selectGenerator) createJoin(tables []tableT) {

// returns 1-3 random expressions based on the last two elements of tables
// tables should have at least two elements
func (sg *selectGenerator) createJoinPredicates(tables []tableT) sqlparser.Exprs {
func (sg *selectGenerator) createJoinPredicates(tables []tableT) []sqlparser.Expr {
if len(tables) < 2 {
log.Fatalf("tables has %d elements, needs at least 2", len(tables))
}
Expand Down Expand Up @@ -427,7 +427,7 @@ func (sg *selectGenerator) createGroupBy(tables []tableT) (grouping []column) {

// add to select
if rand.IntN(2) < 1 {
sg.sel.SelectExprs = append(sg.sel.SelectExprs, newAliasedColumn(col, ""))
sg.sel.AddSelectExpr(newAliasedColumn(col, ""))
grouping = append(grouping, col)
}
}
Expand All @@ -437,13 +437,13 @@ func (sg *selectGenerator) createGroupBy(tables []tableT) (grouping []column) {

// aliasGroupingColumns randomly aliases the grouping columns in the SelectExprs
func (sg *selectGenerator) aliasGroupingColumns(grouping []column) []column {
if len(grouping) != len(sg.sel.SelectExprs) {
log.Fatalf("grouping (length: %d) and sg.sel.SelectExprs (length: %d) should have the same length at this point", len(grouping), len(sg.sel.SelectExprs))
if len(grouping) != len(sg.sel.SelectExprs.Exprs) {
log.Fatalf("grouping (length: %d) and sg.sel.SelectExprs (length: %d) should have the same length at this point", len(grouping), len(sg.sel.SelectExprs.Exprs))
}

for i := range grouping {
if rand.IntN(2) < 1 {
if aliasedExpr, ok := sg.sel.SelectExprs[i].(*sqlparser.AliasedExpr); ok {
if aliasedExpr, ok := sg.sel.SelectExprs.Exprs[i].(*sqlparser.AliasedExpr); ok {
alias := fmt.Sprintf("cgroup%d", i)
aliasedExpr.SetAlias(alias)
grouping[i].name = alias
Expand All @@ -454,7 +454,7 @@ func (sg *selectGenerator) aliasGroupingColumns(grouping []column) []column {
return grouping
}

// returns the aggregation columns as three types: sqlparser.SelectExprs, []column
// returns the aggregation columns as three types: *sqlparser.SelectExprs, []column
func (sg *selectGenerator) createAggregations(tables []tableT) (aggregates []column) {
exprGenerators := slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })
// add scalar subqueries
Expand Down Expand Up @@ -485,7 +485,7 @@ func (sg *selectGenerator) createOrderBy() {
}

// randomly order on SelectExprs
for _, selExpr := range sg.sel.SelectExprs {
for _, selExpr := range sg.sel.SelectExprs.Exprs {
if aliasedExpr, ok := selExpr.(*sqlparser.AliasedExpr); ok && rand.IntN(2) < 1 {
literal, ok := aliasedExpr.Expr.(*sqlparser.Literal)
isIntLiteral := ok && literal.Type == sqlparser.IntVal
Expand Down Expand Up @@ -527,7 +527,7 @@ func (sg *selectGenerator) createHavingPredicates(grouping []column) {
}

// returns between minExprs and maxExprs random expressions using generators
func (sg *selectGenerator) createRandomExprs(minExprs, maxExprs int, generators ...sqlparser.ExprGenerator) (predicates sqlparser.Exprs) {
func (sg *selectGenerator) createRandomExprs(minExprs, maxExprs int, generators ...sqlparser.ExprGenerator) (predicates []sqlparser.Expr) {
if minExprs > maxExprs {
log.Fatalf("minExprs is greater than maxExprs; minExprs: %d, maxExprs: %d\n", minExprs, maxExprs)
} else if maxExprs <= 0 {
Expand Down Expand Up @@ -578,28 +578,28 @@ func (sg *selectGenerator) randomlyAlias(expr sqlparser.Expr, alias string) colu
} else {
col.name = alias
}
sg.sel.SelectExprs = append(sg.sel.SelectExprs, sqlparser.NewAliasedExpr(expr, alias))
sg.sel.AddSelectExpr(sqlparser.NewAliasedExpr(expr, alias))

return col
}

// matchNumCols makes sure sg.sel.SelectExprs and newTable both have the same number of cols: sg.genConfig.NumCols
func (sg *selectGenerator) matchNumCols(tables []tableT, newTable tableT, canAggregate bool) tableT {
// remove SelectExprs and newTable.cols randomly until there are sg.genConfig.NumCols amount
for len(sg.sel.SelectExprs) > sg.genConfig.NumCols && sg.genConfig.NumCols > 0 {
for len(sg.sel.SelectExprs.Exprs) > sg.genConfig.NumCols && sg.genConfig.NumCols > 0 {
// select a random index and remove it from SelectExprs and newTable
idx := rand.IntN(len(sg.sel.SelectExprs))
idx := rand.IntN(len(sg.sel.SelectExprs.Exprs))

sg.sel.SelectExprs[idx] = sg.sel.SelectExprs[len(sg.sel.SelectExprs)-1]
sg.sel.SelectExprs = sg.sel.SelectExprs[:len(sg.sel.SelectExprs)-1]
sg.sel.SelectExprs.Exprs[idx] = sg.sel.SelectExprs.Exprs[len(sg.sel.SelectExprs.Exprs)-1]
sg.sel.SelectExprs.Exprs = sg.sel.SelectExprs.Exprs[:len(sg.sel.SelectExprs.Exprs)-1]

newTable.cols[idx] = newTable.cols[len(newTable.cols)-1]
newTable.cols = newTable.cols[:len(newTable.cols)-1]
}

// alternatively, add random expressions until there are sg.genConfig.NumCols amount
if sg.genConfig.NumCols > len(sg.sel.SelectExprs) {
diff := sg.genConfig.NumCols - len(sg.sel.SelectExprs)
if sg.genConfig.NumCols > len(sg.sel.SelectExprs.Exprs) {
diff := sg.genConfig.NumCols - len(sg.sel.SelectExprs.Exprs)
exprs := sg.createRandomExprs(diff, diff,
slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })...)

Expand Down
26 changes: 24 additions & 2 deletions go/tools/astfmtgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,6 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr)

token := format[i]
switch token {
case 'c':
cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteByte", expr.Args[2+fieldnum]))
case 's':
cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", expr.Args[2+fieldnum]))
case 'l', 'r', 'v':
Expand Down Expand Up @@ -249,6 +247,26 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr)
Args: []ast.Expr{&ast.BasicLit{Value: `"%d"`, Kind: gotoken.STRING}, expr.Args[2+fieldnum]},
}
cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", call))
case 'n': // directive for slices of AST nodes checked at code generation time
inputExpr := expr.Args[2+fieldnum]
inputType := r.pkg.TypesInfo.Types[inputExpr].Type
sliceType, ok := inputType.(*types.Slice)
if !ok {
panic("'%n' directive requires a slice")
}
if types.Implements(sliceType.Elem(), r.astExpr) {
// Fast path: input is []Expr
call := &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: callexpr.X,
Sel: &ast.Ident{Name: "formatExprs"},
},
Args: []ast.Expr{inputExpr},
}
cursor.InsertBefore(&ast.ExprStmt{X: call})
break
}
panic("slow path for `n` directive for slice of type other than Expr")
default:
panic(fmt.Sprintf("unsupported escape %q", token))
}
Expand All @@ -259,3 +277,7 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr)
cursor.Delete()
return true
}

var noQualifier = func(p *types.Package) string {
return ""
}
11 changes: 1 addition & 10 deletions go/tools/asthelpergen/asthelpergen.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ type (
}
)

// exprInterfacePath is the path of the sqlparser.Expr interface.
const exprInterfacePath = "vitess.io/vitess/go/vt/sqlparser.Expr"

func (gen *astHelperGen) iface() *types.Interface {
return gen._iface
}
Expand Down Expand Up @@ -207,19 +204,13 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
return nil, err
}

exprType, _ := findTypeObject(exprInterfacePath, scopes)
var exprInterface *types.Interface
if exprType != nil {
exprInterface = exprType.Type().(*types.Named).Underlying().(*types.Interface)
}

nt := tt.Type().(*types.Named)
pName := nt.Obj().Pkg().Name()
generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt,
newEqualsGen(pName, &options.Equals),
newCloneGen(pName, &options.Clone),
newVisitGen(pName),
newRewriterGen(pName, types.TypeString(nt, noQualifier), exprInterface),
newRewriterGen(pName, types.TypeString(nt, noQualifier)),
newCOWGen(pName, nt),
)

Expand Down
19 changes: 19 additions & 0 deletions go/tools/asthelpergen/asthelpergen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/tools/codegen"

"github.com/stretchr/testify/require"
)

Expand All @@ -45,3 +47,20 @@ func TestFullGeneration(t *testing.T) {
require.False(t, applyIdx == 0 && cloneIdx == 0, "file doesn't contain expected contents")
}
}

func TestRecreateAllFiles(t *testing.T) {
// t.Skip("This test recreates all files in the integration directory. It should only be run when the ASTHelperGen code has changed.")
result, err := GenerateASTHelpers(&Options{
Packages: []string{"./integration/..."},
RootInterface: "vitess.io/vitess/go/tools/asthelpergen/integration.AST",
Clone: CloneOptions{
Exclude: []string{"*NoCloneType"},
},
})
require.NoError(t, err)

for fullPath, file := range result {
err := codegen.SaveJenFile(fullPath, file)
require.NoError(t, err)
}
}
Loading

0 comments on commit 9bf3326

Please sign in to comment.