Skip to content

Commit

Permalink
feat: also analyze deletes
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Nov 22, 2024
1 parent e769a72 commit 0ad0d12
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions go/transactions/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,10 @@ func (s *state) consume(ch <-chan []sqlparser.Statement, wg *sync.WaitGroup) {
switch query := query.(type) {
case *sqlparser.Update:
s.consumeUpdate(query, st, n, tx)
case *sqlparser.Delete:
s.consumeDelete(query, st, n, tx)
default:
panic("not supported for now")
panic(fmt.Sprintf("not supported for now: %T", query))
}
}
s.addSignature(tx)
Expand Down Expand Up @@ -244,11 +246,28 @@ func (s *state) consumeUpdate(query *sqlparser.Update, st *semantics.SemTable, n

// Find all predicates in the where clause that use a column and a literal
tx.addPredicate(getPredicates(query.Where.Expr, st, n))
query.Where = normalizeWhere(query.Where, n)
}

func (s *state) consumeDelete(del *sqlparser.Delete, st *semantics.SemTable, n *normalizer, tx *TxSignature) {
defer func() {
tx.Queries = append(tx.Queries, sqlparser.String(del))
}()

var newWhere sqlparser.Where
wheres := sqlparser.SplitAndExpression(nil, query.Where.Expr)
for _, where := range wheres {
switch cmp := where.(type) {
if del.Where == nil {
return
}

// Find all predicates in the where clause that use a column and a literal
tx.addPredicate(getPredicates(del.Where.Expr, st, n))
del.Where = normalizeWhere(del.Where, n)
}

func normalizeWhere(where *sqlparser.Where, n *normalizer) (newWhere *sqlparser.Where) {
newWhere = new(sqlparser.Where)
predicates := sqlparser.SplitAndExpression(nil, where.Expr)
for _, predicate := range predicates {
switch cmp := predicate.(type) {
case *sqlparser.ComparisonExpr:
lhs, lhsOK := cmp.Left.(*sqlparser.Literal)
rhs, rhsOK := cmp.Right.(*sqlparser.Literal)
Expand All @@ -270,10 +289,10 @@ func (s *state) consumeUpdate(query *sqlparser.Update, st *semantics.SemTable, n
}
newWhere.Expr = sqlparser.AndExpressions(newWhere.Expr, &newCmp)
default:
newWhere.Expr = sqlparser.AndExpressions(newWhere.Expr, where)
newWhere.Expr = sqlparser.AndExpressions(newWhere.Expr, predicate)
}
}
query.Where = &newWhere
return
}

func (s *state) addSignature(tx *TxSignature) {
Expand Down

0 comments on commit 0ad0d12

Please sign in to comment.