Skip to content

Commit

Permalink
Merge pull request #8753 from dolthub/db/async
Browse files Browse the repository at this point in the history
make autoincrement tracker load async
  • Loading branch information
coffeegoddd authored Jan 16, 2025
2 parents 30258b9 + cabcb31 commit 57effdc
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
75 changes: 70 additions & 5 deletions go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ package dsess

import (
"context"
"errors"
"io"
"math"
"strings"
"sync"
"time"

"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
Expand Down Expand Up @@ -48,6 +50,8 @@ type AutoIncrementTracker struct {
sequences *sync.Map // map[string]uint64
mm *mutexmap.MutexMap
lockMode LockMode
init chan struct{}
initErr error
}

var _ globalstate.AutoIncrementTracker = &AutoIncrementTracker{}
Expand All @@ -61,8 +65,9 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb
dbName: dbName,
sequences: &sync.Map{},
mm: mutexmap.NewMutexMap(),
init: make(chan struct{}),
}
ait.InitWithRoots(ctx, roots...)
ait.runInitWithRootsAsync(ctx, roots...)
return &ait, nil
}

Expand All @@ -76,13 +81,22 @@ func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 {
}

// Current returns the next value to be generated in the auto increment sequence for the table named
func (a *AutoIncrementTracker) Current(tableName string) uint64 {
return loadAutoIncValue(a.sequences, tableName)
func (a *AutoIncrementTracker) Current(tableName string) (uint64, error) {
err := a.waitForInit()
if err != nil {
return 0, err
}
return loadAutoIncValue(a.sequences, tableName), nil
}

// Next returns the next auto increment value for the table named using the provided value from an insert (which may
// be null or 0, in which case it will be generated from the sequence).
func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
err := a.waitForInit()
if err != nil {
return 0, err
}

tbl = strings.ToLower(tbl)

given, err := CoerceAutoIncrementValue(insertVal)
Expand Down Expand Up @@ -113,6 +127,10 @@ func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64,
}

func (a *AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
err := a.waitForInit()
if err != nil {
return 0, err
}
return CoerceAutoIncrementValue(val)
}

Expand Down Expand Up @@ -140,6 +158,11 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
// table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
// maximum value for this table across all branches.
func (a *AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
err := a.waitForInit()
if err != nil {
return nil, err
}

tableName = strings.ToLower(tableName)

release := a.mm.Lock(tableName)
Expand Down Expand Up @@ -338,16 +361,27 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err
}

// AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
func (a *AutoIncrementTracker) AddNewTable(tableName string) {
func (a *AutoIncrementTracker) AddNewTable(tableName string) error {
err := a.waitForInit()
if err != nil {
return err
}

tableName = strings.ToLower(tableName)
// only initialize the sequence for this table if no other branch has such a table
a.sequences.LoadOrStore(tableName, uint64(1))
return nil
}

// DropTable drops the table with the name given.
// To establish the new auto increment value, callers must also pass all other working sets in scope that may include
// a table with the same name, omitting the working set that just deleted the table named.
func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
err := a.waitForInit()
if err != nil {
return err
}

tableName = strings.ToLower(tableName)

release := a.mm.Lock(tableName)
Expand Down Expand Up @@ -389,6 +423,11 @@ func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wse
}

func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName string) (func(), error) {
err := a.waitForInit()
if err != nil {
return nil, err
}

_, i, _ := sql.SystemVariables.GetGlobal("innodb_autoinc_lock_mode")
lockMode := LockMode(i.(int64))
if lockMode == LockMode_Interleaved {
Expand All @@ -398,7 +437,23 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri
return a.mm.Lock(tableName), nil
}

func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
func (a *AutoIncrementTracker) waitForInit() error {
select {
case <-a.init:
return a.initErr
case <-time.After(5 * time.Minute):
return errors.New("failed to initialize autoincrement tracker")
}
}

func (a *AutoIncrementTracker) runInitWithRootsAsync(ctx context.Context, roots ...doltdb.Rootish) {
go func() {
defer close(a.init)
a.initErr = a.initWithRoots(ctx, roots...)
}()
}

func (a *AutoIncrementTracker) initWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
eg, egCtx := errgroup.WithContext(ctx)
eg.SetLimit(128)

Expand Down Expand Up @@ -435,3 +490,13 @@ func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltd

return eg.Wait()
}

func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
err := a.waitForInit()
if err != nil {
return err
}
a.init = make(chan struct{})
a.runInitWithRootsAsync(ctx, roots...)
return a.waitForInit()
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import (
// interface here because implementations need to reach into session state, requiring a dependency on this package.
type AutoIncrementTracker interface {
// Current returns the current auto increment value for the given table.
Current(tableName string) uint64
Current(tableName string) (uint64, error)
// Next returns the next auto increment value for the given table, and increments the current value.
Next(tbl string, insertVal interface{}) (uint64, error)
// AddNewTable adds a new table to the tracker, initializing the auto increment value to 1.
AddNewTable(tableName string)
AddNewTable(tableName string) error
// DropTable removes a table from the tracker.
DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error
// CoerceAutoIncrementValue coerces the given value to a uint64, returning an error if it can't be done.
Expand Down
5 changes: 4 additions & 1 deletion go/libraries/doltcore/sqle/writer/noms_write_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ func (s *nomsWriteSession) flush(ctx *sql.Context) (*doltdb.WorkingSet, error) {
// Update the auto increment value for the table if a tracker was provided
// TODO: the table probably needs an autoincrement tracker no matter what
if schema.HasAutoIncrement(ed.Schema()) {
v := s.aiTracker.Current(name)
v, err := s.aiTracker.Current(name)
if err != nil {
return err
}
tbl, err = tbl.SetAutoIncrementValue(ctx, v)
if err != nil {
return err
Expand Down
5 changes: 4 additions & 1 deletion go/libraries/doltcore/sqle/writer/prolly_write_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ func (s *prollyWriteSession) flush(ctx *sql.Context, autoIncSet bool, manualAuto
// override was specified (e.g. if the next value was set explicitly)
if schema.HasAutoIncrement(wr.sch) {
// TODO: need schema name for auto increment
autoIncVal := s.aiTracker.Current(name.Name)
autoIncVal, err := s.aiTracker.Current(name.Name)
if err != nil {
return err
}
override, hasManuallySetAi := manualAutoIncrementsSettings[name.Name]
if hasManuallySetAi {
autoIncVal = override
Expand Down

0 comments on commit 57effdc

Please sign in to comment.