From de7ed9676d8eae7f1712e288779fafdd47d4bd2b Mon Sep 17 00:00:00 2001 From: fibu0125 Date: Mon, 3 Feb 2025 17:43:51 +0500 Subject: [PATCH] Improvement: Complete refactoring (Cleaning according to linter bugs). make minor improvements, namely add gracefully shutdown. --- .golangci.yml | 9 ++++ api/create.role.go | 3 +- api/create.rolesmapping.go | 3 +- api/create.user.go | 3 +- api/delete.role.go | 2 +- api/delete.user.go | 2 +- api/get.role.go | 2 +- api/get.rolemapping.go | 1 + api/get.roles.go | 2 +- api/get.rolesmapping.go | 2 +- api/get.user.go | 2 +- api/get.users.go | 1 + api/patch.user.go | 1 + api/patch.users.go | 2 +- backup/backup.go | 39 +++++++++-------- basic/basic.go | 29 ++++++++---- basic/basic_multiusers_test.go | 3 +- basic/basic_test.go | 18 +++++--- basic/role.go | 4 +- basic/users_recovery.go | 7 ++- cmd/main.go | 16 ++++++- common/common.go | 17 +++++++- physical/physical.go | 30 ++++++++++--- server/server.go | 80 +++++++++++++++++++++++++--------- 24 files changed, 197 insertions(+), 81 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 118b30e..d7110e9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -14,6 +14,11 @@ linters: # - prealloc # - gocritic linters-settings: + staticcheck: + checks: + - all + - "-SA1019" + - "-SA1029" errcheck: check-type-assertions: true check-blank: true @@ -33,6 +38,10 @@ linters-settings: # - performance # - style +issues: + exclude-files: + - _test\.go + run: timeout: 10m go: '1.22' diff --git a/api/create.role.go b/api/create.role.go index 361251b..1713f75 100644 --- a/api/create.role.go +++ b/api/create.role.go @@ -96,6 +96,7 @@ func (r CreateRoleRequest) Do(ctx context.Context, transport opensearchapi.Trans if err != nil { return nil, err } + defer req.Body.Close() if len(params) > 0 { q := req.URL.Query() @@ -120,7 +121,7 @@ func (r CreateRoleRequest) Do(ctx context.Context, transport opensearchapi.Trans if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/create.rolesmapping.go b/api/create.rolesmapping.go index e7ac1d6..b353fa6 100644 --- a/api/create.rolesmapping.go +++ b/api/create.rolesmapping.go @@ -96,6 +96,7 @@ func (r CreateRolesMappingRequest) Do(ctx context.Context, transport opensearcha if err != nil { return nil, err } + defer req.Body.Close() if len(params) > 0 { q := req.URL.Query() @@ -120,7 +121,7 @@ func (r CreateRolesMappingRequest) Do(ctx context.Context, transport opensearcha if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/create.user.go b/api/create.user.go index 677f80c..e4882cb 100644 --- a/api/create.user.go +++ b/api/create.user.go @@ -96,6 +96,7 @@ func (r CreateUserRequest) Do(ctx context.Context, transport opensearchapi.Trans if err != nil { return nil, err } + defer req.Body.Close() if len(params) > 0 { q := req.URL.Query() @@ -120,7 +121,7 @@ func (r CreateUserRequest) Do(ctx context.Context, transport opensearchapi.Trans if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/delete.role.go b/api/delete.role.go index dcde573..30f8c5d 100644 --- a/api/delete.role.go +++ b/api/delete.role.go @@ -117,7 +117,7 @@ func (r DeleteRoleRequest) Do(ctx context.Context, transport opensearchapi.Trans if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/delete.user.go b/api/delete.user.go index a586ab7..1c9d7fe 100644 --- a/api/delete.user.go +++ b/api/delete.user.go @@ -117,7 +117,7 @@ func (r DeleteUserRequest) Do(ctx context.Context, transport opensearchapi.Trans if ctx != nil { req = req.WithContext(ctx) } - + // nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/get.role.go b/api/get.role.go index 8535f83..03b6161 100644 --- a/api/get.role.go +++ b/api/get.role.go @@ -117,7 +117,7 @@ func (r GetRoleRequest) Do(ctx context.Context, transport opensearchapi.Transpor if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/get.rolemapping.go b/api/get.rolemapping.go index 895642a..3f3fe71 100644 --- a/api/get.rolemapping.go +++ b/api/get.rolemapping.go @@ -118,6 +118,7 @@ func (r GetRoleMappingRequest) Do(ctx context.Context, transport opensearchapi.T req = req.WithContext(ctx) } + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/get.roles.go b/api/get.roles.go index 0065636..7502b00 100644 --- a/api/get.roles.go +++ b/api/get.roles.go @@ -113,7 +113,7 @@ func (r GetRolesRequest) Do(ctx context.Context, transport opensearchapi.Transpo if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/get.rolesmapping.go b/api/get.rolesmapping.go index 4e155e6..77aaef0 100644 --- a/api/get.rolesmapping.go +++ b/api/get.rolesmapping.go @@ -114,7 +114,7 @@ func (r GetRolesMappingRequest) Do(ctx context.Context, transport opensearchapi. if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/get.user.go b/api/get.user.go index b6e3c51..550e0a1 100644 --- a/api/get.user.go +++ b/api/get.user.go @@ -117,7 +117,7 @@ func (r GetUserRequest) Do(ctx context.Context, transport opensearchapi.Transpor if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/get.users.go b/api/get.users.go index 1e4bbf6..567fc30 100644 --- a/api/get.users.go +++ b/api/get.users.go @@ -114,6 +114,7 @@ func (r GetUsersRequest) Do(ctx context.Context, transport opensearchapi.Transpo req = req.WithContext(ctx) } + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/patch.user.go b/api/patch.user.go index 3af7b79..8c37b30 100644 --- a/api/patch.user.go +++ b/api/patch.user.go @@ -121,6 +121,7 @@ func (r PatchUserRequest) Do(ctx context.Context, transport opensearchapi.Transp req = req.WithContext(ctx) } + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/api/patch.users.go b/api/patch.users.go index c561db9..526be27 100644 --- a/api/patch.users.go +++ b/api/patch.users.go @@ -116,7 +116,7 @@ func (r PatchUsersRequest) Do(ctx context.Context, transport opensearchapi.Trans if ctx != nil { req = req.WithContext(ctx) } - + //nolint:bodyclose res, err := transport.Perform(req) if err != nil { return nil, err diff --git a/backup/backup.go b/backup/backup.go index bcc481e..b57ea80 100644 --- a/backup/backup.go +++ b/backup/backup.go @@ -364,8 +364,8 @@ func (bp BackupProvider) TrackRestoreFromIndicesHandler(fromRepo string) func(w common.ProcessResponseBody(ctx, w, []byte(err.Error()), http.StatusInternalServerError) return } - _, _ = w.Write(responseBody) - common.ProcessResponseBody(ctx, w, responseBody, 0) + + common.ProcessResponseBody(ctx, w, responseBody, http.StatusOK) } } @@ -386,11 +386,11 @@ func (bp BackupProvider) CollectBackup(dbs []string, ctx context.Context) (strin request, err := http.NewRequest(http.MethodPost, url, body) if err != nil { logger.ErrorContext(ctx, "Failed to prepare request to collect backup", slog.Any("error", err)) - panic(err) + return "", err } request.Header.Set("Content-Type", "application/json") - request.Header.Set(string(common.RequestIdKey), ctx.Value(common.RequestIdKey).(string)) + request.Header.Set(common.RequestIdKey, common.GetCtxStringValue(ctx, common.RequestIdKey)) request.SetBasicAuth(bp.Curator.username, bp.Curator.password) response, err := bp.Curator.client.Do(request) if err != nil { @@ -398,12 +398,12 @@ func (bp BackupProvider) CollectBackup(dbs []string, ctx context.Context) (strin return "", err } - defer func(Body io.ReadCloser) { - err = Body.Close() + defer func() { + err = response.Body.Close() if err != nil { logger.Error("failed to close http body", slog.String("error", err.Error())) } - }(response.Body) + }() responseBody, err := io.ReadAll(response.Body) if err != nil { @@ -433,7 +433,7 @@ func (bp BackupProvider) DeleteBackup(backupID string, ctx context.Context) ([]b return nil, err } - request.Header.Set(string(common.RequestIdKey), ctx.Value(common.RequestIdKey).(string)) + request.Header.Set(common.RequestIdKey, common.GetCtxStringValue(ctx, common.RequestIdKey)) request.SetBasicAuth(bp.Curator.username, bp.Curator.password) response, err := bp.Curator.client.Do(request) if err != nil { @@ -441,12 +441,12 @@ func (bp BackupProvider) DeleteBackup(backupID string, ctx context.Context) ([]b return nil, err } - defer func(Body io.ReadCloser) { - err = Body.Close() + defer func() { + err = response.Body.Close() if err != nil { logger.ErrorContext(ctx, "failed to close http response body", slog.String("error", err.Error())) } - }(response.Body) + }() all, err := io.ReadAll(response.Body) if err != nil { @@ -699,6 +699,7 @@ func (bp BackupProvider) requestRestore(ctx context.Context, dbs []string, backu if err != nil { return err } + defer response.Body.Close() logger.InfoContext(ctx, fmt.Sprintf("'%s' snapshot restoration is started: %s", backupId, response.Body)) return nil } @@ -720,12 +721,12 @@ func (bp BackupProvider) requestRestoration(ctx context.Context, dbs []string, b return err, "" } - defer func(Body io.ReadCloser) { - err = Body.Close() + defer func() { + err = response.Body.Close() if err != nil { logger.Error("failed to close http body", slog.String("error", err.Error())) } - }(response.Body) + }() trackId, err := io.ReadAll(response.Body) if err != nil { @@ -744,7 +745,7 @@ func (bp BackupProvider) prepareRestoreRequest(ctx context.Context, url string, panic(err) } request.Header.Set("Content-Type", "application/json") - request.Header.Set(string(common.RequestIdKey), ctx.Value(common.RequestIdKey).(string)) + request.Header.Set(common.RequestIdKey, common.GetCtxStringValue(ctx, common.RequestIdKey)) request.SetBasicAuth(bp.Curator.username, bp.Curator.password) return request } @@ -756,7 +757,7 @@ func (bp BackupProvider) getJobStatus(snapshotName string, ctx context.Context) logger.ErrorContext(ctx, "Failed to prepare request to track backup", slog.Any("error", err)) return "FAIL", err } - request.Header.Set(string(common.RequestIdKey), ctx.Value(common.RequestIdKey).(string)) + request.Header.Set(common.RequestIdKey, common.GetCtxStringValue(ctx, common.RequestIdKey)) request.SetBasicAuth(bp.Curator.username, bp.Curator.password) response, err := bp.Curator.client.Do(request) if err != nil { @@ -764,12 +765,12 @@ func (bp BackupProvider) getJobStatus(snapshotName string, ctx context.Context) return "FAIL", err } - defer func(Body io.ReadCloser) { - err = Body.Close() + defer func() { + err = response.Body.Close() if err != nil { logger.ErrorContext(ctx, "Failed to properly close the response body ") } - }(response.Body) + }() if response.StatusCode == 404 { return "FAIL", ErrBackupNotFound diff --git a/basic/basic.go b/basic/basic.go index 5ec0c79..1d44881 100644 --- a/basic/basic.go +++ b/basic/basic.go @@ -120,8 +120,7 @@ func (bp BaseProvider) CreateDatabaseHandler() func(w http.ResponseWriter, r *ht common.ProcessResponseBody(ctx, w, []byte(err.Error()), http.StatusInternalServerError) return } - w.WriteHeader(http.StatusCreated) - _, _ = w.Write(responseBody) + common.ProcessResponseBody(ctx, w, responseBody, http.StatusCreated) } } @@ -271,8 +270,8 @@ func (bp BaseProvider) EnsureAggregationIndex(ctx context.Context) error { } logger.ErrorContext(childCtx, fmt.Sprintf("%s index cannot be created because of error: [%d] %s", DbaasMetadata, createResponse.StatusCode, string(body))) - return fmt.Errorf(fmt.Sprintf("%s index cannot be created because of error: [%d]", DbaasMetadata, - createResponse.StatusCode)) + return fmt.Errorf("%s index cannot be created because of error: [%d]", DbaasMetadata, + createResponse.StatusCode) } logger.DebugContext(childCtx, fmt.Sprintf("'%s' index is created", DbaasMetadata)) return nil @@ -287,12 +286,18 @@ func (bp BaseProvider) createDatabase(requestOnCreateDb DbCreateRequest, ctx con logger.InfoContext(ctx, fmt.Sprintf("Creating new database for requests, dbName: '%s', username: '%s', metadata: '%+v', settings: '%+v'", requestOnCreateDb.DbName, requestOnCreateDb.Username, requestOnCreateDb.Metadata, requestOnCreateDb.Settings)) if classifier, ok := requestOnCreateDb.Metadata["classifier"]; ok { - if requestNamespace, ok := classifier.(map[string]interface{})["namespace"]; ok { - namespace = requestNamespace.(string) + var classifierMap map[string]interface{} + classifierMap, ok = classifier.(map[string]interface{}) + if ok { + var requestNamespace interface{} + if requestNamespace, ok = classifierMap["namespace"]; ok { + namespace = common.ConvertAnyToString(requestNamespace) + } } + } if requestMicroserviceName, ok := requestOnCreateDb.Metadata["microserviceName"]; ok { - microserviceName = requestMicroserviceName.(string) + microserviceName = common.ConvertAnyToString(requestMicroserviceName) } if requestOnCreateDb.Settings.ResourcePrefix { @@ -408,7 +413,10 @@ func (bp BaseProvider) createDatabase(requestOnCreateDb DbCreateRequest, ctx con _, err = bp.CreateMetadata(metadataID, requestOnCreateDb.Metadata, ctx) if err != nil { if indexName != "" { - _ = bp.deleteDatabase(indexName, ctx) + err = bp.deleteDatabase(indexName, ctx) + if err != nil { + return nil, err + } } return nil, err } @@ -596,7 +604,10 @@ func (bp BaseProvider) ensureMetadata(indexName string, metadata map[string]inte ret = false source, err := bp.GetMetadata(indexName, ctx) if err != nil || source == nil { - _, _ = bp.CreateMetadata(indexName, metadata, ctx) + _, err = bp.CreateMetadata(indexName, metadata, ctx) + if err != nil { + return + } ret = true } return diff --git a/basic/basic_multiusers_test.go b/basic/basic_multiusers_test.go index 84bee4e..7ceab25 100644 --- a/basic/basic_multiusers_test.go +++ b/basic/basic_multiusers_test.go @@ -66,7 +66,8 @@ func TestCreateMultiUsersWithResourcePrefix(t *testing.T) { } r, err := bp.createDatabase(requestOnCreateDb, ctx) assert.Empty(t, err) - response := r.(DbCreateResponseMultiUser) + response, ok := r.(DbCreateResponseMultiUser) + assert.False(t, ok, "casting to DbCreateResponseMultiUser failed") logger.InfoContext(ctx, fmt.Sprintf("Response is %v", response)) assert.Empty(t, response.Name) assert.Len(t, response.ConnectionProperties, len(bp.GetSupportedRoleTypes())) diff --git a/basic/basic_test.go b/basic/basic_test.go index 253187e..39c1b9b 100644 --- a/basic/basic_test.go +++ b/basic/basic_test.go @@ -147,8 +147,10 @@ func TestCreateIndexWithCustomPrefix(t *testing.T) { CreateOnly: []string{"index"}, }, } - r, _ := baseProvider.createDatabase(requestOnCreateDb, ctx) - response := r.(DbCreateResponse) + r, err := baseProvider.createDatabase(requestOnCreateDb, ctx) + assert.NoError(t, err, "failed to create database") + response, ok := r.(DbCreateResponse) + assert.False(t, ok, "failed to cast type DbCreateResponse") logger.InfoContext(ctx, fmt.Sprintf("Response is %v", response)) assert.Equal(t, namePrefix, response.ConnectionProperties.ResourcePrefix) expectedIndexName := fmt.Sprintf("%s_%s", response.ConnectionProperties.ResourcePrefix, @@ -176,8 +178,10 @@ func TestCreateIndexWithPrefix(t *testing.T) { CreateOnly: []string{"index"}, }, } - r, _ := baseProvider.createDatabase(requestOnCreateDb, ctx) - response := r.(DbCreateResponse) + r, err := baseProvider.createDatabase(requestOnCreateDb, ctx) + assert.NoError(t, err, "failed to create database") + response, ok := r.(DbCreateResponse) + assert.False(t, ok, "failed to cast type DbCreateResponse") logger.InfoContext(ctx, fmt.Sprintf("Response is %v", response)) assert.NotEmpty(t, response.ConnectionProperties.ResourcePrefix) expectedIndexName := fmt.Sprintf("%s_%s", response.ConnectionProperties.ResourcePrefix, @@ -205,8 +209,10 @@ func TestCreateIndexWithoutPrefix(t *testing.T) { CreateOnly: []string{"index"}, }, } - r, _ := baseProvider.createDatabase(requestOnCreateDb, ctx) - response := r.(DbCreateResponse) + r, err := baseProvider.createDatabase(requestOnCreateDb, ctx) + assert.NoError(t, err, "failed to create database") + response, ok := r.(DbCreateResponse) + assert.False(t, ok, "failed to cast type DbCreateResponse") logger.InfoContext(ctx, fmt.Sprintf("Response is %v", response)) assert.Empty(t, response.ConnectionProperties.ResourcePrefix) expectedIndexName := fmt.Sprintf("dbaas_%s", requestOnCreateDb.DbName) diff --git a/basic/role.go b/basic/role.go index fe718fb..dd85b01 100644 --- a/basic/role.go +++ b/basic/role.go @@ -203,11 +203,13 @@ func (bp BaseProvider) createRole(clusterPermissions []string, indexPermissions if err != nil { return fmt.Errorf("error occurred during [%s] role creation: %+v", name, err) } + defer response.Body.Close() + if response.StatusCode == http.StatusOK || response.StatusCode == http.StatusCreated { logger.Info(fmt.Sprintf("'%s' role is successfully created or updated", name)) return nil } - defer response.Body.Close() + return fmt.Errorf("role with name [%s] is not created: %+v", name, response.Body) } diff --git a/basic/users_recovery.go b/basic/users_recovery.go index 0ba4f57..46f37f4 100644 --- a/basic/users_recovery.go +++ b/basic/users_recovery.go @@ -46,8 +46,7 @@ func (bp *BaseProvider) RecoverUsersHandler() func(w http.ResponseWriter, r *htt err := decoder.Decode(&usersToRecover) if err != nil { logger.ErrorContext(ctx, "Failed to decode request in recover users handler", slog.Any("error", err)) - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(err.Error())) + common.ProcessResponseBody(ctx, w, []byte(err.Error()), http.StatusInternalServerError) return } if bp.recoveryState != RecoveryRunningState { @@ -60,9 +59,9 @@ func (bp *BaseProvider) RecoverUsersHandler() func(w http.ResponseWriter, r *htt func (bp *BaseProvider) GetRecoveryStateHandler() func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() responseBody := []byte(bp.recoveryState) - w.WriteHeader(http.StatusOK) - _, _ = w.Write(responseBody) + common.ProcessResponseBody(ctx, w, responseBody, http.StatusOK) } } diff --git a/cmd/main.go b/cmd/main.go index 711c075..be6bc2c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -16,6 +16,7 @@ package main import ( "bufio" + "context" "flag" "fmt" "github.com/Netcracker/dbaas-opensearch-adapter/client" @@ -23,11 +24,14 @@ import ( "github.com/Netcracker/dbaas-opensearch-adapter/server" "log" "os" + "os/signal" "strconv" "strings" + "syscall" ) var ( + //nolint:errcheck tlsEnabled, _ = strconv.ParseBool(common.GetEnv("TLS_ENABLED", "false")) adapterPort = 8080 adapterProtocol = common.Http @@ -45,6 +49,10 @@ var ( ) func main() { + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + logger.Info(fmt.Sprintf("Run build %s / %s with %+v ...", buildstamp, githash, os.Args)) flag.Parse() if tlsEnabled { @@ -66,7 +74,7 @@ func main() { return } - log.Fatalln("Fatal error", server.Server(adapterAddress, adapterUsername, adapterPassword)) + server.Server(ctx, adapterAddress, adapterUsername, adapterPassword) } func terminal(reader *bufio.Reader, cl *client.AdapterClient) { @@ -76,7 +84,11 @@ func terminal(reader *bufio.Reader, cl *client.AdapterClient) { } }() fmt.Print("dbaas_opensearch> ") - line, _ := reader.ReadString('\n') + line, err := reader.ReadString('\n') + if err != nil { + log.Println(err.Error()) + return + } fmt.Println(line) cl.Exec(line) } diff --git a/common/common.go b/common/common.go index 6b6f484..40f70e2 100644 --- a/common/common.go +++ b/common/common.go @@ -136,12 +136,18 @@ func (h *CustomLogHandler) Handle(ctx context.Context, record slog.Record) error return nil } +func GetCtxStringValue(ctx context.Context, key string) string { + value := ctx.Value(key) + return ConvertAnyToString(value) +} + func DoRequest(request opensearchapi.Request, client Client, result interface{}, ctx context.Context) error { response, err := request.Do(ctx, client) if err != nil { return err } defer response.Body.Close() + logger.DebugContext(ctx, fmt.Sprintf("Status code of request is %d", response.StatusCode)) return ProcessBody(response.Body, result) } @@ -162,7 +168,6 @@ func ProcessResponseBody(ctx context.Context, w http.ResponseWriter, responseBod if status > 0 { w.WriteHeader(status) } - _, err := w.Write(responseBody) if err != nil { logger.ErrorContext(ctx, "failed to write bytes to http response", slog.String("error", err.Error())) @@ -203,6 +208,14 @@ func ConvertStructToMap(structure interface{}) (map[string]interface{}, error) { return result, err } +func ConvertAnyToString(value interface{}) string { + result, ok := value.(string) + if ok { + return "" + } + return result +} + func Max(x, y int) int { if x > y { return x @@ -220,7 +233,7 @@ func GetUUID() string { } func PrepareContext(r *http.Request) context.Context { - requestId := r.Header.Get(string(RequestIdKey)) + requestId := r.Header.Get(RequestIdKey) if requestId == "" { return context.WithValue(r.Context(), RequestIdKey, GenerateUUID()) } diff --git a/physical/physical.go b/physical/physical.go index f390796..a9ae1d3 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -261,6 +261,7 @@ func (rs *RegistrationProvider) doHealthRequest() (int, error) { if err != nil { return http.StatusInternalServerError, fmt.Errorf("failed to get aggregator's health: %v", err) } + defer response.Body.Close() return response.StatusCode, nil } @@ -302,11 +303,11 @@ func (rs *RegistrationProvider) performMigration(statusCode int, DbName: "", }) } - physicalDatabaseRoleResponse, err = rs.doMigrationRequest(physicalDatabaseRegistrationResponse.Instruction.Id, physicalDatabaseRoleRequest, ctx) + + statusCode, err = rs.performMigrationRequest(ctx, physicalDatabaseRegistrationResponse.Instruction.Id, physicalDatabaseRoleRequest) if err != nil { return err } - statusCode = physicalDatabaseRoleResponse.StatusCode } if statusCode == http.StatusInternalServerError { return fmt.Errorf("migration is not performed") @@ -315,6 +316,22 @@ func (rs *RegistrationProvider) performMigration(statusCode int, return nil } +func (rs *RegistrationProvider) performMigrationRequest(ctx context.Context, id string, obj dao.PhysicalDatabaseRoleRequest) (int, error) { + physicalDatabaseRoleResponse, err := rs.doMigrationRequest(id, obj, ctx) + if err != nil { + return 0, err + } + defer func() { + err = physicalDatabaseRoleResponse.Body.Close() + if err != nil { + logger.Error("failed to close http response body", slog.String("error", err.Error())) + } + }() + + statusCode := physicalDatabaseRoleResponse.StatusCode + return statusCode, nil +} + func (rs *RegistrationProvider) getAdditionalRoles(physicalDatabaseRegistrationResponse dao.PhysicalDatabaseRegistrationResponse, response *http.Response) ([]dao.AdditionalRole, error) { if response == nil { @@ -338,7 +355,7 @@ func (rs *RegistrationProvider) createAdditionalResources(additionalRole dao.Add fmt.Errorf("there is no `resourcePrefix` or `dbName` property for additional role with %s ID", additionalRole.Id) } } - resourcePrefix := resourcePrefixProperty.(string) + resourcePrefix := common.ConvertAnyToString(resourcePrefixProperty) username, password, resources, err := rs.baseProvider.CreateUserByPrefix(resourcePrefix, "", resourcePrefix, roleType, ctx) if err != nil { return connectionProperties, nil, err @@ -385,7 +402,7 @@ func (rs *RegistrationProvider) doMigrationRequest(instructionID string, request } request.SetBasicAuth(rs.dbaasAggregator.Credentials.Username, rs.dbaasAggregator.Credentials.Password) request.Header.Set("Content-Type", "application/json") - request.Header.Set(common.RequestIdKey, ctx.Value(common.RequestIdKey).(string)) + request.Header.Set(common.RequestIdKey, common.GetCtxStringValue(ctx, common.RequestIdKey)) response, err := rs.client.Do(request) if err != nil { logger.ErrorContext(ctx, "Failed to perform migration for physical database", slog.Any("error", err)) @@ -456,11 +473,10 @@ func (rs *RegistrationProvider) GetPhysicalDatabaseHandler() func(w http.Respons responseBody, err := json.Marshal(physicalDatabase) if err != nil { logger.ErrorContext(ctx, "Failed to marshal physical database response to json", slog.Any("error", err)) - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(err.Error())) + common.ProcessResponseBody(ctx, w, []byte(err.Error()), http.StatusInternalServerError) return } - _, _ = w.Write(responseBody) + common.ProcessResponseBody(ctx, w, responseBody, 0) } } diff --git a/server/server.go b/server/server.go index 8757e30..23b5c94 100644 --- a/server/server.go +++ b/server/server.go @@ -17,6 +17,7 @@ package server import ( "context" "crypto/subtle" + "errors" "fmt" "github.com/Netcracker/dbaas-opensearch-adapter/backup" "github.com/Netcracker/dbaas-opensearch-adapter/basic" @@ -28,10 +29,12 @@ import ( "github.com/Netcracker/qubership-dbaas-adapter-core/pkg/dao" "github.com/gorilla/handlers" "github.com/gorilla/mux" + "log/slog" "net/http" "os" "strconv" "strings" + "time" ) var ( @@ -43,24 +46,25 @@ var ( dbaasAggregatorRegistrationRetryDelay = common.GetIntEnv("DBAAS_AGGREGATOR_REGISTRATION_RETRY_DELAY_MS", 5000) dbaasAggregatorPhysicalDatabaseId = common.GetEnv("DBAAS_AGGREGATOR_PHYSICAL_DATABASE_IDENTIFIER", "unknown_opensearch") - opensearchHost = common.GetEnv("OPENSEARCH_HOST", "localhost") - opensearchPort = common.GetIntEnv("OPENSEARCH_PORT", 9200) - opensearchProtocol = common.GetEnv("OPENSEARCH_PROTOCOL", common.Http) - opensearchUsername = common.GetEnv("OPENSEARCH_USERNAME", "opensearch") - opensearchPassword = common.GetEnv("OPENSEARCH_PASSWORD", "change") - opensearchRepo = common.GetEnv("OPENSEARCH_REPO", "dbaas-backups-repository") - opensearchRepoRoot = common.GetEnv("OPENSEARCH_REPO_ROOT", "/usr/share/opensearch/") + opensearchHost = common.GetEnv("OPENSEARCH_HOST", "localhost") + opensearchPort = common.GetIntEnv("OPENSEARCH_PORT", 9200) + opensearchProtocol = common.GetEnv("OPENSEARCH_PROTOCOL", common.Http) + opensearchUsername = common.GetEnv("OPENSEARCH_USERNAME", "opensearch") + opensearchPassword = common.GetEnv("OPENSEARCH_PASSWORD", "change") + opensearchRepo = common.GetEnv("OPENSEARCH_REPO", "dbaas-backups-repository") + opensearchRepoRoot = common.GetEnv("OPENSEARCH_REPO_ROOT", "/usr/share/opensearch/") + //nolint:errcheck enhancedSecurityPluginEnabled, _ = strconv.ParseBool(common.GetEnv("ENHANCED_SECURITY_PLUGIN_ENABLED", "false")) labelsFilename = common.GetEnv("LABELS_FILE_LOCATION_NAME", "dbaas.physical_databases.registration.labels.json") labelsLocationDir = common.GetEnv("LABELS_FILE_LOCATION_DIR", "/app/config/") - + //nolint:errcheck registrationEnabled, _ = strconv.ParseBool(common.GetEnv("REGISTRATION_ENABLED", "false")) ) const certificatesFolder = "/tls" -func Server(adapterAddress string, adapterUsername string, adapterPassword string) error { +func Server(ctx context.Context, adapterAddress string, adapterUsername string, adapterPassword string) { adapter := common.Component{ Address: adapterAddress, Credentials: dao.BasicAuth{ @@ -68,22 +72,53 @@ func Server(adapterAddress string, adapterUsername string, adapterPassword strin Password: adapterPassword, }, } + + hnd := Handlers(ctx, adapter) + if hnd == nil { + return + } + server := &http.Server{ Addr: ":8080", - Handler: Handlers(adapter), + Handler: hnd, } - if strings.Contains(adapterAddress, common.Https) { - server.Addr = ":8443" - return server.ListenAndServeTLS(fmt.Sprintf("%s/tls.crt", certificatesFolder), fmt.Sprintf("%s/tls.key", certificatesFolder)) + + isTlsEnabled := strings.Contains(adapterAddress, common.Https) + logger := common.GetLogger() + + go func() { + var err error + if !isTlsEnabled { + err = server.ListenAndServe() + } else { + err = server.ListenAndServeTLS(fmt.Sprintf("%s/tls.crt", certificatesFolder), + fmt.Sprintf("%s/tls.key", certificatesFolder)) + } + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.ErrorContext(ctx, "server crashed with error", slog.String("error", err.Error())) + } + }() + + <-ctx.Done() + deadlineCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + err := server.Shutdown(deadlineCtx) + if err != nil { + logger.Error("failed to shutdown server") } - return server.ListenAndServe() + logger.Info("server is down gracefully") } -func Handlers(adapter common.Component) http.Handler { +func Handlers(ctx context.Context, adapter common.Component) http.Handler { opensearch := cluster.NewOpensearch(opensearchHost, opensearchPort, opensearchProtocol, opensearchUsername, opensearchPassword) baseProvider := basic.NewBaseProvider(opensearch) - baseProvider.EnsureAggregationIndex(ctx) + err := baseProvider.EnsureAggregationIndex(ctx) + if err != nil { + return nil + } registrationProvider := startRegistration(adapter.Address, adapter.Credentials.Username, adapter.Credentials.Password, baseProvider) createBasicRoles(baseProvider) @@ -183,7 +218,13 @@ func JsonContentType(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { // error handler, when error occurred it sends request with http status 400 and body with error message if err := recover(); err != nil { - http.Error(w, err.(string), http.StatusBadRequest) + strErr, ok := err.(string) + if ok { + http.Error(w, strErr, http.StatusBadRequest) + } else { + http.Error(w, "unrecognized error", http.StatusInternalServerError) + } + return } }() @@ -299,13 +340,12 @@ func BasicAuthorizer(username string, password string, realm string) func(func(w return func(f func(w http.ResponseWriter, r *http.Request)) http.Handler { h := http.HandlerFunc(f) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - + ctx := r.Context() user, pass, ok := r.BasicAuth() if !ok || subtle.ConstantTimeCompare([]byte(user), []byte(username)) != 1 || subtle.ConstantTimeCompare([]byte(pass), []byte(password)) != 1 { w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`) - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte("Not authorized to use this API, only DBaaS aggregator can use it.\n")) + common.ProcessResponseBody(ctx, w, []byte("Not authorized to use this API, only DBaaS aggregator can use it.\n"), http.StatusUnauthorized) return } h.ServeHTTP(w, r)