Skip to content

Commit

Permalink
Merge pull request #8820 from dolthub/aaron/doltdb-commit-hooks-sqle
Browse files Browse the repository at this point in the history
[no-release-notes] go: sqle,doltdb: Make commit hooks take a DoltDB instead of a datas.Database. Move them to sqle.
  • Loading branch information
reltuk authored Feb 11, 2025
2 parents b4fef81 + c248410 commit 529e51e
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 88 deletions.
2 changes: 1 addition & 1 deletion go/cmd/dolt/commands/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ func TestCommitHooksNoErrors(t *testing.T) {
t.Error("failed to produce noop hook")
} else {
switch h := hooks[0].(type) {
case *doltdb.LogHook:
case *sqle.LogHook:
default:
t.Errorf("expected LogHook, found: %s", h)
}
Expand Down
5 changes: 3 additions & 2 deletions go/cmd/dolt/commands/sqlserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,14 @@ func TestServerArgs(t *testing.T) {
}

func TestDeprecatedUserPasswordServerArgs(t *testing.T) {
ctx := context.Background()
controller := svcs.NewController()
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, dEnv.DoltDB(context.Background()).Close())
assert.NoError(t, dEnv.DoltDB(ctx).Close())
}()
err = StartServer(context.Background(), "0.0.0", "dolt sql-server", []string{
err = StartServer(ctx, "0.0.0", "dolt sql-server", []string{
"-H", "localhost",
"-P", "15200",
"-u", "username",
Expand Down
8 changes: 6 additions & 2 deletions go/libraries/doltcore/doltdb/doltdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ func DoltDBFromCS(cs chunks.ChunkStore, databaseName string) *DoltDB {
ns := tree.NewNodeStore(cs)
db := datas.NewTypesDatabase(vrw, ns)

return &DoltDB{db: hooksDatabase{Database: db}, vrw: vrw, ns: ns, databaseName: databaseName}
ret := &DoltDB{db: hooksDatabase{Database: db}, vrw: vrw, ns: ns, databaseName: databaseName}
ret.db.db = ret
return ret
}

// GetDatabaseName returns the name of the database.
Expand Down Expand Up @@ -148,7 +150,9 @@ func LoadDoltDBWithParams(ctx context.Context, nbf *types.NomsBinFormat, urlStr
return nil, err
}

return &DoltDB{db: hooksDatabase{Database: db}, vrw: vrw, ns: ns, databaseName: name}, nil
ret := &DoltDB{db: hooksDatabase{Database: db}, vrw: vrw, ns: ns, databaseName: name}
ret.db.db = ret
return ret, nil
}

// NomsRoot returns the hash of the noms dataset map
Expand Down
25 changes: 13 additions & 12 deletions go/libraries/doltcore/doltdb/hooksdatabase.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ import (

type hooksDatabase struct {
datas.Database
postCommitHooks []CommitHook
rsc *ReplicationStatusController
db *DoltDB
hooks []CommitHook
rsc *ReplicationStatusController
}

// CommitHook is an abstraction for executing arbitrary commands after atomic database commits
type CommitHook interface {
// Execute is arbitrary read-only function whose arguments are new Dataset commit into a specific Database
Execute(ctx context.Context, ds datas.Dataset, db datas.Database) (func(context.Context) error, error)
Execute(ctx context.Context, ds datas.Dataset, db *DoltDB) (func(context.Context) error, error)
// HandleError is an bridge function to handle Execute errors
HandleError(ctx context.Context, err error) error
// SetLogger lets clients specify an output stream for HandleError
Expand All @@ -51,13 +52,13 @@ type NotifyWaitFailedCommitHook interface {
}

func (db hooksDatabase) SetCommitHooks(ctx context.Context, postHooks []CommitHook) hooksDatabase {
db.postCommitHooks = make([]CommitHook, len(postHooks))
copy(db.postCommitHooks, postHooks)
db.hooks = make([]CommitHook, len(postHooks))
copy(db.hooks, postHooks)
return db
}

func (db hooksDatabase) SetCommitHookLogger(ctx context.Context, wr io.Writer) hooksDatabase {
for _, h := range db.postCommitHooks {
for _, h := range db.hooks {
h.SetLogger(ctx, wr)
}
return db
Expand All @@ -69,8 +70,8 @@ func (db hooksDatabase) withReplicationStatusController(rsc *ReplicationStatusCo
}

func (db hooksDatabase) PostCommitHooks() []CommitHook {
toret := make([]CommitHook, len(db.postCommitHooks))
copy(toret, db.postCommitHooks)
toret := make([]CommitHook, len(db.hooks))
copy(toret, db.hooks)
return toret
}

Expand All @@ -80,17 +81,17 @@ func (db hooksDatabase) ExecuteCommitHooks(ctx context.Context, ds datas.Dataset
var ioff int
if rsc != nil {
ioff = len(rsc.Wait)
rsc.Wait = append(rsc.Wait, make([]func(context.Context) error, len(db.postCommitHooks))...)
rsc.NotifyWaitFailed = append(rsc.NotifyWaitFailed, make([]func(), len(db.postCommitHooks))...)
rsc.Wait = append(rsc.Wait, make([]func(context.Context) error, len(db.hooks))...)
rsc.NotifyWaitFailed = append(rsc.NotifyWaitFailed, make([]func(), len(db.hooks))...)
}
for il, hook := range db.postCommitHooks {
for il, hook := range db.hooks {
if !onlyWS || hook.ExecuteForWorkingSets() {
i := il
hook := hook
wg.Add(1)
go func() {
defer wg.Done()
f, err := hook.Execute(ctx, ds, db)
f, err := hook.Execute(ctx, ds, db.db)
if err != nil {
hook.HandleError(ctx, err)
}
Expand Down
5 changes: 2 additions & 3 deletions go/libraries/doltcore/sqle/cluster/commithook.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,10 @@ var errDetectedBrokenConfigStr = "error: more than one server was configured as

// Execute on this commithook updates the target root hash we're attempting to
// replicate and wakes the replication thread.
func (h *commithook) Execute(ctx context.Context, ds datas.Dataset, db datas.Database) (func(context.Context) error, error) {
func (h *commithook) Execute(ctx context.Context, ds datas.Dataset, db *doltdb.DoltDB) (func(context.Context) error, error) {
lgr := h.logger()
lgr.Tracef("cluster/commithook: Execute called post commit")
cs := datas.ChunkStoreFromDatabase(db)
root, err := cs.Root(ctx)
root, err := db.NomsRoot(ctx)
if err != nil {
lgr.Errorf("cluster/commithook: Execute: error retrieving local database root: %v", err)
return nil, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package doltdb
package sqle

import (
"context"
Expand All @@ -23,44 +23,43 @@ import (

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/types"
)

type PushOnWriteHook struct {
destDB datas.Database
destDB *doltdb.DoltDB
tmpDir string
out io.Writer
fmt *types.NomsBinFormat
}

var _ CommitHook = (*PushOnWriteHook)(nil)
var _ doltdb.CommitHook = (*PushOnWriteHook)(nil)

// NewPushOnWriteHook creates a ReplicateHook, parameterizaed by the backup database
// and a local tempfile for pushing
func NewPushOnWriteHook(destDB *DoltDB, tmpDir string) *PushOnWriteHook {
func NewPushOnWriteHook(destDB *doltdb.DoltDB, tmpDir string) *PushOnWriteHook {
return &PushOnWriteHook{
destDB: destDB.db,
destDB: destDB,
tmpDir: tmpDir,
fmt: destDB.Format(),
}
}

// Execute implements CommitHook, replicates head updates to the destDb field
func (ph *PushOnWriteHook) Execute(ctx context.Context, ds datas.Dataset, db datas.Database) (func(context.Context) error, error) {
func (ph *PushOnWriteHook) Execute(ctx context.Context, ds datas.Dataset, db *doltdb.DoltDB) (func(context.Context) error, error) {
return nil, pushDataset(ctx, ph.destDB, db, ds, ph.tmpDir)
}

func pushDataset(ctx context.Context, destDB, srcDB datas.Database, ds datas.Dataset, tmpDir string) error {
func pushDataset(ctx context.Context, destDB, srcDB *doltdb.DoltDB, ds datas.Dataset, tmpDir string) error {
addr, ok := ds.MaybeHeadAddr()
if !ok {
_, err := destDB.Delete(ctx, ds, "")
// TODO: fix up hack usage.
_, err := doltdb.HackDatasDatabaseFromDoltDB(destDB).Delete(ctx, ds, "")
return err
}

err := pullHash(ctx, destDB, srcDB, []hash.Hash{addr}, tmpDir, nil, nil)
err := destDB.PullChunks(ctx, tmpDir, srcDB, []hash.Hash{addr}, nil, nil)
if err != nil {
return err
}
Expand All @@ -70,13 +69,7 @@ func pushDataset(ctx context.Context, destDB, srcDB datas.Database, ds datas.Dat
return err
}

ds, err = destDB.GetDataset(ctx, rf.String())
if err != nil {
return err
}

_, err = destDB.SetHead(ctx, ds, addr, "")
return err
return destDB.SetHead(ctx, rf, addr)
}

// HandleError implements CommitHook
Expand All @@ -102,7 +95,7 @@ func (ph *PushOnWriteHook) SetLogger(ctx context.Context, wr io.Writer) error {

type PushArg struct {
ds datas.Dataset
db datas.Database
db *doltdb.DoltDB
hash hash.Hash
}

Expand All @@ -118,10 +111,10 @@ const (
asyncPushSyncReplica = "async_push_sync_replica"
)

var _ CommitHook = (*AsyncPushOnWriteHook)(nil)
var _ doltdb.CommitHook = (*AsyncPushOnWriteHook)(nil)

// NewAsyncPushOnWriteHook creates a AsyncReplicateHook
func NewAsyncPushOnWriteHook(bThreads *sql.BackgroundThreads, destDB *DoltDB, tmpDir string, logger io.Writer) (*AsyncPushOnWriteHook, error) {
func NewAsyncPushOnWriteHook(bThreads *sql.BackgroundThreads, destDB *doltdb.DoltDB, tmpDir string, logger io.Writer) (*AsyncPushOnWriteHook, error) {
ch := make(chan PushArg, asyncPushBufferSize)
err := RunAsyncReplicationThreads(bThreads, ch, destDB, tmpDir, logger)
if err != nil {
Expand All @@ -135,16 +128,11 @@ func (*AsyncPushOnWriteHook) ExecuteForWorkingSets() bool {
}

// Execute implements CommitHook, replicates head updates to the destDb field
func (ah *AsyncPushOnWriteHook) Execute(ctx context.Context, ds datas.Dataset, db datas.Database) (func(context.Context) error, error) {
func (ah *AsyncPushOnWriteHook) Execute(ctx context.Context, ds datas.Dataset, db *doltdb.DoltDB) (func(context.Context) error, error) {
addr, _ := ds.MaybeHeadAddr()

select {
case ah.ch <- PushArg{ds: ds, db: db, hash: addr}:
case <-ctx.Done():
ah.ch <- PushArg{ds: ds, db: db, hash: addr}
return nil, ctx.Err()
}
return nil, nil
// TODO: Unconditional push here seems dangerous.
ah.ch <- PushArg{ds: ds, db: db, hash: addr}
return nil, ctx.Err()
}

// HandleError implements CommitHook
Expand All @@ -166,15 +154,15 @@ type LogHook struct {
out io.Writer
}

var _ CommitHook = (*LogHook)(nil)
var _ doltdb.CommitHook = (*LogHook)(nil)

// NewLogHook is a noop that logs to a writer when invoked
func NewLogHook(msg []byte) *LogHook {
return &LogHook{msg: msg}
}

// Execute implements CommitHook, writes message to log channel
func (lh *LogHook) Execute(ctx context.Context, ds datas.Dataset, db datas.Database) (func(context.Context) error, error) {
func (lh *LogHook) Execute(ctx context.Context, ds datas.Dataset, db *doltdb.DoltDB) (func(context.Context) error, error) {
if lh.out != nil {
_, err := lh.out.Write(lh.msg)
return nil, err
Expand All @@ -200,7 +188,7 @@ func (*LogHook) ExecuteForWorkingSets() bool {
return false
}

func RunAsyncReplicationThreads(bThreads *sql.BackgroundThreads, ch chan PushArg, destDB *DoltDB, tmpDir string, logger io.Writer) error {
func RunAsyncReplicationThreads(bThreads *sql.BackgroundThreads, ch chan PushArg, destDB *doltdb.DoltDB, tmpDir string, logger io.Writer) error {
mu := &sync.Mutex{}
var newHeads = make(map[string]PushArg, asyncPushBufferSize)

Expand Down Expand Up @@ -259,7 +247,7 @@ func RunAsyncReplicationThreads(bThreads *sql.BackgroundThreads, ch chan PushArg
for id, newCm := range newHeadsCopy {
if latest, ok := latestHeads[id]; !ok || latest != newCm.hash {
// use background context to drain after sql context is canceled
err := pushDataset(context.Background(), destDB.db, newCm.db, newCm.ds, tmpDir)
err := pushDataset(context.Background(), destDB, newCm.db, newCm.ds, tmpDir)
if err != nil {
logger.Write([]byte("replication failed: " + err.Error()))
}
Expand Down
Loading

0 comments on commit 529e51e

Please sign in to comment.