diff --git a/docker-compose.yaml b/docker-compose.yaml index 3c5186a..449a350 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -15,8 +15,8 @@ services: max-size: 10m ports: - '3000:3000' - image: b44427a64de93c20123c068387b0adc0434434ba709fbd91dd03d33ade489c3e - container_name: sshdbg + image: ssh-sync-server-prerelease + container_name: ssh-sync-server ssh-sync-db: image: therealpaulgg/ssh-sync-db:latest container_name: ssh-sync-db-debug @@ -26,18 +26,28 @@ services: - POSTGRES_DB=sshsync restart: always ssh-sync: - image: 62eab8fb32b34e0a2cf36e8635d810c20a38baa2d7beaf5b6918139339e23c23 + image: ssh-debug container_name: ssh-sync stdin_open: true # Allows Docker container to keep STDIN open tty: true # Allocates a pseudo-TTY + volumes: + - ssh-sync-volume:/root ssh-sync-2: - image: 62eab8fb32b34e0a2cf36e8635d810c20a38baa2d7beaf5b6918139339e23c23 + image: ssh-debug container_name: ssh-sync-2 stdin_open: true # Allows Docker container to keep STDIN open tty: true # Allocates a pseudo-TTY + volumes: + - ssh-sync-2-volume:/root ssh-sync-3: - image: 62eab8fb32b34e0a2cf36e8635d810c20a38baa2d7beaf5b6918139339e23c23 + image: ssh-debug container_name: ssh-sync-3 stdin_open: true # Allows Docker container to keep STDIN open tty: true # Allocates a pseudo-TTY - #http://ssh-sync-server-debug:3000 \ No newline at end of file + volumes: + - ssh-sync-3-volume:/root + +volumes: + ssh-sync-volume: + ssh-sync-2-volume: + ssh-sync-3-volume: \ No newline at end of file diff --git a/pkg/database/query/transaction.go b/pkg/database/query/transaction.go index 64dfd38..979c202 100644 --- a/pkg/database/query/transaction.go +++ b/pkg/database/query/transaction.go @@ -2,9 +2,11 @@ package query import ( "context" + "net/http" "github.com/georgysavva/scany/v2/pgxscan" "github.com/jackc/pgx/v5" + "github.com/rs/zerolog/log" "github.com/therealpaulgg/ssh-sync-server/pkg/database" ) @@ -63,3 +65,22 @@ func (q *QueryServiceTxImpl[T]) Insert(tx pgx.Tx, query string, args ...interfac _, err := tx.Exec(context.Background(), query, args...) return err } + +func RollbackFunc(txQueryService TransactionService, tx pgx.Tx, w http.ResponseWriter, err *error) { + rb := func(tx pgx.Tx) { + err := txQueryService.Rollback(tx) + if err != nil { + log.Err(err).Msg("error rolling back transaction") + } + } + if *err != nil { + rb(tx) + } else { + internalErr := txQueryService.Commit(tx) + if internalErr != nil { + log.Err(internalErr).Msg("error committing transaction") + rb(tx) + w.WriteHeader(http.StatusInternalServerError) + } + } +} diff --git a/pkg/database/repository/machine.go b/pkg/database/repository/machine.go index b65099b..4049329 100644 --- a/pkg/database/repository/machine.go +++ b/pkg/database/repository/machine.go @@ -45,8 +45,7 @@ func (repo *MachineRepo) DeleteMachine(id uuid.UUID) error { if _, err := tx.Exec(context.TODO(), "delete from machines where id = $1", id); err != nil { return err } - err = tx.Commit(context.TODO()) - return err + return tx.Commit(context.TODO()) } func (repo *MachineRepo) GetMachine(id uuid.UUID) (*models.Machine, error) { diff --git a/pkg/database/repository/user.go b/pkg/database/repository/user.go index 2670438..a2036a0 100644 --- a/pkg/database/repository/user.go +++ b/pkg/database/repository/user.go @@ -23,10 +23,12 @@ type UserRepository interface { DeleteUser(id uuid.UUID) error GetUserConfig(id uuid.UUID) ([]models.SshConfig, error) GetUserKeys(id uuid.UUID) ([]models.SshKey, error) + GetUserKey(userId uuid.UUID, keyId uuid.UUID) (*models.SshKey, error) AddAndUpdateKeys(user *models.User) error AddAndUpdateKeysTx(user *models.User, tx pgx.Tx) error AddAndUpdateConfig(user *models.User) error AddAndUpdateConfigTx(user *models.User, tx pgx.Tx) error + DeleteUserKeyTx(user *models.User, id uuid.UUID, tx pgx.Tx) error } type UserRepo struct { @@ -195,3 +197,20 @@ func (repo *UserRepo) AddAndUpdateConfigTx(user *models.User, tx pgx.Tx) error { } return nil } + +func (repo *UserRepo) GetUserKey(userId uuid.UUID, keyId uuid.UUID) (*models.SshKey, error) { + q := do.MustInvoke[query.QueryService[models.SshKey]](repo.Injector) + key, err := q.QueryOne("select * from ssh_keys where user_id = $1 and id = $2", userId, keyId) + if err != nil { + return nil, err + } + if key == nil { + return nil, sql.ErrNoRows + } + return key, nil +} + +func (repo *UserRepo) DeleteUserKeyTx(user *models.User, id uuid.UUID, tx pgx.Tx) error { + _, err := tx.Exec(context.TODO(), "delete from ssh_keys where user_id = $1 and id = $2", user.ID, id) + return err +} diff --git a/pkg/database/repository/usermock.go b/pkg/database/repository/usermock.go index 5345a84..bc5fbd4 100644 --- a/pkg/database/repository/usermock.go +++ b/pkg/database/repository/usermock.go @@ -136,6 +136,20 @@ func (mr *MockUserRepositoryMockRecorder) DeleteUser(id interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockUserRepository)(nil).DeleteUser), id) } +// DeleteUserKeyTx mocks base method. +func (m *MockUserRepository) DeleteUserKeyTx(user *models.User, id uuid.UUID, tx pgx.Tx) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserKeyTx", user, id, tx) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserKeyTx indicates an expected call of DeleteUserKeyTx. +func (mr *MockUserRepositoryMockRecorder) DeleteUserKeyTx(user, id, tx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserKeyTx", reflect.TypeOf((*MockUserRepository)(nil).DeleteUserKeyTx), user, id, tx) +} + // GetUser mocks base method. func (m *MockUserRepository) GetUser(id uuid.UUID) (*models.User, error) { m.ctrl.T.Helper() @@ -181,6 +195,21 @@ func (mr *MockUserRepositoryMockRecorder) GetUserConfig(id interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserConfig", reflect.TypeOf((*MockUserRepository)(nil).GetUserConfig), id) } +// GetUserKey mocks base method. +func (m *MockUserRepository) GetUserKey(userId, keyId uuid.UUID) (*models.SshKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserKey", userId, keyId) + ret0, _ := ret[0].(*models.SshKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserKey indicates an expected call of GetUserKey. +func (mr *MockUserRepositoryMockRecorder) GetUserKey(userId, keyId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserKey", reflect.TypeOf((*MockUserRepository)(nil).GetUserKey), userId, keyId) +} + // GetUserKeys mocks base method. func (m *MockUserRepository) GetUserKeys(id uuid.UUID) ([]models.SshKey, error) { m.ctrl.T.Helper() diff --git a/pkg/web/router/routes/data.go b/pkg/web/router/routes/data.go index 1247c54..fc80f6a 100644 --- a/pkg/web/router/routes/data.go +++ b/pkg/web/router/routes/data.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/go-chi/chi" + "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" "github.com/samber/do" @@ -115,24 +116,7 @@ func addData(i *do.Injector) http.HandlerFunc { w.WriteHeader(http.StatusInternalServerError) return } - defer func() { - rb := func(tx pgx.Tx) { - err := txQueryService.Rollback(tx) - if err != nil { - log.Err(err).Msg("error rolling back transaction") - } - } - if err != nil { - rb(tx) - } else { - internalErr := txQueryService.Commit(tx) - if internalErr != nil { - log.Err(err).Msg("error committing transaction") - rb(tx) - w.WriteHeader(http.StatusInternalServerError) - } - } - }() + defer query.RollbackFunc(txQueryService, tx, w, &err) if err = userRepo.AddAndUpdateConfigTx(user, tx); err != nil { log.Err(err).Msg("could not add config") w.WriteHeader(http.StatusInternalServerError) @@ -168,10 +152,49 @@ func addData(i *do.Injector) http.HandlerFunc { } } +func deleteData(i *do.Injector) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(context_keys.UserContextKey).(*models.User) + if !ok { + log.Err(errors.New("could not get user from context")) + w.WriteHeader(http.StatusInternalServerError) + return + } + keyIdStr := chi.URLParam(r, "id") + keyId, err := uuid.Parse(keyIdStr) + if err != nil { + log.Err(err).Msg("could not parse key id") + w.WriteHeader(http.StatusBadRequest) + return + } + userRepo := do.MustInvoke[repository.UserRepository](i) + key, err := userRepo.GetUserKey(user.ID, keyId) + if err != nil { + log.Err(err).Msg("could not get key") + w.WriteHeader(http.StatusNotFound) + return + } + txQueryService := do.MustInvoke[query.TransactionService](i) + tx, err := txQueryService.StartTx(pgx.TxOptions{}) + if err != nil { + log.Err(err).Msg("error starting transaction") + w.WriteHeader(http.StatusInternalServerError) + return + } + defer query.RollbackFunc(txQueryService, tx, w, &err) + if err = userRepo.DeleteUserKeyTx(user, key.ID, tx); err != nil { + log.Err(err).Msg("could not delete key") + w.WriteHeader(http.StatusInternalServerError) + return + } + } +} + func DataRoutes(i *do.Injector) chi.Router { r := chi.NewRouter() r.Use(middleware.ConfigureAuth(i)) r.Get("/", getData(i)) r.Post("/", addData(i)) + r.Delete("/key/{id}", deleteData(i)) return r } diff --git a/pkg/web/router/routes/data_test.go b/pkg/web/router/routes/data_test.go index cdebfed..d574940 100644 --- a/pkg/web/router/routes/data_test.go +++ b/pkg/web/router/routes/data_test.go @@ -5,11 +5,13 @@ import ( "crypto/rand" "encoding/json" "errors" + "fmt" "mime/multipart" "net/http" "net/http/httptest" "testing" + "github.com/go-chi/chi" "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/samber/do" @@ -258,3 +260,84 @@ func TestAddDataError(t *testing.T) { status, http.StatusOK) } } + +func TestDeleteKey(t *testing.T) { + // Arrange + keyId := uuid.New() + req := httptest.NewRequest("DELETE", fmt.Sprintf("/%s", keyId.String()), nil) + user := testutils.GenerateUser() + req = testutils.AddUserContext(req, user) + key := &models.SshKey{ + ID: keyId, + UserID: user.ID, + } + + injector := do.New() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockUserRepo := repository.NewMockUserRepository(ctrl) + txMock := pgx.NewMockTx(ctrl) + mockUserRepo.EXPECT().GetUserKey(user.ID, keyId).Return(key, nil) + mockUserRepo.EXPECT().DeleteUserKeyTx(gomock.Any(), keyId, txMock).Return(nil) + do.Provide(injector, func(i *do.Injector) (repository.UserRepository, error) { + return mockUserRepo, nil + }) + mockTransactionService := query.NewMockTransactionService(ctrl) + mockTransactionService.EXPECT().StartTx(gomock.Any()).Return(txMock, nil) + mockTransactionService.EXPECT().Commit(txMock).Return(nil) + do.Provide(injector, func(i *do.Injector) (query.TransactionService, error) { + return mockTransactionService, nil + }) + // Act + rr := httptest.NewRecorder() + handler := chi.NewRouter() + handler.Delete("/{id}", deleteData(injector)) + handler.ServeHTTP(rr, req) + // Assert + if status := rr.Code; status != http.StatusOK { + t.Errorf("deleteData returned wrong status code: got %v want %v", + status, http.StatusOK) + } +} + +func TestDeleteKeyError(t *testing.T) { + // Arrange + keyId := uuid.New() + req := httptest.NewRequest("DELETE", fmt.Sprintf("/%s", keyId.String()), nil) + user := testutils.GenerateUser() + req = testutils.AddUserContext(req, user) + key := &models.SshKey{ + ID: keyId, + UserID: user.ID, + } + + injector := do.New() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockUserRepo := repository.NewMockUserRepository(ctrl) + txMock := pgx.NewMockTx(ctrl) + mockUserRepo.EXPECT().GetUserKey(user.ID, keyId).Return(key, nil) + mockUserRepo.EXPECT().DeleteUserKeyTx(gomock.Any(), keyId, txMock).Return(errors.New("error")) + do.Provide(injector, func(i *do.Injector) (repository.UserRepository, error) { + return mockUserRepo, nil + }) + mockTransactionService := query.NewMockTransactionService(ctrl) + mockTransactionService.EXPECT().StartTx(gomock.Any()).Return(txMock, nil) + mockTransactionService.EXPECT().Rollback(txMock).Return(nil) + do.Provide(injector, func(i *do.Injector) (query.TransactionService, error) { + return mockTransactionService, nil + }) + + // Act + rr := httptest.NewRecorder() + handler := chi.NewRouter() + handler.Delete("/{id}", deleteData(injector)) + handler.ServeHTTP(rr, req) + + // Assert + + if status := rr.Code; status != http.StatusInternalServerError { + t.Errorf("deleteData returned wrong status code: got %v want %v", + status, http.StatusInternalServerError) + } +} diff --git a/pkg/web/router/routes/setup.go b/pkg/web/router/routes/setup.go index eddf704..ec56315 100644 --- a/pkg/web/router/routes/setup.go +++ b/pkg/web/router/routes/setup.go @@ -67,24 +67,7 @@ func initialSetup(i *do.Injector) http.HandlerFunc { w.WriteHeader(http.StatusInternalServerError) return } - defer func() { - rb := func(tx pgx.Tx) { - err := txQueryService.Rollback(tx) - if err != nil { - log.Err(err).Msg("error rolling back transaction") - } - } - if err != nil { - rb(tx) - } else { - internalErr := txQueryService.Commit(tx) - if internalErr != nil { - log.Err(err).Msg("error committing transaction") - rb(tx) - w.WriteHeader(http.StatusInternalServerError) - } - } - }() + defer query.RollbackFunc(txQueryService, tx, w, &err) userRepo := do.MustInvoke[repository.UserRepository](i) user := &models.User{} user.Username = userDto.Username