Skip to content

Commit

Permalink
feat: flush the delta buffer to duckdb in batched transaction (apeclo…
Browse files Browse the repository at this point in the history
…ud#93)

* feat: flush the delta buffer only if necessary
* Fix INSERT then DELETE in the same transaction
* Transaction boundary handling
* Clear out the states when retrying the connection
* conditional logging
  • Loading branch information
fanyang01 authored Oct 14, 2024
1 parent 28ed15f commit 5598311
Show file tree
Hide file tree
Showing 23 changed files with 979 additions and 452 deletions.
64 changes: 58 additions & 6 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,99 @@ import (

// ConnectionHolder is the contract the helpers in this package expect a
// session to satisfy: every helper asserts ctx.Session to this interface
// before obtaining a connection or a transaction from it.
type ConnectionHolder interface {
	// Connections.

	// GetConn returns the connection used for regular queries.
	GetConn(ctx context.Context) (*stdsql.Conn, error)
	// GetCatalogConn returns the connection used for catalog queries.
	GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)

	// Transactions.

	// GetTxn returns a transaction configured with the given options.
	GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
	// GetCatalogTxn returns a transaction for catalog modifications,
	// configured with the given options.
	GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
	// TryGetTxn returns the holder's transaction without taking a context or
	// reporting an error.
	// NOTE(review): presumably nil when no transaction exists — confirm
	// against the implementation.
	TryGetTxn() *stdsql.Tx
	// CloseTxn releases the holder's transaction.
	// NOTE(review): whether this commits or rolls back is up to the
	// implementation — confirm before relying on either.
	CloseTxn()
}

// GetConn returns the connection held by the session attached to ctx.
// The session is asserted to implement ConnectionHolder; the assertion
// panics if it does not.
func GetConn(ctx *sql.Context) (*stdsql.Conn, error) {
	holder := ctx.Session.(ConnectionHolder)
	return holder.GetConn(ctx)
}

func QueryContext(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
// GetTxn asks the session attached to ctx for a transaction, forwarding
// options unchanged. The session is asserted to implement ConnectionHolder;
// the assertion panics if it does not.
func GetTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
	holder := ctx.Session.(ConnectionHolder)
	return holder.GetTxn(ctx, options)
}

// GetCatalogTxn asks the session attached to ctx for a catalog transaction,
// forwarding options unchanged. The session is asserted to implement
// ConnectionHolder; the assertion panics if it does not.
func GetCatalogTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
	holder := ctx.Session.(ConnectionHolder)
	return holder.GetCatalogTxn(ctx, options)
}

// TryGetTxn returns whatever the session's TryGetTxn reports, without
// creating anything and without an error path. The session is asserted to
// implement ConnectionHolder; the assertion panics if it does not.
func TryGetTxn(ctx *sql.Context) *stdsql.Tx {
	holder := ctx.Session.(ConnectionHolder)
	return holder.TryGetTxn()
}

// CloseTxn forwards to the CloseTxn method of the session attached to ctx.
// The session is asserted to implement ConnectionHolder; the assertion
// panics if it does not.
func CloseTxn(ctx *sql.Context) {
	holder := ctx.Session.(ConnectionHolder)
	holder.CloseTxn()
}

// Query runs a row-returning statement on the connection held by the session
// of ctx. If the connection cannot be obtained, that error is returned and
// no query is issued.
func Query(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
	holder := ctx.Session.(ConnectionHolder)
	c, connErr := holder.GetConn(ctx)
	if connErr != nil {
		return nil, connErr
	}
	return c.QueryContext(ctx, query, args...)
}

// QueryCatalogContext is a helper function to query the catalog, such as information_schema.
func QueryRow(ctx *sql.Context, query string, args ...any) *stdsql.Row {
	// The connection error is not propagated: a failure to obtain the
	// connection yields a nil row.
	// NOTE(review): calling Scan on that nil result would dereference a nil
	// pointer — callers should check for nil first.
	c, connErr := GetConn(ctx)
	if connErr != nil {
		return nil
	}
	// Errors from the query itself surface when Scan is called on the row.
	return c.QueryRowContext(ctx, query, args...)
}

