diff --git a/internal/sqlitedb/init.go b/internal/sqlitedb/init.go index 56a1b04..5088659 100644 --- a/internal/sqlitedb/init.go +++ b/internal/sqlitedb/init.go @@ -19,12 +19,38 @@ func NewDB(logger *slog.Logger, open func(driverName string, dataSourceName stri if useLocalDB == "true" { logger.Info("using local database") - return NewFileDB() + return NewFileDB(open) } return NewTursoDB(open) } +func NewFileDB(open func(driverName string, dataSourceName string) (*sql.DB, error)) (*sql.DB, error) { + workingDir, err := os.Getwd() + + if err != nil { + return nil, errors.Wrap(err, "error getting working directory") + } + + // migrationsDir := workingDir + "/sql/migrations" + localDbDir := workingDir + "/db" + localDbPath := localDbDir + "/database.db" + + if _, err := os.Stat(localDbDir); os.IsNotExist(err) { + err := os.Mkdir(localDbDir, os.ModePerm) + if err != nil { + return nil, err + } + } + + _, err = os.OpenFile(localDbPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, err + } + + return open("sqlite", localDbPath) +} + func NewTursoDB(open func(driverName string, dataSourceName string) (*sql.DB, error)) (*sql.DB, error) { dbURL := os.Getenv("DATABASE_URL") diff --git a/internal/sqlitedb/init_test.go b/internal/sqlitedb/init_test.go index 186d036..d8d8e6b 100644 --- a/internal/sqlitedb/init_test.go +++ b/internal/sqlitedb/init_test.go @@ -19,8 +19,38 @@ func (m *MockSqlOpener) Open(driverName string, dataSourceName string) (*sql.DB, return args.Get(0).(*sql.DB), args.Error(1) } -func TestNewDB(t *testing.T) { - t.Run("returns a new db connection", func(t *testing.T) { +func TestNewFileDB(t *testing.T) { + t.Run("returns a new file db connection", func(t *testing.T) { + mockOpener := setupFileDBTest(t) + + db, err := NewFileDB(mockOpener.Open) + + mockOpener.AssertExpectations(t) + assert.NotNil(t, db) + assert.NoError(t, err) + }) +} + +func setupFileDBTest(t *testing.T) *MockSqlOpener { + t.Helper() + + os.Setenv("LOCAL_DB", "true") + defer os.Setenv("LOCAL_DB", "") + workingDir, err := os.Getwd() + + if err != nil { + t.Fatal("could not get working directory") + } + + mockOpener := new(MockSqlOpener) + mockDb := &sql.DB{} + mockOpener.On("Open", "sqlite", workingDir+"/db/database.db").Return(mockDb, nil) + + return mockOpener +} + +func TestNewTursoDB(t *testing.T) { + t.Run("returns a new turso db connection", func(t *testing.T) { dataSourceName := "libsql://jobsummoner.turso.io/db" os.Setenv("DATABASE_URL", dataSourceName) defer os.Setenv("DATABASE_URL", "") diff --git a/internal/sqlitedb/local.go b/internal/sqlitedb/local.go index d609f64..dbc8b5a 100644 --- a/internal/sqlitedb/local.go +++ b/internal/sqlitedb/local.go @@ -15,32 +15,6 @@ func init() { sql.Register("sqlite3", &sqlite.Driver{}) } -func NewFileDB() (*sql.DB, error) { - workingDir, err := os.Getwd() - - if err != nil { - return nil, errors.Wrap(err, "error getting working directory") - } - - migrationsDir := workingDir + "/sql/migrations" - localDbDir := workingDir + "/db" - localDbPath := localDbDir + "/database.db" - - if _, err := os.Stat(localDbDir); os.IsNotExist(err) { - err := os.Mkdir(localDbDir, os.ModePerm) - if err != nil { - return nil, err - } - } - - _, err = os.OpenFile(localDbPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return nil, err - } - - return migrateLocalDB(localDbPath, migrationsDir) -} - func NewInMemoryDB() (*sql.DB, error) { workingDir, err := os.Getwd()