From 41fe1acfb58c1cd889b7dd63c965f0a22213614d Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Sat, 25 Jan 2025 06:30:03 -0800 Subject: [PATCH 01/34] [no-release-notes] Change some tests to use t.TempDir instead of os.TempDir so that their temporary files get cleaned up. --- go/cmd/dolt/commands/signed_commits_test.go | 13 +++--- .../doltcore/doltdb/commit_hooks_test.go | 8 ++-- go/libraries/doltcore/doltdb/doltdb_test.go | 6 +-- .../doltcore/dtestutils/environment.go | 8 +--- .../dtestutils/sql_server_driver/cmd.go | 8 +--- .../doltcore/env/multi_repo_env_test.go | 6 +-- .../doltcore/sqle/enginetest/dolt_harness.go | 2 +- go/libraries/events/event_flush_test.go | 2 +- go/libraries/utils/filesys/fs_test.go | 10 ++--- go/libraries/utils/test/files.go | 16 ++------ go/libraries/utils/test/test_test.go | 2 +- go/performance/import_benchmarker/cmd/main.go | 10 ++++- go/performance/import_benchmarker/testdef.go | 2 +- go/performance/sysbench/cmd/main.go | 10 ++++- go/performance/sysbench/testdef.go | 2 +- go/store/nbs/byte_sink_test.go | 16 ++++---- go/store/nbs/cmp_chunk_table_writer_test.go | 6 +-- go/store/nbs/journal_writer_test.go | 6 +-- go/store/nbs/mem_table_test.go | 7 +--- go/store/nbs/store_test.go | 3 ++ go/store/valuefile/value_file_test.go | 2 +- .../concurrent_gc_test.go | 2 +- .../gc_oldgen_conjoin_test.go | 2 +- .../go-sql-server-driver/main_test.go | 41 ++++++++++--------- .../sqlserver_info_test.go | 4 +- .../go-sql-server-driver/testdef.go | 4 +- 26 files changed, 96 insertions(+), 102 deletions(-) diff --git a/go/cmd/dolt/commands/signed_commits_test.go b/go/cmd/dolt/commands/signed_commits_test.go index e83d1132bf3..05a44f60f3d 100644 --- a/go/cmd/dolt/commands/signed_commits_test.go +++ b/go/cmd/dolt/commands/signed_commits_test.go @@ -45,10 +45,9 @@ func importKey(t *testing.T, ctx context.Context) { } func setupTestDB(t *testing.T, ctx context.Context, fs filesys.Filesys) string { - dir, err := os.MkdirTemp(os.TempDir(), "signed_commits") - require.NoError(t, err) + dir := t.TempDir() dbDir := filepath.Join(dir, "db") - err = filesys.CopyDir("testdata/signed_commits/db/", dbDir, fs) + err := filesys.CopyDir("testdata/signed_commits/db/", dbDir, fs) require.NoError(t, err) log.Println(dbDir) @@ -90,7 +89,7 @@ func TestSignAndVerifyCommit(t *testing.T) { apr, err := cli.CreateCommitArgParser().Parse(test.commitArgs) require.NoError(t, err) - _, err = execCommand(ctx, dbDir, CommitCmd{}, test.commitArgs, apr, map[string]string{}, global) + _, err = execCommand(ctx, t, dbDir, CommitCmd{}, test.commitArgs, apr, map[string]string{}, global) if test.expectErr { require.Error(t, err) @@ -103,14 +102,14 @@ func TestSignAndVerifyCommit(t *testing.T) { apr, err = cli.CreateLogArgParser(false).Parse(args) require.NoError(t, err) - logOutput, err := execCommand(ctx, dbDir, LogCmd{}, args, apr, map[string]string{}, global) + logOutput, err := execCommand(ctx, t, dbDir, LogCmd{}, args, apr, map[string]string{}, global) require.NoError(t, err) require.Contains(t, logOutput, "Good signature from \"Test User \"") }) } } -func execCommand(ctx context.Context, wd string, cmd cli.Command, args []string, apr *argparser.ArgParseResults, local, global map[string]string) (output string, err error) { +func execCommand(ctx context.Context, t *testing.T, wd string, cmd cli.Command, args []string, apr *argparser.ArgParseResults, local, global map[string]string) (output string, err error) { err = os.Chdir(wd) if err != nil { err = fmt.Errorf("error changing directory to %s: %w", wd, err) @@ -157,7 
+156,7 @@ func execCommand(ctx context.Context, wd string, cmd cli.Command, args []string, initialOut := os.Stdout initialErr := os.Stderr - f, err := os.CreateTemp(os.TempDir(), "signed-commit-test-*") + f, err := os.CreateTemp(t.TempDir(), "signed-commit-test-*") if err != nil { err = fmt.Errorf("error creating temp file: %w", err) return diff --git a/go/libraries/doltcore/doltdb/commit_hooks_test.go b/go/libraries/doltcore/doltdb/commit_hooks_test.go index fa20ac9c52e..fae89ab9824 100644 --- a/go/libraries/doltcore/doltdb/commit_hooks_test.go +++ b/go/libraries/doltcore/doltdb/commit_hooks_test.go @@ -43,7 +43,7 @@ func TestPushOnWriteHook(t *testing.T) { ctx := context.Background() // destination repo - testDir, err := test.ChangeToTestDir("TestReplicationDest") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestReplicationDest") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -62,7 +62,7 @@ func TestPushOnWriteHook(t *testing.T) { destDB, _ := LoadDoltDB(context.Background(), types.Format_Default, LocalDirDoltDB, filesys.LocalFS) // source repo - testDir, err = test.ChangeToTestDir("TestReplicationSource") + testDir, err = test.ChangeToTestDir(t.TempDir(), "TestReplicationSource") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -183,7 +183,7 @@ func TestAsyncPushOnWrite(t *testing.T) { ctx := context.Background() // destination repo - testDir, err := test.ChangeToTestDir("TestReplicationDest") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestReplicationDest") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -202,7 +202,7 @@ func TestAsyncPushOnWrite(t *testing.T) { destDB, _ := LoadDoltDB(context.Background(), types.Format_Default, LocalDirDoltDB, filesys.LocalFS) // source repo - testDir, err = test.ChangeToTestDir("TestReplicationSource") + testDir, err = test.ChangeToTestDir(t.TempDir(), "TestReplicationSource") if err != nil { panic("Couldn't change the working directory to the test directory.") diff --git a/go/libraries/doltcore/doltdb/doltdb_test.go b/go/libraries/doltcore/doltdb/doltdb_test.go index ef1f09de50e..c04abe1be23 100644 --- a/go/libraries/doltcore/doltdb/doltdb_test.go +++ b/go/libraries/doltcore/doltdb/doltdb_test.go @@ -219,7 +219,7 @@ func TestEmptyInMemoryRepoCreation(t *testing.T) { } func TestLoadNonExistentLocalFSRepo(t *testing.T) { - _, err := test.ChangeToTestDir("TestLoadRepo") + _, err := test.ChangeToTestDir(t.TempDir(), "TestLoadRepo") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -231,7 +231,7 @@ func TestLoadNonExistentLocalFSRepo(t *testing.T) { } func TestLoadBadLocalFSRepo(t *testing.T) { - testDir, err := test.ChangeToTestDir("TestLoadRepo") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestLoadRepo") if err != nil { panic("Couldn't change the working directory to the test directory.") @@ -246,7 +246,7 @@ func TestLoadBadLocalFSRepo(t *testing.T) { } func TestLDNoms(t *testing.T) { - testDir, err := test.ChangeToTestDir("TestLoadRepo") + testDir, err := test.ChangeToTestDir(t.TempDir(), "TestLoadRepo") if err != nil { panic("Couldn't change the working directory to the test directory.") diff --git a/go/libraries/doltcore/dtestutils/environment.go b/go/libraries/doltcore/dtestutils/environment.go index c2d1eddf86b..9555afab831 100644 --- a/go/libraries/doltcore/dtestutils/environment.go +++ b/go/libraries/doltcore/dtestutils/environment.go @@ -16,7 +16,6 @@ package 
dtestutils import ( "context" - "os" "path/filepath" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -41,12 +40,7 @@ func CreateTestEnv() *env.DoltEnv { // CreateTestEnvForLocalFilesystem creates a new DoltEnv for testing, using a local FS, instead of an in-memory // filesystem, for persisting files. This is useful for tests that require a disk-based filesystem and will not // work correctly with an in-memory filesystem and in-memory blob store (e.g. dolt_undrop() tests). -func CreateTestEnvForLocalFilesystem() *env.DoltEnv { - tempDir, err := os.MkdirTemp(os.TempDir(), "dolt-*") - if err != nil { - panic(err) - } - +func CreateTestEnvForLocalFilesystem(tempDir string) *env.DoltEnv { fs, err := filesys.LocalFilesysWithWorkingDir(tempDir) if err != nil { panic(err) diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go index 802abf595ca..75c6cd936ad 100644 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go @@ -76,13 +76,9 @@ type DoltUser struct { var _ DoltCmdable = DoltUser{} var _ DoltDebuggable = DoltUser{} -func NewDoltUser() (DoltUser, error) { - tmpdir, err := os.MkdirTemp("", "go-sql-server-driver-") - if err != nil { - return DoltUser{}, err - } +func NewDoltUser(tmpdir string) (DoltUser, error) { res := DoltUser{tmpdir} - err = res.DoltExec("config", "--global", "--add", "metrics.disabled", "true") + err := res.DoltExec("config", "--global", "--add", "metrics.disabled", "true") if err != nil { return DoltUser{}, err } diff --git a/go/libraries/doltcore/env/multi_repo_env_test.go b/go/libraries/doltcore/env/multi_repo_env_test.go index 9a5b6bdb82b..d97baeacd8b 100644 --- a/go/libraries/doltcore/env/multi_repo_env_test.go +++ b/go/libraries/doltcore/env/multi_repo_env_test.go @@ -115,7 +115,7 @@ func initRepoWithRelativePath(t *testing.T, envPath string, hdp HomeDirProvider) } func TestMultiEnvForDirectory(t *testing.T) { - rootPath, err := test.ChangeToTestDir("TestDoltEnvAsMultiEnv") + rootPath, err := test.ChangeToTestDir(t.TempDir(), "TestDoltEnvAsMultiEnv") require.NoError(t, err) hdp := func() (string, error) { return rootPath, nil } @@ -150,7 +150,7 @@ func TestMultiEnvForDirectory(t *testing.T) { } func TestMultiEnvForDirectoryWithMultipleRepos(t *testing.T) { - rootPath, err := test.ChangeToTestDir("TestDoltEnvAsMultiEnvWithMultipleRepos") + rootPath, err := test.ChangeToTestDir(t.TempDir(), "TestDoltEnvAsMultiEnvWithMultipleRepos") require.NoError(t, err) hdp := func() (string, error) { return rootPath, nil } @@ -177,7 +177,7 @@ func TestMultiEnvForDirectoryWithMultipleRepos(t *testing.T) { } func initMultiEnv(t *testing.T, testName string, names []string) (string, HomeDirProvider, map[string]*DoltEnv) { - rootPath, err := test.ChangeToTestDir(testName) + rootPath, err := test.ChangeToTestDir(t.TempDir(), testName) require.NoError(t, err) hdp := func() (string, error) { return rootPath, nil } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go index be5b92cb257..43436848afd 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go @@ -515,7 +515,7 @@ func (d *DoltHarness) newProvider() sql.MutableDatabaseProvider { var dEnv *env.DoltEnv if d.useLocalFilesystem { - dEnv = dtestutils.CreateTestEnvForLocalFilesystem() + dEnv = 
dtestutils.CreateTestEnvForLocalFilesystem(d.t.TempDir()) } else { dEnv = dtestutils.CreateTestEnv() } diff --git a/go/libraries/events/event_flush_test.go b/go/libraries/events/event_flush_test.go index 1798aba4db6..70934a593de 100644 --- a/go/libraries/events/event_flush_test.go +++ b/go/libraries/events/event_flush_test.go @@ -105,7 +105,7 @@ func TestEventFlushing(t *testing.T) { fs := filesys.LocalFS path := filepath.Join(dPath, evtPath) - dDir := testLib.TestDir(path) + dDir := testLib.TestDir(t.TempDir(), path) ft = createFlushTester(fs, "", dDir) } diff --git a/go/libraries/utils/filesys/fs_test.go b/go/libraries/utils/filesys/fs_test.go index 0982dd79931..a5cfa7ed318 100644 --- a/go/libraries/utils/filesys/fs_test.go +++ b/go/libraries/utils/filesys/fs_test.go @@ -41,8 +41,8 @@ var filesysetmsToTest = map[string]Filesys{ } func TestFilesystems(t *testing.T) { - dir := test.TestDir("filesys_test") - newLocation := test.TestDir("newLocation") + dir := test.TestDir(t.TempDir(), "filesys_test") + newLocation := test.TestDir(t.TempDir(), "newLocation") subdir := filepath.Join(dir, "subdir") subdirFile := filepath.Join(subdir, testSubdirFilename) fp := filepath.Join(dir, testFilename) @@ -186,7 +186,7 @@ func TestNewInMemFS(t *testing.T) { } func TestRecursiveFSIteration(t *testing.T) { - dir := test.TestDir("TestRecursiveFSIteration") + dir := test.TestDir(t.TempDir(), "TestRecursiveFSIteration") for fsName, fs := range filesysetmsToTest { var expectedDirs []string @@ -215,7 +215,7 @@ func TestRecursiveFSIteration(t *testing.T) { } func TestFSIteration(t *testing.T) { - dir := test.TestDir("TestFSIteration") + dir := test.TestDir(t.TempDir(), "TestFSIteration") for fsName, fs := range filesysetmsToTest { var expectedDirs []string @@ -249,7 +249,7 @@ func TestFSIteration(t *testing.T) { } func TestDeletes(t *testing.T) { - dir := test.TestDir("TestDeletes") + dir := test.TestDir(t.TempDir(), "TestDeletes") for fsName, fs := range filesysetmsToTest { var ignored []string diff --git a/go/libraries/utils/test/files.go b/go/libraries/utils/test/files.go index d8561fa9e29..c75105ee5d8 100644 --- a/go/libraries/utils/test/files.go +++ b/go/libraries/utils/test/files.go @@ -22,27 +22,19 @@ import ( ) // TestDir creates a subdirectory inside the systems temp directory -func TestDir(testName string) string { - id, err := uuid.NewRandom() - - if err != nil { - panic(ShouldNeverHappen) - } - - return filepath.Join(os.TempDir(), testName, id.String()) +func TestDir(dir, testName string) string { + return filepath.Join(dir, testName, uuid.NewString()) } // ChangeToTestDir creates a new test directory and changes the current directory to be -func ChangeToTestDir(testName string) (string, error) { - dir := TestDir(testName) +func ChangeToTestDir(tempDir, testName string) (string, error) { + dir := TestDir(tempDir, testName) err := os.MkdirAll(dir, os.ModePerm) - if err != nil { return "", err } err = os.Chdir(dir) - if err != nil { return "", err } diff --git a/go/libraries/utils/test/test_test.go b/go/libraries/utils/test/test_test.go index a5c040346e9..a486f794f53 100644 --- a/go/libraries/utils/test/test_test.go +++ b/go/libraries/utils/test/test_test.go @@ -24,7 +24,7 @@ import ( // test your tests so you can test while you test func TestLDTestUtils(t *testing.T) { - dir, err := ChangeToTestDir("TestLDTestUtils") + dir, err := ChangeToTestDir(t.TempDir(), "TestLDTestUtils") if err != nil { t.Fatal("Couldn't change to test dir") diff --git a/go/performance/import_benchmarker/cmd/main.go 
b/go/performance/import_benchmarker/cmd/main.go index 71ad9eccf98..728ea3f9584 100644 --- a/go/performance/import_benchmarker/cmd/main.go +++ b/go/performance/import_benchmarker/cmd/main.go @@ -42,9 +42,16 @@ func main() { if err != nil { log.Fatalln(err) } + defer os.RemoveAll(tmpdir) + + userdir, err := os.MkdirTemp("", "import-benchmarker-") + if err != nil { + log.Fatalln(err) + } + defer os.RemoveAll(userdir) results := new(ib.ImportResults) - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(userdir) for _, test := range def.Tests { test.Results = results test.InitWithTmpDir(tmpdir) @@ -73,5 +80,4 @@ func main() { } else { fmt.Println(results.SqlDump()) } - os.Exit(0) } diff --git a/go/performance/import_benchmarker/testdef.go b/go/performance/import_benchmarker/testdef.go index 3bc352cfc10..b0c42d500bd 100644 --- a/go/performance/import_benchmarker/testdef.go +++ b/go/performance/import_benchmarker/testdef.go @@ -214,7 +214,7 @@ func (test *ImportTest) Run(t *testing.T) { test.InitWithTmpDir(tmp) } - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) for _, r := range test.Repos { if r.ExternalServer != nil { err := test.RunExternalServerTests(r.Name, r.ExternalServer) diff --git a/go/performance/sysbench/cmd/main.go b/go/performance/sysbench/cmd/main.go index 77f6769799e..caf493e6aa2 100644 --- a/go/performance/sysbench/cmd/main.go +++ b/go/performance/sysbench/cmd/main.go @@ -52,9 +52,16 @@ func main() { if err != nil { log.Fatalln(err) } + defer os.RemoveAll(tmpdir) + + userdir, err := os.MkdirTemp("", "sysbench-user-dir_") + if err != nil { + log.Fatalln(err) + } + defer os.RemoveAll(userdir) results := new(sysbench.Results) - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(userdir) for _, test := range defs.Tests { test.InitWithTmpDir(tmpdir) @@ -83,5 +90,4 @@ func main() { } else { fmt.Println(results.SqlDump()) } - os.Exit(0) } diff --git a/go/performance/sysbench/testdef.go b/go/performance/sysbench/testdef.go index 5fc910d2524..7b3d3d71bdd 100644 --- a/go/performance/sysbench/testdef.go +++ b/go/performance/sysbench/testdef.go @@ -440,7 +440,7 @@ func (test *Script) Run(t *testing.T) { } results := new(Results) - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) test.Results = results test.InitWithTmpDir(tmpdir) for _, r := range test.Repos { diff --git a/go/store/nbs/byte_sink_test.go b/go/store/nbs/byte_sink_test.go index d388f2d42df..7717e32a253 100644 --- a/go/store/nbs/byte_sink_test.go +++ b/go/store/nbs/byte_sink_test.go @@ -27,7 +27,7 @@ import ( ) func TestBlockBufferTableSink(t *testing.T) { - createSink := func() ByteSink { + createSink := func(*testing.T) ByteSink { return NewBlockBufferByteSink(128) } @@ -35,7 +35,7 @@ func TestBlockBufferTableSink(t *testing.T) { } func TestFixedBufferTableSink(t *testing.T) { - createSink := func() ByteSink { + createSink := func(*testing.T) ByteSink { return NewFixedBufferByteSink(make([]byte, 32*1024)) } @@ -43,8 +43,8 @@ func TestFixedBufferTableSink(t *testing.T) { } func TestBufferedFileByteSink(t *testing.T) { - createSink := func() ByteSink { - sink, err := NewBufferedFileByteSink("", 4*1024, 16) + createSink := func(t *testing.T) ByteSink { + sink, err := NewBufferedFileByteSink(t.TempDir(), 4*1024, 16) require.NoError(t, err) return sink @@ -53,7 +53,7 @@ func TestBufferedFileByteSink(t *testing.T) { suite.Run(t, &TableSinkSuite{createSink, t}) t.Run("ReaderTwice", func(t *testing.T) { - sink, err := NewBufferedFileByteSink("", 4*1024, 16) + 
sink, err := NewBufferedFileByteSink(t.TempDir(), 4*1024, 16) require.NoError(t, err) _, err = sink.Write([]byte{1, 2, 3, 4}) require.NoError(t, err) @@ -76,7 +76,7 @@ func TestBufferedFileByteSink(t *testing.T) { } type TableSinkSuite struct { - sinkFactory func() ByteSink + sinkFactory func(*testing.T) ByteSink t *testing.T } @@ -116,7 +116,7 @@ func verifyContents(t *testing.T, bytes []byte) { } func (suite *TableSinkSuite) TestWriteAndFlush() { - sink := suite.sinkFactory() + sink := suite.sinkFactory(suite.t) err := writeToSink(sink) require.NoError(suite.t, err) @@ -128,7 +128,7 @@ func (suite *TableSinkSuite) TestWriteAndFlush() { } func (suite *TableSinkSuite) TestWriteAndFlushToFile() { - sink := suite.sinkFactory() + sink := suite.sinkFactory(suite.t) err := writeToSink(sink) require.NoError(suite.t, err) diff --git a/go/store/nbs/cmp_chunk_table_writer_test.go b/go/store/nbs/cmp_chunk_table_writer_test.go index 170cc43cb64..1f91d55b025 100644 --- a/go/store/nbs/cmp_chunk_table_writer_test.go +++ b/go/store/nbs/cmp_chunk_table_writer_test.go @@ -56,7 +56,7 @@ func TestCmpChunkTableWriter(t *testing.T) { require.NoError(t, eg.Wait()) // for all the chunks we find, write them using the compressed writer - tw, err := NewCmpChunkTableWriter("") + tw, err := NewCmpChunkTableWriter(t.TempDir()) require.NoError(t, err) for _, cmpChnk := range found { err = tw.AddCmpChunk(cmpChnk) @@ -67,7 +67,7 @@ func TestCmpChunkTableWriter(t *testing.T) { require.NoError(t, err) t.Run("ErrDuplicateChunkWritten", func(t *testing.T) { - tw, err := NewCmpChunkTableWriter("") + tw, err := NewCmpChunkTableWriter(t.TempDir()) require.NoError(t, err) for _, cmpChnk := range found { err = tw.AddCmpChunk(cmpChnk) @@ -96,7 +96,7 @@ func TestCmpChunkTableWriter(t *testing.T) { } func TestCmpChunkTableWriterGhostChunk(t *testing.T) { - tw, err := NewCmpChunkTableWriter("") + tw, err := NewCmpChunkTableWriter(t.TempDir()) require.NoError(t, err) require.Error(t, tw.AddCmpChunk(NewGhostCompressedChunk(hash.Parse("6af71afc2ea0hmp4olev0vp9q1q5gvb1")))) } diff --git a/go/store/nbs/journal_writer_test.go b/go/store/nbs/journal_writer_test.go index df8c45946f3..e6f6e59fd33 100644 --- a/go/store/nbs/journal_writer_test.go +++ b/go/store/nbs/journal_writer_test.go @@ -279,9 +279,7 @@ func TestJournalWriterSyncClose(t *testing.T) { } func newTestFilePath(t *testing.T) string { - path, err := os.MkdirTemp("", "") - require.NoError(t, err) - return filepath.Join(path, "journal.log") + return filepath.Join(t.TempDir(), "journal.log") } func TestJournalIndexBootstrap(t *testing.T) { @@ -398,6 +396,8 @@ func TestJournalIndexBootstrap(t *testing.T) { require.True(t, ok) _, err = jnl.bootstrapJournal(ctx, nil) assert.Error(t, err) + err = jnl.Close() + require.NoError(t, err) }) } } diff --git a/go/store/nbs/mem_table_test.go b/go/store/nbs/mem_table_test.go index 250f2994636..6802abc66c4 100644 --- a/go/store/nbs/mem_table_test.go +++ b/go/store/nbs/mem_table_test.go @@ -69,12 +69,7 @@ func TestWriteChunks(t *testing.T) { t.Error(err) } - dir, err := os.MkdirTemp("", "write_chunks_test") - if err != nil { - t.Error(err) - } - - err = os.WriteFile(dir+name, data, os.ModePerm) + err = os.WriteFile(t.TempDir()+name, data, os.ModePerm) if err != nil { t.Error(err) } diff --git a/go/store/nbs/store_test.go b/go/store/nbs/store_test.go index 90f204bc996..041394e45a8 100644 --- a/go/store/nbs/store_test.go +++ b/go/store/nbs/store_test.go @@ -44,6 +44,9 @@ func makeTestLocalStore(t *testing.T, maxTableFiles int) (st *NomsBlockStore, 
no nomsDir = filepath.Join(tempfiles.MovableTempFileProvider.GetTempDir(), "noms_"+uuid.New().String()[:8]) err := os.MkdirAll(nomsDir, os.ModePerm) require.NoError(t, err) + t.Cleanup(func() { + os.RemoveAll(nomsDir) + }) // create a v5 manifest fm, err := getFileManifest(ctx, nomsDir, asyncFlush) diff --git a/go/store/valuefile/value_file_test.go b/go/store/valuefile/value_file_test.go index 2d410ad0800..318dde9f97d 100644 --- a/go/store/valuefile/value_file_test.go +++ b/go/store/valuefile/value_file_test.go @@ -53,7 +53,7 @@ func TestReadWriteValueFile(t *testing.T) { values = append(values, m) } - path := filepath.Join(os.TempDir(), "file.nvf") + path := filepath.Join(t.TempDir(), "file.nvf") err = WriteValueFile(ctx, path, store, values...) require.NoError(t, err) diff --git a/integration-tests/go-sql-server-driver/concurrent_gc_test.go b/integration-tests/go-sql-server-driver/concurrent_gc_test.go index c6a3c5e6c27..002876c14f3 100644 --- a/integration-tests/go-sql-server-driver/concurrent_gc_test.go +++ b/integration-tests/go-sql-server-driver/concurrent_gc_test.go @@ -172,7 +172,7 @@ func (gct gcTest) finalize(t *testing.T, ctx context.Context, db *sql.DB) { } func (gct gcTest) run(t *testing.T) { - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) require.NoError(t, err) t.Cleanup(func() { u.Cleanup() diff --git a/integration-tests/go-sql-server-driver/gc_oldgen_conjoin_test.go b/integration-tests/go-sql-server-driver/gc_oldgen_conjoin_test.go index b527f0f7ecf..564b5bfca9d 100644 --- a/integration-tests/go-sql-server-driver/gc_oldgen_conjoin_test.go +++ b/integration-tests/go-sql-server-driver/gc_oldgen_conjoin_test.go @@ -29,7 +29,7 @@ import ( ) func TestGCConjoinsOldgen(t *testing.T) { - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir() ) require.NoError(t, err) t.Cleanup(func() { u.Cleanup() diff --git a/integration-tests/go-sql-server-driver/main_test.go b/integration-tests/go-sql-server-driver/main_test.go index cba1f98d948..b1bff123a3d 100644 --- a/integration-tests/go-sql-server-driver/main_test.go +++ b/integration-tests/go-sql-server-driver/main_test.go @@ -30,26 +30,29 @@ import ( // It's good enough for now, and it keeps us from checking in certificates or // JWT which will expire at some point in the future. 
func TestMain(m *testing.M) { - old := os.Getenv("TESTGENDIR") - defer func() { - os.Setenv("TESTGENDIR", old) + res := func() int { + old := os.Getenv("TESTGENDIR") + defer func() { + os.Setenv("TESTGENDIR", old) + }() + gendir, err := os.MkdirTemp(os.TempDir(), "go-sql-server-driver-gen-*") + if err != nil { + log.Fatalf("could not create temp dir: %v", err) + } + defer os.RemoveAll(gendir) + err = GenerateTestJWTs(gendir) + if err != nil { + log.Fatalf("%v", err) + } + err = GenerateX509Certs(gendir) + if err != nil { + log.Fatalf("%v", err) + } + os.Setenv("TESTGENDIR", gendir) + flag.Parse() + return m.Run() }() - gendir, err := os.MkdirTemp(os.TempDir(), "go-sql-server-driver-gen-*") - if err != nil { - log.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(gendir) - err = GenerateTestJWTs(gendir) - if err != nil { - log.Fatalf("%v", err) - } - err = GenerateX509Certs(gendir) - if err != nil { - log.Fatalf("%v", err) - } - os.Setenv("TESTGENDIR", gendir) - flag.Parse() - os.Exit(m.Run()) + os.Exit(res) } func TestConfig(t *testing.T) { diff --git a/integration-tests/go-sql-server-driver/sqlserver_info_test.go b/integration-tests/go-sql-server-driver/sqlserver_info_test.go index 5461ecc03b0..e537832efc1 100644 --- a/integration-tests/go-sql-server-driver/sqlserver_info_test.go +++ b/integration-tests/go-sql-server-driver/sqlserver_info_test.go @@ -30,7 +30,7 @@ import ( func TestSQLServerInfoFile(t *testing.T) { t.Run("With Two Repos", func(t *testing.T) { - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) require.NoError(t, err) t.Cleanup(func() { u.Cleanup() @@ -298,7 +298,7 @@ func TestSQLServerInfoFile(t *testing.T) { }) }) t.Run("With Empty RepoStore", func(t *testing.T) { - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) require.NoError(t, err) t.Cleanup(func() { u.Cleanup() diff --git a/integration-tests/go-sql-server-driver/testdef.go b/integration-tests/go-sql-server-driver/testdef.go index 821371fb235..714cbacf96a 100644 --- a/integration-tests/go-sql-server-driver/testdef.go +++ b/integration-tests/go-sql-server-driver/testdef.go @@ -129,7 +129,7 @@ func (test Test) Run(t *testing.T) { t.Skip(test.Skip) } - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) require.NoError(t, err) t.Cleanup(func() { u.Cleanup() @@ -153,7 +153,7 @@ func (test Test) Run(t *testing.T) { } for _, mr := range test.MultiRepos { // Each MultiRepo gets its own dolt config --global. - u, err := driver.NewDoltUser() + u, err := driver.NewDoltUser(t.TempDir()) require.NoError(t, err) t.Cleanup(func() { u.Cleanup() From 742f2760c75c5a2e05cf6d6cd01de57ac44585ad Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Tue, 28 Jan 2025 06:37:35 -0800 Subject: [PATCH 02/34] go/cmd/dolt: signed_commits_test.go: Try to close DoltDBs so that windows test cleanup works. 
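The t.TempDir conversions in the previous patch only pay off if nothing inside the directory is still open when the testing framework deletes it: t.TempDir registers its removal via t.Cleanup, cleanup functions run last-in-first-out, and on Windows a directory cannot be removed while a file inside it is still held open. Closing the databases in a cleanup of their own therefore lets the automatic removal succeed. Below is a minimal sketch of the pattern, with openDB standing in for any resource that keeps files open under the directory; it is not a Dolt API.

package example

import (
	"os"
	"path/filepath"
	"testing"
)

// openDB stands in for opening a resource that holds files open under dir,
// such as a database handle; it is not a Dolt API.
func openDB(dir string) (*os.File, error) {
	return os.Create(filepath.Join(dir, "db.file"))
}

func TestTempDirCleanup(t *testing.T) {
	dir := t.TempDir() // removed automatically after the test's cleanups run

	db, err := openDB(dir)
	if err != nil {
		t.Fatal(err)
	}
	// Registered after t.TempDir, so it runs first (cleanups are LIFO): the
	// handle is closed before the directory is removed, which is what the
	// removal needs to succeed on Windows.
	t.Cleanup(func() {
		if err := db.Close(); err != nil {
			t.Errorf("closing db: %v", err)
		}
	})
}

Because the Close cleanup is registered after the first t.TempDir call, it always runs before the directory removal that t.TempDir schedules. The TestMain restructuring in the previous patch addresses the related ordering problem for whole-package cleanup: os.Exit does not run deferred functions, so m.Run is wrapped in a closure whose deferreds finish before the final os.Exit.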
--- go/cmd/dolt/commands/signed_commits_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/go/cmd/dolt/commands/signed_commits_test.go b/go/cmd/dolt/commands/signed_commits_test.go index 05a44f60f3d..e2d91ad0425 100644 --- a/go/cmd/dolt/commands/signed_commits_test.go +++ b/go/cmd/dolt/commands/signed_commits_test.go @@ -141,6 +141,12 @@ func execCommand(ctx context.Context, t *testing.T, wd string, cmd cli.Command, err = fmt.Errorf("error creating multi repo: %w", err) return } + t.Cleanup(func() { + mr.Iter(func(_ string, env *env.DoltEnv) (bool, error) { + env.DoltDB.Close() + return false, nil + }) + }) latebind, verr := BuildSqlEngineQueryist(ctx, dEnv.FS, mr, &cli.UserPassword{}, apr) if verr != nil { From be448ee6821efe5c6b4362c428934d00eecb0468 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Tue, 28 Jan 2025 09:25:20 -0800 Subject: [PATCH 03/34] go/cmd/dolt: signed_commits_test.go: Try a different way to cleanup the databases. --- go/cmd/dolt/commands/signed_commits_test.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/go/cmd/dolt/commands/signed_commits_test.go b/go/cmd/dolt/commands/signed_commits_test.go index e2d91ad0425..2e73d280026 100644 --- a/go/cmd/dolt/commands/signed_commits_test.go +++ b/go/cmd/dolt/commands/signed_commits_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/dolthub/dolt/go/cmd/dolt/cli" + "github.com/dolthub/dolt/go/libraries/doltcore/dbfactory" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/config" @@ -78,6 +79,9 @@ func TestSignAndVerifyCommit(t *testing.T) { ctx := context.Background() importKey(t, ctx) dbDir := setupTestDB(t, ctx, filesys.LocalFS) + t.Cleanup(func() { + dbfactory.CloseAllLocalDatabases() + }) global := map[string]string{ "user.name": "First Last", @@ -141,12 +145,6 @@ func execCommand(ctx context.Context, t *testing.T, wd string, cmd cli.Command, err = fmt.Errorf("error creating multi repo: %w", err) return } - t.Cleanup(func() { - mr.Iter(func(_ string, env *env.DoltEnv) (bool, error) { - env.DoltDB.Close() - return false, nil - }) - }) latebind, verr := BuildSqlEngineQueryist(ctx, dEnv.FS, mr, &cli.UserPassword{}, apr) if verr != nil { From 98547ce6ce7b36e93e04a8f34edca915e7996e88 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Thu, 23 Jan 2025 16:59:11 -0800 Subject: [PATCH 04/34] Bug fix: fixing ref conflict case that can cause tags to be temporarily removed on a replica --- .../doltcore/sqle/read_replica_database.go | 16 ++++++- .../doltcore/sqle/replication_test.go | 9 +++- integration-tests/bats/replication.bats | 43 ++++++++++++++++++- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/go/libraries/doltcore/sqle/read_replica_database.go b/go/libraries/doltcore/sqle/read_replica_database.go index fba8ca14cd4..4c757684ea9 100644 --- a/go/libraries/doltcore/sqle/read_replica_database.go +++ b/go/libraries/doltcore/sqle/read_replica_database.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "sort" "strings" "sync" @@ -509,9 +510,20 @@ func getReplicationRefs(ctx *sql.Context, rrd ReadReplicaDatabase) ( func refsToDelete(remRefs, localRefs []doltdb.RefWithHash) []doltdb.RefWithHash { toDelete := make([]doltdb.RefWithHash, 0, len(localRefs)) var i, j int + + // Before we map remote refs to local refs to determine which refs to delete, we need to sort them + // by Ref.String() – this ensures a unique identifier that does not conflict with other refs, unlike + 
// Ref.GetPath(), which can conflict if a branch or tag has the same name. + sort.Slice(remRefs, func(i, j int) bool { + return remRefs[i].Ref.String() < remRefs[j].Ref.String() + }) + sort.Slice(localRefs, func(i, j int) bool { + return localRefs[i].Ref.String() < localRefs[j].Ref.String() + }) + for i < len(remRefs) && j < len(localRefs) { - rem := remRefs[i].Ref.GetPath() - local := localRefs[j].Ref.GetPath() + rem := remRefs[i].Ref.String() + local := localRefs[j].Ref.String() if rem == local { i++ j++ diff --git a/go/libraries/doltcore/sqle/replication_test.go b/go/libraries/doltcore/sqle/replication_test.go index f2922de2147..538db4c3aea 100644 --- a/go/libraries/doltcore/sqle/replication_test.go +++ b/go/libraries/doltcore/sqle/replication_test.go @@ -80,12 +80,17 @@ func TestReplicationBranches(t *testing.T) { local: []string{"feature4", "feature5", "feature6", "feature7", "feature8", "feature9"}, expToDelete: []string{"feature4", "feature5", "feature6", "feature7", "feature8", "feature9"}, }, + { + remote: []string{"main", "new1", "a1"}, + local: []string{"main", "a1"}, + expToDelete: []string{}, + }, } for _, tt := range tests { remoteRefs := make([]doltdb.RefWithHash, len(tt.remote)) for i := range tt.remote { - remoteRefs[i] = doltdb.RefWithHash{Ref: ref.NewRemoteRef("", tt.remote[i])} + remoteRefs[i] = doltdb.RefWithHash{Ref: ref.NewBranchRef(tt.remote[i])} } localRefs := make([]doltdb.RefWithHash, len(tt.local)) for i := range tt.local { @@ -96,6 +101,6 @@ func TestReplicationBranches(t *testing.T) { for i := range diff { diffNames[i] = diff[i].Ref.GetPath() } - assert.Equal(t, diffNames, tt.expToDelete) + assert.Equal(t, tt.expToDelete, diffNames) } } diff --git a/integration-tests/bats/replication.bats b/integration-tests/bats/replication.bats index 460ebdc3f62..d10420ab689 100644 --- a/integration-tests/bats/replication.bats +++ b/integration-tests/bats/replication.bats @@ -201,6 +201,48 @@ teardown() { [[ "$output" =~ "b1" ]] || false } +# When a replica pulls refs, the remote refs are compared with the local refs to identify which local refs +# need to be deleted. Branches, tags, and remotes all share the ref space and previous versions of Dolt could +# incorrectly map remote refs and local refs, resulting in local refs being incorrectly removed, until future +# runs of replica synchronization. +@test "replication: local tag refs are not deleted" { + # Configure repo1 to push changes on commit and create tag a1 + cd repo1 + dolt config --local --add sqlserver.global.dolt_replicate_to_remote remote1 + dolt sql -q "call dolt_tag('a1');" + + # Configure repo2 to pull changes on read + cd .. + dolt clone file://./rem1 repo2 + cd repo2 + dolt config --local --add sqlserver.global.dolt_read_replica_remote origin + dolt config --local --add sqlserver.global.dolt_replicate_all_heads 1 + run dolt sql -q "select tag_name from dolt_tags;" + [ "$status" -eq 0 ] + [[ "$output" =~ "| tag_name |" ]] || false + [[ "$output" =~ "| a1 |" ]] || false + + # Create branch new1 in repo1 – "new1" sorts after "main", but before "a1", and previous + # versions of Dolt had problems computing which local refs to delete in this case. + cd ../repo1 + dolt sql -q "call dolt_branch('new1');" + + # Confirm that tag a1 has not been deleted. Note that we need to check for this immediately after + # creating branch new1 (i.e. before looking at branches), because the bug in the previous versions + # of Dolt would only manifest in the next command, and would be fixed by later remote pulls. 
+ cd ../repo2 + run dolt sql -q "select tag_name from dolt_tags;" + [ "$status" -eq 0 ] + [[ "$output" =~ "| tag_name |" ]] || false + [[ "$output" =~ "| a1 |" ]] || false + + # Try again to make sure the results are stable + run dolt sql -q "select tag_name from dolt_tags;" + [ "$status" -eq 0 ] + [[ "$output" =~ "| tag_name |" ]] || false + [[ "$output" =~ "| a1 |" ]] || false +} + @test "replication: pull branch delete current branch" { skip "broken by latest transaction changes" @@ -627,7 +669,6 @@ SQL } @test "replication: pull all heads pulls tags" { - dolt clone file://./rem1 repo2 cd repo2 dolt checkout -b new_feature From 2e8fdedbfec3d97195820789d9df5841b3ce1b2a Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Fri, 24 Jan 2025 12:24:05 -0800 Subject: [PATCH 05/34] Fixing a merge error that included the env section twice --- .github/workflows/ci-check-repo.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci-check-repo.yaml b/.github/workflows/ci-check-repo.yaml index 770ecbba2e1..fa16ce85111 100644 --- a/.github/workflows/ci-check-repo.yaml +++ b/.github/workflows/ci-check-repo.yaml @@ -152,8 +152,6 @@ jobs: cwd: "." pull: "--ff" - name: Check generated protobufs - env: - USE_BAZEL_VERSION: 7.4.0 working-directory: ./proto env: USE_BAZEL_VERSION: 7.4.0 From f9affcc64be760832a323066fa3de589aabc93df Mon Sep 17 00:00:00 2001 From: coffeegoddd Date: Fri, 24 Jan 2025 21:35:09 +0000 Subject: [PATCH 06/34] [ga-bump-release] Update Dolt version to 1.47.2 and release v1.47.2 --- go/cmd/dolt/doltversion/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/cmd/dolt/doltversion/version.go b/go/cmd/dolt/doltversion/version.go index 8e9fa83276f..3c53914a1b0 100644 --- a/go/cmd/dolt/doltversion/version.go +++ b/go/cmd/dolt/doltversion/version.go @@ -16,5 +16,5 @@ package doltversion const ( - Version = "1.47.1" + Version = "1.47.2" ) From bd8c83419c1c180eb179876f715aaa2fc8d248af Mon Sep 17 00:00:00 2001 From: Dustin Brown Date: Mon, 27 Jan 2025 10:53:58 -0800 Subject: [PATCH 07/34] [ga-bump-dep] Bump dependency in Dolt by jycor (#8790) Co-authored-by: jycor --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index ca0a435c3e7..50ac73be0e7 100644 --- a/go/go.mod +++ b/go/go.mod @@ -56,7 +56,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.19.1-0.20250123004221-f5a5bcea7eed + github.com/dolthub/go-mysql-server v0.19.1-0.20250124213954-8a1af52235d7 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 github.com/esote/minmaxheap v1.0.0 diff --git a/go/go.sum b/go/go.sum index e18888690e4..223d7a54136 100644 --- a/go/go.sum +++ b/go/go.sum @@ -179,8 +179,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250123004221-f5a5bcea7eed h1:2EQHWtMkjyN/SNfbg/nh/a0RANq8V8gxNynYum2Kq+s= -github.com/dolthub/go-mysql-server v0.19.1-0.20250123004221-f5a5bcea7eed/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= 
+github.com/dolthub/go-mysql-server v0.19.1-0.20250124213954-8a1af52235d7 h1:DjirOAU+gMlWqr3Ut9PsVT5iqdirAcLr84Cbbi60Kis= +github.com/dolthub/go-mysql-server v0.19.1-0.20250124213954-8a1af52235d7/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= From cce2336b293ff231d113a25d1a24bde1d2e26428 Mon Sep 17 00:00:00 2001 From: Neil Macneale IV Date: Mon, 27 Jan 2025 10:10:49 -0800 Subject: [PATCH 08/34] Add the dirty column to the dolt_branches system table --- .../doltcore/sqle/dtables/branches_table.go | 59 ++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/go/libraries/doltcore/sqle/dtables/branches_table.go b/go/libraries/doltcore/sqle/dtables/branches_table.go index a343d825d0b..291f97d60cf 100644 --- a/go/libraries/doltcore/sqle/dtables/branches_table.go +++ b/go/libraries/doltcore/sqle/dtables/branches_table.go @@ -18,6 +18,7 @@ import ( "fmt" "io" + "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" @@ -90,6 +91,7 @@ func (bt *BranchesTable) Schema() sql.Schema { if !bt.remote { columns = append(columns, &sql.Column{Name: "remote", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true}) columns = append(columns, &sql.Column{Name: "branch", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true}) + columns = append(columns, &sql.Column{Name: "dirty", Type: types.Boolean, Source: bt.tableName, PrimaryKey: false, Nullable: true}) } return columns } @@ -114,6 +116,7 @@ type BranchItr struct { table *BranchesTable branches []string commits []*doltdb.Commit + dirty []bool idx int } @@ -145,19 +148,28 @@ func NewBranchItr(ctx *sql.Context, table *BranchesTable) (*BranchItr, error) { branchNames := make([]string, len(branchRefs)) commits := make([]*doltdb.Commit, len(branchRefs)) + dirtyBits := make([]bool, len(branchRefs)) for i, branch := range branchRefs { commit, err := ddb.ResolveCommitRefAtRoot(ctx, branch, txRoot) - if err != nil { return nil, err } + var dirty bool + if !remote { + dirty, err = isDirty(ctx, ddb, commit, branch, txRoot) + if err != nil { + return nil, err + } + } + if branch.GetType() == ref.RemoteRefType { branchNames[i] = "remotes/" + branch.GetPath() } else { branchNames[i] = branch.GetPath() } + dirtyBits[i] = dirty commits[i] = commit } @@ -165,6 +177,7 @@ func NewBranchItr(ctx *sql.Context, table *BranchesTable) (*BranchItr, error) { table: table, branches: branchNames, commits: commits, + dirty: dirtyBits, idx: 0, }, nil } @@ -182,6 +195,7 @@ func (itr *BranchItr) Next(ctx *sql.Context) (sql.Row, error) { name := itr.branches[itr.idx] cm := itr.commits[itr.idx] + dirty := itr.dirty[itr.idx] meta, err := cm.GetCommitMeta(ctx) if err != nil { @@ -211,8 +225,49 @@ func (itr *BranchItr) Next(ctx *sql.Context) (sql.Row, error) { remoteName = branch.Remote branchName = branch.Merge.Ref.GetPath() } - return sql.NewRow(name, h.String(), meta.Name, meta.Email, meta.Time(), meta.Description, remoteName, branchName), nil + return sql.NewRow(name, h.String(), meta.Name, meta.Email, meta.Time(), meta.Description, remoteName, branchName, dirty), nil + } +} + +// isDirty returns true if the working ref 
points to a dirty branch. +func isDirty(ctx *sql.Context, ddb *doltdb.DoltDB, commit *doltdb.Commit, branch ref.DoltRef, txRoot hash.Hash) (bool, error) { + wsRef, err := ref.WorkingSetRefForHead(branch) + if err != nil { + return false, err + } + ws, err := ddb.ResolveWorkingSetAtRoot(ctx, wsRef, txRoot) + if err != nil { + return false, err + } + + workingRoot := ws.WorkingRoot() + workingRootHash, err := workingRoot.HashOf() + if err != nil { + return false, err + } + stagedRoot := ws.StagedRoot() + stagedRootHash, err := stagedRoot.HashOf() + if err != nil { + return false, err + } + + dirty := false + if workingRootHash != stagedRootHash { + dirty = true + } else { + cmRt, err := commit.GetRootValue(ctx) + if err != nil { + return false, err + } + cmRtHash, err := cmRt.HashOf() + if err != nil { + return false, err + } + if cmRtHash != workingRootHash { + dirty = true + } } + return dirty, nil } // Close closes the iterator. From ecea68ce222f71a5c4090adef975bc10db72b67c Mon Sep 17 00:00:00 2001 From: Neil Macneale IV Date: Mon, 27 Jan 2025 11:22:02 -0800 Subject: [PATCH 09/34] Tests for the dolt_branches dirty column --- go/libraries/doltcore/sqle/sqlselect_test.go | 2 + go/libraries/doltcore/sqle/testutil.go | 2 +- integration-tests/bats/sql-server.bats | 42 +++++++++++++++++++- 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/go/libraries/doltcore/sqle/sqlselect_test.go b/go/libraries/doltcore/sqle/sqlselect_test.go index 7f8fc08464d..d065a84f0bb 100644 --- a/go/libraries/doltcore/sqle/sqlselect_test.go +++ b/go/libraries/doltcore/sqle/sqlselect_test.go @@ -784,6 +784,7 @@ func BasicSelectTests() []SelectTest { "Initialize data repository", "", "", + true, // Test setup has a dirty workspace. }, }, ExpectedSqlSchema: sql.Schema{ @@ -795,6 +796,7 @@ func BasicSelectTests() []SelectTest { &sql.Column{Name: "latest_commit_message", Type: gmstypes.Text}, &sql.Column{Name: "remote", Type: gmstypes.Text}, &sql.Column{Name: "branch", Type: gmstypes.Text}, + &sql.Column{Name: "dirty", Type: gmstypes.Boolean}, }, }, } diff --git a/go/libraries/doltcore/sqle/testutil.go b/go/libraries/doltcore/sqle/testutil.go index ee4ad010714..5379fbc77f3 100644 --- a/go/libraries/doltcore/sqle/testutil.go +++ b/go/libraries/doltcore/sqle/testutil.go @@ -452,7 +452,7 @@ func CreateEmptyTestTable(dEnv *env.DoltEnv, tableName string, sch schema.Schema return dEnv.UpdateWorkingRoot(ctx, newRoot) } -// CreateTestDatabase creates a test database with the test data set in it. +// CreateTestDatabase creates a test database with the test data set in it. Has a dirty workspace as well. func CreateTestDatabase() (*env.DoltEnv, error) { ctx := context.Background() dEnv, err := CreateEmptyTestDatabase() diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index ce6373a6b30..2f06c100ff5 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -1996,4 +1996,44 @@ EOF run dolt --data-dir datadir1 sql-server --data-dir datadir2 [ $status -eq 1 ] [[ "$output" =~ "cannot specify both global --data-dir argument and --data-dir in sql-server config" ]] || false -} \ No newline at end of file +} + +# This is really a test of the dolt_Branches system table, but due to needing a server with multiple dirty branches +# it was easier to test it with a sql-server. 
+@test "sql-server: dirty branches listed properly in dolt_branches table" { + skiponwindows "Missing dependencies" + + cd repo1 + dolt checkout main + dolt branch br1 # Will be a clean commit, ahead of main. + dolt branch br2 # will be a dirty branch, on main. + dolt branch br3 # will be a dirty branch, on br1 + start_sql_server repo1 + + dolt --use-db "repo1" --branch br1 sql -q "CREATE TABLE tbl (i int primary key)" + dolt --use-db "repo1" --branch br1 sql -q "CALL DOLT_COMMIT('-Am', 'commit it')" + + dolt --use-db "repo1" --branch br2 sql -q "CREATE TABLE tbl (j int primary key)" + + # Fast forward br3 to br1, then make it dirty. + dolt --use-db "repo1" --branch br3 sql -q "CALL DOLT_MERGE('br1')" + dolt --use-db "repo1" --branch br3 sql -q "CREATE TABLE othertbl (k int primary key)" + + stop_sql_server 1 && sleep 0.5 + + run dolt sql -q "SELECT name,dirty FROM dolt_branches" + [ "$status" -eq 0 ] + [[ "$output" =~ "br1 | false" ]] || false + [[ "$output" =~ "br2 | true " ]] || false + [[ "$output" =~ "br3 | true" ]] || false + [[ "$output" =~ "main | false" ]] || false + + # Verify that the dolt_branches table show the same output, regardless of the checked out branch. + dolt checkout br1 + run dolt sql -q "SELECT name,dirty FROM dolt_branches" + [ "$status" -eq 0 ] + [[ "$output" =~ "br1 | false" ]] || false + [[ "$output" =~ "br2 | true " ]] || false + [[ "$output" =~ "br3 | true" ]] || false + [[ "$output" =~ "main | false" ]] || false +} From 4bda2e4d8a49ed234d1d25c1cba42ba108732ddf Mon Sep 17 00:00:00 2001 From: macneale4 Date: Mon, 27 Jan 2025 19:31:31 +0000 Subject: [PATCH 10/34] [ga-format-pr] Run go/utils/repofmt/format_repo.sh and go/Godeps/update.sh --- go/libraries/doltcore/sqle/dtables/branches_table.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/libraries/doltcore/sqle/dtables/branches_table.go b/go/libraries/doltcore/sqle/dtables/branches_table.go index 291f97d60cf..e332f50510d 100644 --- a/go/libraries/doltcore/sqle/dtables/branches_table.go +++ b/go/libraries/doltcore/sqle/dtables/branches_table.go @@ -18,7 +18,6 @@ import ( "fmt" "io" - "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" @@ -27,6 +26,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" + "github.com/dolthub/dolt/go/store/hash" ) const branchesDefaultRowCount = 10 From ccb34c42630234f49259f0b5e521ce2330cf81b2 Mon Sep 17 00:00:00 2001 From: Neil Macneale IV Date: Mon, 27 Jan 2025 12:15:40 -0800 Subject: [PATCH 11/34] Fix workbench tests to expect dirty column --- .../mysql-client-tests/node/workbenchTests/branches.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integration-tests/mysql-client-tests/node/workbenchTests/branches.js b/integration-tests/mysql-client-tests/node/workbenchTests/branches.js index 1f0b63b575e..09fb1bdd51c 100644 --- a/integration-tests/mysql-client-tests/node/workbenchTests/branches.js +++ b/integration-tests/mysql-client-tests/node/workbenchTests/branches.js @@ -58,6 +58,7 @@ export const branchTests = [ latest_commit_message: "Initialize data repository", remote: "", branch: "", + dirty: 0, }, { name: "mybranch", @@ -68,6 +69,7 @@ export const branchTests = [ latest_commit_message: "Create table test", remote: "", branch: "", + dirty: 0, }, ], matcher: branchesMatcher, From 7f7f2c6bc577e1359a2782c3d65cff977712d20f Mon Sep 17 00:00:00 2001 
From: Neil Macneale IV Date: Mon, 27 Jan 2025 14:27:25 -0800 Subject: [PATCH 12/34] Don't error out when no workingset is found --- go/libraries/doltcore/sqle/dtables/branches_table.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/go/libraries/doltcore/sqle/dtables/branches_table.go b/go/libraries/doltcore/sqle/dtables/branches_table.go index e332f50510d..1a32fd0e764 100644 --- a/go/libraries/doltcore/sqle/dtables/branches_table.go +++ b/go/libraries/doltcore/sqle/dtables/branches_table.go @@ -15,6 +15,7 @@ package dtables import ( + "errors" "fmt" "io" @@ -237,6 +238,10 @@ func isDirty(ctx *sql.Context, ddb *doltdb.DoltDB, commit *doltdb.Commit, branch } ws, err := ddb.ResolveWorkingSetAtRoot(ctx, wsRef, txRoot) if err != nil { + if errors.Is(err, doltdb.ErrWorkingSetNotFound) { + // If there is no working set for this branch, then it is never dirty. This happens on servers commonly. + return false, nil + } return false, err } From 83bbd1760ec7245393d39bfd9b51fea114037db7 Mon Sep 17 00:00:00 2001 From: coffeegoddd Date: Tue, 28 Jan 2025 00:01:37 +0000 Subject: [PATCH 13/34] [ga-bump-release] Update Dolt version to 1.48.0 and release v1.48.0 --- go/cmd/dolt/doltversion/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/cmd/dolt/doltversion/version.go b/go/cmd/dolt/doltversion/version.go index 3c53914a1b0..e6561764635 100644 --- a/go/cmd/dolt/doltversion/version.go +++ b/go/cmd/dolt/doltversion/version.go @@ -16,5 +16,5 @@ package doltversion const ( - Version = "1.47.2" + Version = "1.48.0" ) From 8f25c115d6f0a34deb361bf27b622329e438cff6 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Wed, 15 Jan 2025 10:58:52 -0800 Subject: [PATCH 14/34] go/store/nbs: store.go: Add a mechanism to bracket outstanding read requests so that a GC does not end while they are in progress. --- go/store/nbs/store.go | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index 7f08eb62c99..6830c344702 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -99,8 +99,14 @@ type NomsBlockStore struct { tables tableSet upstream manifestContents - cond *sync.Cond + cond *sync.Cond + // |true| after BeginGC is called, and false once the corresponding EndGC call returns. gcInProgress bool + // When unlocked read operations are occuring against the + // block store, and they started when |gcInProgress == true|, + // this variable is incremented. EndGC will not return until + // no outstanding reads are in progress. + gcOutstandingReads int // keeperFunc is set when |gcInProgress| and appends to the GC sweep queue // or blocks on GC finalize keeperFunc func(hash.Hash) bool @@ -703,7 +709,9 @@ func (nbs *NomsBlockStore) WithoutConjoiner() *NomsBlockStore { } } -// Wait for GC to complete to continue with writes +// Wait for GC to complete to continue with ongoing operations. +// Called with nbs.mu held. When this function returns with a nil +// error, gcInProgress will be false. func (nbs *NomsBlockStore) waitForGC(ctx context.Context) error { stop := make(chan struct{}) defer close(stop) @@ -1593,11 +1601,36 @@ func (nbs *NomsBlockStore) EndGC() { if !nbs.gcInProgress { panic("EndGC called when gc was not in progress") } + for nbs.gcOutstandingReads > 0 { + nbs.cond.Wait() + } nbs.gcInProgress = false nbs.keeperFunc = nil nbs.cond.Broadcast() } +// beginRead() is called with |nbs.mu| held. 
It signals an ongoing +// read operation which will be operating against the existing table +// files without |nbs.mu| held. The read should be bracket with a call +// to the returned |endRead|, which must be called with |nbs.mu| held +// if it is non-|nil|, and should not be called otherwise. +// +// If there is an ongoing GC operation which this call is made, it is +// guaranteed not to complete until the corresponding |endRead| call. +func (nbs *NomsBlockStore) beginRead() (endRead func()) { + if nbs.gcInProgress { + nbs.gcOutstandingReads += 1 + return func() { + nbs.gcOutstandingReads -= 1 + if nbs.gcOutstandingReads < 0 { + panic("impossible") + } + nbs.cond.Broadcast() + } + } + return nil +} + func (nbs *NomsBlockStore) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { return markAndSweepChunks(ctx, nbs, nbs, dest, getAddrs, filter, mode) } From 1f4b4870fef58f4c8bd84560e214812c55e4c8fe Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Wed, 15 Jan 2025 13:59:07 -0800 Subject: [PATCH 15/34] go/store/nbs: chunkReader,chunkSource: GC: Add the ability to take dependencies on chunks during reads. --- go/store/nbs/archive_build.go | 7 +- go/store/nbs/archive_chunk_source.go | 48 ++++-- go/store/nbs/archive_test.go | 16 +- go/store/nbs/cmp_chunk_table_writer_test.go | 4 +- go/store/nbs/conjoiner_test.go | 4 +- go/store/nbs/empty_chunk_source.go | 24 +-- go/store/nbs/journal.go | 3 +- go/store/nbs/journal_chunk_source.go | 61 ++++--- go/store/nbs/journal_test.go | 8 +- go/store/nbs/journal_writer_test.go | 2 +- go/store/nbs/mem_table.go | 45 ++++-- go/store/nbs/mem_table_test.go | 94 ++++++----- go/store/nbs/root_tracker_test.go | 2 +- go/store/nbs/store.go | 46 ++++-- go/store/nbs/table.go | 29 +++- go/store/nbs/table_reader.go | 97 +++++++---- go/store/nbs/table_set.go | 168 ++++++++++++-------- go/store/nbs/table_test.go | 30 ++-- 18 files changed, 427 insertions(+), 261 deletions(-) diff --git a/go/store/nbs/archive_build.go b/go/store/nbs/archive_build.go index 72f15d94f04..349bf239d11 100644 --- a/go/store/nbs/archive_build.go +++ b/go/store/nbs/archive_build.go @@ -425,7 +425,7 @@ func gatherAllChunks(ctx context.Context, cs chunkSource, idx tableIndex, stats return nil, nil, err } - bytes, err := cs.get(ctx, h, stats) + bytes, _, err := cs.get(ctx, h, nil, stats) if err != nil { return nil, nil, err } @@ -907,7 +907,7 @@ func (csc *simpleChunkSourceCache) get(ctx context.Context, h hash.Hash, stats * return chk, nil } - bytes, err := csc.cs.get(ctx, h, stats) + bytes, _, err := csc.cs.get(ctx, h, nil, stats) if bytes == nil || err != nil { return nil, err } @@ -919,7 +919,8 @@ func (csc *simpleChunkSourceCache) get(ctx context.Context, h hash.Hash, stats * // has returns true if the chunk is in the ChunkSource. This is not related to what is cached, just a helper. func (csc *simpleChunkSourceCache) has(h hash.Hash) (bool, error) { - return csc.cs.has(h) + res, _, err := csc.cs.has(h, nil) + return res, err } // addresses get all chunk addresses of the ChunkSource as a hash.HashSet. 
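The call-site changes above (passing a nil keeper and discarding the returned gcBehavior) are one half of a two-part mechanism built in this patch and the previous one: every chunkSource read now accepts a keeper function and reports a gcBehavior, so a read made while a GC is in progress can record a dependency on the chunks it touches or, when the keeper can no longer accept new dependencies, ask the caller to wait for the GC to finish and retry; the store, in turn, brackets such unlocked reads so that EndGC cannot return while one of them is still using the old table files. The sketch below shows the shape of that protocol with simplified stand-in types (string addresses instead of hash.Hash, an immutable map instead of table files); it illustrates the pattern rather than the nbs implementation.

package example

import "sync"

// Simplified stand-ins for the nbs internals: the real keeperF takes a
// hash.Hash, and gcBehavior is threaded through every chunkSource method.
type keeperF func(addr string) bool

type gcBehavior int

const (
	gcBehavior_Continue gcBehavior = iota // use the result
	gcBehavior_Block                      // wait for the GC to finish, then retry
)

type store struct {
	mu               sync.Mutex
	cond             *sync.Cond
	gcInProgress     bool
	outstandingReads int
	keeper           keeperF
	data             map[string][]byte // immutable after construction; stands in for table files
}

func newStore(data map[string][]byte) *store {
	s := &store{data: data}
	s.cond = sync.NewCond(&s.mu)
	return s
}

// BeginGC installs the keeper that records GC dependencies for chunks read
// while the collection runs.
func (s *store) BeginGC(keep keeperF) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.gcInProgress = true
	s.keeper = keep
}

// EndGC refuses to return until every bracketed read has finished.
func (s *store) EndGC() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for s.outstandingReads > 0 {
		s.cond.Wait()
	}
	s.gcInProgress = false
	s.keeper = nil
	s.cond.Broadcast()
}

// Get sketches the read path: bracket the unlocked read so EndGC waits for
// it, consult the keeper, and if the keeper rejects the chunk, block until
// the GC ends and retry against the post-GC state.
func (s *store) Get(addr string) []byte {
	for {
		s.mu.Lock()
		keeper := s.keeper
		bracketed := s.gcInProgress
		if bracketed {
			s.outstandingReads++
		}
		s.mu.Unlock()

		b, behavior := s.data[addr], gcBehavior_Continue
		if b != nil && keeper != nil && keeper(addr) {
			behavior = gcBehavior_Block
		}

		s.mu.Lock()
		if bracketed {
			s.outstandingReads--
			s.cond.Broadcast()
		}
		if behavior == gcBehavior_Block {
			for s.gcInProgress {
				s.cond.Wait()
			}
			s.mu.Unlock()
			continue
		}
		s.mu.Unlock()
		return b
	}
}

Callers that pass a nil keeper, like the archive_build.go sites above, opt out of dependency tracking and always see gcBehavior_Continue.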
diff --git a/go/store/nbs/archive_chunk_source.go b/go/store/nbs/archive_chunk_source.go index 3ccd0183c57..f2d7a3a69e2 100644 --- a/go/store/nbs/archive_chunk_source.go +++ b/go/store/nbs/archive_chunk_source.go @@ -64,42 +64,60 @@ func openReader(file string) (io.ReaderAt, uint64, error) { return f, uint64(stat.Size()), nil } -func (acs archiveChunkSource) has(h hash.Hash) (bool, error) { - return acs.aRdr.has(h), nil +func (acs archiveChunkSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { + res := acs.aRdr.has(h) + if res && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return res, gcBehavior_Continue, nil } -func (acs archiveChunkSource) hasMany(addrs []hasRecord) (bool, error) { +func (acs archiveChunkSource) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { // single threaded first pass. foundAll := true for i, addr := range addrs { - if acs.aRdr.has(*(addr.a)) { + h := *addr.a + if acs.aRdr.has(h) { + if keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } addrs[i].has = true } else { foundAll = false } } - return !foundAll, nil + return !foundAll, gcBehavior_Continue, nil } -func (acs archiveChunkSource) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { - // ctx, stats ? NM4. - return acs.aRdr.get(h) +func (acs archiveChunkSource) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { + res, err := acs.aRdr.get(h) + if err != nil { + return nil, gcBehavior_Continue, err + } + if res != nil && keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil + } + return res, gcBehavior_Continue, nil } -func (acs archiveChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (acs archiveChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { // single threaded first pass. foundAll := true for i, req := range reqs { - data, err := acs.aRdr.get(*req.a) + h := *req.a + data, err := acs.aRdr.get(h) if err != nil || data == nil { foundAll = false } else { + if keeper != nil && keeper(h) { + return true, gcBehavior_Block, nil + } chunk := chunks.NewChunk(data) found(ctx, &chunk) reqs[i].found = true } } - return !foundAll, nil + return !foundAll, gcBehavior_Continue, nil } // iterate iterates over the archive chunks. The callback is called for each chunk in the archive. 
This is not optimized @@ -146,14 +164,14 @@ func (acs archiveChunkSource) clone() (chunkSource, error) { return archiveChunkSource{acs.file, rdr}, nil } -func (acs archiveChunkSource) getRecordRanges(_ context.Context, _ []getRecord) (map[hash.Hash]Range, error) { - return nil, errors.New("Archive chunk source does not support getRecordRanges") +func (acs archiveChunkSource) getRecordRanges(_ context.Context, _ []getRecord, _ keeperF) (map[hash.Hash]Range, gcBehavior, error) { + return nil, gcBehavior_Continue, errors.New("Archive chunk source does not support getRecordRanges") } -func (acs archiveChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (acs archiveChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { return acs.getMany(ctx, eg, reqs, func(ctx context.Context, chk *chunks.Chunk) { found(ctx, ChunkToCompressedChunk(*chk)) - }, stats) + }, keeper, stats) } func (acs archiveChunkSource) iterateAllChunks(ctx context.Context, cb func(chunks.Chunk)) error { diff --git a/go/store/nbs/archive_test.go b/go/store/nbs/archive_test.go index da3c324cf2a..c78d3b5710e 100644 --- a/go/store/nbs/archive_test.go +++ b/go/store/nbs/archive_test.go @@ -655,28 +655,28 @@ type testChunkSource struct { var _ chunkSource = (*testChunkSource)(nil) -func (tcs *testChunkSource) get(_ context.Context, h hash.Hash, _ *Stats) ([]byte, error) { +func (tcs *testChunkSource) get(_ context.Context, h hash.Hash, _ keeperF, _ *Stats) ([]byte, gcBehavior, error) { for _, chk := range tcs.chunks { if chk.Hash() == h { - return chk.Data(), nil + return chk.Data(), gcBehavior_Continue, nil } } - return nil, errors.New("not found") + return nil, gcBehavior_Continue, errors.New("not found") } -func (tcs *testChunkSource) has(h hash.Hash) (bool, error) { +func (tcs *testChunkSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { panic("never used") } -func (tcs *testChunkSource) hasMany(addrs []hasRecord) (bool, error) { +func (tcs *testChunkSource) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { panic("never used") } -func (tcs *testChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (tcs *testChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { panic("never used") } -func (tcs *testChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (tcs *testChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { panic("never used") } @@ -700,7 +700,7 @@ func (tcs *testChunkSource) reader(ctx context.Context) (io.ReadCloser, uint64, panic("never used") } -func (tcs *testChunkSource) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { +func (tcs *testChunkSource) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { panic("never used") } diff --git 
a/go/store/nbs/cmp_chunk_table_writer_test.go b/go/store/nbs/cmp_chunk_table_writer_test.go index 1f91d55b025..33323d018b0 100644 --- a/go/store/nbs/cmp_chunk_table_writer_test.go +++ b/go/store/nbs/cmp_chunk_table_writer_test.go @@ -51,7 +51,7 @@ func TestCmpChunkTableWriter(t *testing.T) { found := make([]CompressedChunk, 0) eg, egCtx := errgroup.WithContext(ctx) - _, err = tr.getManyCompressed(egCtx, eg, reqs, func(ctx context.Context, c CompressedChunk) { found = append(found, c) }, &Stats{}) + _, _, err = tr.getManyCompressed(egCtx, eg, reqs, func(ctx context.Context, c CompressedChunk) { found = append(found, c) }, nil, &Stats{}) require.NoError(t, err) require.NoError(t, eg.Wait()) @@ -146,7 +146,7 @@ func readAllChunks(ctx context.Context, hashes hash.HashSet, reader tableReader) reqs := toGetRecords(hashes) found := make([]*chunks.Chunk, 0) eg, ctx := errgroup.WithContext(ctx) - _, err := reader.getMany(ctx, eg, reqs, func(ctx context.Context, c *chunks.Chunk) { found = append(found, c) }, &Stats{}) + _, _, err := reader.getMany(ctx, eg, reqs, func(ctx context.Context, c *chunks.Chunk) { found = append(found, c) }, nil, &Stats{}) if err != nil { return nil, err } diff --git a/go/store/nbs/conjoiner_test.go b/go/store/nbs/conjoiner_test.go index 846aa4c7f4a..01438450e7e 100644 --- a/go/store/nbs/conjoiner_test.go +++ b/go/store/nbs/conjoiner_test.go @@ -159,11 +159,11 @@ func testConjoin(t *testing.T, factory func(t *testing.T) tablePersister) { var ok bool for _, act := range actualSrcs { var err error - ok, err = act.has(rec.a) + ok, _, err = act.has(rec.a, nil) require.NoError(t, err) var buf []byte if ok { - buf, err = act.get(ctx, rec.a, stats) + buf, _, err = act.get(ctx, rec.a, nil, stats) require.NoError(t, err) assert.Equal(t, rec.data, buf) break diff --git a/go/store/nbs/empty_chunk_source.go b/go/store/nbs/empty_chunk_source.go index 5df2696c33d..8d00c820de4 100644 --- a/go/store/nbs/empty_chunk_source.go +++ b/go/store/nbs/empty_chunk_source.go @@ -34,24 +34,24 @@ import ( type emptyChunkSource struct{} -func (ecs emptyChunkSource) has(h hash.Hash) (bool, error) { - return false, nil +func (ecs emptyChunkSource) has(h hash.Hash, _ keeperF) (bool, gcBehavior, error) { + return false, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) hasMany(addrs []hasRecord) (bool, error) { - return true, nil +func (ecs emptyChunkSource) hasMany(addrs []hasRecord, _ keeperF) (bool, gcBehavior, error) { + return true, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { - return nil, nil +func (ecs emptyChunkSource) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { + return nil, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { - return true, nil +func (ecs emptyChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return true, gcBehavior_Continue, nil } -func (ecs emptyChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { - return true, nil +func (ecs emptyChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, 
CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return true, gcBehavior_Continue, nil } func (ecs emptyChunkSource) count() (uint32, error) { @@ -74,8 +74,8 @@ func (ecs emptyChunkSource) reader(context.Context) (io.ReadCloser, uint64, erro return io.NopCloser(&bytes.Buffer{}), 0, nil } -func (ecs emptyChunkSource) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { - return map[hash.Hash]Range{}, nil +func (ecs emptyChunkSource) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { + return map[hash.Hash]Range{}, gcBehavior_Continue, nil } func (ecs emptyChunkSource) currentSize() uint64 { diff --git a/go/store/nbs/journal.go b/go/store/nbs/journal.go index 8f415cbfa3d..df7fbf95622 100644 --- a/go/store/nbs/journal.go +++ b/go/store/nbs/journal.go @@ -248,7 +248,8 @@ func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkRea if haver != nil { sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted. - if _, err := haver.hasMany(mt.order); err != nil { + // TODO: keeperF + if _, _, err := haver.hasMany(mt.order, nil); err != nil { return nil, err } sort.Sort(hasRecordByOrder(mt.order)) // restore "insertion" order for write diff --git a/go/store/nbs/journal_chunk_source.go b/go/store/nbs/journal_chunk_source.go index c8dd8a4ac02..e7bf50dc4a9 100644 --- a/go/store/nbs/journal_chunk_source.go +++ b/go/store/nbs/journal_chunk_source.go @@ -39,20 +39,29 @@ type journalChunkSource struct { var _ chunkSource = journalChunkSource{} -func (s journalChunkSource) has(h hash.Hash) (bool, error) { - return s.journal.hasAddr(h), nil +func (s journalChunkSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { + res := s.journal.hasAddr(h) + if res && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return res, gcBehavior_Continue, nil } -func (s journalChunkSource) hasMany(addrs []hasRecord) (missing bool, err error) { +func (s journalChunkSource) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { + missing := false for i := range addrs { - ok := s.journal.hasAddr(*addrs[i].a) + h := *addrs[i].a + ok := s.journal.hasAddr(h) if ok { + if keeper != nil && keeper(h) { + return true, gcBehavior_Block, nil + } addrs[i].has = true } else { missing = true } } - return + return missing, gcBehavior_Continue, nil } func (s journalChunkSource) getCompressed(ctx context.Context, h hash.Hash, _ *Stats) (CompressedChunk, error) { @@ -60,20 +69,23 @@ func (s journalChunkSource) getCompressed(ctx context.Context, h hash.Hash, _ *S return s.journal.getCompressedChunk(h) } -func (s journalChunkSource) get(ctx context.Context, h hash.Hash, _ *Stats) ([]byte, error) { +func (s journalChunkSource) get(ctx context.Context, h hash.Hash, keeper keeperF, _ *Stats) ([]byte, gcBehavior, error) { defer trace.StartRegion(ctx, "journalChunkSource.get").End() cc, err := s.journal.getCompressedChunk(h) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } else if cc.IsEmpty() { - return nil, nil + return nil, gcBehavior_Continue, nil + } + if keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil } ch, err := cc.ToChunk() if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } - return ch.Data(), nil + return ch.Data(), gcBehavior_Continue, nil } type journalRecord struct { @@ -83,7 +95,7 @@ type journalRecord struct { idx int } -func (s 
journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (s journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { return s.getManyCompressed(ctx, eg, reqs, func(ctx context.Context, cc CompressedChunk) { ch, err := cc.ToChunk() if err != nil { @@ -94,7 +106,7 @@ func (s journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, req } chWHash := chunks.NewChunkWithHash(cc.Hash(), ch.Data()) found(ctx, &chWHash) - }, stats) + }, keeper, stats) } // getManyCompressed implements chunkReader. Here we (1) synchronously check @@ -103,7 +115,7 @@ func (s journalChunkSource) getMany(ctx context.Context, eg *errgroup.Group, req // and then (4) asynchronously perform reads. We release the journal read // lock after returning when all reads are completed, which can be after the // function returns. -func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { defer trace.StartRegion(ctx, "journalChunkSource.getManyCompressed").End() var remaining bool @@ -114,11 +126,16 @@ func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup. if r.found { continue } - rang, ok := s.journal.ranges.get(*r.a) + h := *r.a + rang, ok := s.journal.ranges.get(h) if !ok { remaining = true continue } + if keeper != nil && keeper(h) { + s.journal.lock.RUnlock() + return true, gcBehavior_Block, nil + } jReqs = append(jReqs, journalRecord{r: rang, idx: i}) reqs[i].found = true } @@ -150,7 +167,7 @@ func (s journalChunkSource) getManyCompressed(ctx context.Context, eg *errgroup. wg.Wait() s.journal.lock.RUnlock() }() - return remaining, nil + return remaining, gcBehavior_Continue, nil } func (s journalChunkSource) count() (uint32, error) { @@ -171,22 +188,26 @@ func (s journalChunkSource) reader(ctx context.Context) (io.ReadCloser, uint64, return rdr, uint64(sz), err } -func (s journalChunkSource) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { +func (s journalChunkSource) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { ranges := make(map[hash.Hash]Range, len(requests)) for _, req := range requests { if req.found { continue } - rng, ok, err := s.journal.getRange(ctx, *req.a) + h := *req.a + rng, ok, err := s.journal.getRange(ctx, h) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } else if !ok { continue } + if keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil + } req.found = true // update |requests| - ranges[hash.Hash(*req.a)] = rng + ranges[h] = rng } - return ranges, nil + return ranges, gcBehavior_Continue, nil } // size implements chunkSource. 
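Each chunk source above follows the same keeper discipline on its read path: when the requested chunk is present but the keeper function claims it, the source returns gcBehavior_Block instead of the data, and the caller is expected to wait out the in-progress GC and retry the whole operation. A minimal sketch of that pattern is below; exampleSource and its lookup method are illustrative stand-ins, not part of this patch.

    // exampleSource stands in for any concrete source (journal, archive,
    // table file, mem table); lookup is a hypothetical presence check.
    func (s exampleSource) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) {
        ok := s.lookup(h)
        if ok && keeper != nil && keeper(h) {
            // The chunk exists, but the in-progress GC cannot record it as
            // a dependency, so tell the caller to wait for GC and retry.
            return false, gcBehavior_Block, nil
        }
        return ok, gcBehavior_Continue, nil
    }
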
diff --git a/go/store/nbs/journal_test.go b/go/store/nbs/journal_test.go index 9486f1edf17..b2f24f01702 100644 --- a/go/store/nbs/journal_test.go +++ b/go/store/nbs/journal_test.go @@ -71,10 +71,10 @@ func TestChunkJournalPersist(t *testing.T) { assert.NoError(t, err) for h, ch := range chunkMap { - ok, err := source.has(h) + ok, _, err := source.has(h, nil) assert.NoError(t, err) assert.True(t, ok) - data, err := source.get(ctx, h, stats) + data, _, err := source.get(ctx, h, nil, stats) assert.NoError(t, err) assert.Equal(t, ch.Data(), data) } @@ -108,11 +108,11 @@ func TestReadRecordRanges(t *testing.T) { require.NoError(t, err) assert.Equal(t, int(sz), n) - ranges, err := jcs.getRecordRanges(ctx, gets) + ranges, _, err := jcs.getRecordRanges(ctx, gets, nil) require.NoError(t, err) for h, rng := range ranges { - b, err := jcs.get(ctx, h, &Stats{}) + b, _, err := jcs.get(ctx, h, nil, &Stats{}) assert.NoError(t, err) ch1 := chunks.NewChunkWithHash(h, b) assert.Equal(t, data[h], ch1) diff --git a/go/store/nbs/journal_writer_test.go b/go/store/nbs/journal_writer_test.go index e6f6e59fd33..77263f23ab2 100644 --- a/go/store/nbs/journal_writer_test.go +++ b/go/store/nbs/journal_writer_test.go @@ -228,7 +228,7 @@ func TestJournalWriterBootstrap(t *testing.T) { source := journalChunkSource{journal: j} for a, cc := range data { - buf, err := source.get(ctx, a, nil) + buf, _, err := source.get(ctx, a, nil, nil) require.NoError(t, err) ch, err := cc.ToChunk() require.NoError(t, err) diff --git a/go/store/nbs/mem_table.go b/go/store/nbs/mem_table.go index cbffa34a72b..135e3f8feb5 100644 --- a/go/store/nbs/mem_table.go +++ b/go/store/nbs/mem_table.go @@ -135,22 +135,27 @@ func (mt *memTable) uncompressedLen() (uint64, error) { return mt.totalData, nil } -func (mt *memTable) has(h hash.Hash) (bool, error) { +func (mt *memTable) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { _, has := mt.chunks[h] - return has, nil + if has && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return has, gcBehavior_Continue, nil } -func (mt *memTable) hasMany(addrs []hasRecord) (bool, error) { +func (mt *memTable) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { var remaining bool for i, addr := range addrs { if addr.has { continue } - ok, err := mt.has(*addr.a) - + ok, gcb, err := mt.has(*addr.a, keeper) if err != nil { - return false, err + return false, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return ok, gcb, nil } if ok { @@ -159,18 +164,25 @@ func (mt *memTable) hasMany(addrs []hasRecord) (bool, error) { remaining = true } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } -func (mt *memTable) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { - return mt.chunks[h], nil +func (mt *memTable) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { + c, ok := mt.chunks[h] + if ok && keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil + } + return c, gcBehavior_Continue, nil } -func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { var remaining bool for i, r := range reqs { data := mt.chunks[*r.a] if data != nil { + if keeper != nil && keeper(*r.a) { + 
return true, gcBehavior_Block, nil + } c := chunks.NewChunkWithHash(hash.Hash(*r.a), data) reqs[i].found = true found(ctx, &c) @@ -178,14 +190,17 @@ func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getR remaining = true } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } -func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { var remaining bool for i, r := range reqs { data := mt.chunks[*r.a] if data != nil { + if keeper != nil && keeper(*r.a) { + return true, gcBehavior_Block, nil + } c := chunks.NewChunkWithHash(hash.Hash(*r.a), data) reqs[i].found = true found(ctx, ChunkToCompressedChunk(c)) @@ -194,7 +209,7 @@ func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, r } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } func (mt *memTable) extract(ctx context.Context, chunks chan<- extractRecord) error { @@ -217,8 +232,8 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data if haver != nil { sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted. - _, err := haver.hasMany(mt.order) - + // TODO: keeperF + _, _, err := haver.hasMany(mt.order, nil) if err != nil { return hash.Hash{}, nil, 0, err } diff --git a/go/store/nbs/mem_table_test.go b/go/store/nbs/mem_table_test.go index 6802abc66c4..3b7b2c222e5 100644 --- a/go/store/nbs/mem_table_test.go +++ b/go/store/nbs/mem_table_test.go @@ -92,14 +92,14 @@ func TestMemTableAddHasGetChunk(t *testing.T) { assertChunksInReader(chunks, mt, assert) for _, c := range chunks { - data, err := mt.get(context.Background(), computeAddr(c), &Stats{}) + data, _, err := mt.get(context.Background(), computeAddr(c), nil, &Stats{}) require.NoError(t, err) assert.Equal(bytes.Compare(c, data), 0) } notPresent := []byte("nope") - assert.False(mt.has(computeAddr(notPresent))) - assert.Nil(mt.get(context.Background(), computeAddr(notPresent), &Stats{})) + assert.False(mt.has(computeAddr(notPresent), nil)) + assert.Nil(mt.get(context.Background(), computeAddr(notPresent), nil, &Stats{})) } func TestMemTableAddOverflowChunk(t *testing.T) { @@ -112,9 +112,9 @@ func TestMemTableAddOverflowChunk(t *testing.T) { bigAddr := computeAddr(big) mt := newMemTable(memTableSize) assert.Equal(mt.addChunk(bigAddr, big), chunkAdded) - assert.True(mt.has(bigAddr)) + assert.True(mt.has(bigAddr, nil)) assert.Equal(mt.addChunk(computeAddr(little), little), chunkNotAdded) - assert.False(mt.has(computeAddr(little))) + assert.False(mt.has(computeAddr(little), nil)) } { @@ -122,12 +122,12 @@ func TestMemTableAddOverflowChunk(t *testing.T) { bigAddr := computeAddr(big) mt := newMemTable(memTableSize) assert.Equal(mt.addChunk(bigAddr, big), chunkAdded) - assert.True(mt.has(bigAddr)) + assert.True(mt.has(bigAddr, nil)) assert.Equal(mt.addChunk(computeAddr(little), little), chunkAdded) - assert.True(mt.has(computeAddr(little))) + assert.True(mt.has(computeAddr(little), nil)) other := []byte("o") assert.Equal(mt.addChunk(computeAddr(other), other), chunkNotAdded) - assert.False(mt.has(computeAddr(other))) + assert.False(mt.has(computeAddr(other), nil)) } } @@ -153,7 +153,7 @@ func TestMemTableWrite(t *testing.T) { tr1, err := 
newTableReader(ti1, tableReaderAtFromBytes(td1), fileBlockSize) require.NoError(t, err) defer tr1.close() - assert.True(tr1.has(computeAddr(chunks[1]))) + assert.True(tr1.has(computeAddr(chunks[1]), nil)) td2, _, err := buildTable(chunks[2:]) require.NoError(t, err) @@ -162,7 +162,7 @@ func TestMemTableWrite(t *testing.T) { tr2, err := newTableReader(ti2, tableReaderAtFromBytes(td2), fileBlockSize) require.NoError(t, err) defer tr2.close() - assert.True(tr2.has(computeAddr(chunks[2]))) + assert.True(tr2.has(computeAddr(chunks[2]), nil)) _, data, count, err := mt.write(chunkReaderGroup{tr1, tr2}, &Stats{}) require.NoError(t, err) @@ -173,9 +173,9 @@ func TestMemTableWrite(t *testing.T) { outReader, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize) require.NoError(t, err) defer outReader.close() - assert.True(outReader.has(computeAddr(chunks[0]))) - assert.False(outReader.has(computeAddr(chunks[1]))) - assert.False(outReader.has(computeAddr(chunks[2]))) + assert.True(outReader.has(computeAddr(chunks[0]), nil)) + assert.False(outReader.has(computeAddr(chunks[1]), nil)) + assert.False(outReader.has(computeAddr(chunks[2]), nil)) } type tableReaderAtAdapter struct { @@ -239,72 +239,82 @@ func (o *outOfLineSnappy) Encode(dst, src []byte) []byte { type chunkReaderGroup []chunkReader -func (crg chunkReaderGroup) has(h hash.Hash) (bool, error) { +func (crg chunkReaderGroup) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { for _, haver := range crg { - ok, err := haver.has(h) - + ok, gcb, err := haver.has(h, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if ok { - return true, nil + return true, gcb, nil } } - - return false, nil + return false, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { +func (crg chunkReaderGroup) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { for _, haver := range crg { - if data, err := haver.get(ctx, h, stats); err != nil { - return nil, err + if data, gcb, err := haver.get(ctx, h, keeper, stats); err != nil { + return nil, gcb, err + } else if gcb != gcBehavior_Continue { + return nil, gcb, nil } else if data != nil { - return data, nil + return data, gcb, nil } } - return nil, nil + return nil, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) hasMany(addrs []hasRecord) (bool, error) { +func (crg chunkReaderGroup) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { for _, haver := range crg { - remaining, err := haver.hasMany(addrs) - + remaining, gcb, err := haver.hasMany(addrs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if !remaining { - return false, nil + return false, gcb, nil } } - return true, nil + return true, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) { +func (crg chunkReaderGroup) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { for _, haver := range crg { - remaining, err := haver.getMany(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getMany(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true, err + return 
true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false, nil + return false, gcb, nil } } - return true, nil + return true, gcBehavior_Continue, nil } -func (crg chunkReaderGroup) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (crg chunkReaderGroup) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { for _, haver := range crg { - remaining, err := haver.getManyCompressed(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getManyCompressed(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true, err + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false, nil + return false, gcb, nil } } - return true, nil + return true, gcBehavior_Continue, nil } func (crg chunkReaderGroup) count() (count uint32, err error) { diff --git a/go/store/nbs/root_tracker_test.go b/go/store/nbs/root_tracker_test.go index f428f0817eb..b5e82e8e182 100644 --- a/go/store/nbs/root_tracker_test.go +++ b/go/store/nbs/root_tracker_test.go @@ -661,7 +661,7 @@ func extractAllChunks(ctx context.Context, src chunkSource, cb func(rec extractR return err } - data, err := src.get(ctx, h, nil) + data, _, err := src.get(ctx, h, nil, nil) if err != nil { return err } diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index 6830c344702..f8f6bda09bb 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -163,7 +163,8 @@ func (nbs *NomsBlockStore) GetChunkLocations(ctx context.Context, hashes hash.Ha fn := func(css chunkSourceSet) error { for _, cs := range css { - rng, err := cs.getRecordRanges(ctx, gr) + // TODO: keeperF + rng, _, err := cs.getRecordRanges(ctx, gr, nil) if err != nil { return err } @@ -859,7 +860,8 @@ func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, defer nbs.mu.RUnlock() if nbs.mt != nil { var err error - data, err = nbs.mt.get(ctx, h, nbs.stats) + // TODO: keeperF + data, _, err = nbs.mt.get(ctx, h, nil, nbs.stats) if err != nil { return nil, nil, err @@ -876,7 +878,8 @@ func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, return chunks.NewChunkWithHash(h, data), nil } - data, err = tables.get(ctx, h, nbs.stats) + // TODO: keeperF + data, _, err = tables.get(ctx, h, nil, nbs.stats) if err != nil { return chunks.EmptyChunk, err @@ -893,7 +896,9 @@ func (nbs *NomsBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, fou ctx, span := tracer.Start(ctx, "nbs.GetMany", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) span.End() return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error) { - return cr.getMany(ctx, eg, reqs, found, nbs.stats) + // TODO: keeperF + res, _, err := cr.getMany(ctx, eg, reqs, found, nil, nbs.stats) + return res, err }) } @@ -901,7 +906,9 @@ func (nbs *NomsBlockStore) GetManyCompressed(ctx context.Context, hashes hash.Ha ctx, span := tracer.Start(ctx, "nbs.GetManyCompressed", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) defer span.End() return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error) { - return cr.getManyCompressed(ctx, eg, 
reqs, found, nbs.stats) + // TODO: keeperF + res, _, err := cr.getManyCompressed(ctx, eg, reqs, found, nil, nbs.stats) + return res, err }) } @@ -1005,7 +1012,8 @@ func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { defer nbs.mu.RUnlock() if nbs.mt != nil { - has, err := nbs.mt.has(h) + // TODO: keeperF + has, _, err := nbs.mt.has(h, nil) if err != nil { return false, nil, err @@ -1022,7 +1030,8 @@ func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { } if !has { - has, err = tables.has(h) + // TODO: keeperF + has, _, err = tables.has(h, nil) if err != nil { return false, err @@ -1060,7 +1069,8 @@ func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSe records := toHasRecords(hashes) - _, err := nbs.tables.hasManyInSources(srcs, records) + // TODO: keeperF + _, _, err := nbs.tables.hasManyInSources(srcs, records, nil) if err != nil { return nil, err } @@ -1080,7 +1090,8 @@ func (nbs *NomsBlockStore) hasMany(reqs []hasRecord) (hash.HashSet, error) { remaining = true if nbs.mt != nil { - remaining, err = nbs.mt.hasMany(reqs) + // TODO: keeperF + remaining, _, err = nbs.mt.hasMany(reqs, nil) if err != nil { return nil, false, err @@ -1095,7 +1106,8 @@ func (nbs *NomsBlockStore) hasMany(reqs []hasRecord) (hash.HashSet, error) { } if remaining { - _, err := tables.hasMany(reqs) + // TODO: keeperF + _, _, err := tables.hasMany(reqs, nil) if err != nil { return nil, err @@ -1965,7 +1977,7 @@ func (nbs *NomsBlockStore) setRootChunk(ctx context.Context, root, previous hash } // CalcReads computes the number of IO operations necessary to fetch |hashes|. -func CalcReads(nbs *NomsBlockStore, hashes hash.HashSet, blockSize uint64) (reads int, split bool, err error) { +func CalcReads(nbs *NomsBlockStore, hashes hash.HashSet, blockSize uint64, keeper keeperF) (int, bool, gcBehavior, error) { reqs := toGetRecords(hashes) tables := func() (tables tableSet) { nbs.mu.RLock() @@ -1975,15 +1987,17 @@ func CalcReads(nbs *NomsBlockStore, hashes hash.HashSet, blockSize uint64) (read return }() - reads, split, remaining, err := tableSetCalcReads(tables, reqs, blockSize) - + reads, split, remaining, gcb, err := tableSetCalcReads(tables, reqs, blockSize, keeper) if err != nil { - return 0, false, err + return 0, false, gcb, err + } + if gcb != gcBehavior_Continue { + return 0, false, gcb, nil } if remaining { - return 0, false, errors.New("failed to find all chunks") + return 0, false, gcBehavior_Continue, errors.New("failed to find all chunks") } - return + return reads, split, gcb, err } diff --git a/go/store/nbs/table.go b/go/store/nbs/table.go index b487422d079..106c6f7f6aa 100644 --- a/go/store/nbs/table.go +++ b/go/store/nbs/table.go @@ -187,24 +187,41 @@ type extractRecord struct { err error } +// Returned by read methods that take a |keeperFunc|, this lets a +// caller know whether the operation was successful or if it needs to +// be retried. It may need to be retried if a GC is in progress but +// the dependencies indicated by the operation cannot be added to the +// GC process. In that case, the caller needs to wait until the GC is +// over and run the entire operation again. +type gcBehavior bool + +const ( + // Operation was successful, go forward with the result. + gcBehavior_Continue gcBehavior = false + // Operation needs to block until the GC is over and then retry. + gcBehavior_Block = true +) + +type keeperF func(hash.Hash) bool + type chunkReader interface { // has returns true if a chunk with addr |h| is present. 
- has(h hash.Hash) (bool, error) + has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) // hasMany sets hasRecord.has to true for each present hasRecord query, it returns // true if any hasRecord query was not found in this chunkReader. - hasMany(addrs []hasRecord) (bool, error) + hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) // get returns the chunk data for a chunk with addr |h| if present, and nil otherwise. - get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) + get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) // getMany sets getRecord.found to true, and calls |found| for each present getRecord query. // It returns true if any getRecord query was not found in this chunkReader. - getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) + getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) // getManyCompressed sets getRecord.found to true, and calls |found| for each present getRecord query. // It returns true if any getRecord query was not found in this chunkReader. - getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) + getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) // count returns the chunk count for this chunkReader. count() (uint32, error) @@ -226,7 +243,7 @@ type chunkSource interface { reader(context.Context) (io.ReadCloser, uint64, error) // getRecordRanges sets getRecord.found to true, and returns a Range for each present getRecord query. - getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) + getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) // index returns the tableIndex of this chunkSource. index() (tableIndex, error) diff --git a/go/store/nbs/table_reader.go b/go/store/nbs/table_reader.go index 3ff059fb480..c55d48c28a4 100644 --- a/go/store/nbs/table_reader.go +++ b/go/store/nbs/table_reader.go @@ -178,7 +178,7 @@ func newTableReader(index tableIndex, r tableReaderAt, blockSize uint64) (tableR } // Scan across (logically) two ordered slices of address prefixes. 
-func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { +func (tr tableReader) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { filterIdx := uint32(0) filterLen := uint32(tr.idx.chunkCount()) @@ -206,7 +206,7 @@ func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { } if filterIdx >= filterLen { - return true, nil + return true, gcBehavior_Continue, nil } if addr.prefix != tr.prefixes[filterIdx] { @@ -218,9 +218,12 @@ func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { for j := filterIdx; j < filterLen && addr.prefix == tr.prefixes[j]; j++ { m, err := tr.idx.entrySuffixMatches(j, addr.a) if err != nil { - return false, err + return false, gcBehavior_Continue, err } if m { + if keeper != nil && keeper(*addr.a) { + return true, gcBehavior_Block, nil + } addrs[i].has = true break } @@ -231,7 +234,7 @@ func (tr tableReader) hasMany(addrs []hasRecord) (bool, error) { } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } func (tr tableReader) count() (uint32, error) { @@ -247,20 +250,27 @@ func (tr tableReader) index() (tableIndex, error) { } // returns true iff |h| can be found in this table. -func (tr tableReader) has(h hash.Hash) (bool, error) { +func (tr tableReader) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { _, ok, err := tr.idx.lookup(&h) - return ok, err + if ok && keeper != nil && keeper(h) { + return false, gcBehavior_Block, nil + } + return ok, gcBehavior_Continue, err } // returns the storage associated with |h|, iff present. Returns nil if absent. On success, // the returned byte slice directly references the underlying storage. -func (tr tableReader) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { +func (tr tableReader) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { e, found, err := tr.idx.lookup(&h) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if !found { - return nil, nil + return nil, gcBehavior_Continue, nil + } + + if keeper != nil && keeper(h) { + return nil, gcBehavior_Block, nil } offset := e.Offset() @@ -270,30 +280,30 @@ func (tr tableReader) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byt n, err := tr.r.ReadAtWithStats(ctx, buff, int64(offset), stats) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if n != int(length) { - return nil, errors.New("failed to read all data") + return nil, gcBehavior_Continue, errors.New("failed to read all data") } cmp, err := NewCompressedChunk(h, buff) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if len(cmp.CompressedData) == 0 { - return nil, errors.New("failed to get data") + return nil, gcBehavior_Continue, errors.New("failed to get data") } chnk, err := cmp.ToChunk() if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } - return chnk.Data(), nil + return chnk.Data(), gcBehavior_Continue, nil } type offsetRec struct { @@ -380,26 +390,33 @@ func (tr tableReader) getMany( eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), - stats *Stats) (bool, error) { + keeper keeperF, + stats *Stats) (bool, gcBehavior, error) { // Pass #1: Iterate over |reqs| and |tr.prefixes| (both sorted by address) and build the set // of table locations which must be read in order to satisfy the getMany operation. 
- offsetRecords, remaining, err := tr.findOffsets(reqs) + offsetRecords, remaining, gcb, err := tr.findOffsets(reqs, keeper) if err != nil { - return false, err + return false, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, nil } err = tr.getManyAtOffsets(ctx, eg, offsetRecords, found, stats) - return remaining, err + return remaining, gcBehavior_Continue, err } -func (tr tableReader) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) { +func (tr tableReader) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { // Pass #1: Iterate over |reqs| and |tr.prefixes| (both sorted by address) and build the set // of table locations which must be read in order to satisfy the getMany operation. - offsetRecords, remaining, err := tr.findOffsets(reqs) + offsetRecords, remaining, gcb, err := tr.findOffsets(reqs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, nil } err = tr.getManyCompressedAtOffsets(ctx, eg, offsetRecords, found, stats) - return remaining, err + return remaining, gcBehavior_Continue, err } func (tr tableReader) getManyCompressedAtOffsets(ctx context.Context, eg *errgroup.Group, offsetRecords offsetRecSlice, found func(context.Context, CompressedChunk), stats *Stats) error { @@ -498,7 +515,7 @@ func (tr tableReader) getManyAtOffsetsWithReadFunc( // chunks remaining will be set to false upon return. If some are not here, // then remaining will be true. The result offsetRecSlice is sorted in offset // order. 
-func (tr tableReader) findOffsets(reqs []getRecord) (ors offsetRecSlice, remaining bool, err error) { +func (tr tableReader) findOffsets(reqs []getRecord, keeper keeperF) (ors offsetRecSlice, remaining bool, gcb gcBehavior, err error) { filterIdx := uint32(0) filterLen := uint32(len(tr.prefixes)) ors = make(offsetRecSlice, 0, len(reqs)) @@ -541,13 +558,16 @@ func (tr tableReader) findOffsets(reqs []getRecord) (ors offsetRecSlice, remaini for j := filterIdx; j < filterLen && req.prefix == tr.prefixes[j]; j++ { m, err := tr.idx.entrySuffixMatches(j, req.a) if err != nil { - return nil, false, err + return nil, false, gcBehavior_Continue, err } if m { + if keeper != nil && keeper(*req.a) { + return nil, false, gcBehavior_Block, nil + } reqs[i].found = true entry, err := tr.idx.indexEntry(j, nil) if err != nil { - return nil, false, err + return nil, false, gcBehavior_Continue, err } ors = append(ors, offsetRec{req.a, entry.Offset(), entry.Length()}) break @@ -560,7 +580,7 @@ func (tr tableReader) findOffsets(reqs []getRecord) (ors offsetRecSlice, remaini } sort.Sort(ors) - return ors, remaining, nil + return ors, remaining, gcBehavior_Continue, nil } func canReadAhead(fRec offsetRec, curStart, curEnd, blockSize uint64) (newEnd uint64, canRead bool) { @@ -584,12 +604,15 @@ func canReadAhead(fRec offsetRec, curStart, curEnd, blockSize uint64) (newEnd ui return fRec.offset + uint64(fRec.length), true } -func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64) (reads int, remaining bool, err error) { +func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64, keeper keeperF) (int, bool, gcBehavior, error) { var offsetRecords offsetRecSlice // Pass #1: Build the set of table locations which must be read in order to find all the elements of |reqs| which are present in this table. - offsetRecords, remaining, err = tr.findOffsets(reqs) + offsetRecords, remaining, gcb, err := tr.findOffsets(reqs, keeper) if err != nil { - return 0, false, err + return 0, false, gcb, err + } + if gcb != gcBehavior_Continue { + return 0, false, gcb, nil } // Now |offsetRecords| contains all locations within the table which must @@ -597,6 +620,7 @@ func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64) (reads int, // location). Scan forward, grouping sequences of reads into large physical // reads. 
+ var reads int var readStart, readEnd uint64 readStarted := false @@ -622,7 +646,7 @@ func (tr tableReader) calcReads(reqs []getRecord, blockSize uint64) (reads int, readStarted = false } - return + return reads, remaining, gcBehavior_Continue, err } func (tr tableReader) extract(ctx context.Context, chunks chan<- extractRecord) error { @@ -681,11 +705,14 @@ func (tr tableReader) reader(ctx context.Context) (io.ReadCloser, uint64, error) return r, sz, nil } -func (tr tableReader) getRecordRanges(ctx context.Context, requests []getRecord) (map[hash.Hash]Range, error) { +func (tr tableReader) getRecordRanges(ctx context.Context, requests []getRecord, keeper keeperF) (map[hash.Hash]Range, gcBehavior, error) { // findOffsets sets getRecord.found - recs, _, err := tr.findOffsets(requests) + recs, _, gcb, err := tr.findOffsets(requests, keeper) if err != nil { - return nil, err + return nil, gcb, err + } + if gcb != gcBehavior_Continue { + return nil, gcb, nil } ranges := make(map[hash.Hash]Range, len(recs)) for _, r := range recs { @@ -694,7 +721,7 @@ func (tr tableReader) getRecordRanges(ctx context.Context, requests []getRecord) Length: r.length, } } - return ranges, nil + return ranges, gcBehavior_Continue, nil } func (tr tableReader) currentSize() uint64 { diff --git a/go/store/nbs/table_set.go b/go/store/nbs/table_set.go index 185743199a0..b647aad46a3 100644 --- a/go/store/nbs/table_set.go +++ b/go/store/nbs/table_set.go @@ -58,58 +58,62 @@ type tableSet struct { rl chan struct{} } -func (ts tableSet) has(h hash.Hash) (bool, error) { - f := func(css chunkSourceSet) (bool, error) { +func (ts tableSet) has(h hash.Hash, keeper keeperF) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - has, err := haver.has(h) - + has, gcb, err := haver.has(h, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if has { - return true, nil + return true, gcBehavior_Continue, nil } } - return false, nil + return false, gcBehavior_Continue, nil } - novelHas, err := f(ts.novel) - + novelHas, gcb, err := f(ts.novel) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if novelHas { - return true, nil + return true, gcBehavior_Continue, nil } return f(ts.upstream) } -func (ts tableSet) hasMany(addrs []hasRecord) (bool, error) { - f := func(css chunkSourceSet) (bool, error) { +func (ts tableSet) hasMany(addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - has, err := haver.hasMany(addrs) - + has, gcb, err := haver.hasMany(addrs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } - if !has { - return false, nil + return false, gcBehavior_Continue, nil } } - return true, nil + return true, gcBehavior_Continue, nil } - remaining, err := f(ts.novel) - + remaining, gcb, err := f(ts.novel) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, err } - if !remaining { - return false, nil + return false, gcBehavior_Continue, nil } return f(ts.upstream) @@ -124,7 +128,10 @@ func (ts tableSet) hasMany(addrs []hasRecord) (bool, error) { // consulted. 
Only used for part of the GC workflow where we want to have // access to all chunks in the store but need to check for existing chunk // presence in only a subset of its files. -func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord) (remaining bool, err error) { +func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord, keeper keeperF) (bool, gcBehavior, error) { + var remaining bool + var err error + var gcb gcBehavior for _, rec := range addrs { if !rec.has { remaining = true @@ -132,7 +139,7 @@ func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord) (remain } } if !remaining { - return false, nil + return false, gcBehavior_Continue, nil } for _, srcAddr := range srcs { src, ok := ts.novel[srcAddr] @@ -142,83 +149,114 @@ func (ts tableSet) hasManyInSources(srcs []hash.Hash, addrs []hasRecord) (remain continue } } - remaining, err = src.hasMany(addrs) + remaining, gcb, err = src.hasMany(addrs, keeper) if err != nil { - return false, err + return false, gcb, err + } + if gcb != gcBehavior_Continue { + return false, gcb, nil } if !remaining { break } } - return remaining, nil + return remaining, gcBehavior_Continue, nil } -func (ts tableSet) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) { +func (ts tableSet) get(ctx context.Context, h hash.Hash, keeper keeperF, stats *Stats) ([]byte, gcBehavior, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } - f := func(css chunkSourceSet) ([]byte, error) { + f := func(css chunkSourceSet) ([]byte, gcBehavior, error) { for _, haver := range css { - data, err := haver.get(ctx, h, stats) - + data, gcb, err := haver.get(ctx, h, keeper, stats) if err != nil { - return nil, err + return nil, gcb, err + } + if gcb != gcBehavior_Continue { + return nil, gcb, nil } - if data != nil { - return data, nil + return data, gcBehavior_Continue, nil } } - - return nil, nil + return nil, gcBehavior_Continue, nil } - data, err := f(ts.novel) - + data, gcb, err := f(ts.novel) if err != nil { - return nil, err + return nil, gcb, err + } + if gcb != gcBehavior_Continue { + return nil, gcb, nil } - if data != nil { - return data, nil + return data, gcBehavior_Continue, nil } return f(ts.upstream) } -func (ts tableSet) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (remaining bool, err error) { - f := func(css chunkSourceSet) bool { +func (ts tableSet) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - remaining, err = haver.getMany(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getMany(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false + return false, gcb, nil } } - return true + return true, gcBehavior_Continue, nil + } + + remaining, gcb, err := f(ts.novel) + if err != nil { + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil + } + if !remaining { + return false, gcBehavior_Continue, nil } - return f(ts.novel) && err == nil && f(ts.upstream), err + return f(ts.upstream) } -func (ts tableSet) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, 
CompressedChunk), stats *Stats) (remaining bool, err error) { - f := func(css chunkSourceSet) bool { +func (ts tableSet) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + f := func(css chunkSourceSet) (bool, gcBehavior, error) { for _, haver := range css { - remaining, err = haver.getManyCompressed(ctx, eg, reqs, found, stats) + remaining, gcb, err := haver.getManyCompressed(ctx, eg, reqs, found, keeper, stats) if err != nil { - return true + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return true, gcb, nil } if !remaining { - return false + return false, gcBehavior_Continue, nil } } + return true, gcBehavior_Continue, nil + } - return true + remaining, gcb, err := f(ts.novel) + if err != nil { + return true, gcb, err + } + if gcb != gcBehavior_Continue { + return remaining, gcb, nil + } + if !remaining { + return false, gcBehavior_Continue, nil } - return f(ts.novel) && err == nil && f(ts.upstream), err + return f(ts.upstream) } func (ts tableSet) count() (uint32, error) { @@ -500,11 +538,12 @@ func (ts tableSet) toSpecs() ([]tableSpec, error) { return tableSpecs, nil } -func tableSetCalcReads(ts tableSet, reqs []getRecord, blockSize uint64) (reads int, split, remaining bool, err error) { +func tableSetCalcReads(ts tableSet, reqs []getRecord, blockSize uint64, keeper keeperF) (reads int, split, remaining bool, gcb gcBehavior, err error) { all := copyChunkSourceSet(ts.upstream) for a, cs := range ts.novel { all[a] = cs } + gcb = gcBehavior_Continue for _, tbl := range all { rdr, ok := tbl.(*fileTableReader) if !ok { @@ -514,9 +553,12 @@ func tableSetCalcReads(ts tableSet, reqs []getRecord, blockSize uint64) (reads i var n int var more bool - n, more, err = rdr.calcReads(reqs, blockSize) + n, more, gcb, err = rdr.calcReads(reqs, blockSize, keeper) if err != nil { - return 0, false, false, err + return 0, false, false, gcb, err + } + if gcb != gcBehavior_Continue { + return 0, false, false, gcb, nil } reads += n diff --git a/go/store/nbs/table_test.go b/go/store/nbs/table_test.go index 596bebc6890..e62bfc1618e 100644 --- a/go/store/nbs/table_test.go +++ b/go/store/nbs/table_test.go @@ -62,7 +62,7 @@ func buildTable(chunks [][]byte) ([]byte, hash.Hash, error) { } func mustGetString(assert *assert.Assertions, ctx context.Context, tr tableReader, data []byte) string { - bytes, err := tr.get(ctx, computeAddr(data), &Stats{}) + bytes, _, err := tr.get(ctx, computeAddr(data), nil, &Stats{}) assert.NoError(err) return string(bytes) } @@ -106,13 +106,13 @@ func TestSimple(t *testing.T) { func assertChunksInReader(chunks [][]byte, r chunkReader, assert *assert.Assertions) { for _, c := range chunks { - assert.True(r.has(computeAddr(c))) + assert.True(r.has(computeAddr(c), nil)) } } func assertChunksNotInReader(chunks [][]byte, r chunkReader, assert *assert.Assertions) { for _, c := range chunks { - assert.False(r.has(computeAddr(c))) + assert.False(r.has(computeAddr(c), nil)) } } @@ -142,7 +142,7 @@ func TestHasMany(t *testing.T) { } sort.Sort(hasRecordByPrefix(hasAddrs)) - _, err = tr.hasMany(hasAddrs) + _, _, err = tr.hasMany(hasAddrs, nil) require.NoError(t, err) for _, ha := range hasAddrs { assert.True(ha.has, "Nothing for prefix %d", ha.prefix) @@ -192,7 +192,7 @@ func TestHasManySequentialPrefix(t *testing.T) { hasAddrs[0] = hasRecord{&addrs[1], addrs[1].Prefix(), 1, false} hasAddrs[1] = hasRecord{&addrs[2], addrs[2].Prefix(), 2, false} - _, err 
= tr.hasMany(hasAddrs) + _, _, err = tr.hasMany(hasAddrs, nil) require.NoError(t, err) for _, ha := range hasAddrs { @@ -246,7 +246,7 @@ func BenchmarkHasMany(b *testing.B) { b.Run("dense has many", func(b *testing.B) { var ok bool for i := 0; i < b.N; i++ { - ok, err = tr.hasMany(hrecs) + ok, _, err = tr.hasMany(hrecs, nil) } assert.False(b, ok) assert.NoError(b, err) @@ -254,7 +254,7 @@ func BenchmarkHasMany(b *testing.B) { b.Run("sparse has many", func(b *testing.B) { var ok bool for i := 0; i < b.N; i++ { - ok, err = tr.hasMany(sparse) + ok, _, err = tr.hasMany(sparse, nil) } assert.True(b, ok) assert.NoError(b, err) @@ -290,7 +290,7 @@ func TestGetMany(t *testing.T) { eg, ctx := errgroup.WithContext(context.Background()) got := make([]*chunks.Chunk, 0) - _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, &Stats{}) + _, _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, nil, &Stats{}) require.NoError(t, err) require.NoError(t, eg.Wait()) @@ -324,13 +324,13 @@ func TestCalcReads(t *testing.T) { gb2 := []getRecord{getBatch[0], getBatch[2]} sort.Sort(getRecordByPrefix(getBatch)) - reads, remaining, err := tr.calcReads(getBatch, 0) + reads, remaining, _, err := tr.calcReads(getBatch, 0, nil) require.NoError(t, err) assert.False(remaining) assert.Equal(1, reads) sort.Sort(getRecordByPrefix(gb2)) - reads, remaining, err = tr.calcReads(gb2, 0) + reads, remaining, _, err = tr.calcReads(gb2, 0, nil) require.NoError(t, err) assert.False(remaining) assert.Equal(2, reads) @@ -398,8 +398,8 @@ func Test65k(t *testing.T) { for i := 0; i < count; i++ { data := dataFn(i) h := computeAddr(data) - assert.True(tr.has(computeAddr(data))) - bytes, err := tr.get(context.Background(), h, &Stats{}) + assert.True(tr.has(computeAddr(data), nil)) + bytes, _, err := tr.get(context.Background(), h, nil, &Stats{}) require.NoError(t, err) assert.Equal(string(data), string(bytes)) } @@ -407,8 +407,8 @@ func Test65k(t *testing.T) { for i := count; i < count*2; i++ { data := dataFn(i) h := computeAddr(data) - assert.False(tr.has(computeAddr(data))) - bytes, err := tr.get(context.Background(), h, &Stats{}) + assert.False(tr.has(computeAddr(data), nil)) + bytes, _, err := tr.get(context.Background(), h, nil, &Stats{}) require.NoError(t, err) assert.NotEqual(string(data), string(bytes)) } @@ -461,7 +461,7 @@ func doTestNGetMany(t *testing.T, count int) { eg, ctx := errgroup.WithContext(context.Background()) got := make([]*chunks.Chunk, 0) - _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, &Stats{}) + _, _, err = tr.getMany(ctx, eg, getBatch, func(ctx context.Context, c *chunks.Chunk) { got = append(got, c) }, nil, &Stats{}) require.NoError(t, err) require.NoError(t, eg.Wait()) From 2f47ac2a4b7964f71c1b4f77b3f21e409ed32765 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 16 Jan 2025 09:49:46 -0800 Subject: [PATCH 16/34] go/store/nbs: Add GC keeper calls on reads through NomsBlockStore. 
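
Reads through NomsBlockStore now consult nbs.keeperFunc and retry when a chunk source reports gcBehavior_Block: the store snapshots its table set under the lock, registers the read with beginRead, performs the actual read unlocked, and on a blocked result releases the read, waits for the in-progress GC, and starts over. The new handleUnlockedRead helper centralizes the endRead and waitForGC bookkeeping. A rough sketch of the loop the read paths follow (memtable handling and some error details abbreviated):

    for {
        nbs.mu.Lock()
        tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead()
        nbs.mu.Unlock()

        data, gcb, err := tables.get(ctx, h, keeper, nbs.stats)
        retry, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err)
        if err != nil {
            return chunks.EmptyChunk, err
        }
        if retry {
            // A GC blocked the read and has since been waited out; retry.
            continue
        }
        if data != nil {
            return chunks.NewChunkWithHash(h, data), nil
        }
        return chunks.EmptyChunk, nil
    }
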
--- go/store/nbs/archive_chunk_source.go | 5 +- go/store/nbs/generational_chunk_store.go | 45 ++- go/store/nbs/store.go | 405 ++++++++++++++--------- go/store/nbs/store_test.go | 2 +- go/store/nbs/table.go | 2 +- 5 files changed, 286 insertions(+), 173 deletions(-) diff --git a/go/store/nbs/archive_chunk_source.go b/go/store/nbs/archive_chunk_source.go index f2d7a3a69e2..1b1f55d8265 100644 --- a/go/store/nbs/archive_chunk_source.go +++ b/go/store/nbs/archive_chunk_source.go @@ -106,7 +106,10 @@ func (acs archiveChunkSource) getMany(ctx context.Context, eg *errgroup.Group, r for i, req := range reqs { h := *req.a data, err := acs.aRdr.get(h) - if err != nil || data == nil { + if err != nil { + return true, gcBehavior_Continue, err + } + if data == nil { foundAll = false } else { if keeper != nil && keeper(h) { diff --git a/go/store/nbs/generational_chunk_store.go b/go/store/nbs/generational_chunk_store.go index 64846797ad3..5d2f801368a 100644 --- a/go/store/nbs/generational_chunk_store.go +++ b/go/store/nbs/generational_chunk_store.go @@ -118,7 +118,9 @@ func (gcs *GenerationalNBS) GetMany(ctx context.Context, hashes hash.HashSet, fo return nil } - err = gcs.newGen.GetMany(ctx, notFound, func(ctx context.Context, chunk *chunks.Chunk) { + hashes = notFound + notFound = hashes.Copy() + err = gcs.newGen.GetMany(ctx, hashes, func(ctx context.Context, chunk *chunks.Chunk) { func() { mu.Lock() defer mu.Unlock() @@ -202,14 +204,30 @@ func (gcs *GenerationalNBS) Has(ctx context.Context, h hash.Hash) (bool, error) } // HasMany returns a new HashSet containing any members of |hashes| that are absent from the store. -func (gcs *GenerationalNBS) HasMany(ctx context.Context, hashes hash.HashSet) (absent hash.HashSet, err error) { - gcs.newGen.mu.RLock() - defer gcs.newGen.mu.RUnlock() - return gcs.hasMany(toHasRecords(hashes)) +func (gcs *GenerationalNBS) HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + absent, err := gcs.newGen.HasMany(ctx, hashes) + if err != nil { + return nil, err + } + if len(absent) == 0 { + return nil, err + } + + absent, err = gcs.oldGen.HasMany(ctx, absent) + if err != nil { + return nil, err + } + if len(absent) == 0 || gcs.ghostGen == nil { + return nil, err + } + + return gcs.ghostGen.HasMany(ctx, absent) } -func (gcs *GenerationalNBS) hasMany(recs []hasRecord) (absent hash.HashSet, err error) { - absent, err = gcs.newGen.hasMany(recs) +// |refCheck| is called from write processes in newGen, so it is called with +// newGen.mu held. oldGen.mu is not held however. +func (gcs *GenerationalNBS) refCheck(recs []hasRecord) (hash.HashSet, error) { + absent, err := gcs.newGen.refCheck(recs) if err != nil { return nil, err } else if len(absent) == 0 { @@ -219,12 +237,11 @@ func (gcs *GenerationalNBS) hasMany(recs []hasRecord) (absent hash.HashSet, err absent, err = func() (hash.HashSet, error) { gcs.oldGen.mu.RLock() defer gcs.oldGen.mu.RUnlock() - return gcs.oldGen.hasMany(recs) + return gcs.oldGen.refCheck(recs) }() if err != nil { return nil, err } - if len(absent) == 0 || gcs.ghostGen == nil { return absent, nil } @@ -237,7 +254,7 @@ func (gcs *GenerationalNBS) hasMany(recs []hasRecord) (absent hash.HashSet, err // to Flush(). Put may be called concurrently with other calls to Put(), // Get(), GetMany(), Has() and HasMany(). 
func (gcs *GenerationalNBS) Put(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry) error { - return gcs.newGen.putChunk(ctx, c, getAddrs, gcs.hasMany) + return gcs.newGen.putChunk(ctx, c, getAddrs, gcs.refCheck) } // Returns the NomsBinFormat with which this ChunkSource is compatible. @@ -277,7 +294,7 @@ func (gcs *GenerationalNBS) Root(ctx context.Context) (hash.Hash, error) { // persisted root hash from last to current (or keeps it the same). // If last doesn't match the root in persistent storage, returns false. func (gcs *GenerationalNBS) Commit(ctx context.Context, current, last hash.Hash) (bool, error) { - return gcs.newGen.commit(ctx, current, last, gcs.hasMany) + return gcs.newGen.commit(ctx, current, last, gcs.refCheck) } // Stats may return some kind of struct that reports statistics about the @@ -400,18 +417,18 @@ func (gcs *GenerationalNBS) AddTableFilesToManifest(ctx context.Context, fileIdT // PruneTableFiles deletes old table files that are no longer referenced in the manifest of the new or old gen chunkstores func (gcs *GenerationalNBS) PruneTableFiles(ctx context.Context) error { - err := gcs.oldGen.pruneTableFiles(ctx, gcs.hasMany) + err := gcs.oldGen.pruneTableFiles(ctx) if err != nil { return err } - return gcs.newGen.pruneTableFiles(ctx, gcs.hasMany) + return gcs.newGen.pruneTableFiles(ctx) } // SetRootChunk changes the root chunk hash from the previous value to the new root for the newgen cs func (gcs *GenerationalNBS) SetRootChunk(ctx context.Context, root, previous hash.Hash) error { - return gcs.newGen.setRootChunk(ctx, root, previous, gcs.hasMany) + return gcs.newGen.setRootChunk(ctx, root, previous, gcs.refCheck) } // SupportedOperations returns a description of the support TableFile operations. Some stores only support reading table files, not writing. 
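The generational store's HasMany now cascades through its generations rather than holding the new-gen lock for the whole check: the new generation is consulted first, anything it reports absent is rechecked against the old generation, and whatever is still missing falls through to the ghost generation when one is configured. A hedged sketch of that narrowing cascade, written against a hypothetical hasManyer interface rather than the real store types:

    // hasManyer is a stand-in for the per-generation stores consulted above.
    type hasManyer interface {
        HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error)
    }

    // cascadeHasMany narrows the absent set through each generation in turn,
    // stopping early once every requested hash has been accounted for.
    func cascadeHasMany(ctx context.Context, gens []hasManyer, hashes hash.HashSet) (hash.HashSet, error) {
        absent := hashes
        for _, g := range gens {
            var err error
            absent, err = g.HasMany(ctx, absent)
            if err != nil {
                return nil, err
            }
            if len(absent) == 0 {
                return absent, nil
            }
        }
        return absent, nil
    }
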
diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index f8f6bda09bb..b423dfc7159 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -23,6 +23,7 @@ package nbs import ( "context" + "errors" "fmt" "io" "os" @@ -39,7 +40,6 @@ import ( lru "github.com/hashicorp/golang-lru/v2" "github.com/oracle/oci-go-sdk/v65/common" "github.com/oracle/oci-go-sdk/v65/objectstorage" - "github.com/pkg/errors" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -158,15 +158,14 @@ func (nbs *NomsBlockStore) GetChunkLocationsWithPaths(ctx context.Context, hashe } func (nbs *NomsBlockStore) GetChunkLocations(ctx context.Context, hashes hash.HashSet) (map[hash.Hash]map[hash.Hash]Range, error) { - gr := toGetRecords(hashes) - ranges := make(map[hash.Hash]map[hash.Hash]Range) - - fn := func(css chunkSourceSet) error { + fn := func(css chunkSourceSet, gr []getRecord, ranges map[hash.Hash]map[hash.Hash]Range, keeper keeperF) (gcBehavior, error) { for _, cs := range css { - // TODO: keeperF - rng, _, err := cs.getRecordRanges(ctx, gr, nil) + rng, gcb, err := cs.getRecordRanges(ctx, gr, keeper) if err != nil { - return err + return gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return gcb, nil } h := hash.Hash(cs.hash()) @@ -178,22 +177,60 @@ func (nbs *NomsBlockStore) GetChunkLocations(ctx context.Context, hashes hash.Ha ranges[h] = rng } } - return nil + return gcBehavior_Continue, nil } - tables := func() tableSet { - nbs.mu.RLock() - defer nbs.mu.RUnlock() - return nbs.tables - }() + for { + nbs.mu.Lock() + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - if err := fn(tables.upstream); err != nil { - return nil, err + gr := toGetRecords(hashes) + ranges := make(map[hash.Hash]map[hash.Hash]Range) + + gcb, err := fn(tables.upstream, gr, ranges, keeper) + if needsContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err); err != nil { + return nil, err + } else if needsContinue { + continue + } + + gcb, err = fn(tables.novel, gr, ranges, keeper) + if needsContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err); err != nil { + return nil, err + } else if needsContinue { + continue + } + + return ranges, nil } - if err := fn(tables.novel); err != nil { - return nil, err +} + +func (nbs *NomsBlockStore) handleUnlockedRead(ctx context.Context, gcb gcBehavior, endRead func(), err error) (bool, error) { + if err != nil { + if endRead != nil { + nbs.mu.Lock() + endRead() + nbs.mu.Unlock() + } + return false, err + } + if gcb == gcBehavior_Block { + nbs.mu.Lock() + if endRead != nil { + endRead() + } + err := nbs.waitForGC(ctx) + nbs.mu.Unlock() + return true, err + } else { + if endRead != nil { + nbs.mu.Lock() + endRead() + nbs.mu.Unlock() + } + return false, nil } - return ranges, nil } func (nbs *NomsBlockStore) conjoinIfRequired(ctx context.Context) (bool, error) { @@ -730,7 +767,7 @@ func (nbs *NomsBlockStore) waitForGC(ctx context.Context) error { } func (nbs *NomsBlockStore) Put(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry) error { - return nbs.putChunk(ctx, c, getAddrs, nbs.hasMany) + return nbs.putChunk(ctx, c, getAddrs, nbs.refCheck) } func (nbs *NomsBlockStore) putChunk(ctx context.Context, c chunks.Chunk, getAddrs chunks.GetAddrsCurry, checker refCheck) error { @@ -854,106 +891,118 @@ func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, nbs.stats.ChunksPerGet.Sample(1) }() - data, tables, err := func() ([]byte, chunkReader, 
error) { - var data []byte - nbs.mu.RLock() - defer nbs.mu.RUnlock() + for { + nbs.mu.Lock() if nbs.mt != nil { - var err error - // TODO: keeperF - data, _, err = nbs.mt.get(ctx, h, nil, nbs.stats) - + data, gcb, err := nbs.mt.get(ctx, h, nbs.keeperFunc, nbs.stats) if err != nil { - return nil, nil, err + nbs.mu.Unlock() + return chunks.EmptyChunk, err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return chunks.EmptyChunk, err + } + continue + } + if data != nil { + nbs.mu.Unlock() + return chunks.NewChunkWithHash(h, data), nil } } - return data, nbs.tables, nil - }() - - if err != nil { - return chunks.EmptyChunk, err - } - - if data != nil { - return chunks.NewChunkWithHash(h, data), nil - } - - // TODO: keeperF - data, _, err = tables.get(ctx, h, nil, nbs.stats) + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - if err != nil { - return chunks.EmptyChunk, err - } + data, gcb, err := tables.get(ctx, h, keeper, nbs.stats) + needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) + if err != nil { + return chunks.EmptyChunk, err + } + if needContinue { + continue + } - if data != nil { - return chunks.NewChunkWithHash(h, data), nil + if data != nil { + return chunks.NewChunkWithHash(h, data), nil + } + return chunks.EmptyChunk, nil } - - return chunks.EmptyChunk, nil } func (nbs *NomsBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, found func(context.Context, *chunks.Chunk)) error { ctx, span := tracer.Start(ctx, "nbs.GetMany", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) - span.End() - return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error) { - // TODO: keeperF - res, _, err := cr.getMany(ctx, eg, reqs, found, nil, nbs.stats) - return res, err + defer span.End() + return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return cr.getMany(ctx, eg, reqs, found, keeper, nbs.stats) }) } func (nbs *NomsBlockStore) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { ctx, span := tracer.Start(ctx, "nbs.GetManyCompressed", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) defer span.End() - return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error) { - // TODO: keeperF - res, _, err := cr.getManyCompressed(ctx, eg, reqs, found, nil, nbs.stats) - return res, err + return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return cr.getManyCompressed(ctx, eg, reqs, found, keeper, nbs.stats) }) } func (nbs *NomsBlockStore) getManyWithFunc( ctx context.Context, hashes hash.HashSet, - getManyFunc func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, stats *Stats) (bool, error), + getManyFunc func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error), ) error { - t1 := time.Now() - reqs := toGetRecords(hashes) + if len(hashes) == 0 { + return nil + } + t1 := time.Now() defer func() { - if len(hashes) > 0 { - nbs.stats.GetLatency.SampleTimeSince(t1) - 
nbs.stats.ChunksPerGet.Sample(uint64(len(reqs))) - } + nbs.stats.GetLatency.SampleTimeSince(t1) + nbs.stats.ChunksPerGet.Sample(uint64(len(hashes))) }() - eg, ctx := errgroup.WithContext(ctx) - const ioParallelism = 16 - eg.SetLimit(ioParallelism) + for { + reqs := toGetRecords(hashes) + eg, ctx := errgroup.WithContext(ctx) + const ioParallelism = 16 + eg.SetLimit(ioParallelism) - tables, remaining, err := func() (tables chunkReader, remaining bool, err error) { - nbs.mu.RLock() - defer nbs.mu.RUnlock() - tables = nbs.tables - remaining = true + nbs.mu.Lock() if nbs.mt != nil { - remaining, err = getManyFunc(ctx, nbs.mt, eg, reqs, nbs.stats) + remaining, gcb, err := getManyFunc(ctx, nbs.mt, eg, reqs, nbs.keeperFunc, nbs.stats) + if err != nil { + nbs.mu.Unlock() + return err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return err + } + continue + } + if !remaining { + nbs.mu.Unlock() + return nil + } } - return - }() - if err != nil { - return err - } + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - if remaining { - _, err = getManyFunc(ctx, tables, eg, reqs, nbs.stats) - } + _, gcb, err := getManyFunc(ctx, tables, eg, reqs, keeper, nbs.stats) + err = errors.Join(err, eg.Wait()) + needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) + if err != nil { + return err + } + if needContinue { + continue + } - if err != nil { - eg.Wait() - return err + return nil } - return eg.Wait() } func toGetRecords(hashes hash.HashSet) []getRecord { @@ -1007,38 +1056,41 @@ func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { nbs.stats.AddressesPerHas.Sample(1) }() - has, tables, err := func() (bool, chunkReader, error) { - nbs.mu.RLock() - defer nbs.mu.RUnlock() - + for { + nbs.mu.Lock() if nbs.mt != nil { - // TODO: keeperF - has, _, err := nbs.mt.has(h, nil) - + has, gcb, err := nbs.mt.has(h, nbs.keeperFunc) if err != nil { - return false, nil, err + nbs.mu.Unlock() + return false, err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return false, err + } + continue + } + if has { + nbs.mu.Unlock() + return true, nil } - - return has, nbs.tables, nil } + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - return false, nbs.tables, nil - }() - - if err != nil { - return false, err - } - - if !has { - // TODO: keeperF - has, _, err = tables.has(h, nil) - + has, gcb, err := tables.has(h, keeper) + needsContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) if err != nil { return false, err } - } + if needsContinue { + continue + } - return has, nil + return has, nil + } } func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { @@ -1047,36 +1099,84 @@ func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (ha } t1 := time.Now() - defer nbs.stats.HasLatency.SampleTimeSince(t1) - nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) + defer func() { + nbs.stats.HasLatency.SampleTimeSince(t1) + nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) + }() - nbs.mu.RLock() - defer nbs.mu.RUnlock() - return nbs.hasMany(toHasRecords(hashes)) -} + for { + reqs := toHasRecords(hashes) -func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSet) (hash.HashSet, error) { - if hashes.Size() == 0 { - return nil, nil - } + nbs.mu.Lock() + if nbs.mt != nil { + remaining, gcb, err := nbs.mt.hasMany(reqs, 
nbs.keeperFunc) + if err != nil { + nbs.mu.Unlock() + return nil, err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + nbs.mu.Unlock() + if err != nil { + return nil, err + } + continue + } + if !remaining { + nbs.mu.Unlock() + return hash.HashSet{}, nil + } + } + tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + nbs.mu.Unlock() - t1 := time.Now() - defer nbs.stats.HasLatency.SampleTimeSince(t1) - nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) + remaining, gcb, err := tables.hasMany(reqs, keeper) + needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) + if err != nil { + return nil, err + } + if needContinue { + continue + } - nbs.mu.RLock() - defer nbs.mu.RUnlock() + if !remaining { + return hash.HashSet{}, nil + } - records := toHasRecords(hashes) + absent := hash.HashSet{} + for _, r := range reqs { + if !r.has { + absent.Insert(*r.a) + } + } + return absent, nil + } +} - // TODO: keeperF - _, _, err := nbs.tables.hasManyInSources(srcs, records, nil) +// Operates a lot like |hasMany|, but without locking and without +// taking read dependencies on the checked references. Should only be +// used for the sanity checking on references for written chunks. +func (nbs *NomsBlockStore) refCheck(reqs []hasRecord) (hash.HashSet, error) { + if nbs.mt != nil { + remaining, _, err := nbs.mt.hasMany(reqs, nil) + if err != nil { + return nil, err + } + if !remaining { + return hash.HashSet{}, nil + } + } + + remaining, _, err := nbs.tables.hasMany(reqs, nil) if err != nil { return nil, err } + if !remaining { + return hash.HashSet{}, nil + } absent := hash.HashSet{} - for _, r := range records { + for _, r := range reqs { if !r.has { absent.Insert(*r.a) } @@ -1084,38 +1184,32 @@ func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSe return absent, nil } -func (nbs *NomsBlockStore) hasMany(reqs []hasRecord) (hash.HashSet, error) { - tables, remaining, err := func() (tables chunkReader, remaining bool, err error) { - tables = nbs.tables +// Only used for a generational full GC, where the table files are +// added to the store and are then used to filter which chunks need to +// make it to the new generation. In this context, we do not need to +// worry about taking read dependencies on the requested chunks. Hence +// our handling of keeperFunc and gcBehavior below. 
+func (nbs *NomsBlockStore) hasManyInSources(srcs []hash.Hash, hashes hash.HashSet) (hash.HashSet, error) { + if hashes.Size() == 0 { + return nil, nil + } - remaining = true - if nbs.mt != nil { - // TODO: keeperF - remaining, _, err = nbs.mt.hasMany(reqs, nil) + t1 := time.Now() + defer nbs.stats.HasLatency.SampleTimeSince(t1) + nbs.stats.AddressesPerHas.SampleLen(hashes.Size()) - if err != nil { - return nil, false, err - } - } + nbs.mu.RLock() + defer nbs.mu.RUnlock() - return tables, remaining, nil - }() + records := toHasRecords(hashes) + _, _, err := nbs.tables.hasManyInSources(srcs, records, nil) if err != nil { return nil, err } - if remaining { - // TODO: keeperF - _, _, err := tables.hasMany(reqs, nil) - - if err != nil { - return nil, err - } - } - absent := hash.HashSet{} - for _, r := range reqs { + for _, r := range records { if !r.has { absent.Insert(*r.a) } @@ -1182,7 +1276,7 @@ func (nbs *NomsBlockStore) Root(ctx context.Context) (hash.Hash, error) { } func (nbs *NomsBlockStore) Commit(ctx context.Context, current, last hash.Hash) (success bool, err error) { - return nbs.commit(ctx, current, last, nbs.hasMany) + return nbs.commit(ctx, current, last, nbs.refCheck) } func (nbs *NomsBlockStore) commit(ctx context.Context, current, last hash.Hash, checker refCheck) (success bool, err error) { @@ -1575,12 +1669,11 @@ func (nbs *NomsBlockStore) AddTableFilesToManifest(ctx context.Context, fileIdTo // PruneTableFiles deletes old table files that are no longer referenced in the manifest. func (nbs *NomsBlockStore) PruneTableFiles(ctx context.Context) (err error) { - return nbs.pruneTableFiles(ctx, nbs.hasMany) + return nbs.pruneTableFiles(ctx) } -func (nbs *NomsBlockStore) pruneTableFiles(ctx context.Context, checker refCheck) (err error) { +func (nbs *NomsBlockStore) pruneTableFiles(ctx context.Context) (err error) { mtime := time.Now() - return nbs.p.PruneTableFiles(ctx, func() []hash.Hash { nbs.mu.Lock() defer nbs.mu.Unlock() @@ -1950,7 +2043,7 @@ func (nbs *NomsBlockStore) swapTables(ctx context.Context, specs []tableSpec, mo // SetRootChunk changes the root chunk hash from the previous value to the new root. func (nbs *NomsBlockStore) SetRootChunk(ctx context.Context, root, previous hash.Hash) error { - return nbs.setRootChunk(ctx, root, previous, nbs.hasMany) + return nbs.setRootChunk(ctx, root, previous, nbs.refCheck) } func (nbs *NomsBlockStore) setRootChunk(ctx context.Context, root, previous hash.Hash, checker refCheck) error { diff --git a/go/store/nbs/store_test.go b/go/store/nbs/store_test.go index 041394e45a8..82b604df1b4 100644 --- a/go/store/nbs/store_test.go +++ b/go/store/nbs/store_test.go @@ -212,7 +212,7 @@ func TestNBSPruneTableFiles(t *testing.T) { addrs.Insert(c.Hash()) return nil } - }, st.hasMany) + }, st.refCheck) require.NoError(t, err) require.True(t, ok) ok, err = st.Commit(ctx, st.upstream.root, st.upstream.root) diff --git a/go/store/nbs/table.go b/go/store/nbs/table.go index 106c6f7f6aa..c4763989734 100644 --- a/go/store/nbs/table.go +++ b/go/store/nbs/table.go @@ -199,7 +199,7 @@ const ( // Operation was successful, go forward with the result. gcBehavior_Continue gcBehavior = false // Operation needs to block until the GC is over and then retry. 
- gcBehavior_Block = true + gcBehavior_Block = true ) type keeperF func(hash.Hash) bool From 7c06b6b4bfc767141910342ef4312b05c16f49c0 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 16 Jan 2025 10:36:28 -0800 Subject: [PATCH 17/34] go/store/nbs: tablePersister: Add GC dependency capturing to written chunks which were in the memtable but get filtered out because they are already in the store. --- go/store/nbs/aws_table_persister.go | 20 ++++++++----- go/store/nbs/aws_table_persister_test.go | 22 +++++++------- go/store/nbs/bs_persister.go | 36 ++++++++++++++--------- go/store/nbs/conjoiner_test.go | 4 +-- go/store/nbs/file_table_persister.go | 15 +++++++--- go/store/nbs/file_table_persister_test.go | 7 +++-- go/store/nbs/journal.go | 17 ++++++----- go/store/nbs/journal_test.go | 4 +-- go/store/nbs/mem_table.go | 19 +++++++----- go/store/nbs/mem_table_test.go | 4 +-- go/store/nbs/no_conjoin_bs_persister.go | 26 +++++++++------- go/store/nbs/root_tracker_test.go | 20 +++++++------ go/store/nbs/store_test.go | 2 +- go/store/nbs/table_persister.go | 2 +- go/store/nbs/table_set.go | 3 +- go/store/nbs/table_set_test.go | 2 +- 16 files changed, 118 insertions(+), 85 deletions(-) diff --git a/go/store/nbs/aws_table_persister.go b/go/store/nbs/aws_table_persister.go index cc58ffea894..816a9314620 100644 --- a/go/store/nbs/aws_table_persister.go +++ b/go/store/nbs/aws_table_persister.go @@ -115,25 +115,31 @@ func (s3p awsTablePersister) key(k string) string { return k } -func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { - name, data, chunkCount, err := mt.write(haver, stats) - +func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { + name, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } err = s3p.multipartUpload(ctx, bytes.NewReader(data), uint64(len(data)), name.String()) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err } tra := &s3TableReaderAt{&s3ObjectReader{s3: s3p.s3, bucket: s3p.bucket, readRl: s3p.rl, ns: s3p.ns}, name} - return newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize) + src, err := newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize) + if err != nil { + return emptyChunkSource{}, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, sz uint64, key string) error { diff --git a/go/store/nbs/aws_table_persister_test.go b/go/store/nbs/aws_table_persister_test.go index 4ab92c1651b..3187f2e08b6 100644 --- a/go/store/nbs/aws_table_persister_test.go +++ b/go/store/nbs/aws_table_persister_test.go @@ -90,7 +90,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := makeFakeS3(t) s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}} - src, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) defer src.close() @@ -108,7 +108,7 @@ func TestAWSTablePersisterPersist(t *testing.T) 
{ s3svc := makeFakeS3(t) s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}} - src, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) defer src.close() if assert.True(mustUint32(src.count()) > 0) { @@ -133,7 +133,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := makeFakeS3(t) s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}} - src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{}) + src, _, err := s3p.Persist(context.Background(), mt, existingTable, nil, &Stats{}) require.NoError(t, err) defer src.close() assert.True(mustUint32(src.count()) == 0) @@ -148,7 +148,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1} s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}} - _, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + _, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) assert.Error(err) }) } @@ -306,7 +306,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { for i := 0; i < len(chunks); i++ { mt := newMemTable(uint64(2 * targetPartSize)) mt.addChunk(computeAddr(chunks[i]), chunks[i]) - cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) sources = append(sources, cs) } @@ -379,7 +379,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { } var err error - sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{}) + sources[i], _, err = s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) } src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{}) @@ -417,9 +417,9 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { rand.Read(medChunks[i]) mt.addChunk(computeAddr(medChunks[i]), medChunks[i]) } - cs1, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs1, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) - cs2, err := s3p.Persist(context.Background(), mtb, nil, &Stats{}) + cs2, _, err := s3p.Persist(context.Background(), mtb, nil, nil, &Stats{}) require.NoError(t, err) sources := chunkSources{cs1, cs2} @@ -450,7 +450,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { mt := newMemTable(uint64(2 * targetPartSize)) mt.addChunk(computeAddr(smallChunks[i]), smallChunks[i]) var err error - sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{}) + sources[i], _, err = s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) } @@ -461,7 +461,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { } var err error - cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) sources = append(sources, cs) @@ -474,7 +474,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) { mt.addChunk(computeAddr(medChunks[i]), medChunks[i]) } - cs, err = s3p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err = s3p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) sources = append(sources, cs) diff --git a/go/store/nbs/bs_persister.go b/go/store/nbs/bs_persister.go index 9aca6ecd73b..274bbcef2e5 
100644 --- a/go/store/nbs/bs_persister.go +++ b/go/store/nbs/bs_persister.go @@ -45,12 +45,16 @@ var _ tableFilePersister = &blobstorePersister{} // Persist makes the contents of mt durable. Chunks already present in // |haver| may be dropped in the process. -func (bsp *blobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { - address, data, chunkCount, err := mt.write(haver, stats) +func (bsp *blobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { + address, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err - } else if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil + } + if chunkCount == 0 { + return emptyChunkSource{}, gcBehavior_Continue, nil } name := address.String() @@ -59,24 +63,28 @@ func (bsp *blobstorePersister) Persist(ctx context.Context, mt *memTable, haver // first write table records and tail (index+footer) as separate blobs eg, ectx := errgroup.WithContext(ctx) - eg.Go(func() (err error) { - _, err = bsp.bs.Put(ectx, name+tableRecordsExt, int64(len(records)), bytes.NewBuffer(records)) - return + eg.Go(func() error { + _, err := bsp.bs.Put(ectx, name+tableRecordsExt, int64(len(records)), bytes.NewBuffer(records)) + return err }) - eg.Go(func() (err error) { - _, err = bsp.bs.Put(ectx, name+tableTailExt, int64(len(tail)), bytes.NewBuffer(tail)) - return + eg.Go(func() error { + _, err := bsp.bs.Put(ectx, name+tableTailExt, int64(len(tail)), bytes.NewBuffer(tail)) + return err }) if err = eg.Wait(); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } // then concatenate into a final blob if _, err = bsp.bs.Concatenate(ctx, name, []string{name + tableRecordsExt, name + tableTailExt}); err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err } rdr := &bsTableReaderAt{name, bsp.bs} - return newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + src, err := newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + if err != nil { + return emptyChunkSource{}, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } // ConjoinAll implements tablePersister. 
diff --git a/go/store/nbs/conjoiner_test.go b/go/store/nbs/conjoiner_test.go index 01438450e7e..a9f64aa2220 100644 --- a/go/store/nbs/conjoiner_test.go +++ b/go/store/nbs/conjoiner_test.go @@ -63,7 +63,7 @@ func makeTestSrcs(t *testing.T, tableSizes []uint32, p tablePersister) (srcs chu c := nextChunk() mt.addChunk(computeAddr(c), c) } - cs, err := p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) c, err := cs.clone() require.NoError(t, err) @@ -180,7 +180,7 @@ func testConjoin(t *testing.T, factory func(t *testing.T) tablePersister) { mt := newMemTable(testMemTableSize) data := []byte{0xde, 0xad} mt.addChunk(computeAddr(data), data) - src, err := p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err := p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) defer src.close() return tableSpec{src.hash(), mustUint32(src.count())} diff --git a/go/store/nbs/file_table_persister.go b/go/store/nbs/file_table_persister.go index 5868b982174..83175088f29 100644 --- a/go/store/nbs/file_table_persister.go +++ b/go/store/nbs/file_table_persister.go @@ -86,16 +86,23 @@ func (ftp *fsTablePersister) Exists(ctx context.Context, name hash.Hash, chunkCo return archiveFileExists(ctx, ftp.dir, name) } -func (ftp *fsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { +func (ftp *fsTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { t1 := time.Now() defer stats.PersistLatency.SampleTimeSince(t1) - name, data, chunkCount, err := mt.write(haver, stats) + name, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } - return ftp.persistTable(ctx, name, data, chunkCount, stats) + src, err := ftp.persistTable(ctx, name, data, chunkCount, stats) + if err != nil { + return emptyChunkSource{}, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } func (ftp *fsTablePersister) Path() string { diff --git a/go/store/nbs/file_table_persister_test.go b/go/store/nbs/file_table_persister_test.go index 00e57d2fc95..4adde94986b 100644 --- a/go/store/nbs/file_table_persister_test.go +++ b/go/store/nbs/file_table_persister_test.go @@ -96,7 +96,8 @@ func persistTableData(p tablePersister, chunx ...[]byte) (src chunkSource, err e return nil, fmt.Errorf("memTable too full to add %s", computeAddr(c)) } } - return p.Persist(context.Background(), mt, nil, &Stats{}) + src, _, err = p.Persist(context.Background(), mt, nil, nil, &Stats{}) + return src, err } func TestFSTablePersisterPersistNoData(t *testing.T) { @@ -113,7 +114,7 @@ func TestFSTablePersisterPersistNoData(t *testing.T) { defer file.RemoveAll(dir) fts := newFSTablePersister(dir, &UnlimitedQuotaProvider{}) - src, err := fts.Persist(context.Background(), mt, existingTable, &Stats{}) + src, _, err := fts.Persist(context.Background(), mt, existingTable, nil, &Stats{}) require.NoError(t, err) assert.True(mustUint32(src.count()) == 0) @@ -177,7 +178,7 @@ func TestFSTablePersisterConjoinAllDups(t *testing.T) { } var err error - sources[0], err = fts.Persist(ctx, mt, nil, &Stats{}) + sources[0], _, err = fts.Persist(ctx, mt, nil, nil, &Stats{}) require.NoError(t, err) sources[1], err = sources[0].clone() 
require.NoError(t, err) diff --git a/go/store/nbs/journal.go b/go/store/nbs/journal.go index df7fbf95622..dc8deff8e19 100644 --- a/go/store/nbs/journal.go +++ b/go/store/nbs/journal.go @@ -239,18 +239,19 @@ func (j *ChunkJournal) IterateRoots(f func(root string, timestamp *time.Time) er } // Persist implements tablePersister. -func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { +func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { if j.backing.readOnly() { - return nil, errReadOnlyManifest + return nil, gcBehavior_Continue, errReadOnlyManifest } else if err := j.maybeInit(ctx); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } if haver != nil { sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted. - // TODO: keeperF - if _, _, err := haver.hasMany(mt.order, nil); err != nil { - return nil, err + if _, gcb, err := haver.hasMany(mt.order, keeper); err != nil { + return nil, gcBehavior_Continue, err + } else if gcb != gcBehavior_Continue { + return nil, gcb, nil } sort.Sort(hasRecordByOrder(mt.order)) // restore "insertion" order for write } @@ -262,10 +263,10 @@ func (j *ChunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkRea c := chunks.NewChunkWithHash(hash.Hash(*record.a), mt.chunks[*record.a]) err := j.wr.writeCompressedChunk(ctx, ChunkToCompressedChunk(c)) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } } - return journalChunkSource{journal: j.wr}, nil + return journalChunkSource{journal: j.wr}, gcBehavior_Continue, nil } // ConjoinAll implements tablePersister. diff --git a/go/store/nbs/journal_test.go b/go/store/nbs/journal_test.go index b2f24f01702..603d1610d43 100644 --- a/go/store/nbs/journal_test.go +++ b/go/store/nbs/journal_test.go @@ -67,7 +67,7 @@ func TestChunkJournalPersist(t *testing.T) { haver := emptyChunkSource{} for i := 0; i < iters; i++ { memTbl, chunkMap := randomMemTable(16) - source, err := j.Persist(ctx, memTbl, haver, stats) + source, _, err := j.Persist(ctx, memTbl, haver, nil, stats) assert.NoError(t, err) for h, ch := range chunkMap { @@ -96,7 +96,7 @@ func TestReadRecordRanges(t *testing.T) { gets = append(gets, getRecord{a: &h, prefix: h.Prefix()}) } - jcs, err := j.Persist(ctx, mt, emptyChunkSource{}, &Stats{}) + jcs, _, err := j.Persist(ctx, mt, emptyChunkSource{}, nil, &Stats{}) require.NoError(t, err) rdr, sz, err := jcs.(journalChunkSource).journal.snapshot(context.Background()) diff --git a/go/store/nbs/mem_table.go b/go/store/nbs/mem_table.go index 135e3f8feb5..1fd8c0ffcda 100644 --- a/go/store/nbs/mem_table.go +++ b/go/store/nbs/mem_table.go @@ -61,7 +61,7 @@ func writeChunksToMT(mt *memTable, chunks []chunks.Chunk) (string, []byte, error } var stats Stats - name, data, count, err := mt.write(nil, &stats) + name, data, count, _, err := mt.write(nil, nil, &stats) if err != nil { return "", nil, err @@ -220,10 +220,11 @@ func (mt *memTable) extract(ctx context.Context, chunks chan<- extractRecord) er return nil } -func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data []byte, count uint32, err error) { +func (mt *memTable) write(haver chunkReader, keeper keeperF, stats *Stats) (name hash.Hash, data []byte, count uint32, gcb gcBehavior, err error) { + gcb = gcBehavior_Continue numChunks := uint64(len(mt.order)) if numChunks == 0 { - return hash.Hash{}, nil, 0, 
fmt.Errorf("mem table cannot write with zero chunks") + return hash.Hash{}, nil, 0, gcBehavior_Continue, fmt.Errorf("mem table cannot write with zero chunks") } maxSize := maxTableSize(uint64(len(mt.order)), mt.totalData) // todo: memory quota @@ -232,10 +233,12 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data if haver != nil { sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted. - // TODO: keeperF - _, _, err := haver.hasMany(mt.order, nil) + _, gcb, err = haver.hasMany(mt.order, keeper) if err != nil { - return hash.Hash{}, nil, 0, err + return hash.Hash{}, nil, 0, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return hash.Hash{}, nil, 0, gcb, err } sort.Sort(hasRecordByOrder(mt.order)) // restore "insertion" order for write @@ -251,7 +254,7 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data tableSize, name, err := tw.finish() if err != nil { - return hash.Hash{}, nil, 0, err + return hash.Hash{}, nil, 0, gcBehavior_Continue, err } if count > 0 { @@ -261,7 +264,7 @@ func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data stats.ChunksPerPersist.Sample(uint64(count)) } - return name, buff[:tableSize], count, nil + return name, buff[:tableSize], count, gcBehavior_Continue, nil } func (mt *memTable) close() error { diff --git a/go/store/nbs/mem_table_test.go b/go/store/nbs/mem_table_test.go index 3b7b2c222e5..647a8395dce 100644 --- a/go/store/nbs/mem_table_test.go +++ b/go/store/nbs/mem_table_test.go @@ -164,7 +164,7 @@ func TestMemTableWrite(t *testing.T) { defer tr2.close() assert.True(tr2.has(computeAddr(chunks[2]), nil)) - _, data, count, err := mt.write(chunkReaderGroup{tr1, tr2}, &Stats{}) + _, data, count, _, err := mt.write(chunkReaderGroup{tr1, tr2}, nil, &Stats{}) require.NoError(t, err) assert.Equal(uint32(1), count) @@ -218,7 +218,7 @@ func TestMemTableSnappyWriteOutOfLine(t *testing.T) { } mt.snapper = &outOfLineSnappy{[]bool{false, true, false}} // chunks[1] should trigger a panic - assert.Panics(func() { mt.write(nil, &Stats{}) }) + assert.Panics(func() { mt.write(nil, nil, &Stats{}) }) } type outOfLineSnappy struct { diff --git a/go/store/nbs/no_conjoin_bs_persister.go b/go/store/nbs/no_conjoin_bs_persister.go index 053c9be710e..98ed3a06a5c 100644 --- a/go/store/nbs/no_conjoin_bs_persister.go +++ b/go/store/nbs/no_conjoin_bs_persister.go @@ -21,7 +21,6 @@ import ( "io" "time" - "github.com/fatih/color" "golang.org/x/sync/errgroup" "github.com/dolthub/dolt/go/store/blobstore" @@ -40,27 +39,32 @@ var _ tableFilePersister = &noConjoinBlobstorePersister{} // Persist makes the contents of mt durable. Chunks already present in // |haver| may be dropped in the process. 
-func (bsp *noConjoinBlobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { - address, data, chunkCount, err := mt.write(haver, stats) +func (bsp *noConjoinBlobstorePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { + address, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } else if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } else if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } name := address.String() eg, ectx := errgroup.WithContext(ctx) - eg.Go(func() (err error) { - fmt.Fprintf(color.Output, "Persist: bs.Put: name: %s\n", name) - _, err = bsp.bs.Put(ectx, name, int64(len(data)), bytes.NewBuffer(data)) - return + eg.Go(func() error { + _, err := bsp.bs.Put(ectx, name, int64(len(data)), bytes.NewBuffer(data)) + return err }) if err = eg.Wait(); err != nil { - return nil, err + return nil, gcBehavior_Continue, err } rdr := &bsTableReaderAt{name, bsp.bs} - return newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + src, err := newReaderFromIndexData(ctx, bsp.q, data, address, rdr, bsp.blockSize) + if err != nil { + return nil, gcBehavior_Continue, err + } + return src, gcBehavior_Continue, nil } // ConjoinAll implements tablePersister. diff --git a/go/store/nbs/root_tracker_test.go b/go/store/nbs/root_tracker_test.go index b5e82e8e182..37d82483743 100644 --- a/go/store/nbs/root_tracker_test.go +++ b/go/store/nbs/root_tracker_test.go @@ -399,7 +399,7 @@ func interloperWrite(fm *fakeManifest, p tablePersister, rootChunk []byte, chunk persisted = append(chunks, rootChunk) var src chunkSource - src, err = p.Persist(context.Background(), createMemTable(persisted), nil, &Stats{}) + src, _, err = p.Persist(context.Background(), createMemTable(persisted), nil, nil, &Stats{}) if err != nil { return hash.Hash{}, nil, err } @@ -505,16 +505,18 @@ type fakeTablePersister struct { var _ tablePersister = fakeTablePersister{} -func (ftp fakeTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) { +func (ftp fakeTablePersister) Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) { if mustUint32(mt.count()) == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } - name, data, chunkCount, err := mt.write(haver, stats) + name, data, chunkCount, gcb, err := mt.write(haver, keeper, stats) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err + } else if gcb != gcBehavior_Continue { + return emptyChunkSource{}, gcb, nil } else if chunkCount == 0 { - return emptyChunkSource{}, nil + return emptyChunkSource{}, gcBehavior_Continue, nil } ftp.mu.Lock() @@ -523,14 +525,14 @@ func (ftp fakeTablePersister) Persist(ctx context.Context, mt *memTable, haver c ti, err := parseTableIndexByCopy(ctx, data, ftp.q) if err != nil { - return nil, err + return nil, gcBehavior_Continue, err } cs, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize) if err != nil { - return emptyChunkSource{}, err + return emptyChunkSource{}, gcBehavior_Continue, err } - return chunkSourceAdapter{cs, name}, nil + return chunkSourceAdapter{cs, 
name}, gcBehavior_Continue, nil } func (ftp fakeTablePersister) ConjoinAll(ctx context.Context, sources chunkSources, stats *Stats) (chunkSource, cleanupFunc, error) { diff --git a/go/store/nbs/store_test.go b/go/store/nbs/store_test.go index 82b604df1b4..82dc8d89005 100644 --- a/go/store/nbs/store_test.go +++ b/go/store/nbs/store_test.go @@ -381,7 +381,7 @@ func persistTableFileSources(t *testing.T, p tablePersister, numTableFiles int) require.True(t, ok) tableFileMap[fileIDHash] = uint32(i + 1) mapIds[i] = fileIDHash - cs, err := p.Persist(context.Background(), createMemTable(chunkData), nil, &Stats{}) + cs, _, err := p.Persist(context.Background(), createMemTable(chunkData), nil, nil, &Stats{}) require.NoError(t, err) require.NoError(t, cs.close()) diff --git a/go/store/nbs/table_persister.go b/go/store/nbs/table_persister.go index 5c230daa091..6d283fb3ec5 100644 --- a/go/store/nbs/table_persister.go +++ b/go/store/nbs/table_persister.go @@ -47,7 +47,7 @@ type cleanupFunc func() type tablePersister interface { // Persist makes the contents of mt durable. Chunks already present in // |haver| may be dropped in the process. - Persist(ctx context.Context, mt *memTable, haver chunkReader, stats *Stats) (chunkSource, error) + Persist(ctx context.Context, mt *memTable, haver chunkReader, keeper keeperF, stats *Stats) (chunkSource, gcBehavior, error) // ConjoinAll conjoins all chunks in |sources| into a single, new // chunkSource. It returns a |cleanupFunc| which can be called to diff --git a/go/store/nbs/table_set.go b/go/store/nbs/table_set.go index b647aad46a3..c020478ed5e 100644 --- a/go/store/nbs/table_set.go +++ b/go/store/nbs/table_set.go @@ -385,7 +385,8 @@ func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, h return tableSet{}, fmt.Errorf("%w: found dangling references to %s", ErrDanglingRef, absent.String()) } - cs, err := ts.p.Persist(ctx, mt, ts, stats) + // TODO: keeperF + cs, _, err := ts.p.Persist(ctx, mt, ts, nil, stats) if err != nil { return tableSet{}, err } diff --git a/go/store/nbs/table_set_test.go b/go/store/nbs/table_set_test.go index b7d54cbd092..9b4ccdd7ecc 100644 --- a/go/store/nbs/table_set_test.go +++ b/go/store/nbs/table_set_test.go @@ -146,7 +146,7 @@ func persist(t *testing.T, p tablePersister, chunks ...[]byte) { for _, c := range chunks { mt := newMemTable(testMemTableSize) mt.addChunk(computeAddr(c), c) - cs, err := p.Persist(context.Background(), mt, nil, &Stats{}) + cs, _, err := p.Persist(context.Background(), mt, nil, nil, &Stats{}) require.NoError(t, err) require.NoError(t, cs.close()) } From 8a5689ec11efcf745120a5babf78597a6c85e475 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 16 Jan 2025 11:42:25 -0800 Subject: [PATCH 18/34] go/store/nbs: table_set: append: Thread GC dependency tracking through for chunks that are written but which are already present in the store. 
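For reference, the caller-side shape this change establishes — append (and Persist beneath it) now report a gcBehavior alongside their result, and callers retry after waiting for the GC instead of consuming a result produced while a GC was finalizing — looks roughly like the sketch below. This is a simplified, self-contained illustration, not code from this patch; store, appendTable, and flush are hypothetical stand-ins for the real NBS types.

```go
package main

// Minimal sketch of the retry pattern threaded through tableSet.append
// callers in this change. Only the shape mirrors the patch: return a
// gcBehavior next to the result, and have the caller wait for GC and retry
// on a "block" signal rather than keeping a stale result.

import (
	"context"
	"fmt"
)

type gcBehavior bool

const (
	gcBehaviorContinue gcBehavior = false // proceed with the result
	gcBehaviorBlock    gcBehavior = true  // a GC is finalizing; wait and retry
)

type store struct {
	gcInProgress bool
}

// waitForGC stands in for blocking until the in-progress GC finishes.
func (s *store) waitForGC(ctx context.Context) error {
	if ctx.Err() != nil {
		return ctx.Err()
	}
	s.gcInProgress = false
	return nil
}

// appendTable stands in for tableSet.append: it may observe that a GC is
// finalizing and ask the caller to retry instead of returning a result.
func (s *store) appendTable() (string, gcBehavior, error) {
	if s.gcInProgress {
		return "", gcBehaviorBlock, nil
	}
	return "new-table", gcBehaviorContinue, nil
}

// flush shows the caller-side loop: on gcBehaviorBlock, wait for the GC and
// restart the whole operation.
func (s *store) flush(ctx context.Context) (string, error) {
	for {
		name, gcb, err := s.appendTable()
		if err != nil {
			return "", err
		}
		if gcb == gcBehaviorBlock {
			if err := s.waitForGC(ctx); err != nil {
				return "", err
			}
			continue
		}
		return name, nil
	}
}

func main() {
	s := &store{gcInProgress: true}
	name, err := s.flush(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Println("flushed to", name)
}
```

The important property is that a blocked result carries no partial state: the caller always restarts the whole append after the GC completes.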
--- go/store/nbs/store.go | 41 +++++++++++++++++++++++----------- go/store/nbs/table_set.go | 16 +++++++------ go/store/nbs/table_set_test.go | 24 ++++++++++---------- 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index b423dfc7159..9ccf1eb637f 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -833,11 +833,18 @@ func (nbs *NomsBlockStore) addChunk(ctx context.Context, ch chunks.Chunk, getAdd addChunkRes = nbs.mt.addChunk(ch.Hash(), ch.Data()) if addChunkRes == chunkNotAdded { - ts, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.hasCache, nbs.stats) + ts, gcb, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.keeperFunc, nbs.hasCache, nbs.stats) if err != nil { nbs.handlePossibleDanglingRefError(err) return false, err } + if gcb == gcBehavior_Block { + retry = true + if err := nbs.waitForGC(ctx); err != nil { + return false, err + } + continue + } nbs.addPendingRefsToHasCache() nbs.tables = ts nbs.mt = newMemTable(nbs.mtSize) @@ -1365,22 +1372,30 @@ func (nbs *NomsBlockStore) updateManifest(ctx context.Context, current, last has return handleOptimisticLockFailure(cached) } - if nbs.mt != nil { - cnt, err := nbs.mt.count() - - if err != nil { - return err - } - - if cnt > 0 { - ts, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.hasCache, nbs.stats) + for { + if nbs.mt != nil { + cnt, err := nbs.mt.count() if err != nil { - nbs.handlePossibleDanglingRefError(err) return err } - nbs.addPendingRefsToHasCache() - nbs.tables, nbs.mt = ts, nil + if cnt > 0 { + ts, gcb, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.keeperFunc, nbs.hasCache, nbs.stats) + if err != nil { + nbs.handlePossibleDanglingRefError(err) + return err + } + if gcb == gcBehavior_Block { + err = nbs.waitForGC(ctx) + if err != nil { + return err + } + continue + } + nbs.addPendingRefsToHasCache() + nbs.tables, nbs.mt = ts, nil + } } + break } didConjoin, err := nbs.conjoinIfRequired(ctx) diff --git a/go/store/nbs/table_set.go b/go/store/nbs/table_set.go index c020478ed5e..88fd92de587 100644 --- a/go/store/nbs/table_set.go +++ b/go/store/nbs/table_set.go @@ -364,7 +364,7 @@ func (ts tableSet) Size() int { // append adds a memTable to an existing tableSet, compacting |mt| and // returning a new tableSet with newly compacted table added. 
-func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, hasCache *lru.TwoQueueCache[hash.Hash, struct{}], stats *Stats) (tableSet, error) { +func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, keeper keeperF, hasCache *lru.TwoQueueCache[hash.Hash, struct{}], stats *Stats) (tableSet, gcBehavior, error) { addrs := hash.NewHashSet() for _, getAddrs := range mt.getChildAddrs { getAddrs(ctx, addrs, func(h hash.Hash) bool { return hasCache.Contains(h) }) @@ -380,15 +380,17 @@ func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, h sort.Sort(hasRecordByPrefix(mt.pendingRefs)) absent, err := checker(mt.pendingRefs) if err != nil { - return tableSet{}, err + return tableSet{}, gcBehavior_Continue, err } else if absent.Size() > 0 { - return tableSet{}, fmt.Errorf("%w: found dangling references to %s", ErrDanglingRef, absent.String()) + return tableSet{}, gcBehavior_Continue, fmt.Errorf("%w: found dangling references to %s", ErrDanglingRef, absent.String()) } - // TODO: keeperF - cs, _, err := ts.p.Persist(ctx, mt, ts, nil, stats) + cs, gcb, err := ts.p.Persist(ctx, mt, ts, keeper, stats) if err != nil { - return tableSet{}, err + return tableSet{}, gcBehavior_Continue, err + } + if gcb != gcBehavior_Continue { + return tableSet{}, gcb, nil } newTs := tableSet{ @@ -399,7 +401,7 @@ func (ts tableSet) append(ctx context.Context, mt *memTable, checker refCheck, h rl: ts.rl, } newTs.novel[cs.hash()] = cs - return newTs, nil + return newTs, gcBehavior_Continue, nil } // flatten returns a new tableSet with |upstream| set to the union of ts.novel diff --git a/go/store/nbs/table_set_test.go b/go/store/nbs/table_set_test.go index 9b4ccdd7ecc..e1cfcef3dad 100644 --- a/go/store/nbs/table_set_test.go +++ b/go/store/nbs/table_set_test.go @@ -41,7 +41,7 @@ var hasManyHasAll = func([]hasRecord) (hash.HashSet, error) { func TestTableSetPrependEmpty(t *testing.T) { hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err := newFakeTableSet(&UnlimitedQuotaProvider{}).append(context.Background(), newMemTable(testMemTableSize), hasManyHasAll, hasCache, &Stats{}) + ts, _, err := newFakeTableSet(&UnlimitedQuotaProvider{}).append(context.Background(), newMemTable(testMemTableSize), hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) specs, err := ts.toSpecs() require.NoError(t, err) @@ -61,7 +61,7 @@ func TestTableSetPrepend(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) firstSpecs, err := ts.toSpecs() @@ -71,7 +71,7 @@ func TestTableSetPrepend(t *testing.T) { mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) secondSpecs, err := ts.toSpecs() @@ -93,17 +93,17 @@ func TestTableSetToSpecsExcludesEmptyTable(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, 
&Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) specs, err = ts.toSpecs() @@ -124,17 +124,17 @@ func TestTableSetFlattenExcludesEmptyTable(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) ts, err = ts.flatten(context.Background()) @@ -164,7 +164,7 @@ func TestTableSetRebase(t *testing.T) { for _, c := range chunks { mt := newMemTable(testMemTableSize) mt.addChunk(computeAddr(c), c) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) } return ts @@ -213,13 +213,13 @@ func TestTableSetPhysicalLen(t *testing.T) { mt.addChunk(computeAddr(testChunks[0]), testChunks[0]) hasCache, err := lru.New2Q[hash.Hash, struct{}](1024) require.NoError(t, err) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) mt = newMemTable(testMemTableSize) mt.addChunk(computeAddr(testChunks[1]), testChunks[1]) mt.addChunk(computeAddr(testChunks[2]), testChunks[2]) - ts, err = ts.append(context.Background(), mt, hasManyHasAll, hasCache, &Stats{}) + ts, _, err = ts.append(context.Background(), mt, hasManyHasAll, nil, hasCache, &Stats{}) require.NoError(t, err) assert.True(mustUint64(ts.physicalLen()) > indexSize(mustUint32(ts.count()))) From f69f4a438299a83573e9bfca0cf91423b0d35b3d Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 16 Jan 2025 11:43:29 -0800 Subject: [PATCH 19/34] go/store/nbs: store.go: MarkAndSweepChunks: After adding read tracking, make it so that reading chunks as part of the GC process does not take further dependencies on them and never blocks on waitForGC. 
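Roughly, the new gcDependencyMode selects whether a read records a GC dependency at all: ordinary reads keep passing the store's keeper function, while the mark-and-sweep walk reads with no keeper and therefore can never block on the GC it is itself driving. The sketch below is a simplified stand-in for that selection point; store, keeperF, and getMany here are illustrative names, not the NBS types.

```go
package main

// Minimal sketch of the gcDependencyMode idea: normal reads record a GC
// dependency via a keeper function, while GC-driven reads pass no keeper.

import "fmt"

type gcDependencyMode int

const (
	takeDependency gcDependencyMode = iota // normal reads: record the chunk as reachable
	noDependency                           // GC-driven reads: do not self-reference the GC
)

type keeperF func(addr string) bool

type store struct {
	keeper keeperF // set while a GC is in progress, nil otherwise
}

// getMany chooses the keeper based on the dependency mode, mirroring how the
// patch picks the store's keeper func or nil before reading.
func (s *store) getMany(addrs []string, mode gcDependencyMode, found func(string)) {
	keeper := s.keeper
	if mode == noDependency {
		keeper = nil
	}
	for _, a := range addrs {
		if keeper != nil && keeper(a) {
			// In the real store this would surface a "block" signal and make
			// the caller wait for the GC and retry.
			fmt.Println("dependency recorded for", a)
		}
		found(a)
	}
}

func main() {
	s := &store{keeper: func(addr string) bool { return true }}
	// A user-facing read takes dependencies; the mark-and-sweep walk does not.
	s.getMany([]string{"aa", "bb"}, takeDependency, func(a string) { fmt.Println("read", a) })
	s.getMany([]string{"aa", "bb"}, noDependency, func(a string) { fmt.Println("gc read", a) })
}
```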
--- go/store/nbs/generational_chunk_store.go | 14 ++++--- go/store/nbs/ghost_store.go | 4 ++ go/store/nbs/store.go | 47 ++++++++++++++++++------ 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/go/store/nbs/generational_chunk_store.go b/go/store/nbs/generational_chunk_store.go index 5d2f801368a..e8790ec7e8c 100644 --- a/go/store/nbs/generational_chunk_store.go +++ b/go/store/nbs/generational_chunk_store.go @@ -145,14 +145,18 @@ func (gcs *GenerationalNBS) GetMany(ctx context.Context, hashes hash.HashSet, fo } func (gcs *GenerationalNBS) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { + return gcs.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) +} + +func (gcs *GenerationalNBS) getManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk), gcDepMode gcDependencyMode) error { var mu sync.Mutex notInOldGen := hashes.Copy() - err := gcs.oldGen.GetManyCompressed(ctx, hashes, func(ctx context.Context, chunk CompressedChunk) { + err := gcs.oldGen.getManyCompressed(ctx, hashes, func(ctx context.Context, chunk CompressedChunk) { mu.Lock() delete(notInOldGen, chunk.Hash()) mu.Unlock() found(ctx, chunk) - }) + }, gcDepMode) if err != nil { return err } @@ -161,12 +165,12 @@ func (gcs *GenerationalNBS) GetManyCompressed(ctx context.Context, hashes hash.H } notFound := notInOldGen.Copy() - err = gcs.newGen.GetManyCompressed(ctx, notInOldGen, func(ctx context.Context, chunk CompressedChunk) { + err = gcs.newGen.getManyCompressed(ctx, notInOldGen, func(ctx context.Context, chunk CompressedChunk) { mu.Lock() delete(notFound, chunk.Hash()) mu.Unlock() found(ctx, chunk) - }) + }, gcDepMode) if err != nil { return err } @@ -176,7 +180,7 @@ func (gcs *GenerationalNBS) GetManyCompressed(ctx context.Context, hashes hash.H // The missing chunks may be ghost chunks. 
if gcs.ghostGen != nil { - return gcs.ghostGen.GetManyCompressed(ctx, notFound, found) + return gcs.ghostGen.getManyCompressed(ctx, notFound, found, gcDepMode) } return nil } diff --git a/go/store/nbs/ghost_store.go b/go/store/nbs/ghost_store.go index 11d23de6a68..9edd0fb40fa 100644 --- a/go/store/nbs/ghost_store.go +++ b/go/store/nbs/ghost_store.go @@ -91,6 +91,10 @@ func (g GhostBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, found } func (g GhostBlockStore) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { + return g.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) +} + +func (g GhostBlockStore) getManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk), gcDepMode gcDependencyMode) error { for h := range hashes { if g.skippedRefs.Has(h) { found(ctx, NewGhostCompressedChunk(h)) diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index 9ccf1eb637f..1970fd258c6 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -89,6 +89,16 @@ type NBSCompressedChunkStore interface { GetManyCompressed(context.Context, hash.HashSet, func(context.Context, CompressedChunk)) error } +type gcDependencyMode int +const ( + gcDependencyMode_TakeDependency gcDependencyMode = iota + gcDependencyMode_NoDependency +) + +type CompressedChunkStoreForGC interface { + getManyCompressed(context.Context, hash.HashSet, func(context.Context, CompressedChunk), gcDependencyMode) error +} + type NomsBlockStore struct { mm manifestManager p tablePersister @@ -941,22 +951,31 @@ func (nbs *NomsBlockStore) Get(ctx context.Context, h hash.Hash) (chunks.Chunk, func (nbs *NomsBlockStore) GetMany(ctx context.Context, hashes hash.HashSet, found func(context.Context, *chunks.Chunk)) error { ctx, span := tracer.Start(ctx, "nbs.GetMany", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) defer span.End() - return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { - return cr.getMany(ctx, eg, reqs, found, keeper, nbs.stats) - }) + return nbs.getManyWithFunc(ctx, hashes, gcDependencyMode_TakeDependency, + func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return cr.getMany(ctx, eg, reqs, found, keeper, nbs.stats) + }, + ) } func (nbs *NomsBlockStore) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk)) error { + return nbs.getManyCompressed(ctx, hashes, found, gcDependencyMode_TakeDependency) +} + +func (nbs *NomsBlockStore) getManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, CompressedChunk), gcDepMode gcDependencyMode) error { ctx, span := tracer.Start(ctx, "nbs.GetManyCompressed", trace.WithAttributes(attribute.Int("num_hashes", len(hashes)))) defer span.End() - return nbs.getManyWithFunc(ctx, hashes, func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { - return cr.getManyCompressed(ctx, eg, reqs, found, keeper, nbs.stats) - }) + return nbs.getManyWithFunc(ctx, hashes, gcDepMode, + func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error) { + return cr.getManyCompressed(ctx, eg, reqs, found, keeper, nbs.stats) 
+ }, + ) } func (nbs *NomsBlockStore) getManyWithFunc( ctx context.Context, hashes hash.HashSet, + gcDepMode gcDependencyMode, getManyFunc func(ctx context.Context, cr chunkReader, eg *errgroup.Group, reqs []getRecord, keeper keeperF, stats *Stats) (bool, gcBehavior, error), ) error { if len(hashes) == 0 { @@ -976,8 +995,12 @@ func (nbs *NomsBlockStore) getManyWithFunc( eg.SetLimit(ioParallelism) nbs.mu.Lock() + keeper := nbs.keeperFunc + if gcDepMode == gcDependencyMode_NoDependency { + keeper = nil + } if nbs.mt != nil { - remaining, gcb, err := getManyFunc(ctx, nbs.mt, eg, reqs, nbs.keeperFunc, nbs.stats) + remaining, gcb, err := getManyFunc(ctx, nbs.mt, eg, reqs, keeper, nbs.stats) if err != nil { nbs.mu.Unlock() return err @@ -995,7 +1018,7 @@ func (nbs *NomsBlockStore) getManyWithFunc( return nil } } - tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + tables, endRead := nbs.tables, nbs.beginRead() nbs.mu.Unlock() _, gcb, err := getManyFunc(ctx, tables, eg, reqs, keeper, nbs.stats) @@ -1755,7 +1778,7 @@ func (nbs *NomsBlockStore) MarkAndSweepChunks(ctx context.Context, getAddrs chun return markAndSweepChunks(ctx, nbs, nbs, dest, getAddrs, filter, mode) } -func markAndSweepChunks(ctx context.Context, nbs *NomsBlockStore, src NBSCompressedChunkStore, dest chunks.ChunkStore, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { +func markAndSweepChunks(ctx context.Context, nbs *NomsBlockStore, src CompressedChunkStoreForGC, dest chunks.ChunkStore, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { ops := nbs.SupportedOperations() if !ops.CanGC || !ops.CanPrune { return nil, chunks.ErrUnsupportedOperation @@ -1823,7 +1846,7 @@ func markAndSweepChunks(ctx context.Context, nbs *NomsBlockStore, src NBSCompres } type markAndSweeper struct { - src NBSCompressedChunkStore + src CompressedChunkStoreForGC dest *NomsBlockStore getAddrs chunks.GetAddrsCurry filter chunks.HasManyFunc @@ -1869,7 +1892,7 @@ func (i *markAndSweeper) SaveHashes(ctx context.Context, hashes []hash.Hash) err found := 0 var addErr error - err = i.src.GetManyCompressed(ctx, toVisit, func(ctx context.Context, cc CompressedChunk) { + err = i.src.getManyCompressed(ctx, toVisit, func(ctx context.Context, cc CompressedChunk) { mu.Lock() defer mu.Unlock() if addErr != nil { @@ -1893,7 +1916,7 @@ func (i *markAndSweeper) SaveHashes(ctx context.Context, hashes []hash.Hash) err return } addErr = i.getAddrs(c)(ctx, nextToVisit, func(h hash.Hash) bool { return false }) - }) + }, gcDependencyMode_NoDependency) if err != nil { return err } From de5c1788304a7d356a6faf937bef1e7c2fe41d4e Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 16 Jan 2025 13:52:05 -0800 Subject: [PATCH 20/34] repofmt.sh. 
--- go/store/nbs/frag/main.go | 4 ++-- go/store/nbs/store.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/go/store/nbs/frag/main.go b/go/store/nbs/frag/main.go index 424259916e7..7eb9c2d8db8 100644 --- a/go/store/nbs/frag/main.go +++ b/go/store/nbs/frag/main.go @@ -153,14 +153,14 @@ func main() { if i+1 == numGroups { // last group go func(i int) { defer wg.Done() - reads[i], _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:].HashSet(), 0) + reads[i], _, _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:].HashSet(), 0, nil) d.PanicIfError(err) }(i) continue } go func(i int) { defer wg.Done() - reads[i], _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:(i+1)*branchFactor].HashSet(), 0) + reads[i], _, _, err = nbs.CalcReads(store, orderedChildren[i*branchFactor:(i+1)*branchFactor].HashSet(), 0, nil) d.PanicIfError(err) }(i) } diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index 1970fd258c6..e5a3c1e0ea9 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -90,6 +90,7 @@ type NBSCompressedChunkStore interface { } type gcDependencyMode int + const ( gcDependencyMode_TakeDependency gcDependencyMode = iota gcDependencyMode_NoDependency From bba09b1b4b48a8a58b73e33f16816b8a46c5bc39 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 23 Jan 2025 13:38:48 -0800 Subject: [PATCH 21/34] go/store/nbs: store.go: Fix errgroup context usage-after-Wait bug in getManyWithFunc. --- go/store/nbs/store.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index e5a3c1e0ea9..bc39ceee3be 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -989,11 +989,9 @@ func (nbs *NomsBlockStore) getManyWithFunc( nbs.stats.ChunksPerGet.Sample(uint64(len(hashes))) }() + const ioParallelism = 16 for { reqs := toGetRecords(hashes) - eg, ctx := errgroup.WithContext(ctx) - const ioParallelism = 16 - eg.SetLimit(ioParallelism) nbs.mu.Lock() keeper := nbs.keeperFunc @@ -1001,7 +999,8 @@ func (nbs *NomsBlockStore) getManyWithFunc( keeper = nil } if nbs.mt != nil { - remaining, gcb, err := getManyFunc(ctx, nbs.mt, eg, reqs, keeper, nbs.stats) + // nbs.mt does not use the errgroup parameter, which we pass at |nil| here. + remaining, gcb, err := getManyFunc(ctx, nbs.mt, nil, reqs, keeper, nbs.stats) if err != nil { nbs.mu.Unlock() return err @@ -1022,8 +1021,12 @@ func (nbs *NomsBlockStore) getManyWithFunc( tables, endRead := nbs.tables, nbs.beginRead() nbs.mu.Unlock() - _, gcb, err := getManyFunc(ctx, tables, eg, reqs, keeper, nbs.stats) - err = errors.Join(err, eg.Wait()) + gcb, err := func() (gcBehavior, error) { + eg, ctx := errgroup.WithContext(ctx) + eg.SetLimit(ioParallelism) + _, gcb, err := getManyFunc(ctx, tables, eg, reqs, keeper, nbs.stats) + return gcb, errors.Join(err, eg.Wait()) + }() needContinue, err := nbs.handleUnlockedRead(ctx, gcb, endRead, err) if err != nil { return err From c2ec3deda50696bf5ea6b1d5ff164dbd44a19cf5 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Mon, 27 Jan 2025 17:46:22 -0800 Subject: [PATCH 22/34] go: binlogreplication: Add Session{{Begin,End}Command,End} lifecycle callbacks to the replica controller execution context session. This makes small clean ups to the lifecycle around replica applier and the ownership and lifecycle of the mysql.Conn read connection. This PR also includes some changes to slightly improve the performance and reliability of the tests when running them locally. 
In particular, some of the changes include: 1) Since `go run ./cmd/dolt` takes about four seconds to validate the existing cached build on my laptop, we just go ahead and use a cached build everywhere. 2) We use t.Log{f,} instead of fmt.Prin.. to improve the ergonomics of test running and getting output from a failure in particular. 3) We try to minimize global process state changes like unnecessary `os.Chdir` calls, since it would be nice to parallelize these tests eventually. 4) We get rid of the unused and seemingly unnecessary --socket= argument to Dolt, where we had to use a directory not corresponding to $TMPDIR, for example, because max pathlength on a sun_path on MacOS is 104 characters or whatever. --- .../binlog_metadata_persistence.go | 4 +- .../binlogreplication/binlog_primary_test.go | 10 +- .../binlog_replica_applier.go | 50 +++- .../binlog_replica_controller.go | 20 +- .../binlog_replica_event_producer.go | 35 ++- .../binlog_replication_restart_test.go | 2 +- .../binlog_replication_test.go | 248 ++++++++---------- 7 files changed, 199 insertions(+), 170 deletions(-) diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go index 01c6b38d0fa..3737498ff1f 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_metadata_persistence.go @@ -104,7 +104,9 @@ func persistReplicaRunningState(ctx *sql.Context, state replicaRunningState) err // loadReplicationConfiguration loads the replication configuration for default channel ("") from // the "mysql" database, |mysqlDb|. -func loadReplicationConfiguration(_ *sql.Context, mysqlDb *mysql_db.MySQLDb) (*mysql_db.ReplicaSourceInfo, error) { +func loadReplicationConfiguration(ctx *sql.Context, mysqlDb *mysql_db.MySQLDb) (*mysql_db.ReplicaSourceInfo, error) { + sql.SessionCommandBegin(ctx.Session) + defer sql.SessionCommandEnd(ctx.Session) rd := mysqlDb.Reader() defer rd.Close() diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go index 02cb6a0a77e..9b8770befa1 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_primary_test.go @@ -430,7 +430,7 @@ func TestBinlogPrimary_ReplicaRestart(t *testing.T) { // Restart the MySQL replica and reconnect to the Dolt primary prevPrimaryDatabase := primaryDatabase var err error - mySqlPort, mySqlProcess, err = startMySqlServer(testDir) + mySqlPort, mySqlProcess, err = startMySqlServer(t, testDir) require.NoError(t, err) replicaDatabase = primaryDatabase primaryDatabase = prevPrimaryDatabase @@ -1042,7 +1042,7 @@ func outputReplicaApplierStatus(t *testing.T) { newRows, err := replicaDatabase.Queryx("select * from performance_schema.replication_applier_status_by_worker;") require.NoError(t, err) allNewRows := readAllRowsIntoMaps(t, newRows) - fmt.Printf("\n\nreplication_applier_status_by_worker: %v\n", allNewRows) + t.Logf("\n\nreplication_applier_status_by_worker: %v\n", allNewRows) } // outputShowReplicaStatus prints out replica status information. 
This is useful for debugging @@ -1052,7 +1052,7 @@ func outputShowReplicaStatus(t *testing.T) { newRows, err := replicaDatabase.Queryx("show replica status;") require.NoError(t, err) allNewRows := readAllRowsIntoMaps(t, newRows) - fmt.Printf("\n\nSHOW REPLICA STATUS: %v\n", allNewRows) + t.Logf("\n\nSHOW REPLICA STATUS: %v\n", allNewRows) } // copyMap returns a copy of the specified map |m|. @@ -1098,7 +1098,7 @@ func waitForReplicaToReconnect(t *testing.T) { func mustRestartDoltPrimaryServer(t *testing.T) { var err error prevReplicaDatabase := replicaDatabase - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) primaryDatabase = replicaDatabase replicaDatabase = prevReplicaDatabase @@ -1109,7 +1109,7 @@ func mustRestartDoltPrimaryServer(t *testing.T) { func mustRestartMySqlReplicaServer(t *testing.T) { var err error prevPrimaryDatabase := primaryDatabase - mySqlPort, mySqlProcess, err = startMySqlServer(testDir) + mySqlPort, mySqlProcess, err = startMySqlServer(t, testDir) require.NoError(t, err) replicaDatabase = primaryDatabase primaryDatabase = prevPrimaryDatabase diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go index 3833aecc6aa..36bce2c5767 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_applier.go @@ -19,6 +19,7 @@ import ( "io" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -63,6 +64,7 @@ type binlogReplicaApplier struct { currentPosition *mysql.Position // successfully executed GTIDs filters *filterConfiguration running atomic.Bool + handlerWg sync.WaitGroup engine *gms.Engine dbsWithUncommittedChanges map[string]struct{} } @@ -88,10 +90,14 @@ const rowFlag_rowsAreComplete = 0x0008 // Go spawns a new goroutine to run the applier's binlog event handler. func (a *binlogReplicaApplier) Go(ctx *sql.Context) { + if !a.running.CompareAndSwap(false, true) { + panic("attempt to start binlogReplicaApplier while it is already running") + } + a.handlerWg.Add(1) go func() { - a.running.Store(true) + defer a.handlerWg.Done() + defer a.running.Store(false) err := a.replicaBinlogEventHandler(ctx) - a.running.Store(false) if err != nil { ctx.GetLogger().Errorf("unexpected error of type %T: '%v'", err, err.Error()) DoltBinlogReplicaController.setSqlError(mysql.ERUnknownError, err.Error()) @@ -104,6 +110,27 @@ func (a *binlogReplicaApplier) IsRunning() bool { return a.running.Load() } +// Stop will shutdown the replication thread if it is running. This is not safe to call concurrently |Go|. +// This is used by the controller when implementing STOP REPLICA, but it is also used on shutdown when the +// replication thread should be shutdown cleanly in the event that it is still running. +func (a *binlogReplicaApplier) Stop() { + if a.IsRunning() { + // We jump through some hoops here. It is not the case that the replication thread + // is guaranteed to read from |stopReplicationChan|. Instead, it can exit on its + // own with an error, for example, after exceeding connection retry attempts. 
+ done := make(chan struct{}) + go func() { + defer close(done) + a.handlerWg.Wait() + }() + select { + case a.stopReplicationChan <- struct{}{}: + case <-done: + } + a.handlerWg.Wait() + } +} + // connectAndStartReplicationEventStream connects to the configured MySQL replication source, including pausing // and retrying if errors are encountered. func (a *binlogReplicaApplier) connectAndStartReplicationEventStream(ctx *sql.Context) (*mysql.Conn, error) { @@ -263,25 +290,21 @@ func (a *binlogReplicaApplier) startReplicationEventStream(ctx *sql.Context, con func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error { engine := a.engine - var conn *mysql.Conn var eventProducer *binlogEventProducer // Process binlog events for { - if conn == nil { + if eventProducer == nil { ctx.GetLogger().Debug("no binlog connection to source, attempting to establish one") - if eventProducer != nil { - eventProducer.Stop() - } - var err error - if conn, err = a.connectAndStartReplicationEventStream(ctx); err == ErrReplicationStopped { + if conn, err := a.connectAndStartReplicationEventStream(ctx); err == ErrReplicationStopped { return nil } else if err != nil { return err + } else { + eventProducer = newBinlogEventProducer(conn) + eventProducer.Go(ctx) } - eventProducer = newBinlogEventProducer(conn) - eventProducer.Go(ctx) } select { @@ -305,8 +328,6 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error }) eventProducer.Stop() eventProducer = nil - conn.Close() - conn = nil } } else { // otherwise, log the error if it's something we don't expect and continue @@ -317,6 +338,7 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error case <-a.stopReplicationChan: ctx.GetLogger().Trace("received stop replication signal") eventProducer.Stop() + eventProducer = nil return nil } } @@ -325,6 +347,8 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error // processBinlogEvent processes a single binlog event message and returns an error if there were any problems // processing it. func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.Engine, event mysql.BinlogEvent) error { + sql.SessionCommandBegin(ctx.Session) + defer sql.SessionCommandEnd(ctx.Session) var err error createCommit := false diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go index 8e22a09cd92..5ea8edd9681 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go @@ -157,7 +157,9 @@ func (d *doltBinlogReplicaController) StartReplica(ctx *sql.Context) error { // changes and execute DDL statements on the running server. If the account doesn't exist, it will be // created and locked to disable log ins, and if it does exist, but is missing super privs or is not // locked, it will be given superuser privs and locked. -func (d *doltBinlogReplicaController) configureReplicationUser(_ *sql.Context) { +func (d *doltBinlogReplicaController) configureReplicationUser(ctx *sql.Context) { + sql.SessionCommandBegin(ctx.Session) + defer sql.SessionCommandEnd(ctx.Session) mySQLDb := d.engine.Analyzer.Catalog.MySQLDb ed := mySQLDb.Editor() defer ed.Close() @@ -180,12 +182,15 @@ func (d *doltBinlogReplicaController) SetEngine(engine *sqle.Engine) { // StopReplica implements the BinlogReplicaController interface. 
func (d *doltBinlogReplicaController) StopReplica(ctx *sql.Context) error { + d.operationMutex.Lock() + defer d.operationMutex.Unlock() + if d.applier.IsRunning() == false { ctx.Warn(3084, "Replication thread(s) for channel '' are already stopped.") return nil } - d.applier.stopReplicationChan <- struct{}{} + d.applier.Stop() d.updateStatus(func(status *binlogreplication.ReplicaStatus) { status.ReplicaIoRunning = binlogreplication.ReplicaIoNotRunning @@ -428,6 +433,17 @@ func (d *doltBinlogReplicaController) AutoStart(_ context.Context) error { return d.StartReplica(d.ctx) } +// Release all resources, such as replication threads, associated with the replication. +// This can only be done once in the lifecycle of the instance. Because DoltBinlogReplicaController +// is currently a global singleton, this should only be done once in the lifecycle of the +// application. +func (d *doltBinlogReplicaController) Close() { + d.applier.Stop() + if d.ctx != nil { + sql.SessionEnd(d.ctx.Session) + } +} + // // Helper functions // diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go index 34f45eb3abf..872a7ccba31 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_event_producer.go @@ -15,6 +15,7 @@ package binlogreplication import ( + "sync" "sync/atomic" "github.com/dolthub/go-mysql-server/sql" @@ -30,19 +31,24 @@ type binlogEventProducer struct { conn *mysql.Conn errorChan chan error eventChan chan mysql.BinlogEvent + closeChan chan struct{} + wg sync.WaitGroup running atomic.Bool } // newBinlogEventProducer creates a new binlog event producer that reads from the specified, established MySQL // connection |conn|. The returned binlogEventProducer owns the communication channels // and is responsible for closing them when the binlogEventProducer is stopped. +// +// The BinlogEventProducer will take ownership of the supplied |*Conn| instance and +// will |Close| it when the producer itself exits. func newBinlogEventProducer(conn *mysql.Conn) *binlogEventProducer { producer := &binlogEventProducer{ conn: conn, eventChan: make(chan mysql.BinlogEvent), errorChan: make(chan error), + closeChan: make(chan struct{}), } - producer.running.Store(true) return producer } @@ -61,7 +67,14 @@ func (p *binlogEventProducer) ErrorChan() <-chan error { // Go starts this binlogEventProducer in a new goroutine. Right before this routine exits, it will close the // two communication channels it owns. 
func (p *binlogEventProducer) Go(_ *sql.Context) { + if !p.running.CompareAndSwap(false, true) { + panic("attempt to start binlogEventProducer more than once.") + } + p.wg.Add(1) go func() { + defer p.wg.Done() + defer close(p.errorChan) + defer close(p.eventChan) for p.IsRunning() { // ReadBinlogEvent blocks until a binlog event can be read and returned, so this has to be done on a // separate thread, otherwise the applier would be blocked and wouldn't be able to handle the STOP @@ -75,13 +88,19 @@ func (p *binlogEventProducer) Go(_ *sql.Context) { } if err != nil { - p.errorChan <- err + select { + case p.errorChan <- err: + case <-p.closeChan: + return + } } else { - p.eventChan <- event + select { + case p.eventChan <- event: + case <-p.closeChan: + return + } } } - close(p.errorChan) - close(p.eventChan) }() } @@ -92,5 +111,9 @@ func (p *binlogEventProducer) IsRunning() bool { // Stop requests for this binlogEventProducer to stop processing events as soon as possible. func (p *binlogEventProducer) Stop() { - p.running.Store(false) + if p.running.CompareAndSwap(true, false) { + p.conn.Close() + close(p.closeChan) + } + p.wg.Wait() } diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go index 70810019f7c..1ebb0a046ca 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_restart_test.go @@ -49,7 +49,7 @@ func TestBinlogReplicationServerRestart(t *testing.T) { time.Sleep(1000 * time.Millisecond) var err error - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) // Check replication status on the replica and assert configuration persisted diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go index 9231185dcdc..55eae2086ed 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replication_test.go @@ -29,6 +29,7 @@ import ( "slices" "strconv" "strings" + "sync" "syscall" "testing" "time" @@ -47,7 +48,6 @@ var mySqlProcess, doltProcess *os.Process var doltLogFilePath, oldDoltLogFilePath, mysqlLogFilePath string var doltLogFile, mysqlLogFile *os.File var testDir string -var originalWorkingDir string // doltReplicaSystemVars are the common system variables that need // to be set on a Dolt replica before replication is turned on. 
@@ -55,6 +55,48 @@ var doltReplicaSystemVars = map[string]string{ "server_id": "42", } +func TestMain(m *testing.M) { + res := func() int { + defer func() { + cachedDoltDevBuildPathOnce.Do(func() {}) + if cachedDoltDevBuildPath != "" { + os.RemoveAll(filepath.Dir(cachedDoltDevBuildPath)) + } + }() + return m.Run() + }() + os.Exit(res) +} + +var cachedDoltDevBuildPath string +var cachedDoltDevBuildPathOnce sync.Once + +func DoltDevBuildPath() string { + cachedDoltDevBuildPathOnce.Do(func() { + tmp, err := os.MkdirTemp("", "binlog-replication-doltbin-") + if err != nil { + panic(err) + } + fullpath := filepath.Join(tmp, "dolt") + + originalWorkingDir, err := os.Getwd() + if err != nil { + panic(err) + } + + goDirPath := filepath.Join(originalWorkingDir, "..", "..", "..", "..") + + cmd := exec.Command("go", "build", "-o", fullpath, "./cmd/dolt") + cmd.Dir = goDirPath + output, err := cmd.CombinedOutput() + if err != nil { + panic("unable to build dolt for binlog integration tests: " + err.Error() + "\nFull output: " + string(output) + "\n") + } + cachedDoltDevBuildPath = fullpath + }) + return cachedDoltDevBuildPath +} + func teardown(t *testing.T) { if mySqlProcess != nil { stopMySqlServer(t) @@ -72,17 +114,17 @@ func teardown(t *testing.T) { // Output server logs on failure for easier debugging if t.Failed() { if oldDoltLogFilePath != "" { - fmt.Printf("\nDolt server log from %s:\n", oldDoltLogFilePath) - printFile(oldDoltLogFilePath) + t.Logf("\nDolt server log from %s:\n", oldDoltLogFilePath) + printFile(t, oldDoltLogFilePath) } - fmt.Printf("\nDolt server log from %s:\n", doltLogFilePath) - printFile(doltLogFilePath) - fmt.Printf("\nMySQL server log from %s:\n", mysqlLogFilePath) - printFile(mysqlLogFilePath) + t.Logf("\nDolt server log from %s:\n", doltLogFilePath) + printFile(t, doltLogFilePath) + t.Logf("\nMySQL server log from %s:\n", mysqlLogFilePath) + printFile(t, mysqlLogFilePath) mysqlErrorLogFilePath := filepath.Join(filepath.Dir(mysqlLogFilePath), "error_log.err") - fmt.Printf("\nMySQL server error log from %s:\n", mysqlErrorLogFilePath) - printFile(mysqlErrorLogFilePath) + t.Logf("\nMySQL server error log from %s:\n", mysqlErrorLogFilePath) + printFile(t, mysqlErrorLogFilePath) } else { // clean up temp files on clean test runs defer os.RemoveAll(testDir) @@ -194,7 +236,7 @@ func TestAutoRestartReplica(t *testing.T) { // Restart the Dolt replica stopDoltSqlServer(t) var err error - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) // Assert that some test data replicates correctly @@ -218,7 +260,7 @@ func TestAutoRestartReplica(t *testing.T) { // Restart the Dolt replica stopDoltSqlServer(t) - doltPort, doltProcess, err = startDoltSqlServer(testDir, nil) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil) require.NoError(t, err) // SHOW REPLICA STATUS should show that replication is NOT running, with no errors @@ -590,11 +632,13 @@ func TestCharsetsAndCollations(t *testing.T) { // Test Helper Functions // -// waitForReplicaToCatchUp waits (up to 30s) for the replica to catch up with the primary database. The -// lag is measured by checking that gtid_executed is the same on the primary and replica. +// waitForReplicaToCatchUp waits for the replica to catch up with the primary database. The +// lag is measured by checking that gtid_executed is the same on the primary and replica. If +// no progress is made in 30 seconds, this function will fail the test. 
func waitForReplicaToCatchUp(t *testing.T) { timeLimit := 30 * time.Second + lastReplicaGtid := "" endTime := time.Now().Add(timeLimit) for time.Now().Before(endTime) { replicaGtid := queryGtid(t, replicaDatabase) @@ -602,8 +646,11 @@ func waitForReplicaToCatchUp(t *testing.T) { if primaryGtid == replicaGtid { return + } else if lastReplicaGtid != replicaGtid { + lastReplicaGtid = replicaGtid + endTime = time.Now().Add(timeLimit) } else { - fmt.Printf("primary and replica not in sync yet... (primary: %s, replica: %s)\n", primaryGtid, replicaGtid) + t.Logf("primary and replica not in sync yet... (primary: %s, replica: %s)\n", primaryGtid, replicaGtid) time.Sleep(250 * time.Millisecond) } } @@ -639,7 +686,7 @@ func waitForReplicaToReachGtid(t *testing.T, target int) { } } - fmt.Printf("replica has not reached transaction %d yet; currently at: %s \n", target, replicaGtid) + t.Logf("replica has not reached transaction %d yet; currently at: %s \n", target, replicaGtid) } t.Fatal("replica did not reach target GTID within " + timeLimit.String()) @@ -725,20 +772,13 @@ func startSqlServersWithDoltSystemVars(t *testing.T, doltPersistentSystemVars ma testDir = filepath.Join(os.TempDir(), fmt.Sprintf("%s-%v", t.Name(), time.Now().Unix())) err := os.MkdirAll(testDir, 0777) - - cmd := exec.Command("chmod", "777", testDir) - _, err = cmd.Output() - if err != nil { - panic(err) - } - require.NoError(t, err) - fmt.Printf("temp dir: %v \n", testDir) + t.Logf("temp dir: %v \n", testDir) // Start up primary and replica databases - mySqlPort, mySqlProcess, err = startMySqlServer(testDir) + mySqlPort, mySqlProcess, err = startMySqlServer(t, testDir) require.NoError(t, err) - doltPort, doltProcess, err = startDoltSqlServer(testDir, doltPersistentSystemVars) + doltPort, doltProcess, err = startDoltSqlServer(t, testDir, doltPersistentSystemVars) require.NoError(t, err) } @@ -856,25 +896,9 @@ func findFreePort() int { // startMySqlServer configures a starts a fresh MySQL server instance and returns the port it is running on, // and the os.Process handle. If unable to start up the MySQL server, an error is returned. -func startMySqlServer(dir string) (int, *os.Process, error) { - originalCwd, err := os.Getwd() - if err != nil { - panic(err) - } - - dir = dir + string(os.PathSeparator) + "mysql" + string(os.PathSeparator) - dataDir := dir + "mysql_data" - err = os.MkdirAll(dir, 0777) - if err != nil { - return -1, nil, err - } - cmd := exec.Command("chmod", "777", dir) - output, err := cmd.Output() - if err != nil { - panic(err) - } - - err = os.Chdir(dir) +func startMySqlServer(t *testing.T, dir string) (int, *os.Process, error) { + dir = filepath.Join(dir, "mysql") + err := os.MkdirAll(dir, 0777) if err != nil { return -1, nil, err } @@ -889,28 +913,31 @@ func startMySqlServer(dir string) (int, *os.Process, error) { } username := user.Username if username == "root" { - fmt.Printf("overriding current user (root) to run mysql as 'mysql' user instead\n") + t.Logf("overriding current user (root) to run mysql as 'mysql' user instead\n") username = "mysql" } + dataDir := filepath.Join(dir, "mysql_data") + // Check to see if the MySQL data directory has the "mysql" directory in it, which // tells us whether this MySQL instance has been initialized yet or not. 
initialized := directoryExists(filepath.Join(dataDir, "mysql")) if !initialized { // Create a fresh MySQL server for the primary - chmodCmd := exec.Command("mysqld", + initCmd := exec.Command("mysqld", "--no-defaults", "--user="+username, "--initialize-insecure", "--datadir="+dataDir, "--default-authentication-plugin=mysql_native_password") - output, err = chmodCmd.CombinedOutput() + initCmd.Dir = dir + output, err := initCmd.CombinedOutput() if err != nil { - return -1, nil, fmt.Errorf("unable to execute command %v: %v – %v", cmd.String(), err.Error(), string(output)) + return -1, nil, fmt.Errorf("unable to execute command %v: %v – %v", initCmd.String(), err.Error(), string(output)) } } - cmd = exec.Command("mysqld", + cmd := exec.Command("mysqld", "--no-defaults", "--user="+username, "--datadir="+dataDir, @@ -920,17 +947,18 @@ func startMySqlServer(dir string) (int, *os.Process, error) { fmt.Sprintf("--port=%v", mySqlPort), "--server-id=11223344", fmt.Sprintf("--socket=mysql-%v.sock", mySqlPort), - "--general_log_file="+dir+"general_log", - "--slow_query_log_file="+dir+"slow_query_log", + "--general_log_file="+filepath.Join(dir, "general_log"), + "--slow_query_log_file="+filepath.Join(dir, "slow_query_log"), "--log-error="+dir+"error_log", - fmt.Sprintf("--pid-file="+dir+"pid-%v.pid", mySqlPort)) + fmt.Sprintf("--pid-file="+filepath.Join(dir, "pid-%v.pid"), mySqlPort)) + cmd.Dir = dir mysqlLogFilePath = filepath.Join(dir, fmt.Sprintf("mysql-%d.out.log", time.Now().Unix())) mysqlLogFile, err = os.Create(mysqlLogFilePath) if err != nil { return -1, nil, err } - fmt.Printf("MySQL server logs at: %s \n", mysqlLogFilePath) + t.Logf("MySQL server logs at: %s \n", mysqlLogFilePath) cmd.Stdout = mysqlLogFile cmd.Stderr = mysqlLogFile err = cmd.Start() @@ -941,7 +969,7 @@ func startMySqlServer(dir string) (int, *os.Process, error) { dsn := fmt.Sprintf("root@tcp(127.0.0.1:%v)/", mySqlPort) primaryDatabase = sqlx.MustOpen("mysql", dsn) - err = waitForSqlServerToStart(primaryDatabase) + err = waitForSqlServerToStart(t, primaryDatabase) if err != nil { return -1, nil, err } @@ -955,8 +983,7 @@ func startMySqlServer(dir string) (int, *os.Process, error) { dsn = fmt.Sprintf("root@tcp(127.0.0.1:%v)/", mySqlPort) primaryDatabase = sqlx.MustOpen("mysql", dsn) - os.Chdir(originalCwd) - fmt.Printf("MySQL server started on port %v \n", mySqlPort) + t.Logf("MySQL server started on port %v \n", mySqlPort) return mySqlPort, cmd.Process, nil } @@ -971,43 +998,10 @@ func directoryExists(path string) bool { return info.IsDir() } -var cachedDoltDevBuildPath = "" - -func initializeDevDoltBuild(dir string, goDirPath string) string { - if cachedDoltDevBuildPath != "" { - return cachedDoltDevBuildPath - } - - // If we're not in a CI environment, don't worry about building a dev build - if os.Getenv("CI") != "true" { - return "" - } - - basedir := filepath.Dir(filepath.Dir(dir)) - fullpath := filepath.Join(basedir, fmt.Sprintf("devDolt-%d", os.Getpid())) - - _, err := os.Stat(fullpath) - if err == nil { - return fullpath - } - - fmt.Printf("building dolt dev build at: %s \n", fullpath) - cmd := exec.Command("go", "build", "-o", fullpath, "./cmd/dolt") - cmd.Dir = goDirPath - - output, err := cmd.CombinedOutput() - if err != nil { - panic("unable to build dolt for binlog integration tests: " + err.Error() + "\nFull output: " + string(output) + "\n") - } - cachedDoltDevBuildPath = fullpath - - return cachedDoltDevBuildPath -} - // startDoltSqlServer starts a Dolt sql-server on a free port from the specified directory 
|dir|. If // |doltPeristentSystemVars| is populated, then those system variables will be set, persistently, for // the Dolt database, before the Dolt sql-server is started. -func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) (int, *os.Process, error) { +func startDoltSqlServer(t *testing.T, dir string, doltPersistentSystemVars map[string]string) (int, *os.Process, error) { dir = filepath.Join(dir, "dolt") err := os.MkdirAll(dir, 0777) if err != nil { @@ -1019,57 +1013,34 @@ func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) if doltPort < 1 { doltPort = findFreePort() } - fmt.Printf("Starting Dolt sql-server on port: %d, with data dir %s\n", doltPort, dir) - - // take the CWD and move up four directories to find the go directory - if originalWorkingDir == "" { - var err error - originalWorkingDir, err = os.Getwd() - if err != nil { - panic(err) - } - } - goDirPath := filepath.Join(originalWorkingDir, "..", "..", "..", "..") - err = os.Chdir(goDirPath) - if err != nil { - panic(err) - } - - socketPath := filepath.Join("/tmp", fmt.Sprintf("dolt.%v.sock", doltPort)) + t.Logf("Starting Dolt sql-server on port: %d, with data dir %s\n", doltPort, dir) // use an admin user NOT named "root" to test that we don't require the "root" account adminUser := "admin" if doltPersistentSystemVars != nil && len(doltPersistentSystemVars) > 0 { // Initialize the dolt directory first - err = runDoltCommand(dir, goDirPath, "init", "--name=binlog-test", "--email=binlog@test") + err = runDoltCommand(t, dir, "init", "--name=binlog-test", "--email=binlog@test") if err != nil { return -1, nil, err } for systemVar, value := range doltPersistentSystemVars { query := fmt.Sprintf("SET @@PERSIST.%s=%s;", systemVar, value) - err = runDoltCommand(dir, goDirPath, "sql", fmt.Sprintf("-q=%s", query)) + err = runDoltCommand(t, dir, "sql", fmt.Sprintf("-q=%s", query)) if err != nil { return -1, nil, err } } } - args := []string{"go", "run", "./cmd/dolt", + args := []string{DoltDevBuildPath(), "sql-server", fmt.Sprintf("-u%s", adminUser), "--loglevel=TRACE", fmt.Sprintf("--data-dir=%s", dir), - fmt.Sprintf("--port=%v", doltPort), - fmt.Sprintf("--socket=%s", socketPath)} - - // If we're running in CI, use a precompiled dolt binary instead of go run - devDoltPath := initializeDevDoltBuild(dir, goDirPath) - if devDoltPath != "" { - args[2] = devDoltPath - args = args[2:] - } + fmt.Sprintf("--port=%v", doltPort)} + cmd := exec.Command(args[0], args[1:]...) 
// Set a unique process group ID so that we can cleanly kill this process, as well as @@ -1094,7 +1065,7 @@ func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) if err != nil { return -1, nil, err } - fmt.Printf("dolt sql-server logs at: %s \n", doltLogFilePath) + t.Logf("dolt sql-server logs at: %s \n", doltLogFilePath) cmd.Stdout = doltLogFile cmd.Stderr = doltLogFile err = cmd.Start() @@ -1102,18 +1073,18 @@ func startDoltSqlServer(dir string, doltPersistentSystemVars map[string]string) return -1, nil, fmt.Errorf("unable to execute command %v: %v", cmd.String(), err.Error()) } - fmt.Printf("Dolt CMD: %s\n", cmd.String()) + t.Logf("Dolt CMD: %s\n", cmd.String()) dsn := fmt.Sprintf("%s@tcp(127.0.0.1:%v)/", adminUser, doltPort) replicaDatabase = sqlx.MustOpen("mysql", dsn) - err = waitForSqlServerToStart(replicaDatabase) + err = waitForSqlServerToStart(t, replicaDatabase) if err != nil { return -1, nil, err } mustCreateReplicatorUser(replicaDatabase) - fmt.Printf("Dolt server started on port %v \n", doltPort) + t.Logf("Dolt server started on port %v \n", doltPort) return doltPort, cmd.Process, nil } @@ -1125,24 +1096,17 @@ func mustCreateReplicatorUser(db *sqlx.DB) { } // runDoltCommand runs a short-lived dolt CLI command with the specified arguments from |doltArgs|. The Dolt data -// directory is specified from |doltDataDir| and |goDirPath| is the path to the go directory within the Dolt repo. +// directory is specified from |doltDataDir|. // This function will only return when the Dolt CLI command has completed, so it is not suitable for running // long-lived commands such as "sql-server". If the command fails, an error is returned with the combined output. -func runDoltCommand(doltDataDir string, goDirPath string, doltArgs ...string) error { - // If we're running in CI, use a precompiled dolt binary instead of go run - devDoltPath := initializeDevDoltBuild(doltDataDir, goDirPath) - - args := append([]string{"go", "run", "./cmd/dolt", +func runDoltCommand(t *testing.T, doltDataDir string, doltArgs ...string) error { + args := append([]string{DoltDevBuildPath(), fmt.Sprintf("--data-dir=%s", doltDataDir)}, doltArgs...) - if devDoltPath != "" { - args[2] = devDoltPath - args = args[2:] - } cmd := exec.Command(args[0], args[1:]...) - fmt.Printf("Running Dolt CMD: %s\n", cmd.String()) + t.Logf("Running Dolt CMD: %s\n", cmd.String()) output, err := cmd.CombinedOutput() - fmt.Printf("Dolt CMD output: %s\n", string(output)) + t.Logf("Dolt CMD output: %s\n", string(output)) if err != nil { return fmt.Errorf("%w: %s", err, string(output)) } @@ -1152,13 +1116,13 @@ func runDoltCommand(doltDataDir string, goDirPath string, doltArgs ...string) er // waitForSqlServerToStart polls the specified database to wait for it to become available, pausing // between retry attempts, and returning an error if it is not able to verify that the database is // available. -func waitForSqlServerToStart(database *sqlx.DB) error { - fmt.Printf("Waiting for server to start...\n") +func waitForSqlServerToStart(t *testing.T, database *sqlx.DB) error { + t.Logf("Waiting for server to start...\n") for counter := 0; counter < 30; counter++ { if database.Ping() == nil { return nil } - fmt.Printf("not up yet; waiting...\n") + t.Logf("not up yet; waiting...\n") time.Sleep(500 * time.Millisecond) } @@ -1166,10 +1130,10 @@ func waitForSqlServerToStart(database *sqlx.DB) error { } // printFile opens the specified filepath |path| and outputs the contents of that file to stdout. 
-func printFile(path string) { +func printFile(t *testing.T, path string) { file, err := os.Open(path) if err != nil { - fmt.Printf("Unable to open file: %s \n", err) + t.Logf("Unable to open file: %s \n", err) return } defer file.Close() @@ -1184,9 +1148,9 @@ func printFile(path string) { panic(err) } } - fmt.Print(s) + t.Log(s) } - fmt.Println() + t.Log() } // assertRepoStateFileExists asserts that the repo_state.json file is present for the specified From 09816e30e98196434e4551a3fd023722081a7633 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Mon, 27 Jan 2025 17:57:35 -0800 Subject: [PATCH 23/34] go: cmd/dolt: sqlengine: Actually call DoltBinlogReplicaController.Close when closing the SqlEngine. --- go/cmd/dolt/commands/engine/sqlengine.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index d37f25e212b..af9adb0f50f 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -316,6 +316,9 @@ func (se *SqlEngine) GetUnderlyingEngine() *gms.Engine { func (se *SqlEngine) Close() error { if se.engine != nil { + if se.engine.Analyzer.Catalog.BinlogReplicaController != nil { + dblr.DoltBinlogReplicaController.Close() + } return se.engine.Close() } return nil From 6f6207baafb3c0c109859b9a8390f0ebeafe9a9e Mon Sep 17 00:00:00 2001 From: zachmu Date: Tue, 28 Jan 2025 18:30:59 +0000 Subject: [PATCH 24/34] [ga-bump-dep] Bump dependency in Dolt by zachmu --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 50ac73be0e7..9122e3a0630 100644 --- a/go/go.mod +++ b/go/go.mod @@ -56,7 +56,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.19.1-0.20250124213954-8a1af52235d7 + github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 github.com/esote/minmaxheap v1.0.0 diff --git a/go/go.sum b/go/go.sum index 223d7a54136..748abe7421f 100644 --- a/go/go.sum +++ b/go/go.sum @@ -179,8 +179,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250124213954-8a1af52235d7 h1:DjirOAU+gMlWqr3Ut9PsVT5iqdirAcLr84Cbbi60Kis= -github.com/dolthub/go-mysql-server v0.19.1-0.20250124213954-8a1af52235d7/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= +github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8 h1:eEGYHOC5Ft+56yPaH26gsdbonrZ2EiTwQLy8Oj3TAFE= +github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= From c8e47a80fb51ca709f03d95bfe4b34b3bd7204ba Mon Sep 17 00:00:00 
2001 From: Aaron Son Date: Tue, 28 Jan 2025 12:01:35 -0800 Subject: [PATCH 25/34] go/store/nbs: generational_chunk_store.go: In GCMode_Full, also take dependencies on chunks read from OldGen. --- go/store/chunks/chunk_store.go | 10 +++- go/store/chunks/memory_store.go | 4 +- go/store/chunks/test_utils.go | 8 +-- go/store/nbs/generational_chunk_store.go | 33 +++++++++-- go/store/nbs/nbs_metrics_wrapper.go | 8 +-- go/store/nbs/store.go | 32 +++++------ go/store/nbs/store_test.go | 4 +- go/store/types/value_store.go | 12 ++-- .../concurrent_gc_test.go | 56 +++++++++++++------ 9 files changed, 107 insertions(+), 60 deletions(-) diff --git a/go/store/chunks/chunk_store.go b/go/store/chunks/chunk_store.go index 19bf4d07f96..bb5fe17a162 100644 --- a/go/store/chunks/chunk_store.go +++ b/go/store/chunks/chunk_store.go @@ -225,11 +225,11 @@ type ChunkStoreGarbageCollector interface { // // This function should not block indefinitely and should return an // error if a GC is already in progress. - BeginGC(addChunk func(hash.Hash) bool) error + BeginGC(addChunk func(hash.Hash) bool, mode GCMode) error // EndGC indicates that the GC is over. The previously provided // addChunk function must not be called after this function. - EndGC() + EndGC(mode GCMode) // MarkAndSweepChunks returns a handle that can be used to supply // hashes which should be saved into |dest|. The hashes are @@ -257,6 +257,12 @@ type GenerationalCS interface { NewGen() ChunkStoreGarbageCollector OldGen() ChunkStoreGarbageCollector GhostGen() ChunkStore + + // Has the same return values as OldGen().HasMany, but should be used by a + // generational GC process as the filter function instead of + // OldGen().HasMany. This function never takes read dependencies on the + // chunks that it queries. 
+ OldGenGCFilter() HasManyFunc } var ErrUnsupportedOperation = errors.New("operation not supported") diff --git a/go/store/chunks/memory_store.go b/go/store/chunks/memory_store.go index a7fd1ae5725..ec664b03748 100644 --- a/go/store/chunks/memory_store.go +++ b/go/store/chunks/memory_store.go @@ -335,11 +335,11 @@ func (ms *MemoryStoreView) Commit(ctx context.Context, current, last hash.Hash) return success, nil } -func (ms *MemoryStoreView) BeginGC(keeper func(hash.Hash) bool) error { +func (ms *MemoryStoreView) BeginGC(keeper func(hash.Hash) bool, _ GCMode) error { return ms.transitionToGC(keeper) } -func (ms *MemoryStoreView) EndGC() { +func (ms *MemoryStoreView) EndGC(_ GCMode) { ms.transitionToNoGC() } diff --git a/go/store/chunks/test_utils.go b/go/store/chunks/test_utils.go index 36e7467bbb6..084166caaa1 100644 --- a/go/store/chunks/test_utils.go +++ b/go/store/chunks/test_utils.go @@ -75,20 +75,20 @@ func (s *TestStoreView) Put(ctx context.Context, c Chunk, getAddrs GetAddrsCurry return s.ChunkStore.Put(ctx, c, getAddrs) } -func (s *TestStoreView) BeginGC(keeper func(hash.Hash) bool) error { +func (s *TestStoreView) BeginGC(keeper func(hash.Hash) bool, mode GCMode) error { collector, ok := s.ChunkStore.(ChunkStoreGarbageCollector) if !ok { return ErrUnsupportedOperation } - return collector.BeginGC(keeper) + return collector.BeginGC(keeper, mode) } -func (s *TestStoreView) EndGC() { +func (s *TestStoreView) EndGC(mode GCMode) { collector, ok := s.ChunkStore.(ChunkStoreGarbageCollector) if !ok { panic(ErrUnsupportedOperation) } - collector.EndGC() + collector.EndGC(mode) } func (s *TestStoreView) MarkAndSweepChunks(ctx context.Context, getAddrs GetAddrsCurry, filter HasManyFunc, dest ChunkStore, mode GCMode) (MarkAndSweeper, error) { diff --git a/go/store/nbs/generational_chunk_store.go b/go/store/nbs/generational_chunk_store.go index e8790ec7e8c..cfbbf33e4ba 100644 --- a/go/store/nbs/generational_chunk_store.go +++ b/go/store/nbs/generational_chunk_store.go @@ -494,12 +494,37 @@ func (gcs *GenerationalNBS) UpdateManifest(ctx context.Context, updates map[hash return gcs.newGen.UpdateManifest(ctx, updates) } -func (gcs *GenerationalNBS) BeginGC(keeper func(hash.Hash) bool) error { - return gcs.newGen.BeginGC(keeper) +func (gcs *GenerationalNBS) OldGenGCFilter() chunks.HasManyFunc { + return func(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + return gcs.oldGen.hasManyDep(ctx, hashes, gcDependencyMode_NoDependency) + } } -func (gcs *GenerationalNBS) EndGC() { - gcs.newGen.EndGC() +func (gcs *GenerationalNBS) BeginGC(keeper func(hash.Hash) bool, mode chunks.GCMode) error { + err := gcs.newGen.BeginGC(keeper, mode) + if err != nil { + return err + } + // In GCMode_Full, the OldGen is also being collected. In normal + // operation, the OldGen is not being collected because it is + // still growing monotonically and nothing in it is at risk of + // going away. In Full mode, we want to take read dependencies + // from the OldGen as well. 
+ if mode == chunks.GCMode_Full { + err = gcs.oldGen.BeginGC(keeper, mode) + if err != nil { + gcs.newGen.EndGC(mode) + return err + } + } + return nil +} + +func (gcs *GenerationalNBS) EndGC(mode chunks.GCMode) { + if mode == chunks.GCMode_Full { + gcs.oldGen.EndGC(mode) + } + gcs.newGen.EndGC(mode) } func (gcs *GenerationalNBS) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { diff --git a/go/store/nbs/nbs_metrics_wrapper.go b/go/store/nbs/nbs_metrics_wrapper.go index 1ca852da04d..36b262075b7 100644 --- a/go/store/nbs/nbs_metrics_wrapper.go +++ b/go/store/nbs/nbs_metrics_wrapper.go @@ -71,12 +71,12 @@ func (nbsMW *NBSMetricWrapper) SupportedOperations() chunks.TableFileStoreOps { return nbsMW.nbs.SupportedOperations() } -func (nbsMW *NBSMetricWrapper) BeginGC(keeper func(hash.Hash) bool) error { - return nbsMW.nbs.BeginGC(keeper) +func (nbsMW *NBSMetricWrapper) BeginGC(keeper func(hash.Hash) bool, mode chunks.GCMode) error { + return nbsMW.nbs.BeginGC(keeper, mode) } -func (nbsMW *NBSMetricWrapper) EndGC() { - nbsMW.nbs.EndGC() +func (nbsMW *NBSMetricWrapper) EndGC(mode chunks.GCMode) { + nbsMW.nbs.EndGC(mode) } func (nbsMW *NBSMetricWrapper) MarkAndSweepChunks(ctx context.Context, getAddrs chunks.GetAddrsCurry, filter chunks.HasManyFunc, dest chunks.ChunkStore, mode chunks.GCMode) (chunks.MarkAndSweeper, error) { diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index bc39ceee3be..a9c4911e360 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -273,11 +273,6 @@ func (nbs *NomsBlockStore) conjoinIfRequired(ctx context.Context) (bool, error) func (nbs *NomsBlockStore) UpdateManifest(ctx context.Context, updates map[hash.Hash]uint32) (mi ManifestInfo, err error) { nbs.mu.Lock() defer nbs.mu.Unlock() - err = nbs.waitForGC(ctx) - if err != nil { - return - } - err = nbs.checkAllManifestUpdatesExist(ctx, updates) if err != nil { return @@ -361,11 +356,6 @@ func (nbs *NomsBlockStore) UpdateManifest(ctx context.Context, updates map[hash. 
func (nbs *NomsBlockStore) UpdateManifestWithAppendix(ctx context.Context, updates map[hash.Hash]uint32, option ManifestAppendixOption) (mi ManifestInfo, err error) { nbs.mu.Lock() defer nbs.mu.Unlock() - err = nbs.waitForGC(ctx) - if err != nil { - return - } - err = nbs.checkAllManifestUpdatesExist(ctx, updates) if err != nil { return @@ -517,11 +507,6 @@ func fromManifestAppendixOptionNewContents(upstream manifestContents, appendixSp func OverwriteStoreManifest(ctx context.Context, store *NomsBlockStore, root hash.Hash, tableFiles map[hash.Hash]uint32, appendixTableFiles map[hash.Hash]uint32) (err error) { store.mu.Lock() defer store.mu.Unlock() - err = store.waitForGC(ctx) - if err != nil { - return - } - contents := manifestContents{ root: root, nbfVers: store.upstream.nbfVers, @@ -1128,6 +1113,10 @@ func (nbs *NomsBlockStore) Has(ctx context.Context, h hash.Hash) (bool, error) { } func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { + return nbs.hasManyDep(ctx, hashes, gcDependencyMode_TakeDependency) +} + +func (nbs *NomsBlockStore) hasManyDep(ctx context.Context, hashes hash.HashSet, gcDepMode gcDependencyMode) (hash.HashSet, error) { if hashes.Size() == 0 { return nil, nil } @@ -1143,7 +1132,11 @@ func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (ha nbs.mu.Lock() if nbs.mt != nil { - remaining, gcb, err := nbs.mt.hasMany(reqs, nbs.keeperFunc) + keeper := nbs.keeperFunc + if gcDepMode == gcDependencyMode_NoDependency { + keeper = nil + } + remaining, gcb, err := nbs.mt.hasMany(reqs, keeper) if err != nil { nbs.mu.Unlock() return nil, err @@ -1162,6 +1155,9 @@ func (nbs *NomsBlockStore) HasMany(ctx context.Context, hashes hash.HashSet) (ha } } tables, keeper, endRead := nbs.tables, nbs.keeperFunc, nbs.beginRead() + if gcDepMode == gcDependencyMode_NoDependency { + keeper = nil + } nbs.mu.Unlock() remaining, gcb, err := tables.hasMany(reqs, keeper) @@ -1730,7 +1726,7 @@ func (nbs *NomsBlockStore) pruneTableFiles(ctx context.Context) (err error) { }, mtime) } -func (nbs *NomsBlockStore) BeginGC(keeper func(hash.Hash) bool) error { +func (nbs *NomsBlockStore) BeginGC(keeper func(hash.Hash) bool, _ chunks.GCMode) error { nbs.cond.L.Lock() defer nbs.cond.L.Unlock() if nbs.gcInProgress { @@ -1742,7 +1738,7 @@ func (nbs *NomsBlockStore) BeginGC(keeper func(hash.Hash) bool) error { return nil } -func (nbs *NomsBlockStore) EndGC() { +func (nbs *NomsBlockStore) EndGC(_ chunks.GCMode) { nbs.cond.L.Lock() defer nbs.cond.L.Unlock() if !nbs.gcInProgress { diff --git a/go/store/nbs/store_test.go b/go/store/nbs/store_test.go index 82dc8d89005..ffcb5459546 100644 --- a/go/store/nbs/store_test.go +++ b/go/store/nbs/store_test.go @@ -337,7 +337,7 @@ func TestNBSCopyGC(t *testing.T) { require.NoError(t, err) require.True(t, ok) - require.NoError(t, st.BeginGC(nil)) + require.NoError(t, st.BeginGC(nil, chunks.GCMode_Full)) noopFilter := func(ctx context.Context, hashes hash.HashSet) (hash.HashSet, error) { return hashes, nil } @@ -352,7 +352,7 @@ func TestNBSCopyGC(t *testing.T) { require.NoError(t, err) require.NoError(t, sweeper.Close(ctx)) require.NoError(t, finalizer.SwapChunksInStore(ctx)) - st.EndGC() + st.EndGC(chunks.GCMode_Full) for h, c := range keepers { out, err := st.Get(ctx, h) diff --git a/go/store/types/value_store.go b/go/store/types/value_store.go index 026a97bcd39..fb348da918f 100644 --- a/go/store/types/value_store.go +++ b/go/store/types/value_store.go @@ -591,7 +591,7 @@ func (lvs *ValueStore) GC(ctx 
context.Context, mode GCMode, oldGenRefs, newGenRe var oldGenHasMany chunks.HasManyFunc switch mode { case GCModeDefault: - oldGenHasMany = oldGen.HasMany + oldGenHasMany = gcs.OldGenGCFilter() chksMode = chunks.GCMode_Default case GCModeFull: oldGenHasMany = unfilteredHashFunc @@ -601,11 +601,11 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe } err := func() error { - err := collector.BeginGC(lvs.gcAddChunk) + err := collector.BeginGC(lvs.gcAddChunk, chksMode) if err != nil { return err } - defer collector.EndGC() + defer collector.EndGC(chksMode) var callCancelSafepoint bool if safepoint != nil { @@ -650,7 +650,7 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe } if mode == GCModeDefault { - oldGenHasMany = oldGen.HasMany + oldGenHasMany = gcs.OldGenGCFilter() } else { oldGenHasMany = newFileHasMany } @@ -685,11 +685,11 @@ func (lvs *ValueStore) GC(ctx context.Context, mode GCMode, oldGenRefs, newGenRe newGenRefs.InsertAll(oldGenRefs) err := func() error { - err := collector.BeginGC(lvs.gcAddChunk) + err := collector.BeginGC(lvs.gcAddChunk, chunks.GCMode_Full) if err != nil { return err } - defer collector.EndGC() + defer collector.EndGC(chunks.GCMode_Full) var callCancelSafepoint bool if safepoint != nil { diff --git a/integration-tests/go-sql-server-driver/concurrent_gc_test.go b/integration-tests/go-sql-server-driver/concurrent_gc_test.go index 002876c14f3..838dfb444e3 100644 --- a/integration-tests/go-sql-server-driver/concurrent_gc_test.go +++ b/integration-tests/go-sql-server-driver/concurrent_gc_test.go @@ -31,15 +31,41 @@ import ( ) func TestConcurrentGC(t *testing.T) { - var gct gcTest - gct.numThreads = 8 - gct.duration = 10 * time.Second t.Run("NoCommits", func(t *testing.T) { - gct.run(t) + t.Run("Normal", func(t *testing.T) { + var gct = gcTest{ + numThreads: 8, + duration: 10 * time.Second, + } + gct.run(t) + }) + t.Run("Full", func(t *testing.T) { + var gct = gcTest{ + numThreads: 8, + duration: 10 * time.Second, + full: true, + } + gct.run(t) + }) }) - gct.commit = true t.Run("WithCommits", func(t *testing.T) { - gct.run(t) + t.Run("Normal", func(t *testing.T) { + var gct = gcTest{ + numThreads: 8, + duration: 10 * time.Second, + commit: true, + } + gct.run(t) + }) + t.Run("Full", func(t *testing.T) { + var gct = gcTest{ + numThreads: 8, + duration: 10 * time.Second, + commit: true, + full: true, + } + gct.run(t) + }) }) } @@ -47,6 +73,7 @@ type gcTest struct { numThreads int duration time.Duration commit bool + full bool } func (gct gcTest) createDB(t *testing.T, ctx context.Context, db *sql.DB) { @@ -118,19 +145,12 @@ func (gct gcTest) doGC(t *testing.T, ctx context.Context, db *sql.DB) error { }) }() b := time.Now() - _, err = conn.ExecContext(ctx, "call dolt_gc()") - if err != nil { - if !assert.NotContains(t, err.Error(), "dangling ref") { - return err - } - if !assert.NotContains(t, err.Error(), "is unexpected noms value") { - return err - } - if !assert.NotContains(t, err.Error(), "interface conversion: types.Value is nil") { - return err - } - t.Logf("err in Exec dolt_gc: %v", err) + if !gct.full { + _, err = conn.ExecContext(ctx, "call dolt_gc()") } else { + _, err = conn.ExecContext(ctx, `call dolt_gc("--full")`) + } + if assert.NoError(t, err) { t.Logf("successful dolt_gc took %v", time.Since(b)) } return nil From 872f8d9b8f3cfec94aa00aadcaaabf89257d5ce4 Mon Sep 17 00:00:00 2001 From: coffeegoddd Date: Tue, 21 Jan 2025 15:32:35 -0800 Subject: [PATCH 26/34] /go/libraries/doltcore/env/actions: 
make iter resolved tags paginated, sort in lexicographical order --- go/libraries/doltcore/env/actions/tag.go | 107 +++++++++---- go/libraries/doltcore/env/actions/tag_test.go | 150 ++++++++++++++++++ 2 files changed, 223 insertions(+), 34 deletions(-) create mode 100644 go/libraries/doltcore/env/actions/tag_test.go diff --git a/go/libraries/doltcore/env/actions/tag.go b/go/libraries/doltcore/env/actions/tag.go index 1580ad43808..d092348dd83 100644 --- a/go/libraries/doltcore/env/actions/tag.go +++ b/go/libraries/doltcore/env/actions/tag.go @@ -17,7 +17,6 @@ package actions import ( "context" "fmt" - "sort" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" @@ -25,6 +24,8 @@ import ( "github.com/dolthub/dolt/go/store/datas" ) +const DefaultPageSize = 100 + type TagProps struct { TaggerName string TaggerEmail string @@ -97,67 +98,105 @@ func DeleteTagsOnDB(ctx context.Context, ddb *doltdb.DoltDB, tagNames ...string) return nil } -// IterResolvedTags iterates over tags in dEnv.DoltDB from newest to oldest, resolving the tag to a commit and calling cb(). -func IterResolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.Tag) (stop bool, err error)) error { +// IterUnresolvedTags iterates over tags in dEnv.DoltDB, and calls cb() for each with an unresolved Tag. +func IterUnresolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.TagResolver) (stop bool, err error)) error { tagRefs, err := ddb.GetTags(ctx) - if err != nil { return err } - var resolved []*doltdb.Tag - for _, r := range tagRefs { - tr, ok := r.(ref.TagRef) - if !ok { - return fmt.Errorf("DoltDB.GetTags() returned non-tag DoltRef") - } + tagResolvers, err := ddb.GetTagResolvers(ctx, tagRefs) + if err != nil { + return err + } - tag, err := ddb.ResolveTag(ctx, tr) + for _, tagResolver := range tagResolvers { + stop, err := cb(&tagResolver) if err != nil { return err } + if stop { + break + } + } + return nil +} + +// IterResolvedTagsPaginated iterates over tags in dEnv.DoltDB in their default lexicographical order, resolving the tag to a commit and calling cb(). +// Returns the next tag name if there are more results available. +func IterResolvedTagsPaginated(ctx context.Context, ddb *doltdb.DoltDB, startTag string, cb func(tag *doltdb.Tag) (stop bool, err error)) (string, error) { + // tags returned here are sorted lexicographically + tagRefs, err := ddb.GetTags(ctx) + if err != nil { + return "", err + } - resolved = append(resolved, tag) + // find starting index based on start tag + startIdx := 0 + if startTag != "" { + for i, tr := range tagRefs { + if tr.GetPath() == startTag { + startIdx = i + 1 // start after the given tag + break + } + } } - // iterate newest to oldest - sort.Slice(resolved, func(i, j int) bool { - return resolved[i].Meta.Timestamp > resolved[j].Meta.Timestamp - }) + // get page of results + endIdx := startIdx + DefaultPageSize + if endIdx > len(tagRefs) { + endIdx = len(tagRefs) + } - for _, tag := range resolved { - stop, err := cb(tag) + pageTagRefs := tagRefs[startIdx:endIdx] + // resolve tags for this page + for _, tr := range pageTagRefs { + tag, err := ddb.ResolveTag(ctx, tr.(ref.TagRef)) if err != nil { - return err + return "", err } + + stop, err := cb(tag) + if err != nil { + return "", err + } + if stop { break } } - return nil -} -// IterUnresolvedTags iterates over tags in dEnv.DoltDB, and calls cb() for each with an unresovled Tag. 
-func IterUnresolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.TagResolver) (stop bool, err error)) error { - tagRefs, err := ddb.GetTags(ctx) - if err != nil { - return err + // return next tag name if there are more results + if endIdx < len(tagRefs) { + lastTag := pageTagRefs[len(pageTagRefs)-1] + return lastTag.GetPath(), nil } - tagResolvers, err := ddb.GetTagResolvers(ctx, tagRefs) + return "", nil +} + +// VisitResolvedTag iterates over tags in ddb until the given tag name is found, then calls cb() with the resolved tag. +func VisitResolvedTag(ctx context.Context, ddb *doltdb.DoltDB, tagName string, cb func(tag *doltdb.Tag) error) error { + tagRefs, err := ddb.GetTags(ctx) if err != nil { return err } - for _, tagResolver := range tagResolvers { - stop, err := cb(&tagResolver) - if err != nil { - return err + for _, r := range tagRefs { + tr, ok := r.(ref.TagRef) + if !ok { + return fmt.Errorf("DoltDB.GetTags() returned non-tag DoltRef") } - if stop { - break + + if tr.GetPath() == tagName { + tag, err := ddb.ResolveTag(ctx, tr) + if err != nil { + return err + } + return cb(tag) } } - return nil + + return doltdb.ErrTagNotFound } diff --git a/go/libraries/doltcore/env/actions/tag_test.go b/go/libraries/doltcore/env/actions/tag_test.go new file mode 100644 index 00000000000..6ad86189abb --- /dev/null +++ b/go/libraries/doltcore/env/actions/tag_test.go @@ -0,0 +1,150 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package actions + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/utils/filesys" + "github.com/dolthub/dolt/go/store/types" +) + +const ( + testHomeDir = "/user/bheni" + workingDir = "/user/bheni/datasets/addresses" + credsDir = "creds" + + configFile = "config.json" + GlobalConfigFile = "config_global.json" +) + +func testHomeDirFunc() (string, error) { + return testHomeDir, nil +} + +func createTestEnv() (*env.DoltEnv, *filesys.InMemFS) { + initialDirs := []string{testHomeDir, workingDir} + initialFiles := map[string][]byte{} + + fs := filesys.NewInMemFS(initialDirs, initialFiles, workingDir) + dEnv := env.Load(context.Background(), testHomeDirFunc, fs, doltdb.InMemDoltDB, "test") + + return dEnv, fs +} + +func TestVisitResolvedTag(t *testing.T) { + dEnv, _ := createTestEnv() + ctx := context.Background() + + // Initialize repo + err := dEnv.InitRepo(ctx, types.Format_Default, "test user", "test@test.com", "main") + require.NoError(t, err) + + // Create a tag + tagName := "test-tag" + tagMsg := "test tag message" + err = CreateTag(ctx, dEnv, tagName, "main", TagProps{TaggerName: "test user", TaggerEmail: "test@test.com", Description: tagMsg}) + require.NoError(t, err) + + // Visit the tag and verify its properties + var foundTag *doltdb.Tag + err = VisitResolvedTag(ctx, dEnv.DoltDB, tagName, func(tag *doltdb.Tag) error { + foundTag = tag + return nil + }) + require.NoError(t, err) + require.NotNil(t, foundTag) + require.Equal(t, tagName, foundTag.Name) + require.Equal(t, tagMsg, foundTag.Meta.Description) + + // Test visiting non-existent tag + err = VisitResolvedTag(ctx, dEnv.DoltDB, "non-existent-tag", func(tag *doltdb.Tag) error { + return nil + }) + require.Equal(t, doltdb.ErrTagNotFound, err) +} + +func TestIterResolvedTagsPaginated(t *testing.T) { + dEnv, _ := createTestEnv() + ctx := context.Background() + + // Initialize repo + err := dEnv.InitRepo(ctx, types.Format_Default, "test user", "test@test.com", "main") + require.NoError(t, err) + + expectedTagNames := make([]string, DefaultPageSize*2) + // Create multiple tags with different timestamps + tagNames := make([]string, DefaultPageSize*2) + for i := range tagNames { + tagName := fmt.Sprintf("tag-%d", i) + err = CreateTag(ctx, dEnv, tagName, "main", TagProps{ + TaggerName: "test user", + TaggerEmail: "test@test.com", + Description: fmt.Sprintf("test tag %s", tagName), + }) + time.Sleep(2 * time.Millisecond) + require.NoError(t, err) + tagNames[i] = tagName + expectedTagNames[i] = tagName + } + + // Sort expected tag names to ensure they are in the correct order + sort.Strings(expectedTagNames) + + // Test first page + var foundTags []string + pageToken, err := IterResolvedTagsPaginated(ctx, dEnv.DoltDB, "", func(tag *doltdb.Tag) (bool, error) { + foundTags = append(foundTags, tag.Name) + return false, nil + }) + require.NoError(t, err) + require.NotEmpty(t, pageToken) // Should have next page + require.Equal(t, DefaultPageSize, len(foundTags)) // Default page size tags returned + require.Equal(t, expectedTagNames[:DefaultPageSize], foundTags) + + // Test second page + var secondPageTags []string + nextPageToken, err := IterResolvedTagsPaginated(ctx, dEnv.DoltDB, pageToken, func(tag *doltdb.Tag) (bool, error) { + secondPageTags = append(secondPageTags, tag.Name) + return false, nil + }) + + require.NoError(t, err) + 
require.Empty(t, nextPageToken) // Should be no more pages + require.Equal(t, DefaultPageSize, len(secondPageTags)) // Remaining tags + require.Equal(t, expectedTagNames[DefaultPageSize:], secondPageTags) + + // Verify all tags were found + allFoundTags := append(foundTags, secondPageTags...) + require.Equal(t, len(tagNames), len(allFoundTags)) + require.Equal(t, expectedTagNames, allFoundTags) + + // Test early termination + var earlyTermTags []string + _, err = IterResolvedTagsPaginated(ctx, dEnv.DoltDB, "", func(tag *doltdb.Tag) (bool, error) { + earlyTermTags = append(earlyTermTags, tag.Name) + return true, nil // Stop after first tag + }) + require.NoError(t, err) + require.Equal(t, 1, len(earlyTermTags)) +} From a80e71b3535a84aac7ed71e9a5b9eb185bc631e3 Mon Sep 17 00:00:00 2001 From: coffeegoddd Date: Thu, 30 Jan 2025 09:29:28 -0800 Subject: [PATCH 27/34] /go/libraries/doltcore/env/actions/tag.go: reinstage iterresolved tags for better deprecation process --- go/libraries/doltcore/env/actions/tag.go | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/go/libraries/doltcore/env/actions/tag.go b/go/libraries/doltcore/env/actions/tag.go index d092348dd83..a6bb57e8c5e 100644 --- a/go/libraries/doltcore/env/actions/tag.go +++ b/go/libraries/doltcore/env/actions/tag.go @@ -17,6 +17,7 @@ package actions import ( "context" "fmt" + "sort" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" @@ -122,6 +123,47 @@ func IterUnresolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *do return nil } +// IterResolvedTags iterates over tags in dEnv.DoltDB from newest to oldest, resolving the tag to a commit and calling cb(). +func IterResolvedTags(ctx context.Context, ddb *doltdb.DoltDB, cb func(tag *doltdb.Tag) (stop bool, err error)) error { + tagRefs, err := ddb.GetTags(ctx) + + if err != nil { + return err + } + + var resolved []*doltdb.Tag + for _, r := range tagRefs { + tr, ok := r.(ref.TagRef) + if !ok { + return fmt.Errorf("DoltDB.GetTags() returned non-tag DoltRef") + } + + tag, err := ddb.ResolveTag(ctx, tr) + if err != nil { + return err + } + + resolved = append(resolved, tag) + } + + // iterate newest to oldest + sort.Slice(resolved, func(i, j int) bool { + return resolved[i].Meta.Timestamp > resolved[j].Meta.Timestamp + }) + + for _, tag := range resolved { + stop, err := cb(tag) + + if err != nil { + return err + } + if stop { + break + } + } + return nil +} + // IterResolvedTagsPaginated iterates over tags in dEnv.DoltDB in their default lexicographical order, resolving the tag to a commit and calling cb(). // Returns the next tag name if there are more results available. 
func IterResolvedTagsPaginated(ctx context.Context, ddb *doltdb.DoltDB, startTag string, cb func(tag *doltdb.Tag) (stop bool, err error)) (string, error) { From d4dd7aa8f81e5bbc8799d2b7603937d0db6525f8 Mon Sep 17 00:00:00 2001 From: Hydrocharged Date: Fri, 31 Jan 2025 11:07:12 +0000 Subject: [PATCH 28/34] [ga-bump-dep] Bump dependency in Dolt by Hydrocharged --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 9122e3a0630..2fc96289a25 100644 --- a/go/go.mod +++ b/go/go.mod @@ -56,7 +56,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8 + github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 github.com/esote/minmaxheap v1.0.0 diff --git a/go/go.sum b/go/go.sum index 748abe7421f..0af441a1c7c 100644 --- a/go/go.sum +++ b/go/go.sum @@ -179,8 +179,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8 h1:eEGYHOC5Ft+56yPaH26gsdbonrZ2EiTwQLy8Oj3TAFE= -github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= +github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366 h1:pJ+upgX6hrhyqgpkmk9Ye9lIPSualMHZcUMs8kWknV4= +github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= From 72ec11fe03e2e89a09a747d75707dcf7336f06eb Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Fri, 31 Jan 2025 11:14:00 -0800 Subject: [PATCH 29/34] Include deleteErr message in returned error. --- go/libraries/doltcore/sqle/database_provider.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 2b6b84b96e8..ec6f2caa13f 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -684,7 +684,8 @@ func (p *DoltDatabaseProvider) CloneDatabaseFromRemote( if exists { deleteErr := p.fs.Delete(dbName, true) if deleteErr != nil { - err = fmt.Errorf("%s: unable to clean up failed clone in directory '%s'", err.Error(), dbName) + err = fmt.Errorf("%s: unable to clean up failed clone in directory '%s': %s", + err.Error(), dbName, deleteErr.Error()) } } return err From eed637b80761f390c02b4cb527746fe0d5d617b1 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 30 Jan 2025 10:15:06 -0800 Subject: [PATCH 30/34] go: store/types: Fix dolt_gc on databases that use vector indexes. 
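For context: dolt_gc decides which chunks are still live by walking address
references out of each serialized message through SerialMessage.WalkAddrs, so a
message type that is missing from the dispatch switch in WalkAddrs cannot be
traversed, which is what broke garbage collection on databases containing
vector indexes. The minimal sketch below is not part of this patch; it only
shows roughly how such a reachability walk consumes the callback, using the
WalkAddrs signature from the diff. The package name, walkReachable, and visit
are illustrative, and the import paths are assumed.

    package gcsketch

    import (
        "github.com/dolthub/dolt/go/store/hash"
        "github.com/dolthub/dolt/go/store/types"
    )

    // walkReachable records every chunk address referenced by one serialized
    // message. It relies on the WalkAddrs dispatch that this patch extends to
    // cover vector index node and tuple messages.
    func walkReachable(nbf *types.NomsBinFormat, msg types.SerialMessage, visit func(addr hash.Hash)) error {
        return msg.WalkAddrs(nbf, func(addr hash.Hash) error {
            // A real GC walk would also enqueue addr and recurse into the
            // message it resolves to; here we only record it.
            visit(addr)
            return nil
        })
    }

With serial.VectorIndexNodeFileID handled alongside the other prolly-tree style
message types (and serial.TupleFileID treated as a leaf with no outgoing
references), the same walk now reaches chunks that are referenced only through
vector index nodes.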
--- go/store/types/serial_message.go | 4 ++-- integration-tests/bats/vector-index.bats | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/go/store/types/serial_message.go b/go/store/types/serial_message.go index 50ee98c3261..e5c4060f147 100644 --- a/go/store/types/serial_message.go +++ b/go/store/types/serial_message.go @@ -769,10 +769,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er return err } } - case serial.TableSchemaFileID, serial.ForeignKeyCollectionFileID: + case serial.TableSchemaFileID, serial.ForeignKeyCollectionFileID, serial.TupleFileID: // no further references from these file types return nil - case serial.ProllyTreeNodeFileID, serial.AddressMapFileID, serial.MergeArtifactsFileID, serial.BlobFileID, serial.CommitClosureFileID: + case serial.ProllyTreeNodeFileID, serial.AddressMapFileID, serial.MergeArtifactsFileID, serial.BlobFileID, serial.CommitClosureFileID, serial.VectorIndexNodeFileID: return message.WalkAddresses(context.TODO(), serial.Message(sm), func(ctx context.Context, addr hash.Hash) error { return cb(addr) }) diff --git a/integration-tests/bats/vector-index.bats b/integration-tests/bats/vector-index.bats index 5693b7d8628..8d233f60887 100644 --- a/integration-tests/bats/vector-index.bats +++ b/integration-tests/bats/vector-index.bats @@ -430,3 +430,14 @@ SQL [[ "$output" =~ "pk1" ]] || false [[ "${#lines[@]}" = "1" ]] || false } + +@test "vector-index: can GC" { + dolt sql < Date: Tue, 28 Jan 2025 13:30:47 -0800 Subject: [PATCH 31/34] go: sqlserver: Clean up how remotesrv gets a handle to the mrEnv.FileSystem, remove the confusing indirection through ctxFactory. --- go/cmd/dolt/commands/engine/sqlengine.go | 7 +++++++ go/cmd/dolt/commands/sqlserver/server.go | 4 +++- go/libraries/doltcore/sqle/cluster/controller.go | 2 +- go/libraries/doltcore/sqle/remotesrv.go | 16 +++++----------- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index af9adb0f50f..19aee1af7ea 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -47,6 +47,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle/statspro" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" "github.com/dolthub/dolt/go/libraries/utils/config" + "github.com/dolthub/dolt/go/libraries/utils/filesys" ) // SqlEngine packages up the context necessary to run sql queries against dsqle. 
@@ -55,6 +56,7 @@ type SqlEngine struct { contextFactory contextFactory dsessFactory sessionFactory engine *gms.Engine + fs filesys.Filesys } type sessionFactory func(mysqlSess *sql.BaseSession, pro sql.DatabaseProvider) (*dsess.DoltSession, error) @@ -194,6 +196,7 @@ func NewSqlEngine( sqlEngine.contextFactory = sqlContextFactory() sqlEngine.dsessFactory = sessFactory sqlEngine.engine = engine + sqlEngine.fs = pro.FileSystem() // configuring stats depends on sessionBuilder // sessionBuilder needs ref to statsProv @@ -314,6 +317,10 @@ func (se *SqlEngine) GetUnderlyingEngine() *gms.Engine { return se.engine } +func (se *SqlEngine) FileSystem() filesys.Filesys { + return se.fs +} + func (se *SqlEngine) Close() error { if se.engine != nil { if se.engine.Analyzer.Catalog.BinlogReplicaController != nil { diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 9374608c37e..2e3d7f038d7 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -567,7 +567,8 @@ func ConfigureServices( ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET, } var err error - args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases) + args.FS = sqlEngine.FileSystem() + args.DBCache, err = sqle.RemoteSrvDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases) if err != nil { lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err @@ -621,6 +622,7 @@ func ConfigureServices( lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } + args.FS = sqlEngine.FileSystem() clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig()) if err != nil { diff --git a/go/libraries/doltcore/sqle/cluster/controller.go b/go/libraries/doltcore/sqle/cluster/controller.go index 3845a1e9fa8..b09b32b9551 100644 --- a/go/libraries/doltcore/sqle/cluster/controller.go +++ b/go/libraries/doltcore/sqle/cluster/controller.go @@ -690,7 +690,7 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. args.GrpcListenAddr = listenaddr args.Options = c.ServerOptions() var err error - args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(ctxFactory, sqle.CreateUnknownDatabases) + args.DBCache, err = sqle.RemoteSrvDBCache(ctxFactory, sqle.CreateUnknownDatabases) if err != nil { return remotesrv.ServerArgs{}, err } diff --git a/go/libraries/doltcore/sqle/remotesrv.go b/go/libraries/doltcore/sqle/remotesrv.go index 87e6164764e..4564723f90f 100644 --- a/go/libraries/doltcore/sqle/remotesrv.go +++ b/go/libraries/doltcore/sqle/remotesrv.go @@ -22,7 +22,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/remotesrv" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" - "github.com/dolthub/dolt/go/libraries/utils/filesys" "github.com/dolthub/dolt/go/store/datas" ) @@ -81,17 +80,12 @@ type CreateUnknownDatabasesSetting bool const CreateUnknownDatabases CreateUnknownDatabasesSetting = true const DoNotCreateUnknownDatabases CreateUnknownDatabasesSetting = false -// Considers |args| and returns a new |remotesrv.ServerArgs| instance which -// will serve databases accessible through |ctxFactory|. 
-func RemoteSrvFSAndDBCache(ctxFactory func(context.Context) (*sql.Context, error), createSetting CreateUnknownDatabasesSetting) (filesys.Filesys, remotesrv.DBCache, error) { - sqlCtx, err := ctxFactory(context.Background()) - if err != nil { - return nil, nil, err - } - sess := dsess.DSessFromSess(sqlCtx.Session) - fs := sess.Provider().FileSystem() +// Returns a remotesrv.DBCache instance which will use the *sql.Context +// returned from |ctxFactory| to access a database in the session +// DatabaseProvider. +func RemoteSrvDBCache(ctxFactory func(context.Context) (*sql.Context, error), createSetting CreateUnknownDatabasesSetting) (remotesrv.DBCache, error) { dbcache := remotesrvStore{ctxFactory, bool(createSetting)} - return fs, dbcache, nil + return dbcache, nil } func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessControl) remotesrv.ServerArgs { From 86acc8c57062afcf219970b4e4c5989a959b2407 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Tue, 28 Jan 2025 14:28:42 -0800 Subject: [PATCH 32/34] go: sqle,remotesrv: Implement sql.Session lifecycle callbacks for sql.Contexts used in remotesrv RPCs. This PR changes each RPC invocation against the gRPC and HTTP servers implementing remotesapi and cluster replication to create a sql.Context which lives the duration of the call. The Session for that call gets SessionCommand{Begin,End} and SessionEnd lifecycle callbacks so that it can participate in GC safepoint rendezvous appropriately. Previously the remotesrv.DBCache and the user/password remotesapi authentication implementation would simply create new sql.Contexts whenever they needed them. There could be multiple sql.Contexts for the single server call. --- go/cmd/dolt/commands/sqlserver/server.go | 11 ++- .../doltcore/sqle/cluster/controller.go | 11 ++- go/libraries/doltcore/sqle/cluster/jwks.go | 9 +- go/libraries/doltcore/sqle/remotesrv.go | 88 +++++++++++++++++++ 4 files changed, 111 insertions(+), 8 deletions(-) diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 2e3d7f038d7..20e538c8e9a 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -559,22 +559,27 @@ func ConfigureServices( } listenaddr := fmt.Sprintf(":%d", port) + sqlContextInterceptor := sqle.SqlContextServerInterceptor{ + Factory: sqlEngine.NewDefaultContext, + } args := remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), ReadOnly: apiReadOnly || serverConfig.ReadOnly(), HttpListenAddr: listenaddr, GrpcListenAddr: listenaddr, ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET, + Options: sqlContextInterceptor.Options(), + HttpInterceptor: sqlContextInterceptor.HTTP(nil), } var err error args.FS = sqlEngine.FileSystem() - args.DBCache, err = sqle.RemoteSrvDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases) + args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.DoNotCreateUnknownDatabases) if err != nil { lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } - authenticator := newAccessController(sqlEngine.NewDefaultContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) + authenticator := newAccessController(sqle.GetInterceptorSqlContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) args = sqle.WithUserPasswordAuth(args, authenticator) args.TLSConfig = serverConf.TLSConfig @@ -636,7 +641,7 @@ func ConfigureServices( lgr.Errorf("error creating remotesapi 
server on port %d: %v", *serverConfig.RemotesapiPort(), err) return err } - clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer()) + clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer()) clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners() if err != nil { diff --git a/go/libraries/doltcore/sqle/cluster/controller.go b/go/libraries/doltcore/sqle/cluster/controller.go index b09b32b9551..4be3f36b341 100644 --- a/go/libraries/doltcore/sqle/cluster/controller.go +++ b/go/libraries/doltcore/sqle/cluster/controller.go @@ -688,9 +688,14 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. listenaddr := c.RemoteSrvListenAddr() args.HttpListenAddr = listenaddr args.GrpcListenAddr = listenaddr - args.Options = c.ServerOptions() + ctxInterceptor := sqle.SqlContextServerInterceptor{ + Factory: ctxFactory, + } + args.Options = append(args.Options, ctxInterceptor.Options()...) + args.Options = append(args.Options, c.ServerOptions()...) + args.HttpInterceptor = ctxInterceptor.HTTP(args.HttpInterceptor) var err error - args.DBCache, err = sqle.RemoteSrvDBCache(ctxFactory, sqle.CreateUnknownDatabases) + args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.CreateUnknownDatabases) if err != nil { return remotesrv.ServerArgs{}, err } @@ -699,7 +704,7 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. keyID := creds.PubKeyToKID(c.pub) keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID) - args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub) + args.HttpInterceptor = JWKSHandlerInterceptor(args.HttpInterceptor, keyIDStr, c.pub) return args, nil } diff --git a/go/libraries/doltcore/sqle/cluster/jwks.go b/go/libraries/doltcore/sqle/cluster/jwks.go index 1e3c357c1c3..36511ae62be 100644 --- a/go/libraries/doltcore/sqle/cluster/jwks.go +++ b/go/libraries/doltcore/sqle/cluster/jwks.go @@ -46,16 +46,21 @@ func (h JWKSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Write(b) } -func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler { +func JWKSHandlerInterceptor(existing func(http.Handler) http.Handler, keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler { jh := JWKSHandler{KeyID: keyID, PublicKey: pub} return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.EscapedPath() == "/.well-known/jwks.json" { jh.ServeHTTP(w, r) return } h.ServeHTTP(w, r) }) + if existing != nil { + return existing(this) + } else { + return this + } } } diff --git a/go/libraries/doltcore/sqle/remotesrv.go b/go/libraries/doltcore/sqle/remotesrv.go index 4564723f90f..1f899c2a7ac 100644 --- a/go/libraries/doltcore/sqle/remotesrv.go +++ b/go/libraries/doltcore/sqle/remotesrv.go @@ -16,8 +16,11 @@ package sqle import ( "context" + "errors" + "net/http" "github.com/dolthub/go-mysql-server/sql" + "google.golang.org/grpc" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/remotesrv" @@ -96,3 +99,88 @@ func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessCont args.Options = append(args.Options, si.Options()...) 
return args } + +type SqlContextServerInterceptor struct { + Factory func(context.Context) (*sql.Context, error) +} + +type serverStreamWrapper struct { + grpc.ServerStream + ctx context.Context +} + +func (s serverStreamWrapper) Context() context.Context { + return s.ctx +} + +type sqlContextInterceptorKey struct{} + +func GetInterceptorSqlContext(ctx context.Context) (*sql.Context, error) { + if v := ctx.Value(sqlContextInterceptorKey{}); v != nil { + return v.(*sql.Context), nil + } + return nil, errors.New("misconfiguration; a sql.Context should always be available from the intercetpor chain.") +} + +func (si SqlContextServerInterceptor) Stream() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + sqlCtx, err := si.Factory(ss.Context()) + if err != nil { + return err + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ss.Context(), sqlContextInterceptorKey{}, sqlCtx) + newSs := serverStreamWrapper{ + ServerStream: ss, + ctx: newCtx, + } + return handler(srv, newSs) + } +} + +func (si SqlContextServerInterceptor) Unary() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + sqlCtx, err := si.Factory(ctx) + if err != nil { + return nil, err + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx) + return handler(newCtx, req) + } +} + +func (si SqlContextServerInterceptor) HTTP(existing func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + sqlCtx, err := si.Factory(ctx) + if err != nil { + http.Error(w, "could not initialize sql.Context", http.StatusInternalServerError) + return + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx) + newReq := r.WithContext(newCtx) + h.ServeHTTP(w, newReq) + }) + if existing != nil { + return existing(this) + } else { + return this + } + } +} + +func (si SqlContextServerInterceptor) Options() []grpc.ServerOption { + return []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(si.Unary()), + grpc.ChainStreamInterceptor(si.Stream()), + } +} From ff7aebbd4eab8cbac55aba48e5890078ce49b44e Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Thu, 30 Jan 2025 10:20:00 -0800 Subject: [PATCH 33/34] Fix typo in error message. 
Co-authored-by: Maximilian Hoffman --- go/libraries/doltcore/sqle/remotesrv.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/libraries/doltcore/sqle/remotesrv.go b/go/libraries/doltcore/sqle/remotesrv.go index 1f899c2a7ac..3d63cb68e74 100644 --- a/go/libraries/doltcore/sqle/remotesrv.go +++ b/go/libraries/doltcore/sqle/remotesrv.go @@ -119,7 +119,7 @@ func GetInterceptorSqlContext(ctx context.Context) (*sql.Context, error) { if v := ctx.Value(sqlContextInterceptorKey{}); v != nil { return v.(*sql.Context), nil } - return nil, errors.New("misconfiguration; a sql.Context should always be available from the intercetpor chain.") + return nil, errors.New("misconfiguration; a sql.Context should always be available from the interceptor chain.") } func (si SqlContextServerInterceptor) Stream() grpc.StreamServerInterceptor { From 29cb66ef19c59a6cc454af151fba71fedb0375ad Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Fri, 31 Jan 2025 13:59:03 -0800 Subject: [PATCH 34/34] go/cmd/dolt/commands: Declare bankruptcy on signed commit test for now. --- go/cmd/dolt/commands/signed_commits_test.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/go/cmd/dolt/commands/signed_commits_test.go b/go/cmd/dolt/commands/signed_commits_test.go index 2e73d280026..e83d1132bf3 100644 --- a/go/cmd/dolt/commands/signed_commits_test.go +++ b/go/cmd/dolt/commands/signed_commits_test.go @@ -24,7 +24,6 @@ import ( "testing" "github.com/dolthub/dolt/go/cmd/dolt/cli" - "github.com/dolthub/dolt/go/libraries/doltcore/dbfactory" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/config" @@ -46,9 +45,10 @@ func importKey(t *testing.T, ctx context.Context) { } func setupTestDB(t *testing.T, ctx context.Context, fs filesys.Filesys) string { - dir := t.TempDir() + dir, err := os.MkdirTemp(os.TempDir(), "signed_commits") + require.NoError(t, err) dbDir := filepath.Join(dir, "db") - err := filesys.CopyDir("testdata/signed_commits/db/", dbDir, fs) + err = filesys.CopyDir("testdata/signed_commits/db/", dbDir, fs) require.NoError(t, err) log.Println(dbDir) @@ -79,9 +79,6 @@ func TestSignAndVerifyCommit(t *testing.T) { ctx := context.Background() importKey(t, ctx) dbDir := setupTestDB(t, ctx, filesys.LocalFS) - t.Cleanup(func() { - dbfactory.CloseAllLocalDatabases() - }) global := map[string]string{ "user.name": "First Last", @@ -93,7 +90,7 @@ func TestSignAndVerifyCommit(t *testing.T) { apr, err := cli.CreateCommitArgParser().Parse(test.commitArgs) require.NoError(t, err) - _, err = execCommand(ctx, t, dbDir, CommitCmd{}, test.commitArgs, apr, map[string]string{}, global) + _, err = execCommand(ctx, dbDir, CommitCmd{}, test.commitArgs, apr, map[string]string{}, global) if test.expectErr { require.Error(t, err) @@ -106,14 +103,14 @@ func TestSignAndVerifyCommit(t *testing.T) { apr, err = cli.CreateLogArgParser(false).Parse(args) require.NoError(t, err) - logOutput, err := execCommand(ctx, t, dbDir, LogCmd{}, args, apr, map[string]string{}, global) + logOutput, err := execCommand(ctx, dbDir, LogCmd{}, args, apr, map[string]string{}, global) require.NoError(t, err) require.Contains(t, logOutput, "Good signature from \"Test User \"") }) } } -func execCommand(ctx context.Context, t *testing.T, wd string, cmd cli.Command, args []string, apr *argparser.ArgParseResults, local, global map[string]string) (output string, err error) { +func execCommand(ctx context.Context, wd string, cmd 
cli.Command, args []string, apr *argparser.ArgParseResults, local, global map[string]string) (output string, err error) { err = os.Chdir(wd) if err != nil { err = fmt.Errorf("error changing directory to %s: %w", wd, err) @@ -160,7 +157,7 @@ func execCommand(ctx context.Context, t *testing.T, wd string, cmd cli.Command, initialOut := os.Stdout initialErr := os.Stderr - f, err := os.CreateTemp(t.TempDir(), "signed-commit-test-*") + f, err := os.CreateTemp(os.TempDir(), "signed-commit-test-*") if err != nil { err = fmt.Errorf("error creating temp file: %w", err) return