// QueryCatalog is a helper function to query the catalog, such as information_schema.
// Unlike Query, this function does not require a schema name to be set on the connection,
// and the current schema of the connection does not matter.
func QueryCatalogContext(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
func QueryCatalog(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
	// Catalog queries go through the session's catalog connection rather
	// than the regular one.
	holder := ctx.Session.(ConnectionHolder)
	catalogConn, connErr := holder.GetCatalogConn(ctx)
	if connErr != nil {
		return nil, connErr
	}
	return catalogConn.QueryContext(ctx, query, args...)
}

func ExecContext(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
// QueryRowCatalog runs a single-row query on the session's catalog
// connection. A failure to obtain that connection is reported as a nil row,
// not as an error.
// NOTE(review): calling Scan on that nil result would dereference a nil
// pointer — callers should check for nil first.
func QueryRowCatalog(ctx *sql.Context, query string, args ...any) *stdsql.Row {
	holder := ctx.Session.(ConnectionHolder)
	catalogConn, connErr := holder.GetCatalogConn(ctx)
	if connErr != nil {
		return nil
	}
	return catalogConn.QueryRowContext(ctx, query, args...)
}

// Exec runs a statement on the connection held by the session of ctx and
// returns its result. If the connection cannot be obtained, that error is
// returned and nothing is executed.
func Exec(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
	holder := ctx.Session.(ConnectionHolder)
	c, connErr := holder.GetConn(ctx)
	if connErr != nil {
		return nil, connErr
	}
	return c.ExecContext(ctx, query, args...)
}

// ExecCatalogContext is a helper function to execute a catalog modification query, such as creating a database.
// ExecCatalog is a helper function to execute a catalog modification query, such as creating a database.
// Unlike Exec, this function does not require a schema name to be set on the connection,
// and the current schema of the connection does not matter.
func ExecCatalogContext(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
func ExecCatalog(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
	// Catalog modifications go through the session's catalog connection
	// rather than the regular one.
	holder := ctx.Session.(ConnectionHolder)
	catalogConn, connErr := holder.GetCatalogConn(ctx)
	if connErr != nil {
		return nil, connErr
	}
	return catalogConn.ExecContext(ctx, query, args...)
}

// ExecCatalogInTxn runs a statement inside the session's catalog
// transaction, which is requested with nil options. This function neither
// commits nor rolls back the transaction; that is left to the caller.
func ExecCatalogInTxn(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
	catalogTx, txErr := GetCatalogTxn(ctx, nil)
	if txErr != nil {
		return nil, txErr
	}
	return catalogTx.ExecContext(ctx, query, args...)
}

// ExecInTxn runs a statement inside the session's transaction, which is
// requested with nil options. This function neither commits nor rolls back
// the transaction; that is left to the caller.
func ExecInTxn(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
	holder := ctx.Session.(ConnectionHolder)
	activeTx, txErr := holder.GetTxn(ctx, nil)
	if txErr != nil {
		return nil, txErr
	}
	return activeTx.ExecContext(ctx, query, args...)
}
8 changes: 8 additions & 0 deletions backend/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ func (p *ConnectionPool) GetTxn(ctx context.Context, id uint32, schemaName strin
return tx, nil
}

// TryGetTxn looks up the transaction stored under id and returns it, or nil
// when nothing is stored for that id. It never creates a transaction.
// The stored value is asserted to be a *stdsql.Tx; the assertion panics if
// any other type was stored under id.
func (p *ConnectionPool) TryGetTxn(id uint32) *stdsql.Tx {
	if stored, found := p.txns.Load(id); found {
		return stored.(*stdsql.Tx)
	}
	return nil
}

// CloseTxn removes the entry stored under id from p.txns.
// It only forgets the entry: this method itself neither commits nor rolls
// back the transaction, so the caller is responsible for finishing it.
func (p *ConnectionPool) CloseTxn(id uint32) {
p.txns.Delete(id)
}
Expand Down
22 changes: 16 additions & 6 deletions backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
// DuckBuilder is a sql.NodeExecBuilder that wraps a base builder and a
// connection pool.
type DuckBuilder struct {
// base is the wrapped builder.
// NOTE(review): presumably the fallback for nodes this builder does not
// handle itself — confirm in Build.
base sql.NodeExecBuilder
// pool is the connection pool this builder was constructed with.
pool *ConnectionPool

// FlushDeltaBuffer is an optional hook. When non-nil, Build calls it before
// executing the query and returns its error, if any, without building.
FlushDeltaBuffer func() error
}

var _ sql.NodeExecBuilder = (*DuckBuilder)(nil)
Expand All @@ -43,6 +45,14 @@ func NewDuckBuilder(base sql.NodeExecBuilder, pool *ConnectionPool) *DuckBuilder
}

func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.RowIter, error) {
// Flush the delta buffer before executing the query.
// TODO(fan): Be fine-grained and flush only when the replicated tables are touched.
if b.FlushDeltaBuffer != nil {
if err := b.FlushDeltaBuffer(); err != nil {
return nil, err
}
}

n := root
qp, ok := n.(*plan.QueryProcess)
if ok {
Expand All @@ -59,7 +69,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"NodeType": fmt.Sprintf("%T", n),
}).Infoln("Building node:", n)
}).Trace("Building node:", n)

