Skip to content

Commit

Permalink
apply partitioned dml via worker pool
Browse files Browse the repository at this point in the history
  • Loading branch information
RoryQ committed Mar 18, 2024
1 parent 67e894d commit e96357d
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.PHONY: test
test: _spanner-up
go test -race -count=1 ./...
go test -race -v -count=1 ./...
-@make _spanner-down


Expand Down
27 changes: 14 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,20 @@ Available Commands:
repair If a migration has failed, clean up any schema changes manually then repair the history with this command
Flags:
--credentials-file string Specify Credentials File
--database string Cloud Spanner database name (optional. if not set, will use $SPANNER_DATABASE_ID value)
--directory string Directory that schema file placed (required)
-h, --help help for wrench
--instance string Cloud Spanner instance name (optional. if not set, will use $SPANNER_INSTANCE_ID value)
--lock-identifier string Random identifier used to lock migration operations to a single wrench process. (optional. if not set then it will be generated) (default "58a4394a-19f9-4dbf-880d-20b6cf169d46")
--project string GCP project id (optional. if not set, will use $SPANNER_PROJECT_ID or $GOOGLE_CLOUD_PROJECT value)
--schema-file string Name of schema file (optional. if not set, will use default 'schema.sql' file name)
--sequence-interval uint16 Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1) (default 1)
--static-data-tables-file string File containing list of static data tables to track (optional)
--stmt-timeout duration Set a non-default timeout for statement execution
--verbose Used to indicate whether to output Migration information during a migration
-v, --version version for wrench
--credentials-file string Specify Credentials File
--database string Cloud Spanner database name (optional. if not set, will use $SPANNER_DATABASE_ID value)
--directory string Directory that schema file placed (required)
-h, --help help for wrench
--instance string Cloud Spanner instance name (optional. if not set, will use $SPANNER_INSTANCE_ID value)
--lock-identifier string Random identifier used to lock migration operations to a single wrench process. (optional. if not set then it will be generated) (default "58a4394a-19f9-4dbf-880d-20b6cf169d46")
--partitioned-dml-concurrency uint16 Set the concurrency for Partitioned-DML statements. (optional. if not set, will use $WRENCH_PARTITIONED_DML_CONCURRENCY or default to 1) (default 1)
--project string GCP project id (optional. if not set, will use $SPANNER_PROJECT_ID or $GOOGLE_CLOUD_PROJECT value)
--schema-file string Name of schema file (optional. if not set, will use default 'schema.sql' file name)
--sequence-interval uint16 Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1) (default 1)
--static-data-tables-file string File containing list of static data tables to track (optional)
--stmt-timeout duration Set a non-default timeout for statement execution
--verbose Used to indicate whether to output Migration information during a migration
-v, --version version for wrench
Use "wrench [command] --help" for more information about a command.
```
6 changes: 4 additions & 2 deletions cmd/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"os"

"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -78,15 +79,16 @@ func apply(c *cobra.Command, _ []string) error {
}

// apply dml
dml, err := ioutil.ReadFile(dmlFile)
dml, err := os.ReadFile(dmlFile)
if err != nil {
return &Error{
err: err,
cmd: c,
}
}

numAffectedRows, err := client.ApplyDMLFile(ctx, dml, partitioned)
concurrency := int(partitionedDMLConcurrency)
numAffectedRows, err := client.ApplyDMLFile(ctx, dml, partitioned, concurrency)
if err != nil {
return &Error{
err: err,
Expand Down
39 changes: 20 additions & 19 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,26 @@ import (
)

const (
flagNameProject = "project"
flagNameInstance = "instance"
flagNameDatabase = "database"
flagNameDirectory = "directory"
flagSkipVersions = "skip-versions"
flagNameCreateNoPrompt = "no-prompt"
flagCredentialsFile = "credentials-file"
flagStaticDataTablesFile = "static-data-tables-file"
flagNameSchemaFile = "schema-file"
flagLockIdentifier = "lock-identifier"
flagSequenceInterval = "sequence-interval"
flagStmtTimeout = "stmt-timeout"
flagVerbose = "verbose"
flagDDLFile = "ddl"
flagDMLFile = "dml"
flagPartitioned = "partitioned"
flagSpannerEmulatorImage = "spanner-emulator-image"
defaultSchemaFileName = "schema.sql"
defaultStaticDataTablesFile = "{wrench.json|static_data_tables.txt}"
flagNameProject = "project"
flagNameInstance = "instance"
flagNameDatabase = "database"
flagNameDirectory = "directory"
flagSkipVersions = "skip-versions"
flagNameCreateNoPrompt = "no-prompt"
flagCredentialsFile = "credentials-file"
flagStaticDataTablesFile = "static-data-tables-file"
flagNameSchemaFile = "schema-file"
flagLockIdentifier = "lock-identifier"
flagSequenceInterval = "sequence-interval"
flagStmtTimeout = "stmt-timeout"
flagPartitionedDMLConcurrency = "partitioned-dml-concurrency"
flagVerbose = "verbose"
flagDDLFile = "ddl"
flagDMLFile = "dml"
flagPartitioned = "partitioned"
flagSpannerEmulatorImage = "spanner-emulator-image"
defaultSchemaFileName = "schema.sql"
defaultStaticDataTablesFile = "{wrench.json|static_data_tables.txt}"
)

func newSpannerClient(ctx context.Context, c *cobra.Command) (*spanner.Client, error) {
Expand Down
4 changes: 3 additions & 1 deletion cmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ func migrateUp(c *cobra.Command, args []string) error {
}
}

concurrency := int(partitionedDMLConcurrency)

var migrationsOutput spanner.MigrationsOutput
switch status {
case spanner.ExistingMigrationsUpgradeStarted:
Expand All @@ -220,7 +222,7 @@ func migrateUp(c *cobra.Command, args []string) error {
return err
}
case spanner.ExistingMigrationsUpgradeCompleted:
migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, limit, migrationTableName)
migrationsOutput, err = client.ExecuteMigrations(ctx, migrations, limit, migrationTableName, concurrency)
if err != nil {
return err
}
Expand Down
32 changes: 21 additions & 11 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,18 @@ var (
)

var (
project string
instance string
database string
directory string
schemaFile string
credentialsFile string
staticDataTablesFile string
lockIdentifier string
sequenceInterval uint16
stmtTimeout time.Duration
verbose bool
project string
instance string
database string
directory string
schemaFile string
credentialsFile string
staticDataTablesFile string
lockIdentifier string
sequenceInterval uint16
partitionedDMLConcurrency uint16
stmtTimeout time.Duration
verbose bool
)

var rootCmd = &cobra.Command{
Expand Down Expand Up @@ -89,6 +90,7 @@ func init() {
rootCmd.PersistentFlags().Uint16Var(&sequenceInterval, flagSequenceInterval, getSequenceInterval(), "Used to generate the next migration id. Rounds up to the next interval. (optional. if not set, will use $WRENCH_SEQUENCE_INTERVAL or default to 1)")
rootCmd.PersistentFlags().BoolVar(&verbose, flagVerbose, false, "Used to indicate whether to output Migration information during a migration")
rootCmd.PersistentFlags().DurationVar(&stmtTimeout, flagStmtTimeout, getStmtTimeout(), "Set a non-default timeout for statement execution")
rootCmd.PersistentFlags().Uint16Var(&partitionedDMLConcurrency, flagPartitionedDMLConcurrency, getPartitionedDMLConcurrency(), "Set the concurrency for Partitioned-DML statements. (optional. if not set, will use $WRENCH_PARTITIONED_DML_CONCURRENCY or default to 1)")

rootCmd.Version = versioninfo.Version
rootCmd.SetVersionTemplate(versionTemplate)
Expand Down Expand Up @@ -137,3 +139,11 @@ func getStmtTimeout() time.Duration {
}
return i
}

// getPartitionedDMLConcurrency returns the default value for the
// partitioned-DML concurrency flag, read from the
// WRENCH_PARTITIONED_DML_CONCURRENCY environment variable.
//
// It returns 1 when the variable is unset, is not a base-10 integer, or holds
// a value outside the range 1..65535. The explicit range check matters because
// a plain uint16 conversion silently wraps: -1 would become 65535 and 65536
// would become 0, turning a typo into an unrelated concurrency level.
func getPartitionedDMLConcurrency() uint16 {
	const (
		fallback uint16 = 1
		maxValue        = int(^uint16(0)) // 65535, the largest value the flag can hold
	)

	n, err := strconv.Atoi(os.Getenv("WRENCH_PARTITIONED_DML_CONCURRENCY"))
	if err != nil || n < 1 || n > maxValue {
		return fallback
	}
	return uint16(n)
}
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ require (
github.com/carlmjohnson/versioninfo v0.22.5
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/googleapis/gax-go/v2 v2.12.2
github.com/kennygrant/sanitize v1.2.4
github.com/ory/dockertest/v3 v3.10.0
github.com/sourcegraph/conc v0.3.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.9.0
Expand Down Expand Up @@ -47,7 +49,6 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
Expand All @@ -69,6 +70,8 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/net v0.22.0 // indirect
Expand All @@ -87,4 +90,4 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.21
go 1.22
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
Expand Down Expand Up @@ -191,6 +193,10 @@ go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB
go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc=
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
Expand Down
54 changes: 34 additions & 20 deletions pkg/spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,31 @@
package spanner

import (
"cmp"
"context"
"errors"
"fmt"
"regexp"
"sort"
"strings"
"sync/atomic"
"time"

"github.com/googleapis/gax-go/v2"
"github.com/sourcegraph/conc/pool"

"github.com/roryq/wrench/pkg/spannerz"

"google.golang.org/grpc/codes"

"github.com/roryq/wrench/pkg/spanner/dataloader"

"cloud.google.com/go/spanner"
admin "cloud.google.com/go/spanner/admin/database/apiv1"
vkit "cloud.google.com/go/spanner/apiv1"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"

"github.com/roryq/wrench/pkg/spanner/dataloader"
)

const (
Expand Down Expand Up @@ -403,11 +407,11 @@ func (c *Client) ApplyDDL(ctx context.Context, statements []string) error {
return nil
}

func (c *Client) ApplyDMLFile(ctx context.Context, ddl []byte, partitioned bool) (int64, error) {
statements := toStatements(ddl)
func (c *Client) ApplyDMLFile(ctx context.Context, dml []byte, partitioned bool, concurrency int) (int64, error) {
statements := toStatements(dml)

if partitioned {
return c.ApplyPartitionedDML(ctx, statements)
return c.ApplyPartitionedDML(ctx, statements, concurrency)
}
return c.ApplyDML(ctx, statements)
}
Expand Down Expand Up @@ -436,24 +440,34 @@ func (c *Client) ApplyDML(ctx context.Context, statements []string) (int64, erro
return numAffectedRows, nil
}

func (c *Client) ApplyPartitionedDML(ctx context.Context, statements []string) (int64, error) {
numAffectedRows := int64(0)
func (c *Client) ApplyPartitionedDML(ctx context.Context, statements []string, concurrency int) (int64, error) {
numAffectedRows := atomic.Int64{}

concurrency = cmp.Or(concurrency, 1)
p := pool.New().WithMaxGoroutines(concurrency).WithErrors()
for _, s := range statements {
num, err := c.spannerClient.PartitionedUpdate(ctx, spanner.Statement{
SQL: s,
})
if err != nil {
return numAffectedRows, &Error{
Code: ErrorCodeUpdatePartitionedDML,
err: err,
p.Go(func() error {
num, err := c.spannerClient.PartitionedUpdate(ctx, spanner.Statement{
SQL: s,
})
if err != nil {
return err
}
}

numAffectedRows += num
numAffectedRows.Add(num)
return nil
})
}

return numAffectedRows, nil
err := p.Wait()
if err != nil {
return numAffectedRows.Load(), &Error{
Code: ErrorCodeUpdatePartitionedDML,
err: err,
}
}

return numAffectedRows.Load(), nil
}

func (c *Client) UpgradeExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) (MigrationsOutput, error) {
Expand All @@ -462,7 +476,7 @@ func (c *Client) UpgradeExecuteMigrations(ctx context.Context, migrations Migrat
return nil, err
}

migrationsOutput, err := c.ExecuteMigrations(ctx, migrations, limit, tableName)
migrationsOutput, err := c.ExecuteMigrations(ctx, migrations, limit, tableName, 1)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -591,7 +605,7 @@ func (i MigrationsOutput) String() string {
return fmt.Sprintf("%s\n", output)
}

func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) (MigrationsOutput, error) {
func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string, partitionedConcurrency int) (MigrationsOutput, error) {
sort.Sort(migrations)

version, dirty, err := c.GetSchemaMigrationVersion(ctx, tableName)
Expand Down Expand Up @@ -663,7 +677,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l
RowsAffected: rowsAffected,
}
case statementKindPartitionedDML:
rowsAffected, err := c.ApplyPartitionedDML(ctx, m.Statements)
rowsAffected, err := c.ApplyPartitionedDML(ctx, m.Statements, partitionedConcurrency)
if err != nil {
return nil, &Error{
Code: ErrorCodeExecuteMigrations,
Expand Down
Loading

0 comments on commit e96357d

Please sign in to comment.