Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CTEs. #37

Merged
merged 2 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 111 additions & 40 deletions pkg/vet/vet.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func validateTableColumns(ctx VetContext, tables []TableUsed, cols []ColumnUsed)
}
}
if !found {
if len(tables) == 1 {
if len(usedTables) == 1 {
// to make error message more useful, if only one table is
// referenced in the query, it's safe to assume user only
// want to use columns from that table.
Expand Down Expand Up @@ -541,6 +541,11 @@ func getUsedColumnsFromSortClause(sortList []*pg_query.Node) []ColumnUsed {
func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams []QueryParam, targetCols []schema.Column, err error) {
usedCols := []ColumnUsed{}

if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
postponed := PostponedNodes{}
for _, fromClause := range stmt.FromClause {
re := &ParseResult{}
Expand Down Expand Up @@ -648,11 +653,22 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams
return queryParams, targetCols, validateTableColumns(ctx, ctx.UsedTables, usedCols)
}

func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, error) {
func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, []ColumnUsed, error) {
if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
tableName := stmt.Relation.Relname
if err := validateTable(ctx, tableName, true); err != nil {
return nil, err
return nil, nil, err
}

var tableAlias string
if stmt.Relation.Alias != nil {
tableAlias = stmt.Relation.Alias.Aliasname
}

usedTables := []TableUsed{{Name: tableName}}
usedTables = append(usedTables, getUsedTablesFromSelectStmt(stmt.FromClause)...)

Expand Down Expand Up @@ -682,7 +698,7 @@ func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam
re := &ParseResult{}
err := parseWhereClause(ctx, stmt.WhereClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
usedCols = append(usedCols, re.Columns...)
AddQueryParams(&queryParams, re.Params)
Expand All @@ -692,13 +708,25 @@ func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam
usedCols = append(usedCols, getUsedColumnsFromReturningList(stmt.ReturningList)...)
}

return queryParams, validateTableColumns(ctx, usedTables, usedCols)
if len(usedCols) > 0 {
usedTables = append(usedTables, TableUsed{Name: tableName, Alias: tableAlias})
if err := validateTableColumns(ctx, usedTables, usedCols); err != nil {
return nil, nil, err
}
}

return queryParams, usedCols, nil
}

func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, error) {
func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, []ColumnUsed, error) {
if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
tableName := stmt.Relation.Relname
if err := validateTable(ctx, tableName, true); err != nil {
return nil, err
return nil, nil, err
}
usedTables := []TableUsed{{Name: tableName}}

Expand Down Expand Up @@ -736,7 +764,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
re := &ParseResult{}
err := parseExpression(ctx, node, re)
if err != nil {
return nil, fmt.Errorf("invalid value list: %w", err)
return nil, nil, fmt.Errorf("invalid value list: %w", err)
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
Expand All @@ -758,7 +786,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
re := &ParseResult{}
err := parseFromClause(ctx, fromClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
Expand All @@ -772,7 +800,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
re := &ParseResult{}
err := parseWhereClause(ctx, selectStmt.WhereClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
Expand All @@ -796,12 +824,12 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
case target.GetSubLink() != nil:
tv := target.GetSubLink().Subselect
if tv.GetSelectStmt() == nil {
return nil, fmt.Errorf(
return nil, nil, fmt.Errorf(
"unsupported subquery type in value list: %s", reflect.TypeOf(tv))
}
qparams, _, err := validateSelectStmt(ctx, tv.GetSelectStmt())
if err != nil {
return nil, fmt.Errorf("invalid SELECT query in value list: %w", err)
return nil, nil, fmt.Errorf("invalid SELECT query in value list: %w", err)
}
if len(qparams) > 0 {
AddQueryParams(&queryParams, qparams)
Expand All @@ -815,17 +843,22 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
}

if err := validateTableColumns(ctx, usedTables, usedCols); err != nil {
return nil, err
return nil, nil, err
}

if err := validateInsertValues(ctx, targetCols, values); err != nil {
return nil, err
return nil, nil, err
}

return queryParams, nil
return queryParams, usedCols, nil
}

func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, error) {
func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, []ColumnUsed, error) {
if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
tableName := stmt.Relation.Relname
var tableAlias string

Expand All @@ -834,7 +867,7 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam
}

if err := validateTable(ctx, tableName, true); err != nil {
return nil, err
return nil, nil, err
}

usedCols := []ColumnUsed{}
Expand All @@ -845,25 +878,25 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam
re := &ParseResult{}
err := parseWhereClause(ctx, stmt.WhereClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
} else {
return nil, fmt.Errorf("no columns in DELETE's WHERE clause")
return nil, nil, fmt.Errorf("no columns in DELETE's WHERE clause")
}
if len(re.Params) > 0 {
queryParams = re.Params
}
} else {
return nil, fmt.Errorf("no WHERE clause for DELETE")
return nil, nil, fmt.Errorf("no WHERE clause for DELETE")
}

for _, using := range stmt.UsingClause {
re := &ParseResult{}
err := parseUsingClause(ctx, using, re)
if err != nil {
return nil, err
return nil, nil, err
}
usedTables = append(usedTables, re.Tables...)
}
Expand All @@ -876,11 +909,39 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam
if len(usedCols) > 0 {
usedTables = append(usedTables, TableUsed{Name: tableName, Alias: tableAlias})
if err := validateTableColumns(ctx, usedTables, usedCols); err != nil {
return nil, err
return nil, nil, err
}
}

return queryParams, nil
return queryParams, usedCols, nil
}

func parseCTE(ctx VetContext, with *pg_query.WithClause) error {
for _, cteNode := range with.Ctes {
cte := cteNode.GetCommonTableExpr()
query := cte.GetCtequery()
_, cols, err := validateSqlQuery(ctx, query)
if err != nil {
return err
}

var columns map[string]schema.Column
if cols != nil {
columns = make(map[string]schema.Column)
for _, col := range cols {
columns[col.Column] = schema.Column{
Name: col.Column,
}
}
}

ctx.InnerSchema.Tables[cte.Ctename] = schema.Table{
Name: cte.Ctename,
Columns: columns,
ReadOnly: true,
}
}
return nil
}

func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) {
Expand All @@ -893,26 +954,36 @@ func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) {
return nil, fmt.Errorf("query contained more than one statement")
}

var raw *pg_query.RawStmt = tree.Stmts[0]
params, _, err := validateSqlQuery(ctx, tree.Stmts[0].Stmt)
return params, err
}

func validateSqlQuery(ctx VetContext, node *pg_query.Node) ([]QueryParam, []ColumnUsed, error) {

switch {
case raw.Stmt.GetSelectStmt() != nil:
qparams, _, err := validateSelectStmt(ctx, raw.Stmt.GetSelectStmt())
return qparams, err
case raw.Stmt.GetUpdateStmt() != nil:
return validateUpdateStmt(ctx, raw.Stmt.GetUpdateStmt())
case raw.Stmt.GetInsertStmt() != nil:
return validateInsertStmt(ctx, raw.Stmt.GetInsertStmt())
case raw.Stmt.GetDeleteStmt() != nil:
return validateDeleteStmt(ctx, raw.Stmt.GetDeleteStmt())
case raw.Stmt.GetDropStmt() != nil:
case raw.Stmt.GetTruncateStmt() != nil:
case raw.Stmt.GetAlterTableStmt() != nil:
case raw.Stmt.GetCreateSchemaStmt() != nil:
case raw.Stmt.GetVariableSetStmt() != nil:
case node.GetSelectStmt() != nil:
qparams, targetCols, err := validateSelectStmt(ctx, node.GetSelectStmt())
var cused []ColumnUsed
for _, tcol := range targetCols {
cused = append(cused, ColumnUsed{Column: tcol.Name})
}
return qparams, cused, err
case node.GetUpdateStmt() != nil:
return validateUpdateStmt(ctx, node.GetUpdateStmt())
case node.GetInsertStmt() != nil:
return validateInsertStmt(ctx, node.GetInsertStmt())
case node.GetDeleteStmt() != nil:
return validateDeleteStmt(ctx, node.GetDeleteStmt())
case node.GetDropStmt() != nil:
case node.GetTruncateStmt() != nil:
case node.GetAlterTableStmt() != nil:
case node.GetCreateSchemaStmt() != nil:
case node.GetVariableSetStmt() != nil:

// TODO: check for invalid pg variables
default:
return nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(raw.Stmt))
return nil, nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(node))
}