// TODO; find a better way to fallback to the base builder
switch n.(type) {
Expand Down Expand Up @@ -130,7 +140,7 @@ func (b *DuckBuilder) executeExpressioner(ctx *sql.Context, n sql.Expressioner,
}

func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Conn) (sql.RowIter, error) {
logrus.Infoln("Executing Query...")
ctx.GetLogger().Trace("Executing Query...")

var (
duckSQL string
Expand All @@ -148,10 +158,10 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
return nil, catalog.ErrTranspiler.New(err)
}

logrus.WithFields(logrus.Fields{
ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"DuckSQL": duckSQL,
}).Infoln("Executing Query...")
}).Trace("Executing Query...")

// Execute the DuckDB query
rows, err := conn.QueryContext(ctx.Context, duckSQL)
Expand All @@ -169,10 +179,10 @@ func (b *DuckBuilder) executeDML(ctx *sql.Context, conn *stdsql.Conn) (sql.RowIt
return nil, catalog.ErrTranspiler.New(err)
}

logrus.WithFields(logrus.Fields{
ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"DuckSQL": duckSQL,
}).Infoln("Executing DML...")
}).Trace("Executing DML...")

// Execute the DuckDB query
result, err := conn.ExecContext(ctx.Context, duckSQL)
Expand Down
38 changes: 23 additions & 15 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *

// NewSessionBuilder returns a session builder for the given database provider.
func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
_, err := pool.Exec("CREATE TABLE IF NOT EXISTS main.persistent_variables (name TEXT PRIMARY KEY, value TEXT, type TEXT)")
if err != nil {
panic(err)
}

return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
Expand Down Expand Up @@ -73,7 +68,7 @@ var _ sql.Transaction = (*Transaction)(nil)

