diff --git a/adapter/adapter.go b/adapter/adapter.go
index b02b4604..e4091ab3 100644
--- a/adapter/adapter.go
+++ b/adapter/adapter.go
@@ -9,14 +9,34 @@ import (

 type ConnectionHolder interface {
 	GetConn(ctx context.Context) (*stdsql.Conn, error)
+	GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
 	GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)
+	GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
+	TryGetTxn() *stdsql.Tx
+	CloseTxn()
 }

 func GetConn(ctx *sql.Context) (*stdsql.Conn, error) {
 	return ctx.Session.(ConnectionHolder).GetConn(ctx)
 }

-func QueryContext(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
+func GetTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
+	return ctx.Session.(ConnectionHolder).GetTxn(ctx, options)
+}
+
+func GetCatalogTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
+	return ctx.Session.(ConnectionHolder).GetCatalogTxn(ctx, options)
+}
+
+func TryGetTxn(ctx *sql.Context) *stdsql.Tx {
+	return ctx.Session.(ConnectionHolder).TryGetTxn()
+}
+
+func CloseTxn(ctx *sql.Context) {
+	ctx.Session.(ConnectionHolder).CloseTxn()
+}
+
+func Query(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
 	conn, err := GetConn(ctx)
 	if err != nil {
 		return nil, err
@@ -24,10 +44,18 @@ func QueryContext(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, er
 	return conn.QueryContext(ctx, query, args...)
 }

-// QueryCatalogContext is a helper function to query the catalog, such as information_schema.
+func QueryRow(ctx *sql.Context, query string, args ...any) *stdsql.Row {
+	conn, err := GetConn(ctx)
+	if err != nil {
+		return nil
+	}
+	return conn.QueryRowContext(ctx, query, args...)
+}
+
+// QueryCatalog is a helper function to query the catalog, such as information_schema.
 // Unlike QueryContext, this function does not require a schema name to be set on the connection,
 // and the current schema of the connection does not matter.
-func QueryCatalogContext(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
+func QueryCatalog(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
 	conn, err := ctx.Session.(ConnectionHolder).GetCatalogConn(ctx)
 	if err != nil {
 		return nil, err
@@ -35,7 +63,15 @@ func QueryCatalogContext(ctx *sql.Context, query string, args ...any) (*stdsql.R
 	return conn.QueryContext(ctx, query, args...)
 }

-func ExecContext(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
+func QueryRowCatalog(ctx *sql.Context, query string, args ...any) *stdsql.Row {
+	conn, err := ctx.Session.(ConnectionHolder).GetCatalogConn(ctx)
+	if err != nil {
+		return nil
+	}
+	return conn.QueryRowContext(ctx, query, args...)
+}
+
+func Exec(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
 	conn, err := GetConn(ctx)
 	if err != nil {
 		return nil, err
@@ -43,13 +79,29 @@ func ExecContext(ctx *sql.Context, query string, args ...any) (stdsql.Result, er
 	return conn.ExecContext(ctx, query, args...)
 }

-// ExecCatalogContext is a helper function to execute a catalog modification query, such as creating a database.
+// ExecCatalog is a helper function to execute a catalog modification query, such as creating a database.
 // Unlike ExecContext, this function does not require a schema name to be set on the connection,
 // and the current schema of the connection does not matter.
-func ExecCatalogContext(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
+func ExecCatalog(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
 	conn, err := ctx.Session.(ConnectionHolder).GetCatalogConn(ctx)
 	if err != nil {
 		return nil, err
 	}
 	return conn.ExecContext(ctx, query, args...)
 }
+
+func ExecCatalogInTxn(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
+	tx, err := ctx.Session.(ConnectionHolder).GetCatalogTxn(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	return tx.ExecContext(ctx, query, args...)
+}
+
+func ExecInTxn(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
+	tx, err := GetTxn(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	return tx.ExecContext(ctx, query, args...)
+}
diff --git a/backend/connpool.go b/backend/connpool.go
index 0d9c312f..8f48964e 100644
--- a/backend/connpool.go
+++ b/backend/connpool.go
@@ -119,6 +119,14 @@ func (p *ConnectionPool) GetTxn(ctx context.Context, id uint32, schemaName strin
 	return tx, nil
 }

+func (p *ConnectionPool) TryGetTxn(id uint32) *stdsql.Tx {
+	entry, ok := p.txns.Load(id)
+	if !ok {
+		return nil
+	}
+	return entry.(*stdsql.Tx)
+}
+
 func (p *ConnectionPool) CloseTxn(id uint32) {
 	p.txns.Delete(id)
 }
diff --git a/backend/executor.go b/backend/executor.go
index 257fe129..fc64fe1c 100644
--- a/backend/executor.go
+++ b/backend/executor.go
@@ -31,6 +31,8 @@ import (
 type DuckBuilder struct {
 	base sql.NodeExecBuilder
 	pool *ConnectionPool
+
+	FlushDeltaBuffer func() error
 }

 var _ sql.NodeExecBuilder = (*DuckBuilder)(nil)
@@ -43,6 +45,14 @@ func NewDuckBuilder(base sql.NodeExecBuilder, pool *ConnectionPool) *DuckBuilder
 }

 func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.RowIter, error) {
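An aside for readers wiring this up outside the patch: the hook is deliberately optional, since Build checks it for nil, so servers that do not replicate can leave it unset. A minimal sketch of the wiring, assuming a hypothetical flushReplicaDelta callback supplied by the replication layer (the names below are illustrative, not part of this diff):

	// Hypothetical server-setup code: connect the replica's delta flush to the builder.
	builder := backend.NewDuckBuilder(base, pool)
	builder.FlushDeltaBuffer = func() error {
		// flushReplicaDelta is an assumed entry point into the replication layer.
		return flushReplicaDelta()
	}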
+	// Flush the delta buffer before executing the query.
+	// TODO(fan): Be fine-grained and flush only when the replicated tables are touched.
+	if b.FlushDeltaBuffer != nil {
+		if err := b.FlushDeltaBuffer(); err != nil {
+			return nil, err
+		}
+	}
+
 	n := root
 	qp, ok := n.(*plan.QueryProcess)
 	if ok {
@@ -59,7 +69,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
 	ctx.GetLogger().WithFields(logrus.Fields{
 		"Query":    ctx.Query(),
 		"NodeType": fmt.Sprintf("%T", n),
-	}).Infoln("Building node:", n)
+	}).Trace("Building node:", n)

 	// TODO; find a better way to fallback to the base builder
 	switch n.(type) {
@@ -130,7 +140,7 @@ func (b *DuckBuilder) executeExpressioner(ctx *sql.Context, n sql.Expressioner,
 }

 func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Conn) (sql.RowIter, error) {
-	logrus.Infoln("Executing Query...")
+	ctx.GetLogger().Trace("Executing Query...")

 	var (
 		duckSQL string
@@ -148,10 +158,10 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
 		return nil, catalog.ErrTranspiler.New(err)
 	}

-	logrus.WithFields(logrus.Fields{
+	ctx.GetLogger().WithFields(logrus.Fields{
 		"Query":   ctx.Query(),
 		"DuckSQL": duckSQL,
-	}).Infoln("Executing Query...")
+	}).Trace("Executing Query...")

 	// Execute the DuckDB query
 	rows, err := conn.QueryContext(ctx.Context, duckSQL)
@@ -169,10 +179,10 @@ func (b *DuckBuilder) executeDML(ctx *sql.Context, conn *stdsql.Conn) (sql.RowIt
 		return nil, catalog.ErrTranspiler.New(err)
 	}

-	logrus.WithFields(logrus.Fields{
+	ctx.GetLogger().WithFields(logrus.Fields{
 		"Query":   ctx.Query(),
 		"DuckSQL": duckSQL,
-	}).Infoln("Executing DML...")
+	}).Trace("Executing DML...")

 	// Execute the DuckDB query
 	result, err := conn.ExecContext(ctx.Context, duckSQL)
diff --git a/backend/session.go b/backend/session.go
index 18133ba9..9f1e29cc 100644
--- a/backend/session.go
+++ b/backend/session.go
@@ -39,11 +39,6 @@ func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *

 // NewSessionBuilder returns a session builder for the given database provider.
 func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
-	_, err := pool.Exec("CREATE TABLE IF NOT EXISTS main.persistent_variables (name TEXT PRIMARY KEY, value TEXT, type TEXT)")
-	if err != nil {
-		panic(err)
-	}
-
 	return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
 		host := ""
 		user := ""
@@ -73,7 +68,7 @@ var _ sql.Transaction = (*Transaction)(nil)

 // StartTransaction implements sql.TransactionSession.
 func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) {
-	sess.GetLogger().Infoln("StartTransaction")
+	sess.GetLogger().Trace("StartTransaction")
 	base, err := sess.Session.StartTransaction(ctx, tCharacteristic)
 	if err != nil {
 		return nil, err
@@ -93,7 +88,7 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans

 	var tx *stdsql.Tx
 	if startUnderlyingTx {
-		sess.GetLogger().Infoln("StartDuckTransaction")
+		sess.GetLogger().Trace("StartDuckTransaction")
 		tx, err = sess.GetTxn(ctx, &stdsql.TxOptions{ReadOnly: tCharacteristic == sql.ReadOnly})
 		if err != nil {
 			return nil, err
@@ -104,10 +99,10 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans

 // CommitTransaction implements sql.TransactionSession.
 func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
-	sess.GetLogger().Infoln("CommitTransaction")
+	sess.GetLogger().Trace("CommitTransaction")
 	transaction := tx.(*Transaction)
 	if transaction.tx != nil {
-		sess.GetLogger().Infoln("CommitDuckTransaction")
+		sess.GetLogger().Trace("CommitDuckTransaction")
 		defer sess.CloseTxn()
 		if err := transaction.tx.Commit(); err != nil {
 			return err
@@ -118,10 +113,10 @@ func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) erro

 // Rollback implements sql.TransactionSession.
 func (sess Session) Rollback(ctx *sql.Context, tx sql.Transaction) error {
-	sess.GetLogger().Infoln("Rollback")
+	sess.GetLogger().Trace("Rollback")
 	transaction := tx.(*Transaction)
 	if transaction.tx != nil {
-		sess.GetLogger().Infoln("RollbackDuckTransaction")
+		sess.GetLogger().Trace("RollbackDuckTransaction")
 		defer sess.CloseTxn()
 		if err := transaction.tx.Rollback(); err != nil {
 			return err
@@ -137,7 +132,7 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
 	}
 	_, err := sess.ExecContext(
 		context.Background(),
-		"INSERT OR REPLACE INTO main.persistent_variables (name, value, vtype) VALUES (?, ?, ?)",
+		catalog.InternalTables.PersistentVariable.UpsertStmt(),
 		sysVarName, value, fmt.Sprintf("%T", value),
 	)
 	return err
@@ -147,7 +142,7 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
 func (sess Session) RemovePersistedGlobal(sysVarName string) error {
 	_, err := sess.ExecContext(
 		context.Background(),
-		"DELETE FROM main.persistent_variables WHERE name = ?",
+		catalog.InternalTables.PersistentVariable.DeleteStmt(),
 		sysVarName,
 	)
 	return err
@@ -155,7 +150,7 @@ func (sess Session) RemovePersistedGlobal(sysVarName string) error {

 // RemoveAllPersistedGlobals implements sql.PersistableSession.
 func (sess Session) RemoveAllPersistedGlobals() error {
-	_, err := sess.ExecContext(context.Background(), "DELETE FROM main.persistent_variables")
+	_, err := sess.ExecContext(context.Background(), "DELETE FROM "+catalog.InternalTables.PersistentVariable.Name)
 	return err
 }

@@ -164,7 +159,8 @@ func (sess Session) GetPersistedValue(k string) (interface{}, error) {
 	var value, vtype string
 	err := sess.QueryRow(
 		context.Background(),
-		"SELECT value, vtype FROM main.persistent_variables WHERE name = ?", k,
+		catalog.InternalTables.PersistentVariable.SelectStmt(),
+		k,
 	).Scan(&value, &vtype)
 	switch {
 	case err == stdsql.ErrNoRows:
@@ -195,10 +191,22 @@ func (sess Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
 	return sess.pool.GetConn(ctx, sess.ID())
 }

+// GetTxn implements adapter.ConnectionHolder.
 func (sess Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
 	return sess.pool.GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
 }

+// GetCatalogTxn implements adapter.ConnectionHolder.
+func (sess Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
+	return sess.pool.GetTxn(ctx, sess.ID(), "", options)
+}
+
+// TryGetTxn implements adapter.ConnectionHolder.
+func (sess Session) TryGetTxn() *stdsql.Tx {
+	return sess.pool.TryGetTxn(sess.ID())
+}
+
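The TryGetTxn/CloseTxn pair exists so that callers can finish a lazily started transaction without ever opening a new one. A minimal usage sketch, mirroring how the replication applier commits later in this same patch:

	// Commit a transaction only if one is already in progress on this session;
	// TryGetTxn never starts a new transaction.
	if tx := adapter.TryGetTxn(ctx); tx != nil {
		if err := tx.Commit(); err != nil && err != stdsql.ErrTxDone {
			return err
		}
		adapter.CloseTxn(ctx)
	}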
+// CloseTxn implements adapter.ConnectionHolder.
 func (sess Session) CloseTxn() {
 	sess.pool.CloseTxn(sess.ID())
 }
diff --git a/binlogreplication/binlog_position_store.go b/binlogreplication/binlog_position_store.go
index 793a40f8..aeaea54b 100644
--- a/binlogreplication/binlog_position_store.go
+++ b/binlogreplication/binlog_position_store.go
@@ -15,6 +15,7 @@
 package binlogreplication

 import (
+	stdsql "database/sql"
 	"errors"
 	"fmt"
 	"io/fs"
@@ -23,94 +24,67 @@ import (
 	"strings"
 	"sync"

+	"github.com/apecloud/myduckserver/adapter"
+	"github.com/apecloud/myduckserver/catalog"
 	gms "github.com/dolthub/go-mysql-server"
 	"github.com/dolthub/go-mysql-server/sql"
 	"vitess.io/vitess/go/mysql/replication"
 )

 const binlogPositionDirectory = ".replica"
-const binlogPositionFilename = "binlog-position"
 const mysqlFlavor = "MySQL56"
+const defaultChannelName = ""

-// binlogPositionStore manages loading and saving data to the binlog position file stored on disk. This provides
+// binlogPositionStore manages loading and saving data to the binlog position metadata table. This provides
 // durable storage for the set of GTIDs that have been successfully executed on the replica, so that the replica
 // server can be restarted and resume binlog event messages at the correct point.
 type binlogPositionStore struct {
 	mu sync.Mutex
 }

-// Load loads a mysql.Position instance from the .replica/binlog-position file at the root of working directory
-// This file MUST be stored at the root of the provider's filesystem, and NOT inside a nested database's .replica directory,
-// since the binlog position contains events that cover all databases in a SQL server. The returned mysql.Position
-// represents the set of GTIDs that have been successfully executed and applied on this replica. Currently only the
-// default binlog channel ("") is supported. If no .replica/binlog-position file is stored, this method returns a nil
-// mysql.Position and a nil error. If any errors are encountered, a nil mysql.Position and an error are returned.
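A note on the stored format: the old file-based store wrote replication.EncodePosition(position), which prefixes the flavor ("MySQL56/<gtid-set>"), while the new table stores position.String(). Load therefore trims the prefix defensively so that both encodings parse. A standalone sketch of that round-trip:

	// Parse a stored GTID set, tolerating the legacy "MySQL56/" prefix.
	func parseStoredPosition(s string) (replication.Position, error) {
		return replication.ParsePosition(mysqlFlavor, strings.TrimPrefix(s, "MySQL56/"))
	}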
+// Load loads a mysql.Position instance from the database. The returned mysql.Position
+// represents the set of GTIDs that have been successfully executed and applied on this replica.
+// Currently only the default binlog channel ("") is supported.
+// If no position is stored, this method returns a zero mysql.Position and a nil error.
+// If any errors are encountered, a zero mysql.Position and an error are returned.
-func (store *binlogPositionStore) Load(engine *gms.Engine) (pos replication.Position, err error) {
+func (store *binlogPositionStore) Load(ctx *sql.Context, engine *gms.Engine) (pos replication.Position, err error) {
 	store.mu.Lock()
 	defer store.mu.Unlock()

-	dir := filepath.Join(getDataDir(engine), binlogPositionDirectory)
-	_, err = os.Stat(dir)
-	if err != nil && errors.Is(err, fs.ErrNotExist) {
-		return pos, nil
+	var positionString string
+	err = adapter.QueryRowCatalog(ctx, catalog.InternalTables.BinlogPosition.SelectStmt(), defaultChannelName).Scan(&positionString)
+	if err == stdsql.ErrNoRows {
+		return replication.Position{}, nil
 	} else if err != nil {
-		return pos, err
-	}
-
-	_, err = os.Stat(filepath.Join(dir, binlogPositionFilename))
-	if err != nil && errors.Is(err, fs.ErrNotExist) {
-		return pos, nil
-	} else if err != nil {
-		return pos, err
-	}
-
-	filePath, err := filepath.Abs(filepath.Join(dir, binlogPositionFilename))
-	if err != nil {
-		return pos, err
-	}
-
-	bytes, err := os.ReadFile(filePath)
-	if err != nil {
-		return pos, err
+		return replication.Position{}, fmt.Errorf("unable to load binlog position: %w", err)
 	}
-	positionString := string(bytes)

 	// Strip off the "MySQL56/" prefix
-	prefix := "MySQL56/"
-	if strings.HasPrefix(positionString, prefix) {
-		positionString = string(bytes[len(prefix):])
-	}
+	positionString = strings.TrimPrefix(positionString, "MySQL56/")

 	return replication.ParsePosition(mysqlFlavor, positionString)
 }

-// Save saves the specified |position| to disk in the .replica/binlog-position file at the root of the provider's
-// filesystem. This file MUST be stored at the root of the provider's filesystem, and NOT inside a nested database's
-// .replica directory, since the binlog position contains events that cover all databases in a SQL server. |position|
-// represents the set of GTIDs that have been successfully executed and applied on this replica. Currently only the
-// default binlog channel ("") is supported. If any errors are encountered persisting the position to disk, an
-// error is returned.
+// Save persists the specified |position| to the binlog position metadata table.
+// The |position| represents the set of GTIDs that have been successfully executed and applied on this replica.
+// Currently only the default binlog channel ("") is supported.
+// If any errors are encountered persisting the position, an error is returned.
 func (store *binlogPositionStore) Save(ctx *sql.Context, engine *gms.Engine, position replication.Position) error {
 	if position.IsZero() {
-		return fmt.Errorf("unable to save binlog position: nil position passed")
+		return fmt.Errorf("unable to save binlog position: empty position passed")
 	}
 	store.mu.Lock()
 	defer store.mu.Unlock()

-	// The .replica dir may not exist yet, so create it if necessary.
-	dir, err := createReplicaDir(engine)
-	if err != nil {
-		return err
+	if _, err := adapter.ExecCatalogInTxn(
+		ctx,
+		catalog.InternalTables.BinlogPosition.UpsertStmt(),
+		defaultChannelName, position.String(),
+	); err != nil {
+		return fmt.Errorf("unable to save binlog position: %w", err)
 	}
-
-	filePath, err := filepath.Abs(filepath.Join(dir, binlogPositionFilename))
-	if err != nil {
-		return err
-	}
-
-	encodedPosition := replication.EncodePosition(position)
-	return os.WriteFile(filePath, []byte(encodedPosition), 0666)
+	return nil
 }

 // Delete deletes the stored mysql.Position information stored in .replica/binlog-position in the root of the provider's
@@ -120,7 +94,8 @@ func (store *binlogPositionStore) Delete(ctx *sql.Context, engine *gms.Engine) e
 	store.mu.Lock()
 	defer store.mu.Unlock()

-	return os.Remove(filepath.Join(getDataDir(engine), binlogPositionDirectory, binlogPositionFilename))
+	_, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.BinlogPosition.DeleteStmt(), defaultChannelName)
+	return err
 }

 // createReplicaDir creates the .replica directory if it doesn't already exist.
diff --git a/binlogreplication/binlog_replica_applier.go b/binlogreplication/binlog_replica_applier.go
index 86849bfc..2d332026 100644
--- a/binlogreplication/binlog_replica_applier.go
+++ b/binlogreplication/binlog_replica_applier.go
@@ -15,6 +15,7 @@
 package binlogreplication

 import (
+	stdsql "database/sql"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -24,11 +25,15 @@ import (
 	"sync/atomic"
 	"time"

+	"github.com/apecloud/myduckserver/adapter"
 	"github.com/apecloud/myduckserver/binlog"
 	"github.com/apecloud/myduckserver/charset"
+	"github.com/apecloud/myduckserver/delta"
+	"github.com/apecloud/myduckserver/mysqlutil"
 	gms "github.com/dolthub/go-mysql-server"
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/binlogreplication"
+	"github.com/dolthub/go-mysql-server/sql/plan"
 	"github.com/dolthub/go-mysql-server/sql/types"
 	doltvtmysql "github.com/dolthub/vitess/go/mysql"
 	"github.com/sirupsen/logrus"
@@ -62,7 +67,16 @@ type binlogReplicaApplier struct {
 	filters             *filterConfiguration
 	running             atomic.Bool
 	engine              *gms.Engine
-	tableWriterProvider TableWriterProvider
+
+	tableWriterProvider TableWriterProvider
+	previousGtid        replication.GTID
+	pendingPosition     replication.Position
+	ongoingBatchTxn     atomic.Bool   // true if we're in a batched transaction, i.e., a series of binlog-format=ROW primary transactions
+	dirtyTxn            atomic.Bool   // true if we're in a transaction that is opened and/or has uncommitted changes
+	dirtyStream         atomic.Bool   // true if the binlog stream does not end with a commit event
+	inTxnStmtID         atomic.Uint64 // auto-incrementing ID for statements within a transaction
+	deltaBufSize        atomic.Uint64 // size of the delta buffer
+	lastCommitTime      time.Time     // time of the last commit
 }

 func newBinlogReplicaApplier(filters *filterConfiguration) *binlogReplicaApplier {
@@ -190,7 +204,7 @@ func (a *binlogReplicaApplier) startReplicationEventStream(ctx *sql.Context, con
 		return err
 	}

-	position, err := positionStore.Load(a.engine)
+	position, err := positionStore.Load(ctx, a.engine)
 	if err != nil {
 		return err
 	}
@@ -235,11 +249,26 @@ func (a *binlogReplicaApplier) startReplicationEventStream(ctx *sql.Context, con
 	}

 	a.currentPosition = position
+	a.pendingPosition = position
+	if err := sql.SystemVariables.AssignValues(map[string]interface{}{"gtid_executed": a.currentPosition.GTIDSet.String()}); err != nil {
+		ctx.GetLogger().Errorf("unable to set @@GLOBAL.gtid_executed: %s", err.Error())
+	}

 	// Clear out the format description in case we're reconnecting, so that we don't use the old format description
 	// to interpret any event messages before we receive the new format description from the new stream.
 	a.format = nil

+	// Clear out the transactional states and the delta buffer
+	a.previousGtid = nil
+	a.currentGtid = nil
+	a.ongoingBatchTxn.Store(false)
+	a.dirtyTxn.Store(false)
+	a.dirtyStream.Store(false)
+	a.inTxnStmtID.Store(0)
+	a.lastCommitTime = time.Now()
+	a.tableWriterProvider.DiscardDeltaBuffer(ctx)
+	a.deltaBufSize.Store(0)
+
 	// If the source server has binlog checksums enabled (@@global.binlog_checksum), then the replica MUST
 	// set @master_binlog_checksum to handshake with the server to acknowledge that it knows that checksums
 	// are in use. Without this step, the server will just send back error messages saying that the replica
@@ -259,7 +288,7 @@ func (a *binlogReplicaApplier) startReplicationEventStream(ctx *sql.Context, con
 		"serverId":   serverId,
 		"binlogFile": binlogFile,
 		"position":   position.String(),
-	}).Infoln("Sending binlog dump command to source")
+	}).Trace("Sending binlog dump command to source")

 	return conn.SendBinlogDumpCommand(serverId, binlogFile, position)
 }
@@ -272,6 +301,9 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error
 	var conn *mysql.Conn
 	var eventProducer *binlogEventProducer

+	ticker := time.NewTicker(200 * time.Millisecond)
+	defer ticker.Stop()
+
 	// Process binlog events
 	for {
 		if conn == nil {
@@ -316,24 +348,40 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error
 				}
 			} else {
 				// otherwise, log the error if it's something we don't expect and continue
-				ctx.GetLogger().Errorf("unexpected error of type %T: '%v'", err, err.Error())
-				MyBinlogReplicaController.setIoError(sqlerror.ERUnknownError, err.Error())
+				recordReplicationError(ctx, err)
+			}
+
+		case <-ticker.C:
+			if a.ongoingBatchTxn.Load() && !a.dirtyStream.Load() {
+				// We should commit the transaction to flush the changes to the database
+				// if we're in a batched transaction and haven't seen any changes for a while.
+				if err := a.extendOrCommitBatchTxn(ctx, engine); err != nil {
+					recordReplicationError(ctx, err)
+				}
 			}

 		case <-a.stopReplicationChan:
 			ctx.GetLogger().Trace("received stop replication signal")
 			eventProducer.Stop()
+			if a.ongoingBatchTxn.Load() && !a.dirtyStream.Load() {
+				if err := a.commitOngoingTxn(ctx, engine, NormalCommit, delta.OnCloseFlushReason); err != nil {
+					recordReplicationError(ctx, err)
+				}
+			}
 			return nil
 		}
 	}
 }

+func recordReplicationError(ctx *sql.Context, err error) {
+	ctx.GetLogger().Errorf("unexpected error of type %T: '%v'", err, err.Error())
+	MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, err.Error())
+}
+
 // processBinlogEvent processes a single binlog event message and returns an error if there were any problems
 // processing it.
 func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.Engine, event mysql.BinlogEvent) error {
 	var err error
-	createCommit := false
-	commitToAllDatabases := false

 	// We don't support checksum validation, so we MUST strip off any checksum bytes if present, otherwise it gets
 	// interpreted as part of the payload and corrupts the data. Future checksum sizes, are not guaranteed to be the
 	// same size, so we can't strip the checksum until we've seen a FormatDescription event, which
 	// tells us if checksums are enabled and what algorithm they use. We can NOT strip the checksum off of
 	// FormatDescription events, because FormatDescription always includes a CRC32 checksum, and Vitess depends on
 	// those bytes always being present when we parse the event into a FormatDescription type.
-	if a.format != nil && event.IsFormatDescription() == false {
+	if a.format != nil && !event.IsFormatDescription() {
 		var err error
 		event, _, err = event.StripChecksum(*a.format)
 		if err != nil {
@@ -351,19 +399,31 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 		}
 	}

+	// ------------------- NOTE -----------------------
+	// Since this function is called in a hot loop,
+	// we invoke the logging API conditionally
+	// to avoid unnecessary memory allocation
+	// made by logrus and interface boxing.
+	// ------------------------------------------------
+	logger := ctx.GetLogger()
+	isTraceLevelEnabled := logger.Logger.IsLevelEnabled(logrus.TraceLevel)
+
 	switch {
 	case event.IsRand():
 		// A RAND_EVENT contains two seed values that set the rand_seed1 and rand_seed2 system variables that are
 		// used to compute the random number. For more details, see: https://mariadb.com/kb/en/rand_event/
 		// Note: it is written only before a QUERY_EVENT and is NOT used with row-based logging.
-		ctx.GetLogger().Trace("Received binlog event: Rand")
+		if isTraceLevelEnabled {
+			logger.Trace("Received binlog event: Rand")
+		}

 	case event.IsXID():
 		// An XID event is generated for a COMMIT of a transaction that modifies one or more tables of an
 		// XA-capable storage engine. For more details, see: https://mariadb.com/kb/en/xid_event/
-		ctx.GetLogger().Trace("Received binlog event: XID")
-		createCommit = true
-		commitToAllDatabases = true
+		if isTraceLevelEnabled {
+			logger.Trace("Received binlog event: XID")
+		}
+		return a.extendOrCommitBatchTxn(ctx, engine)

 	case event.IsQuery():
 		// A Query event represents a statement executed on the source server that should be executed on the
@@ -377,69 +437,85 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.

 		flags, mode := parseQueryEventVars(*a.format, event)

-		ctx.GetLogger().WithFields(logrus.Fields{
-			"database": query.Database,
-			"charset":  query.Charset,
-			"query":    query.SQL,
-			"flags":    fmt.Sprintf("0x%x", flags),
-			"sql_mode": fmt.Sprintf("0x%x", mode),
-		}).Infoln("Received binlog event: Query")
-
-		// When executing SQL statements sent from the primary, we can't be sure what database was modified unless we
-		// look closely at the statement. For example, we could be connected to db01, but executed
-		// "create table db02.t (...);" – i.e., looking at query.Database is NOT enough to always determine the correct
-		// database that was modified, so instead, we commit to all databases when we see a Query binlog event to
-		// avoid issues with correctness, at the cost of being slightly less efficient
-		commitToAllDatabases = true
+		if isTraceLevelEnabled {
+			logger.WithFields(logrus.Fields{
+				"database": query.Database,
+				"charset":  query.Charset,
+				"query":    query.SQL,
+				"flags":    fmt.Sprintf("0x%x", flags),
+				"sql_mode": fmt.Sprintf("0x%x", mode),
+			}).Trace("Received binlog event: Query")
+		}

+		var msg string
 		if flags&doltvtmysql.QFlagOptionAutoIsNull > 0 {
-			ctx.GetLogger().Tracef("Setting sql_auto_is_null ON")
+			msg = "Setting sql_auto_is_null ON"
 			ctx.SetSessionVariable(ctx, "sql_auto_is_null", 1)
 		} else {
-			ctx.GetLogger().Tracef("Setting sql_auto_is_null OFF")
+			msg = "Setting sql_auto_is_null OFF"
 			ctx.SetSessionVariable(ctx, "sql_auto_is_null", 0)
 		}
+		if isTraceLevelEnabled {
+			logger.Trace(msg)
+		}

 		if flags&doltvtmysql.QFlagOptionNotAutocommit > 0 {
-			ctx.GetLogger().Tracef("Setting autocommit=0")
+			msg = "Setting autocommit=0"
 			ctx.SetSessionVariable(ctx, "autocommit", 0)
 		} else {
-			ctx.GetLogger().Tracef("Setting autocommit=1")
+			msg = "Setting autocommit=1"
 			ctx.SetSessionVariable(ctx, "autocommit", 1)
 		}
+		if isTraceLevelEnabled {
+			logger.Trace(msg)
+		}

 		if flags&doltvtmysql.QFlagOptionNoForeignKeyChecks > 0 {
-			ctx.GetLogger().Tracef("Setting foreign_key_checks=0")
+			msg = "Setting foreign_key_checks=0"
 			ctx.SetSessionVariable(ctx, "foreign_key_checks", 0)
 		} else {
-			ctx.GetLogger().Tracef("Setting foreign_key_checks=1")
+			msg = "Setting foreign_key_checks=1"
 			ctx.SetSessionVariable(ctx, "foreign_key_checks", 1)
 		}
+		if isTraceLevelEnabled {
+			logger.Trace(msg)
+		}

 		// NOTE: unique_checks is not currently honored by Dolt
 		if flags&doltvtmysql.QFlagOptionRelaxedUniqueChecks > 0 {
-			ctx.GetLogger().Tracef("Setting unique_checks=0")
+			msg = "Setting unique_checks=0"
 			ctx.SetSessionVariable(ctx, "unique_checks", 0)
 		} else {
-			ctx.GetLogger().Tracef("Setting unique_checks=1")
+			msg = "Setting unique_checks=1"
 			ctx.SetSessionVariable(ctx, "unique_checks", 1)
 		}
+		if isTraceLevelEnabled {
+			logger.Trace(msg)
+		}

-		createCommit = !strings.EqualFold(query.SQL, "begin")
-		// TODO(fan): Here we
-		//   skip the transaction for now;
-		//   skip the operations on `mysql.time_zone*` tables, which are not supported by go-mysql-server yet.
-		if createCommit && !(query.Database == "mysql" && strings.HasPrefix(query.SQL, "TRUNCATE TABLE time_zone")) {
-			ctx.SetCurrentDatabase(query.Database)
-			executeQueryWithEngine(ctx, engine, query.SQL)
+		if err := a.executeQueryWithEngine(ctx, engine, query); err != nil {
+			ctx.GetLogger().WithFields(logrus.Fields{
+				"error": err.Error(),
+				"query": query.SQL,
+			}).Error("Applying query failed")
+			msg := fmt.Sprintf("Applying query failed: %v", err.Error())
+			MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, msg)
 		}
+		a.inTxnStmtID.Add(1)

 	case event.IsRotate():
 		// When a binary log file exceeds the configured size limit, a ROTATE_EVENT is written at the end of the file,
 		// pointing to the next file in the sequence. ROTATE_EVENT is generated locally and written to the binary log
 		// on the source server and it's also written when a FLUSH LOGS statement occurs on the source server.
 		// For more details, see: https://mariadb.com/kb/en/rotate_event/
-		ctx.GetLogger().Trace("Received binlog event: Rotate")
+		if isTraceLevelEnabled {
+			logger.Trace("Received binlog event: Rotate")
+		}
+
+		// https://dev.mysql.com/doc/refman/8.4/en/binary-log.html
+		// > ... a transaction is written to the file in one piece, never split between files.
+		if a.currentGtid != nil {
+			return a.extendOrCommitBatchTxn(ctx, engine)
+		}

 	case event.IsFormatDescription():
 		// This is a descriptor event that is written to the beginning of a binary log file, at position 4 (after
@@ -449,12 +525,15 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 			return err
 		}
 		a.format = &format
-		ctx.GetLogger().WithFields(logrus.Fields{
-			"format":        a.format,
-			"formatVersion": a.format.FormatVersion,
-			"serverVersion": a.format.ServerVersion,
-			"checksum":      a.format.ChecksumAlgorithm,
-		}).Trace("Received binlog event: FormatDescription")
+
+		if isTraceLevelEnabled {
+			logger.WithFields(logrus.Fields{
+				"format":        a.format,
+				"formatVersion": a.format.FormatVersion,
+				"serverVersion": a.format.ServerVersion,
+				"checksum":      a.format.ChecksumAlgorithm,
+			}).Trace("Received binlog event: FormatDescription")
+		}

 	case event.IsPreviousGTIDs():
 		// Logged in every binlog to record the current replication state. Consists of the last GTID seen for each
@@ -463,9 +542,12 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 		if err != nil {
 			return err
 		}
-		ctx.GetLogger().WithFields(logrus.Fields{
-			"previousGtids": position.GTIDSet.String(),
-		}).Trace("Received binlog event: PreviousGTIDs")
+
+		if isTraceLevelEnabled {
+			logger.WithFields(logrus.Fields{
+				"previousGtids": position.GTIDSet.String(),
+			}).Trace("Received binlog event: PreviousGTIDs")
+		}

 	case event.IsGTID():
 		// For global transaction ID, used to start a new transaction event group, instead of the old BEGIN query event,
@@ -474,14 +556,24 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 		if err != nil {
 			return err
 		}
-		if isBegin {
-			ctx.GetLogger().Errorf("unsupported binlog protocol message: GTID event with 'isBegin' set to true")
+
+		if isTraceLevelEnabled {
+			logger.WithFields(logrus.Fields{
+				"gtid":    gtid,
+				"isBegin": isBegin,
+			}).Trace("Received binlog event: GTID")
 		}
-		ctx.GetLogger().WithFields(logrus.Fields{
-			"gtid":    gtid,
-			"isBegin": isBegin,
-		}).Trace("Received binlog event: GTID")
+
+		// BEGIN commits the previous transaction implicitly, so save the GTID of that
+		// transaction before overwriting it with the new one.
+		a.previousGtid = a.currentGtid
 		a.currentGtid = gtid
+
+		if isBegin {
+			if err := a.executeQueryWithEngine(ctx, engine, mysql.Query{SQL: "BEGIN"}); err != nil {
+				return err
+			}
+		}
+
 		// if the source's UUID hasn't been set yet, set it and persist it
 		if a.replicationSourceUuid == "" {
 			uuid := fmt.Sprintf("%v", gtid.SourceServer())
@@ -502,18 +594,21 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 		if err != nil {
 			return err
 		}
-		ctx.GetLogger().WithFields(logrus.Fields{
-			"id":        tableId,
-			"tableName": tableMap.Name,
-			"database":  tableMap.Database,
-			"flags":     convertToHexString(tableMap.Flags),
-			"metadata":  tableMap.Metadata,
-			"types":     tableMap.Types,
-		}).Trace("Received binlog event: TableMap")
+
+		if isTraceLevelEnabled {
+			logger.WithFields(logrus.Fields{
+				"id":        tableId,
+				"tableName": tableMap.Name,
+				"database":  tableMap.Database,
+				"flags":     convertToHexString(tableMap.Flags),
+				"metadata":  tableMap.Metadata,
+				"types":     tableMap.Types,
+			}).Trace("Received binlog event: TableMap")
+		}

 		if tableId == 0xFFFFFF {
 			// Table ID 0xFFFFFF is a special value that indicates table maps can be freed.
-			ctx.GetLogger().Infof("binlog protocol message: table ID '0xFFFFFF'; clearing table maps")
+			ctx.GetLogger().Trace("binlog protocol message: table ID '0xFFFFFF'; clearing table maps")
 			a.tableMapsById = make(map[uint64]*mysql.TableMap)
 		} else {
 			flags := tableMap.Flags
@@ -526,7 +621,7 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 			}
 			if flags != 0 {
 				msg := fmt.Sprintf("unsupported binlog protocol message: TableMap event with unsupported flags '%x'", flags)
-				ctx.GetLogger().Errorf(msg)
+				ctx.GetLogger().Error(msg)
 				MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, msg)
 			}
 			a.tableMapsById[tableId] = tableMap
@@ -539,6 +634,9 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 		if err != nil {
 			return err
 		}
+		a.dirtyTxn.Store(true)
+		a.dirtyStream.Store(true)
+		a.inTxnStmtID.Add(1)

 	default:
 		// https://mariadb.com/kb/en/2-binlog-event-header/
@@ -557,33 +655,221 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
 		}
 	}

-	if createCommit {
-		// TODO(fan): Skip the transaction commit for now
-		_ = commitToAllDatabases
-		// var databasesToCommit []string
-		// if commitToAllDatabases {
-		// 	databasesToCommit = getAllUserDatabaseNames(ctx, engine)
-		// 	for _, database := range databasesToCommit {
-		// 		executeQueryWithEngine(ctx, engine, "use `"+database+"`;")
-		// 		executeQueryWithEngine(ctx, engine, "commit;")
-		// 	}
-		// }
+	return nil
+}

-		// Record the last GTID processed after the commit
-		a.currentPosition.GTIDSet = a.currentPosition.GTIDSet.AddGTID(a.currentGtid)
-		err := sql.SystemVariables.AssignValues(map[string]interface{}{"gtid_executed": a.currentPosition.GTIDSet.String()})
-		if err != nil {
-			ctx.GetLogger().Errorf("unable to set @@GLOBAL.gtid_executed: %s", err.Error())
+type CommitKind int
+
+const (
+	NormalCommit CommitKind = iota
+	ImplicitCommitBeforeStmt
+	ImplicitCommitAfterStmt
+)
+
+func (a *binlogReplicaApplier) commitOngoingTxn(ctx *sql.Context, engine *gms.Engine, kind CommitKind, reason delta.FlushReason) error {
+	// Flush any changes buffered in the delta buffer.
+	// TODO(fan): Make the flush threshold configurable
+	if err := a.flushDeltaBuffer(ctx, reason); err != nil {
+		return err
+	}
+
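To make the two-position bookkeeping concrete: pendingPosition accumulates GTIDs for work that has been applied but not yet durably committed, and only a successful commit copies it into currentPosition (and @@gtid_executed). A small illustration with a hypothetical helper name (not part of this patch):

	// AppendGTID returns an extended copy of the position, so currentPosition
	// stays unchanged until commitOngoingTxn synchronizes the two.
	func advance(pending replication.Position, gtid replication.GTID) replication.Position {
		return replication.AppendGTID(pending, gtid)
	}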
+	// Record the last GTID processed.
+	// If the commit is caused implicitly (by, e.g., a BEGIN statement or a DDL statement),
+	// then we don't want to update the saved position to include the current GTID.
+	if kind != ImplicitCommitBeforeStmt {
+		a.pendingPosition = replication.AppendGTID(a.pendingPosition, a.currentGtid)
+	}
+	if err := positionStore.Save(ctx, engine, a.pendingPosition); err != nil {
+		return fmt.Errorf("unable to persist GTID executed metadata: %s", err.Error())
+	}
+
+	// --- Commit the transaction --- //
+
+	// Commit the transaction started on this session
+	if kind != ImplicitCommitAfterStmt || !getAutocommit(ctx) {
+		subctx := sql.NewContext(ctx, sql.WithSession(ctx.Session)).WithQuery("COMMIT")
+		if err := a.execute(subctx, engine, "COMMIT"); err != nil {
+			return err
 		}
-		err = positionStore.Save(ctx, engine, a.currentPosition)
-		if err != nil {
-			return fmt.Errorf("unable to store GTID executed metadata to disk: %s", err.Error())
+	}
+	// The session manager does not start an actual transaction in autocommit=1 mode,
+	// but there may be a transaction in progress if we have started it manually.
+	if tx := adapter.TryGetTxn(ctx); tx != nil {
+		if err := tx.Commit(); err != nil && err != stdsql.ErrTxDone {
+			return err
+		}
+		adapter.CloseTxn(ctx)
+	}
+
+	// --- Update the in-memory states --- //
+
+	// Reset the transaction-related flags
+	a.ongoingBatchTxn.Store(false)
+	a.dirtyTxn.Store(false)
+	a.dirtyStream.Store(false)
+	a.inTxnStmtID.Store(0)
+
+	// Record the time of the last commit
+	a.lastCommitTime = time.Now()
+
+	// Synchronize the current position with the pending position
+	a.currentPosition = a.pendingPosition
+
+	// Expose the last GTID executed as a system variable
+	err := sql.SystemVariables.AssignValues(map[string]interface{}{"gtid_executed": a.currentPosition.GTIDSet.String()})
+	if err != nil {
+		ctx.GetLogger().Errorf("unable to set @@GLOBAL.gtid_executed: %s", err.Error())
+	}
+
+	return nil
+}
+
+func (a *binlogReplicaApplier) mayExtendBatchTxn() (bool, delta.FlushReason) {
+	extend, reason := false, delta.UnknownFlushReason
+	if a.ongoingBatchTxn.Load() {
+		extend = true
+		switch {
+		case time.Since(a.lastCommitTime) >= 200*time.Millisecond: // commit the batched txn every 200ms
+			extend, reason = false, delta.TimeTickFlushReason
+		case a.deltaBufSize.Load() >= (128 << 20): // commit the batched txn if the delta buffer is too large (>= 128MB)
+			extend, reason = false, delta.MemoryLimitFlushReason
+		}
+	}
+	return extend, reason
+}
+
+func (a *binlogReplicaApplier) extendOrCommitBatchTxn(ctx *sql.Context, engine *gms.Engine) error {
+	// If we're in a batched transaction, then we don't want to commit yet.
+	extend, reason := a.mayExtendBatchTxn()
+	if extend {
+		a.pendingPosition = replication.AppendGTID(a.pendingPosition, a.currentGtid)
+		a.dirtyStream.Store(false)
+		return nil
+	}
+
+	return a.commitOngoingTxn(ctx, engine, NormalCommit, reason)
+}
+
+// executeQueryWithEngine executes a query against the engine and returns an error if the query failed.
+func (a *binlogReplicaApplier) executeQueryWithEngine(ctx *sql.Context, engine *gms.Engine, query mysql.Query) error {
+	// TODO(fan): Here we skip the operations on `mysql.time_zone*` tables, which are not supported by go-mysql-server yet.
+	if query.Database == "mysql" && strings.HasPrefix(query.SQL, "TRUNCATE TABLE time_zone") {
+		return a.commitOngoingTxn(ctx, engine, ImplicitCommitAfterStmt, delta.DMLStmtFlushReason)
+	}
+
+	// Create a sub-context when running queries against the engine, so that we get an accurate query start time.
+	subctx := sql.NewContext(ctx, sql.WithSession(ctx.Session)).WithQuery(query.SQL)
+	subctx.SetCurrentDatabase(query.Database)
+	if subctx.GetCurrentDatabase() == "" {
+		ctx.GetLogger().WithField("query", query).Warn("No current database selected")
+	}
+
+	// Analyze the query first to check if it's a DDL or DML statement,
+	// and flush the changelog if necessary.
+	var (
+		implicitCommit bool
+		flushChangelog bool
+		flushReason    delta.FlushReason
+	)
+	node, err := engine.PrepareQuery(subctx, query.SQL)
+	if err != nil {
+		return err
+	}
+
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+		log.WithFields(logrus.Fields{
+			"query":           query.SQL,
+			"db":              subctx.GetCurrentDatabase(),
+			"ongoingBatchTxn": a.ongoingBatchTxn.Load(),
+			"dirtyTxn":        a.dirtyTxn.Load(),
+			"dirtyStream":     a.dirtyStream.Load(),
+		}).Tracef("Executing %T query", node)
+	}
+
+	switch node.(type) {
+	case *plan.StartTransaction:
+		var extendable bool
+		extendable, flushReason = a.mayExtendBatchTxn()
+		if extendable {
+			// If we're in an extended batched transaction,
+			// then we don't want to start a new transaction yet.
+			a.pendingPosition = replication.AppendGTID(a.pendingPosition, a.previousGtid)
+			a.dirtyStream.Store(true)
+			return nil
+		}
+		implicitCommit, flushChangelog = true, true
+
+	case *plan.Commit:
+		return a.extendOrCommitBatchTxn(subctx, engine)
+
+	case *plan.InsertInto, *plan.Update, *plan.DeleteFrom, *plan.LoadData:
+		implicitCommit, flushChangelog, flushReason = false, true, delta.DMLStmtFlushReason
+
+	default:
+		if mysqlutil.CauseImplicitCommitBefore(node) {
+			implicitCommit, flushChangelog, flushReason = true, true, delta.UnknownFlushReason
+		}
+		if mysqlutil.CauseSchemaChange(node) {
+			flushReason = delta.DDLStmtFlushReason
+		}
+	}
+
+	if flushChangelog {
+		// Flush the buffered changes
+		if err := a.flushDeltaBuffer(subctx, flushReason); err != nil {
+			return err
+		}
+	}
+
+	if implicitCommit && a.dirtyTxn.Load() {
+		// Commit the previous transaction before executing the query.
+		if err := a.commitOngoingTxn(subctx, engine, ImplicitCommitBeforeStmt, flushReason); err != nil {
+			return err
+		}
+	}
+
+	if err := a.execute(subctx, engine, query.SQL); err != nil {
+		return err
+	}
+
+	a.dirtyTxn.Store(true)
+	a.dirtyStream.Store(true)
+
+	switch node.(type) {
+	case *plan.StartTransaction:
+		a.ongoingBatchTxn.Store(true)
+	default:
+		a.ongoingBatchTxn.Store(false)
+	}
+
+	if mysqlutil.CauseImplicitCommitAfter(node) {
+		// Commit the transaction after executing the query
+		return a.commitOngoingTxn(subctx, engine, ImplicitCommitAfterStmt, flushReason)
+	}
+
+	return nil
+}
+
+func (a *binlogReplicaApplier) execute(ctx *sql.Context, engine *gms.Engine, query string) error {
+	_, iter, _, err := engine.Query(ctx, query)
+	if err != nil {
+		// Ignore harmless "nothing to commit" errors from COMMIT statements; propagate everything else.
+		if err.Error() != "nothing to commit" {
+			return err
+		}
+		return nil
+	}
+	for {
+		if _, err := iter.Next(ctx); err != nil {
+			if err == io.EOF {
+				return nil
+			}
+			ctx.GetLogger().Errorf("ERROR reading query results: %v ", err.Error())
+			return err
+		}
+	}
+}
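The net effect of the classification above is that row-format transactions are folded into one underlying DuckDB transaction until either 200 ms have passed or the delta buffer reaches 128 MB. A compact restatement of that policy as a standalone function, with the thresholds copied from mayExtendBatchTxn (a sketch for illustration, not part of this patch):

	// shouldCommitBatch reports whether a batched replica transaction should be
	// committed now rather than extended further.
	func shouldCommitBatch(sinceLastCommit time.Duration, deltaBufBytes uint64) bool {
		return sinceLastCommit >= 200*time.Millisecond || deltaBufBytes >= 128<<20
	}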
 // processRowEvent processes a WriteRows, DeleteRows, or UpdateRows binlog event and returns an error if any problems
 // were encountered.
 func (a *binlogReplicaApplier) processRowEvent(ctx *sql.Context, event mysql.BinlogEvent, engine *gms.Engine) error {
@@ -598,7 +884,9 @@ func (a *binlogReplicaApplier) processRowEvent(ctx *sql.Context, event mysql.Bin
 	default:
 		return fmt.Errorf("unsupported event type: %v", event)
 	}
-	ctx.GetLogger().Tracef("Received binlog event: %s", eventName)
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+		log.Tracef("Received binlog event: %s", eventName)
+	}

 	tableId := event.TableID(*a.format)
 	tableMap, ok := a.tableMapsById[tableId]
@@ -620,9 +908,11 @@ func (a *binlogReplicaApplier) processRowEvent(ctx *sql.Context, event mysql.Bin
 		return err
 	}

-	ctx.GetLogger().WithFields(logrus.Fields{
-		"flags": fmt.Sprintf("%x", rows.Flags),
-	}).Tracef("Processing rows from %s event", eventName)
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+		log.WithFields(logrus.Fields{
+			"flags": fmt.Sprintf("%x", rows.Flags),
+		}).Tracef("Processing rows from %s event", eventName)
+	}

 	flags := rows.Flags
 	foreignKeyChecksDisabled := false
@@ -663,12 +953,16 @@ func (a *binlogReplicaApplier) processRowEvent(ctx *sql.Context, event mysql.Bin
 		eventType = binlog.InsertRowEvent
 		isRowFormat = rows.DataColumns.BitCount() == fieldCount
 	}
-	ctx.GetLogger().Tracef(" - %s Rows (db: %s, table: %s, row-format: %v)", eventType, tableMap.Database, tableName, isRowFormat)
+
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+		log.Tracef(" - %s Rows (db: %s, table: %s, IsRowFormat: %v, RowCount: %v)", eventType, tableMap.Database, tableName, isRowFormat, len(rows.Rows))
+	}

 	if isRowFormat && len(pkSchema.PkOrdinals) > 0 {
 		// --binlog-format=ROW & --binlog-row-image=full
-		return a.appendRowFormatChanges(ctx, engine, tableMap, tableName, schema, eventType, &rows)
+		return a.appendRowFormatChanges(ctx, tableMap, tableName, schema, eventType, &rows)
 	} else {
+		a.ongoingBatchTxn.Store(false)
 		return a.writeChanges(ctx, engine, tableMap, tableName, pkSchema, eventType, &rows, foreignKeyChecksDisabled)
 	}
 }
@@ -706,8 +1000,13 @@ func (a *binlogReplicaApplier) writeChanges(
 		dataRows = append(dataRows, dataRow)
 	}

+	txn, err := adapter.GetTxn(ctx, nil)
+	if err != nil {
+		return err
+	}
 	tableWriter, err := a.tableWriterProvider.GetTableWriter(
-		ctx, engine,
+		ctx,
+		txn,
 		tableMap.Database, tableName,
 		pkSchema,
 		len(tableMap.Types), len(rows.Rows),
@@ -718,7 +1017,6 @@ func (a *binlogReplicaApplier) writeChanges(
 	if err != nil {
 		return err
 	}
-	defer tableWriter.Rollback()

 	switch event {
 	case binlog.DeleteRowEvent:
@@ -732,38 +1030,42 @@ func (a *binlogReplicaApplier) writeChanges(
 		return err
 	}

-	ctx.GetLogger().WithFields(logrus.Fields{
-		"db":    tableMap.Database,
-		"table": tableName,
-		"event": event,
-		"rows":  len(rows.Rows),
-	}).Infoln("processRowEvent")
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+		log.WithFields(logrus.Fields{
+			"db":    tableMap.Database,
+			"table": tableName,
+			"event": event,
+			"rows":  len(rows.Rows),
+		}).Trace("processRowEvent")
+	}

-	return tableWriter.Commit()
+	return nil
 }

 func (a *binlogReplicaApplier) appendRowFormatChanges(
-	ctx *sql.Context, engine *gms.Engine,
+	ctx *sql.Context,
 	tableMap *mysql.TableMap, tableName string, schema sql.Schema,
 	event binlog.RowEventType, rows *mysql.Rows,
 ) error {
-	appender, err := a.tableWriterProvider.GetDeltaAppender(ctx, engine, tableMap.Database, tableName, schema)
+	appender, err := a.tableWriterProvider.GetDeltaAppender(ctx, tableMap.Database, tableName, schema)
 	if err != nil {
 		return err
 	}

 	var (
-		fields        = appender.Fields()
-		actions       = appender.Action()
-		txnTags       = appender.TxnTag()
-		txnServers    = appender.TxnServer()
-		txnGroups     = appender.TxnGroup()
-		txnSeqNumbers = appender.TxnSeqNumber()
-
-		txnTag    []byte
-		txnServer []byte
-		txnGroup  []byte
-		txnSeq    uint64
+		fields          = appender.Fields()
+		actions         = appender.Action()
+		txnTags         = appender.TxnTag()
+		txnServers      = appender.TxnServer()
+		txnGroups       = appender.TxnGroup()
+		txnSeqNumbers   = appender.TxnSeqNumber()
+		TxnStmtOrdinals = appender.TxnStmtOrdinal()
+
+		txnTag         []byte
+		txnServer      []byte
+		txnGroup       []byte
+		txnSeq         uint64
+		txnStmtOrdinal = a.inTxnStmtID.Load()
 	)

 	switch gtid := a.currentGtid.(type) {
@@ -794,6 +1096,7 @@ func (a *binlogReplicaApplier) appendRowFormatChanges(
 			txnServers.Append(txnServer)
 			txnGroups.Append(txnGroup)
 			txnSeqNumbers.Append(txnSeq)
+			TxnStmtOrdinals.Append(txnStmtOrdinal)

 			pos := 0
 			for i := range schema {
@@ -810,6 +1113,7 @@ func (a *binlogReplicaApplier) appendRowFormatChanges(
 				}
 				pos += length
 			}
+			a.deltaBufSize.Add(uint64(pos))
 		}
 	}

@@ -822,6 +1126,7 @@ func (a *binlogReplicaApplier) appendRowFormatChanges(
 			txnServers.Append(txnServer)
 			txnGroups.Append(txnGroup)
 			txnSeqNumbers.Append(txnSeq)
+			TxnStmtOrdinals.Append(txnStmtOrdinal)

 			pos := 0
 			for i := range schema {
@@ -838,12 +1143,26 @@ func (a *binlogReplicaApplier) appendRowFormatChanges(
 				}
 				pos += length
 			}
+			a.deltaBufSize.Add(uint64(pos))
 		}
 	}

-	// TODO(fan): Apparently this is not how the delta appender is supposed to be used.
-	// But let's make it work for now.
-	return a.tableWriterProvider.FlushDelta(ctx)
+	return nil
+}
+
+func (a *binlogReplicaApplier) flushDeltaBuffer(ctx *sql.Context, reason delta.FlushReason) error {
+	tx, err := adapter.GetCatalogTxn(ctx, nil)
+	if err != nil {
+		return err
+	}
+
+	defer a.deltaBufSize.Store(0)
+
+	if err = a.tableWriterProvider.FlushDeltaBuffer(ctx, tx, reason); err != nil {
+		ctx.GetLogger().Errorf("Failed to flush changelog: %v", err.Error())
+		MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, err.Error())
+	}
+	return err
 }

 //
@@ -1060,40 +1379,6 @@ func loadReplicaServerId() (uint32, error) {
 	return serverId, nil
 }

-func executeQueryWithEngine(ctx *sql.Context, engine *gms.Engine, query string) {
-	// Create a sub-context when running queries against the engine, so that we get an accurate query start time.
-	queryCtx := sql.NewContext(ctx, sql.WithSession(ctx.Session)).WithQuery(query)
-
-	if queryCtx.GetCurrentDatabase() == "" {
-		ctx.GetLogger().WithFields(logrus.Fields{
-			"query": query,
-		}).Warn("No current database selected")
-	}
-
-	_, iter, _, err := engine.Query(queryCtx, query)
-	if err != nil {
-		// Log any errors, except for commits with "nothing to commit"
-		if err.Error() != "nothing to commit" {
-			queryCtx.GetLogger().WithFields(logrus.Fields{
-				"error": err.Error(),
-				"query": query,
-			}).Errorf("Applying query failed")
-			msg := fmt.Sprintf("Applying query failed: %v", err.Error())
-			MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, msg)
-		}
-		return
-	}
-	for {
-		_, err := iter.Next(queryCtx)
-		if err != nil {
-			if err != io.EOF {
-				queryCtx.GetLogger().Errorf("ERROR reading query results: %v ", err.Error())
-			}
-			return
-		}
-	}
-}
-
 //
 // Generic util functions...
 //
@@ -1103,11 +1388,15 @@ func convertToHexString(v uint16) string {
 	return fmt.Sprintf("%x", v)
 }

-// keys returns a slice containing the keys in the specified map |m|.
-func keys[K comparable, V any](m map[K]V) []K {
-	keys := make([]K, 0, len(m))
-	for k := range m {
-		keys = append(keys, k)
+func getAutocommit(ctx *sql.Context) bool {
+	autocommit := true
+	autoCommitSessionVar, err := ctx.GetSessionVariable(ctx, sql.AutoCommitSessionVar)
+	if err == nil {
+		autocommit, err = sql.ConvertToBool(ctx, autoCommitSessionVar)
+	}
+	if err != nil {
+		ctx.GetLogger().Warn("Unable to get @@autocommit session variable; assuming autocommit is enabled:", err)
+		return true
 	}
-	return keys
+	return autocommit
 }
diff --git a/binlogreplication/binlog_replication_test.go b/binlogreplication/binlog_replication_test.go
index 1c508e9c..7b0252dd 100644
--- a/binlogreplication/binlog_replication_test.go
+++ b/binlogreplication/binlog_replication_test.go
@@ -854,9 +854,9 @@ func startDuckSqlServer(dir string, persistentSystemVars map[string]string) (int
 	}

 	args := []string{"go", "run", ".",
-		// "--loglevel=TRACE",
 		fmt.Sprintf("--port=%v", duckPort),
 		fmt.Sprintf("--datadir=%s", dir),
+		"--loglevel=6", // TRACE
 	}

 	// If we're running in CI, use a precompiled dolt binary instead of go run
diff --git a/binlogreplication/writer.go b/binlogreplication/writer.go
index b9c6acf7..c72d073b 100644
--- a/binlogreplication/writer.go
+++ b/binlogreplication/writer.go
@@ -1,9 +1,11 @@
 package binlogreplication

 import (
+	stdsql "database/sql"
+
 	"github.com/apache/arrow/go/v17/arrow/array"
 	"github.com/apecloud/myduckserver/binlog"
-	sqle "github.com/dolthub/go-mysql-server"
+	"github.com/apecloud/myduckserver/delta"
 	"github.com/dolthub/go-mysql-server/sql"
 	"vitess.io/vitess/go/mysql"
 )
@@ -12,8 +14,6 @@ type TableWriter interface {
 	Insert(ctx *sql.Context, keyRows []sql.Row) error
 	Delete(ctx *sql.Context, keyRows []sql.Row) error
 	Update(ctx *sql.Context, keyRows []sql.Row, valueRows []sql.Row) error
-	Commit() error
-	Rollback() error
 }

 type DeltaAppender interface {
@@ -24,12 +24,14 @@ type DeltaAppender interface {
 	TxnServer() *array.BinaryDictionaryBuilder
 	TxnGroup() *array.BinaryDictionaryBuilder
 	TxnSeqNumber() *array.Uint64Builder
+	TxnStmtOrdinal() *array.Uint64Builder
 }

 type TableWriterProvider interface {
 	// GetTableWriter returns a TableWriter for writing to the specified |table| in the specified |database|.
 	GetTableWriter(
-		ctx *sql.Context, engine *sqle.Engine,
+		ctx *sql.Context,
+		txn *stdsql.Tx,
 		databaseName, tableName string,
 		schema sql.PrimaryKeySchema,
 		columnCount, rowCount int,
@@ -38,13 +40,16 @@ type TableWriterProvider interface {
 		foreignKeyChecksDisabled bool,
 	) (TableWriter, error)

-	// GetDeltaAppender returns an ArrowAppender for appending updates to the specified |table| in the specified |database|.
+	// GetDeltaAppender returns a DeltaAppender for appending updates to the specified |table| in the specified |database|.
 	GetDeltaAppender(
-		ctx *sql.Context, engine *sqle.Engine,
+		ctx *sql.Context,
 		databaseName, tableName string,
 		schema sql.Schema,
 	) (DeltaAppender, error)

 	// FlushDelta writes the accumulated changes to the database.
-	FlushDelta(ctx *sql.Context) error
+	FlushDeltaBuffer(ctx *sql.Context, tx *stdsql.Tx, reason delta.FlushReason) error
+
+	// DiscardDeltaBuffer discards the accumulated changes.
+	DiscardDeltaBuffer(ctx *sql.Context)
 }
diff --git a/catalog/database.go b/catalog/database.go
index deb1a708..78dce515 100644
--- a/catalog/database.go
+++ b/catalog/database.go
@@ -77,7 +77,7 @@ func (d *Database) tablesInsensitive(ctx *sql.Context, pattern string) ([]*Table
 }

 func (d *Database) findTables(ctx *sql.Context, pattern string) ([]*Table, error) {
-	rows, err := adapter.QueryCatalogContext(ctx, "SELECT DISTINCT table_name, comment FROM duckdb_tables() where database_name = ? and schema_name = ? and table_name ILIKE ?", d.catalog, d.name, pattern)
+	rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT table_name, comment FROM duckdb_tables() where database_name = ? and schema_name = ? and table_name ILIKE ?", d.catalog, d.name, pattern)
 	if err != nil {
 		return nil, ErrDuckDB.New(err)
 	}
@@ -167,7 +167,7 @@ func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.Primary
 		sqlsBuild.WriteString(s)
 	}

-	_, err := adapter.ExecContext(ctx, sqlsBuild.String())
+	_, err := adapter.Exec(ctx, sqlsBuild.String())
 	if err != nil {
 		if IsDuckDBTableAlreadyExistsError(err) {
 			return sql.ErrTableAlreadyExists.New(name)
@@ -185,7 +185,7 @@ func (d *Database) DropTable(ctx *sql.Context, name string) error {
 	d.mu.Lock()
 	defer d.mu.Unlock()

-	_, err := adapter.ExecContext(ctx, fmt.Sprintf(`DROP TABLE %s`, FullTableName(d.catalog, d.name, name)))
+	_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP TABLE %s`, FullTableName(d.catalog, d.name, name)))

 	if err != nil {
 		if IsDuckDBTableNotFoundError(err) {
@@ -201,7 +201,7 @@ func (d *Database) RenameTable(ctx *sql.Context, oldName string, newName string)
 	d.mu.Lock()
 	defer d.mu.Unlock()

-	_, err := adapter.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME TO "%s"`, FullTableName(d.catalog, d.name, oldName), newName))
+	_, err := adapter.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME TO "%s"`, FullTableName(d.catalog, d.name, oldName), newName))
 	if err != nil {
 		if IsDuckDBTableNotFoundError(err) {
 			return sql.ErrTableNotFound.New(oldName)
@@ -228,7 +228,7 @@ func (d *Database) extractViewDefinitions(ctx *sql.Context, schemaName string, v
 		args = append(args, viewName)
 	}

-	rows, err := adapter.QueryCatalogContext(ctx, query, args...)
+	rows, err := adapter.QueryCatalog(ctx, query, args...)
if err != nil { return nil, ErrDuckDB.New(err) } @@ -281,7 +281,7 @@ func (d *Database) CreateView(ctx *sql.Context, name string, selectStatement str d.mu.Lock() defer d.mu.Unlock() - _, err := adapter.ExecContext(ctx, fmt.Sprintf(`USE %s; CREATE VIEW "%s" AS %s`, FullSchemaName(d.catalog, d.name), name, selectStatement)) + _, err := adapter.Exec(ctx, fmt.Sprintf(`USE %s; CREATE VIEW "%s" AS %s`, FullSchemaName(d.catalog, d.name), name, selectStatement)) if err != nil { return ErrDuckDB.New(err) } @@ -293,7 +293,7 @@ func (d *Database) DropView(ctx *sql.Context, name string) error { d.mu.Lock() defer d.mu.Unlock() - _, err := adapter.ExecContext(ctx, fmt.Sprintf(`USE %s; DROP VIEW "%s"`, FullSchemaName(d.catalog, d.name), name)) + _, err := adapter.Exec(ctx, fmt.Sprintf(`USE %s; DROP VIEW "%s"`, FullSchemaName(d.catalog, d.name), name)) if err != nil { if IsDuckDBViewNotFoundError(err) { return sql.ErrViewDoesNotExist.New(name) diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go new file mode 100644 index 00000000..6db33c82 --- /dev/null +++ b/catalog/internal_tables.go @@ -0,0 +1,90 @@ +package catalog + +import "strings" + +type InternalTable struct { + Name string + KeyColumns []string + ValueColumns []string + DDL string +} + +func (it *InternalTable) QualifiedName() string { + return "main." + it.Name +} + +func (it *InternalTable) UpsertStmt() string { + var b strings.Builder + b.Grow(128) + b.WriteString("INSERT OR REPLACE INTO main.") + b.WriteString(it.Name) + b.WriteString(" VALUES (?") + for range it.KeyColumns[1:] { + b.WriteString(", ?") + } + for range it.ValueColumns { + b.WriteString(", ?") + } + b.WriteString(")") + return b.String() +} + +func (it *InternalTable) DeleteStmt() string { + var b strings.Builder + b.Grow(128) + b.WriteString("DELETE FROM main.") + b.WriteString(it.Name) + b.WriteString(" WHERE ") + b.WriteString(it.KeyColumns[0]) + b.WriteString(" = ?") + for _, c := range it.KeyColumns[1:] { + b.WriteString(c) + b.WriteString(" = ?") + } + return b.String() +} + +func (it *InternalTable) SelectStmt() string { + var b strings.Builder + b.Grow(128) + b.WriteString("SELECT ") + b.WriteString(it.ValueColumns[0]) + for _, c := range it.ValueColumns[1:] { + b.WriteString(", ") + b.WriteString(c) + } + b.WriteString(" FROM main.") + b.WriteString(it.Name) + b.WriteString(" WHERE ") + b.WriteString(it.KeyColumns[0]) + b.WriteString(" = ?") + for _, c := range it.KeyColumns[1:] { + b.WriteString(" AND ") + b.WriteString(c) + b.WriteString(" = ?") + } + return b.String() +} + +var InternalTables = struct { + PersistentVariable InternalTable + BinlogPosition InternalTable +}{ + PersistentVariable: InternalTable{ + Name: "persistent_variable", + KeyColumns: []string{"name"}, + ValueColumns: []string{"value", "vtype"}, + DDL: "name TEXT PRIMARY KEY, value TEXT, vtype TEXT", + }, + BinlogPosition: InternalTable{ + Name: "binlog_position", + KeyColumns: []string{"channel"}, + ValueColumns: []string{"position"}, + DDL: "channel TEXT PRIMARY KEY, position TEXT", + }, +} + +var internalTables = []InternalTable{ + InternalTables.PersistentVariable, + InternalTables.BinlogPosition, +} diff --git a/catalog/provider.go b/catalog/provider.go index ebe312f5..0ae9b500 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -67,7 +67,16 @@ func NewDBProvider(dataDir, dbFile string) (*DatabaseProvider, error) { if _, err := storage.ExecContext(context.Background(), q); err != nil { storage.Close() connector.Close() - return nil, fmt.Errorf("failed to execute 
boot query %q: %v", q, err) + return nil, fmt.Errorf("failed to execute boot query %q: %w", q, err) + } + } + + for _, t := range internalTables { + if _, err := storage.ExecContext( + context.Background(), + "CREATE TABLE IF NOT EXISTS "+t.QualifiedName()+"("+t.DDL+")", + ); err != nil { + storage.Close() + connector.Close() + return nil, fmt.Errorf("failed to create internal table %q: %w", t.Name, err) + } + } } @@ -117,7 +128,7 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database { prov.mu.RLock() defer prov.mu.RUnlock() - rows, err := adapter.QueryCatalogContext(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName) + rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName) if err != nil { panic(ErrDuckDB.New(err)) } @@ -175,7 +186,7 @@ func (prov *DatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool { } func hasDatabase(ctx *sql.Context, catalog string, name string) (bool, error) { - rows, err := adapter.QueryCatalogContext(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ? AND schema_name ILIKE ?", catalog, name) + rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ? AND schema_name ILIKE ?", catalog, name) if err != nil { return false, ErrDuckDB.New(err) } @@ -188,7 +199,7 @@ func (prov *DatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro prov.mu.Lock() defer prov.mu.Unlock() - _, err := adapter.ExecCatalogContext(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name))) + _, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name))) if err != nil { return ErrDuckDB.New(err) } @@ -201,7 +212,7 @@ func (prov *DatabaseProvider) DropDatabase(ctx *sql.Context, name string) error prov.mu.Lock() defer prov.mu.Unlock() - _, err := adapter.ExecContext(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name))) + _, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name))) if err != nil { return ErrDuckDB.New(err) } diff --git a/catalog/table.go b/catalog/table.go index 49fc1693..7d6e8ff9 100644 --- a/catalog/table.go +++ b/catalog/table.go @@ -143,7 +143,7 @@ func (t *Table) PrimaryKeySchema() sql.PrimaryKeySchema { } func getPrimaryKeyOrdinals(ctx *sql.Context, catalogName, dbName, tableName string) []int { - rows, err := adapter.QueryCatalogContext(ctx, ` + rows, err := adapter.QueryCatalog(ctx, ` SELECT constraint_column_indexes FROM duckdb_constraints() WHERE database_name = ? AND schema_name = ? AND table_name = ?
AND constraint_type = 'PRIMARY KEY' LIMIT 1 `, catalogName, dbName, tableName) if err != nil { @@ -193,7 +193,7 @@ func (t *Table) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.Colum comment := NewCommentWithMeta(column.Comment, typ.mysql) sql += fmt.Sprintf(`; COMMENT ON COLUMN %s IS '%s'`, FullColumnName(t.db.catalog, t.db.name, t.name, column.Name), comment.Encode()) - _, err = adapter.ExecContext(ctx, sql) + _, err = adapter.Exec(ctx, sql) if err != nil { return ErrDuckDB.New(err) } @@ -208,7 +208,7 @@ func (t *Table) DropColumn(ctx *sql.Context, columnName string) error { sql := fmt.Sprintf(`ALTER TABLE %s DROP COLUMN "%s"`, FullTableName(t.db.catalog, t.db.name, t.name), columnName) - _, err := adapter.ExecContext(ctx, sql) + _, err := adapter.Exec(ctx, sql) if err != nil { return ErrDuckDB.New(err) } @@ -256,7 +256,7 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co sqls = append(sqls, fmt.Sprintf(`COMMENT ON COLUMN %s IS '%s'`, FullColumnName(t.db.catalog, t.db.name, t.name, column.Name), comment.Encode())) joinedSQL := strings.Join(sqls, "; ") - _, err = adapter.ExecContext(ctx, joinedSQL) + _, err = adapter.Exec(ctx, joinedSQL) if err != nil { logrus.Errorf("run duckdb sql failed: %s", joinedSQL) return ErrDuckDB.New(err) @@ -361,7 +361,7 @@ func (t *Table) CreateIndex(ctx *sql.Context, indexDef sql.IndexDef) error { } // Execute the SQL statement to create the index - _, err := adapter.ExecContext(ctx, sqlsBuilder.String()) + _, err := adapter.Exec(ctx, sqlsBuilder.String()) if err != nil { if IsDuckDBIndexAlreadyExistsError(err) { return sql.ErrDuplicateKey.New(indexDef.Name) @@ -388,7 +388,7 @@ func (t *Table) DropIndex(ctx *sql.Context, indexName string) error { EncodeIndexName(t.name, indexName)) // Execute the SQL statement to drop the index - _, err := adapter.ExecContext(ctx, sql) + _, err := adapter.Exec(ctx, sql) if err != nil { return ErrDuckDB.New(err) } @@ -408,7 +408,7 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { defer t.mu.RUnlock() // Query to get the indexes for the table - rows, err := adapter.QueryCatalogContext(ctx, `SELECT index_name, is_unique, comment, sql FROM duckdb_indexes() WHERE database_name = ? AND schema_name = ? AND table_name = ?`, + rows, err := adapter.QueryCatalog(ctx, `SELECT index_name, is_unique, comment, sql FROM duckdb_indexes() WHERE database_name = ? AND schema_name = ? AND table_name = ?`, t.db.catalog, t.db.name, t.name) if err != nil { return nil, ErrDuckDB.New(err) @@ -486,7 +486,7 @@ func (t *Table) Comment() string { } func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) ([]*ColumnInfo, error) { - rows, err := adapter.QueryCatalogContext(ctx, ` + rows, err := adapter.QueryCatalog(ctx, ` SELECT column_name, column_index, data_type, is_nullable, column_default, comment, numeric_precision, numeric_scale FROM duckdb_columns() WHERE database_name = ? AND schema_name = ? AND table_name = ? 
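Aside (reviewer note, not part of the patch): a minimal sketch of the SQL that the InternalTable statement builders in catalog/internal_tables.go generate for the persistent_variable table, assuming the " AND " separator fix in DeleteStmt above; the expected strings are shown in the comments.

package main

import (
	"fmt"

	"github.com/apecloud/myduckserver/catalog"
)

func main() {
	// Print the statements derived from the table definition
	// (KeyColumns: name; ValueColumns: value, vtype).
	t := catalog.InternalTables.PersistentVariable
	fmt.Println(t.UpsertStmt()) // INSERT OR REPLACE INTO main.persistent_variable VALUES (?, ?, ?)
	fmt.Println(t.DeleteStmt()) // DELETE FROM main.persistent_variable WHERE name = ?
	fmt.Println(t.SelectStmt()) // SELECT value, vtype FROM main.persistent_variable WHERE name = ?
}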
diff --git a/replica/controller.go b/delta/controller.go similarity index 69% rename from replica/controller.go rename to delta/controller.go index bf3b5998..88777699 100644 --- a/replica/controller.go +++ b/delta/controller.go @@ -1,10 +1,10 @@ -package replica +package delta import ( "bytes" - "context" stdsql "database/sql" "fmt" + "math/bits" "strconv" "strings" "sync" @@ -13,30 +13,38 @@ import ( "github.com/apache/arrow/go/v17/arrow/ipc" "github.com/apecloud/myduckserver/backend" "github.com/apecloud/myduckserver/binlog" - "github.com/apecloud/myduckserver/binlogreplication" "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" "github.com/sirupsen/logrus" ) +type FlushStats struct { + DeltaSize int64 + Insertions int64 + Deletions int64 +} + type DeltaController struct { mutex sync.Mutex - tables map[tableIdentifier]*deltaAppender + tables map[tableIdentifier]*DeltaAppender pool *backend.ConnectionPool } +func NewController(pool *backend.ConnectionPool) *DeltaController { + return &DeltaController{ + pool: pool, + tables: make(map[tableIdentifier]*DeltaAppender), + } +} + func (c *DeltaController) GetDeltaAppender( databaseName, tableName string, schema sql.Schema, -) (binlogreplication.DeltaAppender, error) { +) (*DeltaAppender, error) { c.mutex.Lock() defer c.mutex.Unlock() - if c.tables == nil { - c.tables = make(map[tableIdentifier]*deltaAppender) - } - id := tableIdentifier{databaseName, tableName} appender, ok := c.tables[id] if ok { @@ -50,8 +58,18 @@ func (c *DeltaController) GetDeltaAppender( return appender, nil } +func (c *DeltaController) Close() { + c.mutex.Lock() + defer c.mutex.Unlock() + + for k, da := range c.tables { + da.appender.Release() + delete(c.tables, k) + } +} + // Flush writes the accumulated changes to the database. -func (c *DeltaController) Flush(ctx context.Context) error { +func (c *DeltaController) Flush(ctx *sql.Context, tx *stdsql.Tx, reason FlushReason) (FlushStats, error) { c.mutex.Lock() defer c.mutex.Unlock() @@ -76,30 +94,53 @@ func (c *DeltaController) Flush(ctx context.Context) error { // See: // https://duckdb.org/docs/sql/indexes.html#limitations-of-art-indexes // https://github.com/duckdb/duckdb/issues/14133 + var ( + // Share the buffer among all tables. + buf bytes.Buffer + stats FlushStats + ) - tx, err := c.pool.Begin() - if err != nil { - return err + for table, appender := range c.tables { + deltaRowCount := appender.RowCount() + if deltaRowCount > 0 { + if err := c.updateTable(ctx, tx, table, appender, &buf, &stats); err != nil { + return stats, err + } + } + switch reason { + case DDLStmtFlushReason: + // DDL statement may change the schema + delete(c.tables, table) + default: + // Pre-allocate memory for the next delta + if deltaRowCount > 0 { + // Next power of 2 + appender.Grow(1 << bits.Len64(uint64(deltaRowCount)-1)) + } + } } - defer tx.Rollback() - - // Share the buffer among all tables. 
- buf := bytes.Buffer{} - for table, appender := range c.tables { - if err := c.updateTable(ctx, tx, table, appender, &buf); err != nil { - return err + if stats.DeltaSize > 0 { + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "DeltaSize": stats.DeltaSize, + "Insertions": stats.Insertions, + "Deletions": stats.Deletions, + "Reason": reason.String(), + }).Trace("Flushed delta buffer") } } - return tx.Commit() + + return stats, nil } func (c *DeltaController) updateTable( - ctx context.Context, + ctx *sql.Context, tx *stdsql.Tx, table tableIdentifier, - appender *deltaAppender, + appender *DeltaAppender, buf *bytes.Buffer, + stats *FlushStats, ) error { buf.Reset() @@ -107,16 +148,15 @@ func (c *DeltaController) updateTable( record := appender.Build() defer record.Release() - fmt.Println("record:", record) // TODO(fan): Switch to zero-copy Arrow ingestion once this PR is merged: // https://github.com/marcboeker/go-duckdb/pull/283 w := ipc.NewWriter(buf, ipc.WithSchema(record.Schema())) if err := w.Write(record); err != nil { - panic(err) + return err } if err := w.Close(); err != nil { - panic(err) + return err } bytes := buf.Bytes() size := len(bytes) @@ -145,7 +185,7 @@ // FROM ( // SELECT // pk1, pk2, ..., - // LAST(ROW(*COLUMNS(*)) ORDER BY txn_group, txn_seq, action) AS r + // LAST(ROW(*COLUMNS(*)) ORDER BY txn_group, txn_seq, txn_stmt, action) AS r // FROM delta // GROUP BY pk1, pk2, ... // ) @@ -170,7 +210,7 @@ } builder.WriteString(" FROM (SELECT ") builder.WriteString(pkList) - builder.WriteString(", LAST(ROW(*COLUMNS(*)) ORDER BY txn_group, txn_seq, action) AS r") + builder.WriteString(", LAST(ROW(*COLUMNS(*)) ORDER BY txn_group, txn_seq, txn_stmt, action) AS r") builder.WriteString(ipcSQL) builder.WriteString(" GROUP BY ") builder.WriteString(pkList) @@ -178,25 +218,28 @@ condenseDeltaSQL := builder.String() var ( - result stdsql.Result - rowsAffected int64 - err error + result stdsql.Result + affected int64 + err error ) // Create a temporary table to store the latest delta view. result, err = tx.ExecContext(ctx, "CREATE OR REPLACE TEMP TABLE delta AS "+condenseDeltaSQL) if err == nil { - rowsAffected, err = result.RowsAffected() + affected, err = result.RowsAffected() } if err != nil { return err } + stats.DeltaSize += affected defer tx.ExecContext(ctx, "DROP TABLE IF EXISTS temp.main.delta") - logrus.WithFields(logrus.Fields{ - "table": qualifiedTableName, - "rows": rowsAffected, - }).Infoln("Delta created") + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "table": qualifiedTableName, + "rows": affected, + }).Trace("Delta created") + } // Insert or replace new rows (action = INSERT) into the base table.
insertSQL := "INSERT OR REPLACE INTO " + @@ -205,16 +248,19 @@ func (c *DeltaController) updateTable( strconv.Itoa(int(binlog.InsertRowEvent)) result, err = tx.ExecContext(ctx, insertSQL) if err == nil { - rowsAffected, err = result.RowsAffected() + affected, err = result.RowsAffected() } if err != nil { return err } + stats.Insertions += affected - logrus.WithFields(logrus.Fields{ - "table": qualifiedTableName, - "rows": rowsAffected, - }).Infoln("Inserted") + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "table": qualifiedTableName, + "rows": affected, + }).Trace("Inserted") + } // Delete rows that have been deleted. // The plan for `IN` is optimized to a SEMI JOIN, @@ -230,16 +276,19 @@ "FROM temp.main.delta WHERE action = " + strconv.Itoa(int(binlog.DeleteRowEvent)) + ")" result, err = tx.ExecContext(ctx, deleteSQL) if err == nil { - rowsAffected, err = result.RowsAffected() + affected, err = result.RowsAffected() } if err != nil { return err } + stats.Deletions += affected - logrus.WithFields(logrus.Fields{ - "table": qualifiedTableName, - "rows": rowsAffected, - }).Infoln("Deleted") + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "table": qualifiedTableName, + "rows": affected, + }).Trace("Deleted") + } return nil } diff --git a/replica/delta.go b/delta/delta.go similarity index 56% rename from replica/delta.go rename to delta/delta.go index 1fb62f38..71c1104c 100644 --- a/replica/delta.go +++ b/delta/delta.go @@ -1,35 +1,32 @@ -package replica +package delta import ( "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" - "github.com/apecloud/myduckserver/binlogreplication" "github.com/apecloud/myduckserver/myarrow" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) const ( - AugmentedColumnList = "action, txn_tag, txn_server, txn_group, txn_seq" + AugmentedColumnList = "action, txn_tag, txn_server, txn_group, txn_seq, txn_stmt" ) type tableIdentifier struct { dbName, tableName string } -type deltaAppender struct { +type DeltaAppender struct { schema sql.Schema appender myarrow.ArrowAppender } -var _ binlogreplication.DeltaAppender = &deltaAppender{} - // Create a new appender. // Add action and GTID columns to the schema: // // https://mariadb.com/kb/en/gtid/ // https://dev.mysql.com/doc/refman/9.0/en/replication-gtids-concepts.html -func newDeltaAppender(schema sql.Schema) (*deltaAppender, error) { - augmented := make(sql.Schema, 0, len(schema)+5) +func newDeltaAppender(schema sql.Schema) (*DeltaAppender, error) { + augmented := make(sql.Schema, 0, len(schema)+6) augmented = append(augmented, &sql.Column{ Name: "action", // delete = 0, update = 1, insert = 2 @@ -44,7 +41,10 @@ func newDeltaAppender(schema sql.Schema) (*deltaAppender, error) { Name: "txn_group", // NULL for MySQL & MariaDB GTID; binlog file name for file position based replication Type: types.Text, }, &sql.Column{ - Name: "txn_seq", + Name: "txn_seq", // Transaction ID for MySQL & MariaDB GTID; binlog position for file position based replication Type: types.Uint64, + }, &sql.Column{ + Name: "txn_stmt", // Ordinal number of the statement in the transaction Type: types.Uint64, }) augmented = append(augmented, schema...)
@@ -54,48 +54,64 @@ func newDeltaAppender(schema sql.Schema) (*deltaAppender, error) { return nil, err } - return &deltaAppender{ + return &DeltaAppender{ schema: augmented, appender: appender, }, nil } -func (a *deltaAppender) Field(i int) array.Builder { - return a.appender.Field(i + 5) +func (a *DeltaAppender) Field(i int) array.Builder { + return a.appender.Field(i + 6) } -func (a *deltaAppender) Fields() []array.Builder { - return a.appender.Fields()[5:] +func (a *DeltaAppender) Fields() []array.Builder { + return a.appender.Fields()[6:] } -func (a *deltaAppender) Schema() sql.Schema { +func (a *DeltaAppender) Schema() sql.Schema { return a.schema } -func (a *deltaAppender) BaseSchema() sql.Schema { - return a.schema[5:] +func (a *DeltaAppender) BaseSchema() sql.Schema { + return a.schema[6:] } -func (a *deltaAppender) Action() *array.Int8Builder { +func (a *DeltaAppender) Action() *array.Int8Builder { return a.appender.Field(0).(*array.Int8Builder) } -func (a *deltaAppender) TxnTag() *array.BinaryDictionaryBuilder { +func (a *DeltaAppender) TxnTag() *array.BinaryDictionaryBuilder { return a.appender.Field(1).(*array.BinaryDictionaryBuilder) } -func (a *deltaAppender) TxnServer() *array.BinaryDictionaryBuilder { +func (a *DeltaAppender) TxnServer() *array.BinaryDictionaryBuilder { return a.appender.Field(2).(*array.BinaryDictionaryBuilder) } -func (a *deltaAppender) TxnGroup() *array.BinaryDictionaryBuilder { +func (a *DeltaAppender) TxnGroup() *array.BinaryDictionaryBuilder { return a.appender.Field(3).(*array.BinaryDictionaryBuilder) } -func (a *deltaAppender) TxnSeqNumber() *array.Uint64Builder { +func (a *DeltaAppender) TxnSeqNumber() *array.Uint64Builder { return a.appender.Field(4).(*array.Uint64Builder) } -func (a *deltaAppender) Build() arrow.Record { +func (a *DeltaAppender) TxnStmtOrdinal() *array.Uint64Builder { + return a.appender.Field(5).(*array.Uint64Builder) +} + +func (a *DeltaAppender) RowCount() int { + return a.Action().Len() +} + +func (a *DeltaAppender) Build() arrow.Record { return a.appender.Build() } + +func (a *DeltaAppender) Grow(n int) { + a.appender.Grow(n) +} + +func (a *DeltaAppender) Release() { + a.appender.Release() +} diff --git a/delta/flush_reason.go b/delta/flush_reason.go new file mode 100644 index 00000000..54db4b4a --- /dev/null +++ b/delta/flush_reason.go @@ -0,0 +1,43 @@ +package delta + +type FlushReason uint8 + +const ( + // UnknownFlushReason means that the changes have to be flushed for an unknown reason. + UnknownFlushReason FlushReason = iota + // DDLStmtFlushReason means that the changes have to be flushed because of a DDL statement. + DDLStmtFlushReason + // DMLStmtFlushReason means that the changes have to be flushed because of a DML statement. + DMLStmtFlushReason + // RowCountLimitFlushReason means that the changes have to be flushed because the row count limit is reached. + RowCountLimitFlushReason + // MemoryLimitFlushReason means that the changes have to be flushed because the memory limit is reached. + MemoryLimitFlushReason + // TimeTickFlushReason means that the changes have to be flushed because a time ticker fires. + TimeTickFlushReason + // QueryFlushReason means that the changes have to be flushed because some tables are queried. + QueryFlushReason + // OnCloseFlushReason means that the changes have to be flushed because the controller is closed.
+ OnCloseFlushReason +) + +func (r FlushReason) String() string { + switch r { + case DDLStmtFlushReason: + return "DDLStmt" + case DMLStmtFlushReason: + return "DMLStmt" + case RowCountLimitFlushReason: + return "RowCountLimit" + case MemoryLimitFlushReason: + return "MemoryLimit" + case TimeTickFlushReason: + return "TimeTick" + case QueryFlushReason: + return "Query" + case OnCloseFlushReason: + return "OnClose" + default: + return "Unknown" + } +} diff --git a/go.mod b/go.mod index 426c5238..6709b47f 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dolthub/vitess v0.0.0-20240919225659-2ad81685e772 github.com/go-sql-driver/mysql v1.8.1 github.com/jmoiron/sqlx v1.4.0 - github.com/marcboeker/go-duckdb v1.8.1 + github.com/marcboeker/go-duckdb v1.8.2-0.20241002112231-62d5fa8c0697 github.com/prometheus/client_golang v1.20.3 github.com/rs/zerolog v1.33.0 github.com/shopspring/decimal v1.3.1 diff --git a/go.sum b/go.sum index 228e3a0c..44f24e48 100644 --- a/go.sum +++ b/go.sum @@ -225,6 +225,8 @@ github.com/marcboeker/go-duckdb v1.8.0 h1:iOWv1wTL0JIMqpyns6hCf5XJJI4fY6lmJNk+it github.com/marcboeker/go-duckdb v1.8.0/go.mod h1:2oV8BZv88S16TKGKM+Lwd0g7DX84x0jMxjTInThC8Is= github.com/marcboeker/go-duckdb v1.8.1 h1:jQjvsN49PNZC9IJLCIMjfD3lMO0QERKNYeZwhyVA8UY= github.com/marcboeker/go-duckdb v1.8.1/go.mod h1:2oV8BZv88S16TKGKM+Lwd0g7DX84x0jMxjTInThC8Is= +github.com/marcboeker/go-duckdb v1.8.2-0.20241002112231-62d5fa8c0697 h1:PU2n7bbll9b4erOPDi4z08JJsICs4L0jeNhr/dZV1So= +github.com/marcboeker/go-duckdb v1.8.2-0.20241002112231-62d5fa8c0697/go.mod h1:2oV8BZv88S16TKGKM+Lwd0g7DX84x0jMxjTInThC8Is= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= diff --git a/main.go b/main.go index d87ae4ac..d4c0ade4 100644 --- a/main.go +++ b/main.go @@ -43,12 +43,14 @@ var ( dataDirectory = "." dbFileName = "mysql.db" dbFilePath string + logLevel = int(logrus.InfoLevel) ) func init() { flag.StringVar(&address, "address", address, "The address to bind to.") flag.IntVar(&port, "port", port, "The port to bind to.") flag.StringVar(&dataDirectory, "datadir", dataDirectory, "The directory to store the database.") + flag.IntVar(&logLevel, "loglevel", logLevel, "The log level to use.") } func ensureSQLTranslate() { @@ -60,6 +62,9 @@ func ensureSQLTranslate() { } func main() { flag.Parse() + + logrus.SetLevel(logrus.Level(logLevel)) + dbFilePath = filepath.Join(dataDirectory, dbFileName) ensureSQLTranslate() @@ -81,7 +86,7 @@ func main() { logrus.Fatalln("Failed to set the persister:", err) } - replica.RegisterReplicaController(provider, engine, pool) + replica.RegisterReplicaController(provider, engine, pool, builder) config := server.Config{ Protocol: "tcp", diff --git a/myarrow/appender.go b/myarrow/appender.go index ee38ff62..d8d6a2c2 100644 --- a/myarrow/appender.go +++ b/myarrow/appender.go @@ -31,6 +31,15 @@ func (a *ArrowAppender) Build() arrow.Record { return a.RecordBuilder.NewRecord() } +// Grow increases the capacity of the builder to at least n rows. +// This method is intended to be used to preallocate memory for the builder +// after Build() has been called.
+func (a *ArrowAppender) Grow(n int) { + for _, b := range a.RecordBuilder.Fields() { + b.Reserve(n) + } +} + func (a *ArrowAppender) Append(row sql.Row) error { for i, b := range a.RecordBuilder.Fields() { v := row[i] diff --git a/mysqlutil/mysqlutil.go b/mysqlutil/mysqlutil.go new file mode 100644 index 00000000..88f52544 --- /dev/null +++ b/mysqlutil/mysqlutil.go @@ -0,0 +1,52 @@ +package mysqlutil + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" +) + +// CauseImplicitCommitBefore returns true if the statement implicitly commits the current transaction before executing: +// https://dev.mysql.com/doc/refman/8.4/en/implicit-commit.html +func CauseImplicitCommitBefore(node sql.Node) bool { + switch node.(type) { + case *plan.StartTransaction, + *plan.LockTables, *plan.UnlockTables: + return true + default: + return CauseImplicitCommitAfter(node) + } +} + +// CauseImplicitCommitAfter returns true if the statement causes an implicit commit after executing: +// https://dev.mysql.com/doc/refman/8.4/en/implicit-commit.html +func CauseImplicitCommitAfter(node sql.Node) bool { + switch node.(type) { + case *plan.CreateDB, *plan.DropDB, *plan.AlterDB, + *plan.CreateTable, *plan.DropTable, *plan.RenameTable, + *plan.AddColumn, *plan.RenameColumn, *plan.DropColumn, *plan.ModifyColumn, *plan.AlterDefaultSet, *plan.AlterDefaultDrop, + *plan.Truncate, + *plan.AnalyzeTable, + *plan.CreateIndex, *plan.DropIndex, *plan.AlterIndex, + *plan.CreateView, *plan.DropView, + *plan.DropRole, *plan.CreateRole, + *plan.AlterUser, *plan.CreateUser, *plan.DropUser, *plan.Grant, *plan.RenameUser, + *plan.StartReplica, *plan.StopReplica, *plan.ResetReplica, *plan.ChangeReplicationSource: + return true + default: + return false + } +} + +// CauseSchemaChange returns true if the statement may change the schema of a database or table. +func CauseSchemaChange(node sql.Node) bool { + switch node.(type) { + case *plan.CreateDB, *plan.DropDB, *plan.AlterDB, + *plan.CreateTable, *plan.DropTable, *plan.RenameTable, + *plan.AddColumn, *plan.RenameColumn, *plan.DropColumn, *plan.ModifyColumn, *plan.AlterDefaultSet, *plan.AlterDefaultDrop, + *plan.CreateIndex, *plan.DropIndex, *plan.AlterIndex, + *plan.CreateView, *plan.DropView: + return true + default: + return false + } +} diff --git a/replica/appender.go b/replica/appender.go deleted file mode 100644 index 0a132def..00000000 --- a/replica/appender.go +++ /dev/null @@ -1,86 +0,0 @@ -package replica - -import ( - "database/sql/driver" - - "github.com/apecloud/myduckserver/binlogreplication" - "github.com/dolthub/go-mysql-server/sql" - "github.com/marcboeker/go-duckdb" -) - -func (twp *tableWriterProvider) newTableAppender( - ctx *sql.Context, - databaseName, tableName string, - columnCount int, -) (*tableAppender, error) { - connector := twp.pool.Connector() - conn, err := connector.Connect(ctx.Context) - if err != nil { - connector.Close() - return nil, err - } - - txn, err := conn.(driver.ConnBeginTx).BeginTx(ctx.Context, driver.TxOptions{}) - if err != nil { - conn.Close() - connector.Close() - return nil, err - } - - appender, err := duckdb.NewAppenderFromConn(conn, databaseName, tableName) - if err != nil { - txn.Rollback() - conn.Close() - connector.Close() - return nil, err - } - - return &tableAppender{ - connector: connector, - conn: conn, - txn: txn, - appender: appender, - buffer: make([]driver.Value, columnCount), - }, nil -} - -type tableAppender struct { - connector *duckdb.Connector - conn driver.Conn - txn driver.Tx - appender *duckdb.Appender - buffer []driver.Value -} - -var _ binlogreplication.TableWriter =
&tableAppender{} - -func (ta *tableAppender) Insert(ctx *sql.Context, rows []sql.Row) error { - for _, row := range rows { - for i, v := range row { - ta.buffer[i] = v - } - } - return ta.appender.AppendRow(ta.buffer...) -} - -func (ta *tableAppender) Delete(ctx *sql.Context, keyRows []sql.Row) error { - panic("not implemented") -} - -func (ta *tableAppender) Update(ctx *sql.Context, keyRows []sql.Row, valueRows []sql.Row) error { - panic("not implemented") -} - -func (ta *tableAppender) Commit() error { - defer ta.connector.Close() - defer ta.conn.Close() - defer ta.txn.Commit() - return ta.appender.Close() -} - -func (ta *tableAppender) Rollback() error { - defer ta.connector.Close() - defer ta.conn.Close() - defer ta.txn.Rollback() - return ta.appender.Close() -} diff --git a/replica/replication.go b/replica/replication.go index 6af7b000..a543edaf 100644 --- a/replica/replication.go +++ b/replica/replication.go @@ -15,6 +15,7 @@ package replica import ( "context" + stdsql "database/sql" sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/memory" @@ -26,11 +27,12 @@ import ( "github.com/apecloud/myduckserver/binlog" "github.com/apecloud/myduckserver/binlogreplication" "github.com/apecloud/myduckserver/catalog" + "github.com/apecloud/myduckserver/delta" ) // registerReplicaController registers the replica controller into the engine // to handle the replication commands, such as START REPLICA, STOP REPLICA, etc. -func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle.Engine, pool *backend.ConnectionPool) { +func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle.Engine, pool *backend.ConnectionPool, builder *backend.DuckBuilder) { replica := binlogreplication.MyBinlogReplicaController replica.SetEngine(engine) @@ -40,8 +42,10 @@ func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle. replica.SetExecutionContext(ctx) twp := &tableWriterProvider{pool: pool} - twp.delta.pool = pool + twp.controller = delta.NewController(pool) + replica.SetTableWriterProvider(twp) + builder.FlushDeltaBuffer = nil // TODO: implement this engine.Analyzer.Catalog.BinlogReplicaController = binlogreplication.MyBinlogReplicaController @@ -52,14 +56,15 @@ func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle. 
} type tableWriterProvider struct { - pool *backend.ConnectionPool - delta DeltaController + pool *backend.ConnectionPool + controller *delta.DeltaController } var _ binlogreplication.TableWriterProvider = &tableWriterProvider{} func (twp *tableWriterProvider) GetTableWriter( - ctx *sql.Context, engine *sqle.Engine, + ctx *sql.Context, + txn *stdsql.Tx, databaseName, tableName string, schema sql.PrimaryKeySchema, columnCount, rowCount int, @@ -67,20 +72,22 @@ func (twp *tableWriterProvider) GetTableWriter( eventType binlog.RowEventType, foreignKeyChecksDisabled bool, ) (binlogreplication.TableWriter, error) { - // if eventType == binlogreplication.InsertEvent { - // return twp.newTableAppender(ctx, databaseName, tableName, columnCount) - // } - return twp.newTableUpdater(ctx, databaseName, tableName, schema, columnCount, rowCount, identifyColumns, dataColumns, eventType) + return twp.newTableUpdater(ctx, txn, databaseName, tableName, schema, columnCount, rowCount, identifyColumns, dataColumns, eventType) } func (twp *tableWriterProvider) GetDeltaAppender( - ctx *sql.Context, engine *sqle.Engine, + ctx *sql.Context, databaseName, tableName string, schema sql.Schema, ) (binlogreplication.DeltaAppender, error) { - return twp.delta.GetDeltaAppender(databaseName, tableName, schema) + return twp.controller.GetDeltaAppender(databaseName, tableName, schema) +} + +func (twp *tableWriterProvider) FlushDeltaBuffer(ctx *sql.Context, tx *stdsql.Tx, reason delta.FlushReason) error { + _, err := twp.controller.Flush(ctx, tx, reason) + return err } -func (twp *tableWriterProvider) FlushDelta(ctx *sql.Context) error { - return twp.delta.Flush(ctx) +func (twp *tableWriterProvider) DiscardDeltaBuffer(ctx *sql.Context) { + twp.controller.Close() } diff --git a/replica/updater.go b/replica/updater.go index 42382a63..bf982f0b 100644 --- a/replica/updater.go +++ b/replica/updater.go @@ -43,6 +43,7 @@ func getPrimaryKeyIndices(schema sql.Schema, columns mysql.Bitmap) []int { func (twp *tableWriterProvider) newTableUpdater( ctx *sql.Context, + txn *stdsql.Tx, databaseName, tableName string, pkSchema sql.PrimaryKeySchema, columnCount, rowCount int, @@ -114,20 +115,13 @@ func (twp *tableWriterProvider) newTableUpdater( "pkUpdate": pkUpdate, }).Infoln("Creating table updater...") - tx, err := twp.pool.BeginTx(ctx, nil) + stmt, err := txn.PrepareContext(ctx.Context, sql) if err != nil { return nil, err } - stmt, err := tx.PrepareContext(ctx.Context, sql) - if err != nil { - tx.Rollback() - return nil, err - } - return &tableUpdater{ pool: twp.pool, - tx: tx, stmt: stmt, replace: replace, cleanup: cleanup, @@ -350,14 +344,6 @@ func (tu *tableUpdater) doInsertThenDelete(ctx *sql.Context, beforeRows []sql.Ro return nil } -func (tu *tableUpdater) Commit() error { - return tu.tx.Commit() -} - -func (tu *tableUpdater) Rollback() error { - return tu.tx.Rollback() -} - func quoteIdentifier(identifier string) string { return `"` + identifier + `"` }
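Aside (reviewer note, not part of the patch): DeltaController.Flush preallocates the next delta buffer by rounding the flushed row count up to the next power of two via 1 << bits.Len64(uint64(deltaRowCount)-1). A standalone sketch of that computation, with nextPow2 as a hypothetical helper name:

package main

import (
	"fmt"
	"math/bits"
)

// nextPow2 mirrors the Grow argument used in DeltaController.Flush:
// it rounds n up to the nearest power of two (valid for n > 0).
func nextPow2(n uint64) uint64 {
	return 1 << bits.Len64(n-1)
}

func main() {
	for _, n := range []uint64{1, 3, 1000, 1024, 1025} {
		// Prints: 1 -> 1, 3 -> 4, 1000 -> 1024, 1024 -> 1024, 1025 -> 2048
		fmt.Printf("%d -> %d\n", n, nextPow2(n))
	}
}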