return nil, nil
return nil, nil, nil
}
27 changes: 26 additions & 1 deletion pkg/vet/vet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ func TestSelect(t *testing.T) {
`SELECT id, f.id, coalesce(bzz.created_at,0)
FROM foo as f
LEFT JOIN LATERAL (
SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at
SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at
FROM baz b
) bzz ON true
WHERE value IS NULL`,
Expand All @@ -447,6 +447,17 @@ func TestSelect(t *testing.T) {
WHERE f.id = b.id) bzz ON true
WHERE value IS NULL`,
},
{
"select CTE",
`WITH cte1 AS (SELECT id FROM foo)
SELECT id FROM cte1`,
},
{
"select 2 CTEs",
`WITH cte1 AS (SELECT id FROM foo),
cte2 AS (SELECT value FROM foo)
SELECT c1.id, c2.value FROM cte1 c1, cte2 c2`,
},
}

for _, tcase := range testCases {
Expand Down Expand Up @@ -494,6 +505,15 @@ func TestUpdate(t *testing.T) {
"update with returning",
`UPDATE foo SET id=1 RETURNING value`,
},
{
"update alias with returning",
`UPDATE foo f SET id=1 RETURNING f.value`,
},
{
"update CTE",
`WITH cte1 AS (SELECT id FROM foo)
UPDATE foo SET value='bar' FROM cte1 WHERE foo.id = cte1.id`,
},
}

for _, tcase := range testCases {
Expand Down Expand Up @@ -590,6 +610,11 @@ func TestDelete(t *testing.T) {
"delete using with aliases",
`DELETE FROM foo AS f USING bar b WHERE f.id = b.id`,
},
{
"delete CTE",
`WITH cte1 AS (SELECT id FROM foo)
DELETE FROM foo USING cte1 WHERE foo.id = cte1.id`,
},
}

for _, tcase := range testCases {
Expand Down
Loading