// StartTransaction implements sql.TransactionSession.
func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) {
sess.GetLogger().Infoln("StartTransaction")
sess.GetLogger().Trace("StartTransaction")
base, err := sess.Session.StartTransaction(ctx, tCharacteristic)
if err != nil {
return nil, err
Expand All @@ -93,7 +88,7 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans

var tx *stdsql.Tx
if startUnderlyingTx {
sess.GetLogger().Infoln("StartDuckTransaction")
sess.GetLogger().Trace("StartDuckTransaction")
tx, err = sess.GetTxn(ctx, &stdsql.TxOptions{ReadOnly: tCharacteristic == sql.ReadOnly})
if err != nil {
return nil, err
Expand All @@ -104,10 +99,10 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans

// CommitTransaction implements sql.TransactionSession.
func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
sess.GetLogger().Infoln("CommitTransaction")
sess.GetLogger().Trace("CommitTransaction")
transaction := tx.(*Transaction)
if transaction.tx != nil {
sess.GetLogger().Infoln("CommitDuckTransaction")
sess.GetLogger().Trace("CommitDuckTransaction")
defer sess.CloseTxn()
if err := transaction.tx.Commit(); err != nil {
return err
Expand All @@ -118,10 +113,10 @@ func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) erro

// Rollback implements sql.TransactionSession.
func (sess Session) Rollback(ctx *sql.Context, tx sql.Transaction) error {
sess.GetLogger().Infoln("Rollback")
sess.GetLogger().Trace("Rollback")
transaction := tx.(*Transaction)
if transaction.tx != nil {
sess.GetLogger().Infoln("RollbackDuckTransaction")
sess.GetLogger().Trace("RollbackDuckTransaction")
defer sess.CloseTxn()
if err := transaction.tx.Rollback(); err != nil {
return err
Expand All @@ -137,7 +132,7 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
}
_, err := sess.ExecContext(
context.Background(),
"INSERT OR REPLACE INTO main.persistent_variables (name, value, vtype) VALUES (?, ?, ?)",
catalog.InternalTables.PersistentVariable.UpsertStmt(),
sysVarName, value, fmt.Sprintf("%T", value),
)
return err
Expand All @@ -147,15 +142,15 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
func (sess Session) RemovePersistedGlobal(sysVarName string) error {
_, err := sess.ExecContext(
context.Background(),
"DELETE FROM main.persistent_variables WHERE name = ?",
catalog.InternalTables.PersistentVariable.DeleteStmt(),
sysVarName,
)
return err
}

// RemoveAllPersistedGlobals implements sql.PersistableSession.
func (sess Session) RemoveAllPersistedGlobals() error {
_, err := sess.ExecContext(context.Background(), "DELETE FROM main.persistent_variables")
_, err := sess.ExecContext(context.Background(), "DELETE FROM "+catalog.InternalTables.PersistentVariable.Name)
return err
}

Expand All @@ -164,7 +159,8 @@ func (sess Session) GetPersistedValue(k string) (interface{}, error) {
var value, vtype string
err := sess.QueryRow(
context.Background(),
"SELECT value, vtype FROM main.persistent_variables WHERE name = ?", k,
catalog.InternalTables.PersistentVariable.SelectStmt(),
k,
).Scan(&value, &vtype)
switch {
case err == stdsql.ErrNoRows:
Expand Down Expand Up @@ -195,10 +191,22 @@ func (sess Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConn(ctx, sess.ID())
}

// GetTxn implements adapter.ConnectionHolder.
// It requests a transaction from the pool, keyed by this session's ID, and
// passes the session's current database as the schema name.
func (sess Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
	sessionID := sess.ID()
	schema := sess.GetCurrentDatabase()
	return sess.pool.GetTxn(ctx, sessionID, schema, options)
}

// GetCatalogTxn implements adapter.ConnectionHolder.
// It requests a transaction from the pool, keyed by this session's ID, with
// an empty schema name instead of the session's current database.
func (sess Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
	const noSchema = ""
	sessionID := sess.ID()
	return sess.pool.GetTxn(ctx, sessionID, noSchema, options)
}

// TryGetTxn implements adapter.ConnectionHolder.
// It returns whatever the pool reports for this session's ID and never
// starts a transaction itself.
func (sess Session) TryGetTxn() *stdsql.Tx {
	sessionID := sess.ID()
	return sess.pool.TryGetTxn(sessionID)
}

// CloseTxn implements adapter.ConnectionHolder.
// It tells the pool to close the transaction keyed by this session's ID.
func (sess Session) CloseTxn() {
	sessionID := sess.ID()
	sess.pool.CloseTxn(sessionID)
}
Expand Down
Loading

0 comments on commit 5598311

Please sign in to comment.