From 0ad0d1262e4a9bdb161b96f23958b3adb9982f05 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 22 Nov 2024 10:30:43 +0100 Subject: [PATCH] feat: also analyze deletes Signed-off-by: Andres Taylor --- go/transactions/transactions.go | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/go/transactions/transactions.go b/go/transactions/transactions.go index 02a2803..b153793 100644 --- a/go/transactions/transactions.go +++ b/go/transactions/transactions.go @@ -215,8 +215,10 @@ func (s *state) consume(ch <-chan []sqlparser.Statement, wg *sync.WaitGroup) { switch query := query.(type) { case *sqlparser.Update: s.consumeUpdate(query, st, n, tx) + case *sqlparser.Delete: + s.consumeDelete(query, st, n, tx) default: - panic("not supported for now") + panic(fmt.Sprintf("not supported for now: %T", query)) } } s.addSignature(tx) @@ -244,11 +246,28 @@ func (s *state) consumeUpdate(query *sqlparser.Update, st *semantics.SemTable, n // Find all predicates in the where clause that use a column and a literal tx.addPredicate(getPredicates(query.Where.Expr, st, n)) + query.Where = normalizeWhere(query.Where, n) +} + +func (s *state) consumeDelete(del *sqlparser.Delete, st *semantics.SemTable, n *normalizer, tx *TxSignature) { + defer func() { + tx.Queries = append(tx.Queries, sqlparser.String(del)) + }() - var newWhere sqlparser.Where - wheres := sqlparser.SplitAndExpression(nil, query.Where.Expr) - for _, where := range wheres { - switch cmp := where.(type) { + if del.Where == nil { + return + } + + // Find all predicates in the where clause that use a column and a literal + tx.addPredicate(getPredicates(del.Where.Expr, st, n)) + del.Where = normalizeWhere(del.Where, n) +} + +func normalizeWhere(where *sqlparser.Where, n *normalizer) (newWhere *sqlparser.Where) { + newWhere = new(sqlparser.Where) + predicates := sqlparser.SplitAndExpression(nil, where.Expr) + for _, predicate := range predicates { + switch cmp := predicate.(type) { case *sqlparser.ComparisonExpr: lhs, lhsOK := cmp.Left.(*sqlparser.Literal) rhs, rhsOK := cmp.Right.(*sqlparser.Literal) @@ -270,10 +289,10 @@ func (s *state) consumeUpdate(query *sqlparser.Update, st *semantics.SemTable, n } newWhere.Expr = sqlparser.AndExpressions(newWhere.Expr, &newCmp) default: - newWhere.Expr = sqlparser.AndExpressions(newWhere.Expr, where) + newWhere.Expr = sqlparser.AndExpressions(newWhere.Expr, predicate) } } - query.Where = &newWhere + return } func (s *state) addSignature(tx *TxSignature) {