Implement Convergent DML Migration Directive (#159)
* Convergent DML Directive

* Update pkg/spanner/migration.go

Co-authored-by: Rory Quinn <[email protected]>

---------

Co-authored-by: Rory Quinn <[email protected]>
tglennan and RoryQ authored Jan 6, 2025
1 parent 5eb6036 commit 36b8944
Showing 5 changed files with 231 additions and 37 deletions.
43 changes: 42 additions & 1 deletion pkg/spanner/client.go
@@ -667,7 +667,8 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
}
}

switch m.Kind {
statementKind := cmp.Or(m.Directives.StatementKind, m.Kind)
switch statementKind {
case StatementKindDDL:
if err := c.ApplyDDL(ctx, m.Statements); err != nil {
return nil, &Error{
@@ -696,6 +697,18 @@ }
}
}

migrationsOutput[m.FileName] = migrationInfo{
RowsAffected: rowsAffected,
}
case StatementKindConvergentDML:
rowsAffected, err := convergentApply(ctx, c.ApplyDML, m.Statements, m.Directives.Concurrency)
if err != nil {
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
err: err,
}
}

migrationsOutput[m.FileName] = migrationInfo{
RowsAffected: rowsAffected,
}
@@ -1108,3 +1121,31 @@ func (c *Client) releaseMigrationLock(ctx context.Context, tableName, lockIdenti
}
return nil
}

func convergentApply(ctx context.Context, applyDMLFunc func(context.Context, []string) (int64, error), statements []string, concurrency int) (int64, error) {
concurrency = max(concurrency, 1)
p := pool.New().WithMaxGoroutines(concurrency).WithErrors()

// Apply each statement in the migration in a separate transaction.
var totalRowCount atomic.Int64
for _, statement := range statements {
p.Go(func() error {
for {
rowCount, err := applyDMLFunc(ctx, []string{statement})
if err != nil {
return err
}

if rowCount == 0 {
return nil
}
totalRowCount.Add(rowCount)
}
})
}

if err := p.Wait(); err != nil {
return 0, err
}
return totalRowCount.Load(), nil
}
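A note on the kind resolution in ExecuteMigrations above: cmp.Or returns its first non-zero argument, so an explicit @wrench.StatementKind directive overrides the auto-detected m.Kind, which remains the fallback when no directive is set. A minimal stand-alone sketch of that behaviour (the sample values are illustrative):

```go
package main

import (
	"cmp"
	"fmt"
)

type StatementKind string

const (
	StatementKindDML           StatementKind = "DML"
	StatementKindConvergentDML StatementKind = "ConvergentDML"
)

func main() {
	detected := StatementKindDML // kind auto-detected from the migration's statements

	// With a directive set, cmp.Or returns the directive value.
	fmt.Println(cmp.Or(StatementKindConvergentDML, detected)) // ConvergentDML

	// With no directive (zero value), cmp.Or falls back to the detected kind.
	fmt.Println(cmp.Or(StatementKind(""), detected)) // DML
}
```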
121 changes: 114 additions & 7 deletions pkg/spanner/client_test.go
@@ -20,10 +20,12 @@
package spanner

import (
"cmp"
"context"
"fmt"
"os"
"reflect"
"sync"
"testing"
"time"

@@ -238,22 +240,28 @@ func TestExecuteMigrations(t *testing.T) {
ensureMigrationVersionRecord(t, ctx, client, 2, false)
ensureMigrationHistoryRecord(t, ctx, client, 2, false)

// execute remaining migrations
if migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable, 1); err != nil {
t.Fatalf("failed to execute migration: %v", err)
}

if want, got := int64(1), migrationsOutput["000003.sql"].RowsAffected; want != got {
t.Errorf("want %d, but got %d", want, got)
t.Errorf("migration %q: want %d rows affected, but got %d", "000003.sql", want, got)
}

// ensure that 000003.sql and 000004.sql have been applied.
// ensure that 000003.sql, 000004.sql, and 000005.sql have been applied.
ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "NO")
ensureMigrationVersionRecord(t, ctx, client, 4, false)
ensureMigrationHistoryRecord(t, ctx, client, 4, false)
ensureMigrationVersionRecord(t, ctx, client, 5, false)
ensureMigrationHistoryRecord(t, ctx, client, 5, false)

// ensure that schema is not changed and ExecuteMigrate is safely finished even though no migrations should be applied.
ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "NO")
ensureMigrationVersionRecord(t, ctx, client, 4, false)
ensureMigrationVersionRecord(t, ctx, client, 5, false)

// ensure that 000005.sql has been applied, inserting an additional 4 rows
if want, got := int64(4), migrationsOutput["000005.sql"].RowsAffected; want != got {
t.Errorf("migration %q: want %d rows affected, but got %d", "000005.sql", want, got)
}
}

func ensureMigrationColumn(t *testing.T, ctx context.Context, client *Client, columnName, spannerType, isNullable string) {
@@ -751,15 +759,15 @@ func TestClient_RepairMigration(t *testing.T) {
defer done()

// add LastName NULLABLE
err := migrateUpDir(t, ctx, client, "testdata/migrations", 3, 4)
err := migrateUpDir(t, ctx, client, "testdata/migrations", 3, 4, 5)
require.NoError(t, err, "error running migrations")

// add row with NULL LastName
_, err = client.spannerClient.Apply(ctx, []*spanner.Mutation{spanner.Insert(singerTable, []string{"SingerID", "FirstName"}, []any{"ABC", "Fred"})})
require.NoError(t, err, "failed to insert row")

// make dirty with bad migration
err = migrateUpDir(t, ctx, client, "testdata/migrations", 3)
err = migrateUpDir(t, ctx, client, "testdata/migrations", 3, 5)
assert.EqualError(t, err, "Cannot specify a null value for column: LastName in table: Singers referenced by key: {String(\"ABC\")}")

assertDirtyCount := func(isDirty bool, expected int64) {
@@ -915,3 +923,102 @@ func Test_parseDDL1(t *testing.T) {
})
}
}

func Test_convergentApply(t *testing.T) {
// newApplyFunc returns an "apply" function that, for each statement, reports
// 100 rows affected on each of the first 10 calls and 0 rows on the 11th call.
newApplyFunc := func(t *testing.T) (func(context.Context, []string) (int64, error), *[]string) {
var callHistory []string
var mu sync.Mutex
applyFn := func(_ context.Context, s []string) (int64, error) {
mu.Lock()
defer mu.Unlock()
require.Len(t, s, 1)
callHistory = append(callHistory, s[0])

if countOccurrences(callHistory, s[0]) > 10 {
return 0, nil
}

// Return a static 100 "rows affected" per call the first 10 times
return 100, nil
}
return applyFn, &callHistory
}

t.Run("CallsEachStatementUntilNoRowsUpdated", func(t *testing.T) {
applyFn, callHistory := newApplyFunc(t)

stmts := []string{"stmt1", "stmt2", "stmt3"}
got, err := convergentApply(context.Background(), applyFn, stmts, 1)
require.NoError(t, err)

// Total rows affected should equal the sum of row counts reported across all calls
expectedRowCount := 3000 // 3 statements * (10 calls * 100 row count + 1 call * 0 row count)
assert.EqualValues(t, expectedRowCount, got)

// Validate that each statement was called the expected number of times
for _, stmt := range stmts {
assert.Equal(t, 11, countOccurrences(*callHistory, stmt))
}

// Validate that calls were ordered with concurrency = 1
assertOrdered(t, *callHistory)
})

t.Run("Concurrency>1", func(t *testing.T) {
applyFn, callHistory := newApplyFunc(t)

stmts := []string{"stmt1", "stmt2", "stmt3", "stmt4", "stmt5", "stmt6"}
got, err := convergentApply(context.Background(), applyFn, stmts, 10)
require.NoError(t, err)

// Total rows affected should equal the sum of row counts reported across all calls
expectedRowCount := 6000 // 6 statements * (10 calls * 100 row count + 1 call * 0 row count)
assert.EqualValues(t, expectedRowCount, got)

// Validate that each statement was called the expected number of times
for _, stmt := range stmts {
assert.Equal(t, 11, countOccurrences(*callHistory, stmt))
}

// Validate that calls were unordered with concurrency > 1
assertNotOrdered(t, *callHistory)
})

t.Run("Error", func(t *testing.T) {
applyFn := func(_ context.Context, s []string) (int64, error) {
return 0, assert.AnError
}
got, err := convergentApply(context.Background(), applyFn, []string{"stmt1", "stmt2", "stmt3"}, 1)
require.ErrorIs(t, err, assert.AnError)
assert.Zero(t, got)
})
}

func assertOrdered[T cmp.Ordered](t *testing.T, vs []T) {
assert.True(t, isOrdered(vs), "slice is unordered: %v", vs)
}

func assertNotOrdered[T cmp.Ordered](t *testing.T, vs []T) {
assert.False(t, isOrdered(vs), "slice is ordered: %v", vs)
}

func isOrdered[T cmp.Ordered](vs []T) bool {
for i := 0; i < len(vs)-1; i++ {
if vs[i] > vs[i+1] {
return false
}
}
return true
}

func countOccurrences(s []string, v string) int {
var c int
for _, i := range s {
if i == v {
c++
}
}
return c
}
27 changes: 21 additions & 6 deletions pkg/spanner/migration.go
@@ -63,6 +63,11 @@ const (
StatementKindDDL StatementKind = "DDL"
StatementKindDML StatementKind = "DML"
StatementKindPartitionedDML StatementKind = "PartitionedDML"
// StatementKindConvergentDML repeatedly executes all statements in
// the migration until no more rows are affected. Each statement is executed
// in its own transaction, and the concurrency can be configured via the
// @wrench.Concurrency directive.
StatementKindConvergentDML StatementKind = "ConvergentDML"
)

type (
@@ -88,7 +93,12 @@ type (

// MigrationDirectives configures how the migration should be executed.
MigrationDirectives struct {
placeholder string
// StatementKind overrides the auto-detected statement kind.
// This can be used to customise how migrations are executed.
StatementKind StatementKind
// Concurrency defines the execution concurrency. Only applicable when
// StatementKind is StatementKindConvergentDML.
Concurrency int
}

Migrations []*Migration
@@ -369,9 +379,8 @@ func removeCommentsAndTrim(sql string) (string, error) {
// @wrench.{key}={value} from the migration preamble.
func parseMigrationDirectives(migration string) (MigrationDirectives, error) {
const (
// placeholderKey is a placeholder to validate parsing until a directive
// is implemented.
placeholderKey = "TODO"
concurrencyKey = "Concurrency"
statementKindKey = "StatementKind"
)

// matches a migration directive in the format @wrench.{key}={value}
@@ -382,8 +391,14 @@ func parseMigrationDirectives(migration string) (MigrationDirectives, error) {
for _, match := range directiveMatches {
key, val := match["Key"], match["Value"]
switch key {
case placeholderKey:
directives.placeholder = val
case statementKindKey:
directives.StatementKind = StatementKind(val)
case concurrencyKey:
concurrency, err := strconv.Atoi(val)
if err != nil || concurrency < 1 {
return MigrationDirectives{}, fmt.Errorf("invalid concurrency value: %s", val)
}
directives.Concurrency = concurrency
default:
return directives, fmt.Errorf("unknown migration directive: %s", key)
}
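For context, a ConvergentDML migration carries its directives in the comment preamble of the migration file. A minimal sketch of how @wrench.{key}={value} pairs can be extracted and validated — the sample migration text and the regex below are illustrative, not the exact pattern wrench uses (that lives in parseMigrationDirectives above):

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// A hypothetical migration preamble followed by batched DML.
const migration = `-- @wrench.StatementKind=ConvergentDML
-- @wrench.Concurrency=4
UPDATE Singers SET LastName = '' WHERE LastName IS NULL;`

// Illustrative pattern for @wrench.{Key}={Value} directives.
var directiveRe = regexp.MustCompile(`@wrench\.(?P<Key>\w+)=(?P<Value>\S+)`)

func main() {
	kind, concurrency := "", 1
	for _, m := range directiveRe.FindAllStringSubmatch(migration, -1) {
		key, val := m[1], m[2]
		switch key {
		case "StatementKind":
			kind = val
		case "Concurrency":
			n, err := strconv.Atoi(val)
			if err != nil || n < 1 {
				fmt.Println("invalid concurrency value:", val)
				return
			}
			concurrency = n
		default:
			fmt.Println("unknown migration directive:", key)
			return
		}
	}
	fmt.Println(kind, concurrency) // ConvergentDML 4
}
```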
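Execution-wise, a ConvergentDML migration re-applies each statement in its own transaction until it reports zero rows affected, with at most Concurrency statements in flight — the behaviour of the convergentApply helper added to client.go above. A stand-alone sketch of the same loop, using golang.org/x/sync/errgroup here in place of the sourcegraph/conc pool, with a fake apply function standing in for ApplyDML:

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"

	"golang.org/x/sync/errgroup"
)

func convergentApply(ctx context.Context, apply func(context.Context, []string) (int64, error), statements []string, concurrency int) (int64, error) {
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(max(concurrency, 1))

	var total atomic.Int64
	for _, stmt := range statements { // per-iteration loop variable (Go 1.22+)
		g.Go(func() error {
			// Re-apply the statement in its own transaction until it
			// converges, i.e. reports zero rows affected.
			for {
				n, err := apply(ctx, []string{stmt})
				if err != nil {
					return err
				}
				if n == 0 {
					return nil
				}
				total.Add(n)
			}
		})
	}
	if err := g.Wait(); err != nil {
		return 0, err
	}
	return total.Load(), nil
}

func main() {
	// Fake apply: pretend each call updates up to 100 of 250 remaining rows.
	var remaining atomic.Int64
	remaining.Store(250)
	apply := func(context.Context, []string) (int64, error) {
		n := min(remaining.Load(), 100)
		remaining.Add(-n)
		return n, nil
	}

	rows, err := convergentApply(context.Background(), apply, []string{"UPDATE ..."}, 2)
	fmt.Println(rows, err) // 250 <nil>
}
```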