From 36b89441d9fe3bfcf3b69a6bb86b66e12da3e13d Mon Sep 17 00:00:00 2001 From: tglennan <7598704+tglennan@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:56:59 +1100 Subject: [PATCH] Implement Convergent DML Migration Directive (#159) * Convergent DML Directive * Update pkg/spanner/migration.go Co-authored-by: Rory Quinn --------- Co-authored-by: Rory Quinn --- pkg/spanner/client.go | 43 +++++++- pkg/spanner/client_test.go | 121 +++++++++++++++++++-- pkg/spanner/migration.go | 27 ++++- pkg/spanner/migration_test.go | 68 ++++++++---- pkg/spanner/testdata/migrations/000005.sql | 9 ++ 5 files changed, 231 insertions(+), 37 deletions(-) create mode 100644 pkg/spanner/testdata/migrations/000005.sql diff --git a/pkg/spanner/client.go b/pkg/spanner/client.go index 72e46f8..f9bdb00 100644 --- a/pkg/spanner/client.go +++ b/pkg/spanner/client.go @@ -667,7 +667,8 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l } } - switch m.Kind { + statementKind := cmp.Or(m.Directives.StatementKind, m.Kind) + switch statementKind { case StatementKindDDL: if err := c.ApplyDDL(ctx, m.Statements); err != nil { return nil, &Error{ @@ -696,6 +697,18 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l } } + migrationsOutput[m.FileName] = migrationInfo{ + RowsAffected: rowsAffected, + } + case StatementKindConvergentDML: + rowsAffected, err := convergentApply(ctx, c.ApplyDML, m.Statements, m.Directives.Concurrency) + if err != nil { + return nil, &Error{ + Code: ErrorCodeExecuteMigrations, + err: err, + } + } + migrationsOutput[m.FileName] = migrationInfo{ RowsAffected: rowsAffected, } @@ -1108,3 +1121,31 @@ func (c *Client) releaseMigrationLock(ctx context.Context, tableName, lockIdenti } return nil } + +func convergentApply(ctx context.Context, applyDMLFunc func(context.Context, []string) (int64, error), statements []string, concurrency int) (int64, error) { + concurrency = max(concurrency, 1) + p := pool.New().WithMaxGoroutines(concurrency).WithErrors() + + // Apply each statement in the migration in a separate transaction. + var totalRowCount atomic.Int64 + for _, statement := range statements { + p.Go(func() error { + for { + rowCount, err := applyDMLFunc(ctx, []string{statement}) + if err != nil { + return err + } + + if rowCount == 0 { + return nil + } + totalRowCount.Add(rowCount) + } + }) + } + + if err := p.Wait(); err != nil { + return 0, err + } + return totalRowCount.Load(), nil +} diff --git a/pkg/spanner/client_test.go b/pkg/spanner/client_test.go index 89fb1ae..5fd7a58 100644 --- a/pkg/spanner/client_test.go +++ b/pkg/spanner/client_test.go @@ -20,10 +20,12 @@ package spanner import ( + "cmp" "context" "fmt" "os" "reflect" + "sync" "testing" "time" @@ -238,22 +240,28 @@ func TestExecuteMigrations(t *testing.T) { ensureMigrationVersionRecord(t, ctx, client, 2, false) ensureMigrationHistoryRecord(t, ctx, client, 2, false) + // execute remaining migrations if migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable, 1); err != nil { t.Fatalf("failed to execute migration: %v", err) } if want, got := int64(1), migrationsOutput["000003.sql"].RowsAffected; want != got { - t.Errorf("want %d, but got %d", want, got) + t.Errorf("migration %q: want %d rows affected, but got %d", "000003.sql", want, got) } - // ensure that 000003.sql and 000004.sql have been applied. + // ensure that 000003.sql, 000004.sql, and 000005.sql have been applied. ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "NO") - ensureMigrationVersionRecord(t, ctx, client, 4, false) - ensureMigrationHistoryRecord(t, ctx, client, 4, false) + ensureMigrationVersionRecord(t, ctx, client, 5, false) + ensureMigrationHistoryRecord(t, ctx, client, 5, false) // ensure that schema is not changed and ExecuteMigrate is safely finished even though no migrations should be applied. ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "NO") - ensureMigrationVersionRecord(t, ctx, client, 4, false) + ensureMigrationVersionRecord(t, ctx, client, 5, false) + + // ensure that 000005.sql has been applied, inserting an additional 4 rows + if want, got := int64(4), migrationsOutput["000005.sql"].RowsAffected; want != got { + t.Errorf("migration %q: want %d rows affected, but got %d", "000005.sql", want, got) + } } func ensureMigrationColumn(t *testing.T, ctx context.Context, client *Client, columnName, spannerType, isNullable string) { @@ -751,7 +759,7 @@ func TestClient_RepairMigration(t *testing.T) { defer done() // add LastName NULLABLE - err := migrateUpDir(t, ctx, client, "testdata/migrations", 3, 4) + err := migrateUpDir(t, ctx, client, "testdata/migrations", 3, 4, 5) require.NoError(t, err, "error running migrations") // add row with NULL LastName @@ -759,7 +767,7 @@ func TestClient_RepairMigration(t *testing.T) { require.NoError(t, err, "failed to insert row") // make dirty with bad migration - err = migrateUpDir(t, ctx, client, "testdata/migrations", 3) + err = migrateUpDir(t, ctx, client, "testdata/migrations", 3, 5) assert.EqualError(t, err, "Cannot specify a null value for column: LastName in table: Singers referenced by key: {String(\"ABC\")}") assertDirtyCount := func(isDirty bool, expected int64) { @@ -915,3 +923,102 @@ func Test_parseDDL1(t *testing.T) { }) } } + +func Test_convergentApply(t *testing.T) { + // create a new "apply" function that executes each statement 11 times, the + // first 10 times returning 100 rows and the 11th time returning 0 rows. + newApplyFunc := func(t *testing.T) (func(context.Context, []string) (int64, error), *[]string) { + var callHistory []string + var mu sync.Mutex + applyFn := func(_ context.Context, s []string) (int64, error) { + mu.Lock() + defer mu.Unlock() + require.Len(t, s, 1) + callHistory = append(callHistory, s[0]) + + if countOccurrences(callHistory, s[0]) > 10 { + return 0, nil + } + + // Return a static 100 "rows affected" per call the first 10 times + return 100, nil + } + return applyFn, &callHistory + } + + t.Run("CallsEachStatementUntilNoRowsUpdated", func(t *testing.T) { + applyFn, callHistory := newApplyFunc(t) + + stmts := []string{"stmt1", "stmt2", "stmt3"} + got, err := convergentApply(context.Background(), applyFn, stmts, 1) + require.NoError(t, err) + + // Total number of rows should equal the number of times each statement was called + expectedRowCount := 3000 // 3 statements * (10 calls * 100 row count + 1 call * 0 row count) + assert.EqualValues(t, expectedRowCount, got) + + // Validate that each statement was called the expected number of times + for _, stmt := range stmts { + assert.Equal(t, 11, countOccurrences(*callHistory, stmt)) + } + + // Validate that calls were ordered with concurrency = 1 + assertOrdered(t, *callHistory) + }) + + t.Run("Concurrency>1", func(t *testing.T) { + applyFn, callHistory := newApplyFunc(t) + + stmts := []string{"stmt1", "stmt2", "stmt3", "stmt4", "stmt5", "stmt6"} + got, err := convergentApply(context.Background(), applyFn, stmts, 10) + require.NoError(t, err) + + // Total number of rows should equal the number of times each statement was called + expectedRowCount := 6000 // 6 statements * (10 calls * 100 row count + 1 call * 0 row count) + assert.EqualValues(t, expectedRowCount, got) + + // Validate that each statement was called the expected number of times + for _, stmt := range stmts { + assert.Equal(t, 11, countOccurrences(*callHistory, stmt)) + } + + // Validate that calls were unordered with concurrency > 1 + assertNotOrdered(t, *callHistory) + }) + + t.Run("Error", func(t *testing.T) { + applyFn := func(_ context.Context, s []string) (int64, error) { + return 0, assert.AnError + } + got, err := convergentApply(context.Background(), applyFn, []string{"stmt1", "stmt2", "stmt3"}, 1) + require.ErrorIs(t, err, assert.AnError) + assert.Zero(t, got) + }) +} + +func assertOrdered[T cmp.Ordered](t *testing.T, vs []T) { + assert.True(t, isOrdered(vs), "slice is unordered: %v", vs) +} + +func assertNotOrdered[T cmp.Ordered](t *testing.T, vs []T) { + assert.False(t, isOrdered(vs), "slice is ordered: %v", vs) +} + +func isOrdered[T cmp.Ordered](vs []T) bool { + for i := 0; i < len(vs)-1; i++ { + if vs[i] > vs[i+1] { + return false + } + } + return true +} + +func countOccurrences(s []string, v string) int { + var c int + for _, i := range s { + if i == v { + c++ + } + } + return c +} diff --git a/pkg/spanner/migration.go b/pkg/spanner/migration.go index b766661..7efab22 100644 --- a/pkg/spanner/migration.go +++ b/pkg/spanner/migration.go @@ -63,6 +63,11 @@ const ( StatementKindDDL StatementKind = "DDL" StatementKindDML StatementKind = "DML" StatementKindPartitionedDML StatementKind = "PartitionedDML" + // StatementKindConvergentDML repeatedly executes all statements in + // the migration until no more rows are affected. Each statement is executed + // in its own transaction, and the concurrency can be configured via the + // @wrench.Concurrency directive. + StatementKindConvergentDML StatementKind = "ConvergentDML" ) type ( @@ -88,7 +93,12 @@ type ( // MigrationDirectives configures how the migration should be executed. MigrationDirectives struct { - placeholder string + // StatementKind overrides the auto-detected statement kind. + // This can be used to customise how migrations are executed. + StatementKind StatementKind + // Kind defines the execution concurrency. Only applicable when + // StatementKind is StatementKindConvergentDML. + Concurrency int } Migrations []*Migration @@ -369,9 +379,8 @@ func removeCommentsAndTrim(sql string) (string, error) { // @wrench.{key}={value} from the migration preamble. func parseMigrationDirectives(migration string) (MigrationDirectives, error) { const ( - // placeholderKey is a placeholder to validate parsing until a directive - // is implemented. - placeholderKey = "TODO" + concurrencyKey = "Concurrency" + statementKindKey = "StatementKind" ) // matches a migration directive in the format @wrench.{key}={value} @@ -382,8 +391,14 @@ func parseMigrationDirectives(migration string) (MigrationDirectives, error) { for _, match := range directiveMatches { key, val := match["Key"], match["Value"] switch key { - case placeholderKey: - directives.placeholder = val + case statementKindKey: + directives.StatementKind = StatementKind(val) + case concurrencyKey: + concurrency, err := strconv.Atoi(val) + if err != nil || concurrency < 1 { + return MigrationDirectives{}, fmt.Errorf("invalid concurrency value: %s", val) + } + directives.Concurrency = concurrency default: return directives, fmt.Errorf("unknown migration directive: %s", key) } diff --git a/pkg/spanner/migration_test.go b/pkg/spanner/migration_test.go index bfd9fad..ef92ed2 100644 --- a/pkg/spanner/migration_test.go +++ b/pkg/spanner/migration_test.go @@ -40,8 +40,8 @@ func TestLoadMigrations(t *testing.T) { t.Fatal(err) } - if len(ms) != 3 { - t.Fatalf("migrations length want 3, but got %v", len(ms)) + if len(ms) != 4 { + t.Fatalf("migrations length want 4, but got %v", len(ms)) } testcases := []struct { @@ -73,7 +73,7 @@ func TestLoadMigrations(t *testing.T) { } func TestLoadMigrationsSkipVersion(t *testing.T) { - ms, err := LoadMigrations(filepath.Join("testdata", "migrations"), []uint{2, 3}, false) + ms, err := LoadMigrations(filepath.Join("testdata", "migrations"), []uint{2, 3, 4}, false) if err != nil { t.Fatal(err) } @@ -82,8 +82,8 @@ func TestLoadMigrationsSkipVersion(t *testing.T) { t.Fatalf("migrations length want 1, but got %v", len(ms)) } - if ms[0].Version != 4 { - t.Errorf("version want %v, but got %v", 4, ms[0].Version) + if ms[0].Version != 5 { + t.Errorf("version want %v, but got %v", 5, ms[0].Version) } } @@ -634,7 +634,7 @@ SELECT 1`, } func Test_parseMigrationDirectives(t *testing.T) { - const placeholderKey = "TODO" + const testStatementKind = StatementKind("Foo") tests := []struct { name string @@ -660,40 +660,48 @@ SELECT 1 FROM Foo`, name: "PreambleWithDirectives_BlockComment", data: fmt.Sprintf(` /* - @wrench.%s=%s + @wrench.StatementKind=%s + @wrench.Concurrency=123 */ -SELECT 1 FROM Foo`, placeholderKey, "value"), +SELECT 1 FROM Foo`, testStatementKind), want: MigrationDirectives{ - placeholder: "value", + StatementKind: testStatementKind, + Concurrency: 123, }, }, { name: "PreambleWithDirectives_LineComment", data: fmt.Sprintf(` --- @wrench.%s=%s -SELECT 1 FROM Foo`, placeholderKey, "value"), +-- @wrench.StatementKind=%s +-- @wrench.Concurrency=123 +SELECT 1 FROM Foo`, testStatementKind), want: MigrationDirectives{ - placeholder: "value", + StatementKind: testStatementKind, + Concurrency: 123, }, }, { name: "PreambleWithDirectives_DirectiveCommentIgnored", data: fmt.Sprintf(` /* - @wrench.%s=%s // This is ignored + @wrench.StatementKind=%s // This is ignored */ -SELECT 1 FROM Foo`, placeholderKey, "value"), +SELECT 1 FROM Foo +`, testStatementKind), want: MigrationDirectives{ - placeholder: "value", + StatementKind: testStatementKind, }, }, { name: "WhitespaceIgnored", data: fmt.Sprintf(` -/* @wrench.%s=%s */ -SELECT 1 FROM Foo`, placeholderKey, "value"), +/* @wrench.StatementKind=%s */ +-- @wrench.Concurrency=123 +SELECT 1 FROM Foo +`, testStatementKind), want: MigrationDirectives{ - placeholder: "value", + StatementKind: testStatementKind, + Concurrency: 123, }, }, { @@ -702,14 +710,17 @@ SELECT 1 FROM Foo`, placeholderKey, "value"), /* This is my migration! -@wrench.%s=%s +@wrench.StatementKind=%s Foo bar baz. +@wrench.Concurrency=123 */ -SELECT 1 FROM Foo`, placeholderKey, "value"), +SELECT 1 FROM Foo +`, testStatementKind), want: MigrationDirectives{ - placeholder: "value", + StatementKind: testStatementKind, + Concurrency: 123, }, }, } @@ -722,6 +733,17 @@ SELECT 1 FROM Foo`, placeholderKey, "value"), } t.Run("Errors", func(t *testing.T) { + t.Run("InvalidConcurrency", func(t *testing.T) { + got, err := parseMigrationDirectives(fmt.Sprintf(`/* +@wrench.StatementKind=%s +@wrench.Concurrency=abc +*/ +SELECT 1 FROM Foo +`, testStatementKind)) + assert.Zero(t, got) + assert.Error(t, err) + }) + t.Run("UnknownKey", func(t *testing.T) { got, err := parseMigrationDirectives(` -- @wrench.foo=bar @@ -825,5 +847,5 @@ block 3`}, t.Run(tt.name, func(t *testing.T) { assert.Equal(t, tt.want, extractPreamble(tt.data)) }) - } -} \ No newline at end of file + } +} diff --git a/pkg/spanner/testdata/migrations/000005.sql b/pkg/spanner/testdata/migrations/000005.sql new file mode 100644 index 0000000..6caa4be --- /dev/null +++ b/pkg/spanner/testdata/migrations/000005.sql @@ -0,0 +1,9 @@ +-- @wrench.StatementKind=ConvergentDML +-- @wrench.Concurrency=1 +-- +-- This statement inserts one new row if fewer than 5 rows are present in the table. +-- The StatementKind is set to "ConvergentDML" via a directive, instructing wrench to +-- repeatedly execute this statement until no more rows are affected, i.e. the table has 5 rows. +INSERT INTO Singers (SingerID, FirstName, LastName) +SELECT NextSingerID, CONCAT("Singer", CAST(TotalSingers AS STRING)), "" +FROM (SELECT GENERATE_UUID() AS NextSingerID, COUNT(1) AS TotalSingers FROM Singers HAVING TotalSingers < 5)