diff --git a/sqldb/migrations.go b/sqldb/migrations.go index 6b104f19bd..c1e12471e1 100644 --- a/sqldb/migrations.go +++ b/sqldb/migrations.go @@ -8,16 +8,71 @@ import ( "net/http" "strings" + "github.com/btcsuite/btclog" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/source/httpfs" ) +// MigrationTarget is a functional option that can be passed to applyMigrations +// to specify a target version to migrate to. +type MigrationTarget func(mig *migrate.Migrate) error + +var ( + // TargetLatest is a MigrationTarget that migrates to the latest + // version available. + TargetLatest = func(mig *migrate.Migrate) error { + return mig.Up() + } + + // TargetVersion is a MigrationTarget that migrates to the given + // version. + TargetVersion = func(version uint) MigrationTarget { + return func(mig *migrate.Migrate) error { + return mig.Migrate(version) + } + } +) + +// migrationLogger is a logger that wraps the passed btclog.Logger so it can be +// used to log migrations. +type migrationLogger struct { + log btclog.Logger +} + +// Printf is like fmt.Printf. We map this to the target logger based on the +// current log level. +func (m *migrationLogger) Printf(format string, v ...interface{}) { + // Trim trailing newlines from the format. + format = strings.TrimRight(format, "\n") + + switch m.log.Level() { + case btclog.LevelTrace: + m.log.Tracef(format, v...) + case btclog.LevelDebug: + m.log.Debugf(format, v...) + case btclog.LevelInfo: + m.log.Infof(format, v...) + case btclog.LevelWarn: + m.log.Warnf(format, v...) + case btclog.LevelError: + m.log.Errorf(format, v...) + case btclog.LevelCritical: + m.log.Criticalf(format, v...) + case btclog.LevelOff: + } +} + +// Verbose should return true when verbose logging output is wanted +func (m *migrationLogger) Verbose() bool { + return m.log.Level() <= btclog.LevelDebug +} + // applyMigrations executes all database migration files found in the given file // system under the given path, using the passed database driver and database // name. func applyMigrations(fs fs.FS, driver database.Driver, path, - dbName string) error { + dbName string, targetVersion MigrationTarget) error { // With the migrate instance open, we'll create a new migration source // using the embedded file system stored in sqlSchemas. The library @@ -37,7 +92,15 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, if err != nil { return err } - err = sqlMigrate.Up() + + migrationVersion, _, _ := sqlMigrate.Version() + log.Infof("Applying migrations from version=%v", migrationVersion) + + // Apply our local logger to the migration instance. + sqlMigrate.Log = &migrationLogger{log} + + // Execute the migration based on the target given. + err = targetVersion(sqlMigrate) if err != nil && !errors.Is(err, migrate.ErrNoChange) { return err } diff --git a/sqldb/postgres.go b/sqldb/postgres.go index e6e88c93bd..2db2c9dfdf 100644 --- a/sqldb/postgres.go +++ b/sqldb/postgres.go @@ -2,6 +2,7 @@ package sqldb import ( "database/sql" + "fmt" "net/url" "path" "strings" @@ -19,6 +20,17 @@ var ( // fully executed yet. So this time needs to be chosen correctly to be // longer than the longest expected individual test run time. DefaultPostgresFixtureLifetime = 10 * time.Minute + + // postgresSchemaReplacements is a map of schema strings that need to be + // replaced for postgres. This is needed because we write the schemas to + // work with sqlite primarily but in sqlc's own dialect, and postgres + // has some differences. + postgresSchemaReplacements = map[string]string{ + "BLOB": "BYTEA", + "INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY", + "BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY", + "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", + } ) // replacePasswordInDSN takes a DSN string and returns it with the password @@ -98,42 +110,43 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { rawDB.SetMaxIdleConns(maxConns) rawDB.SetConnMaxLifetime(connIdleLifetime) - if !cfg.SkipMigrations { - // Now that the database is open, populate the database with - // our set of schemas based on our embedded in-memory file - // system. - // - // First, we'll need to open up a new migration instance for - // our current target database: Postgres. - driver, err := postgres_migrate.WithInstance( - rawDB, &postgres_migrate.Config{}, - ) - if err != nil { - return nil, err - } - - postgresFS := newReplacerFS(sqlSchemas, map[string]string{ - "BLOB": "BYTEA", - "INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY", - "BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY", - "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", - }) - - err = applyMigrations( - postgresFS, driver, "sqlc/migrations", dbName, - ) - if err != nil { - return nil, err - } - } - queries := sqlc.New(rawDB) - return &PostgresStore{ + s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ DB: rawDB, Queries: queries, }, - }, nil + } + + // Now that the database is open, populate the database with our + // set of schemas based on our embedded in-memory file system. + if !cfg.SkipMigrations { + err := s.ExecuteMigrations(dbName, TargetLatest) + if err != nil { + return nil, fmt.Errorf("error executing migrations: %w", + err) + } + } + + return s, nil +} + +// ExecuteMigrations runs migrations for the Postgres database, depending on the +// target given, either all migrations or up to a given version. +func (s *PostgresStore) ExecuteMigrations(dbName string, + target MigrationTarget) error { + + driver, err := postgres_migrate.WithInstance( + s.DB, &postgres_migrate.Config{}, + ) + if err != nil { + return fmt.Errorf("error creating postgres migration: %w", err) + } + + postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) + return applyMigrations( + postgresFS, driver, "sqlc/migrations", dbName, target, + ) } diff --git a/sqldb/sqlite.go b/sqldb/sqlite.go index 705d5cc47a..103374ad10 100644 --- a/sqldb/sqlite.go +++ b/sqldb/sqlite.go @@ -26,6 +26,16 @@ const ( sqliteTxLockImmediate = "_txlock=immediate" ) +var ( + // sqliteSchemaReplacements is a map of schema strings that need to be + // replaced for sqlite. This is needed because sqlite doesn't directly + // support the BIGINT type for primary keys, so we need to replace it + // with INTEGER. + sqliteSchemaReplacements = map[string]string{ + "BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY", + } +) + // SqliteStore is a database store implementation that uses a sqlite backend. type SqliteStore struct { cfg *SqliteConfig @@ -95,46 +105,43 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { db.SetMaxOpenConns(defaultMaxConns) db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(connIdleLifetime) - - if !cfg.SkipMigrations { - // Now that the database is open, populate the database with - // our set of schemas based on our embedded in-memory file - // system. - // - // First, we'll need to open up a new migration instance for - // our current target database: sqlite. - driver, err := sqlite_migrate.WithInstance( - db, &sqlite_migrate.Config{}, - ) - if err != nil { - return nil, err - } - - // We use INTEGER PRIMARY KEY for sqlite, because it acts as a - // ROWID alias which is 8 bytes big and also autoincrements. - // It's important to use the ROWID as a primary key because the - // key look ups are almost twice as fast - sqliteFS := newReplacerFS(sqlSchemas, map[string]string{ - "BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY", - }) - - err = applyMigrations( - sqliteFS, driver, "sqlc/migrations", "sqlc", - ) - if err != nil { - return nil, err - } - } - queries := sqlc.New(db) - return &SqliteStore{ + s := &SqliteStore{ cfg: cfg, BaseDB: &BaseDB{ DB: db, Queries: queries, }, - }, nil + } + + // Now that the database is open, populate the database with our set of + // schemas based on our embedded in-memory file system. + if !cfg.SkipMigrations { + if err := s.ExecuteMigrations(TargetLatest); err != nil { + return nil, fmt.Errorf("error executing migrations: "+ + "%w", err) + + } + } + + return s, nil +} + +// ExecuteMigrations runs migrations for the sqlite database, depending on the +// target given, either all migrations or up to a given version. +func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error { + driver, err := sqlite_migrate.WithInstance( + s.DB, &sqlite_migrate.Config{}, + ) + if err != nil { + return fmt.Errorf("error creating sqlite migration: %w", err) + } + + sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) + return applyMigrations( + sqliteFS, driver, "sqlc/migrations", "sqlite", target, + ) } // NewTestSqliteDB is a helper function that creates an SQLite database for