diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go index 230170ad8ed..6ccc88bef53 100644 --- a/agent/app/agent_test.go +++ b/agent/app/agent_test.go @@ -58,6 +58,7 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + smithy "github.com/aws/smithy-go" "github.com/docker/docker/api/types" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -1059,8 +1060,9 @@ func TestReregisterContainerInstanceInstanceTypeChanged(t *testing.T) { mockDockerClient.EXPECT().ListPluginsWithFilters(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return([]string{}, nil), client.EXPECT().RegisterContainerInstance(containerInstanceARN, gomock.Any(), gomock.Any(), gomock.Any(), - gomock.Any(), gomock.Any()).Return("", "", awserr.New("", - apierrors.InstanceTypeChangedErrorMessage, errors.New(""))), + gomock.Any(), gomock.Any()).Return("", "", &smithy.GenericAPIError{ + Message: apierrors.InstanceTypeChangedErrorMessage, + }), ) cfg := getTestConfig() diff --git a/agent/go.mod b/agent/go.mod index 1c2c637dd44..e405b4abfdc 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.28.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.42 github.com/aws/aws-sdk-go-v2/service/ecs v1.47.3 + github.com/aws/smithy-go v1.22.0 github.com/awslabs/go-config-generator-for-fluentd-and-fluentbit v0.0.0-20210308162251-8959c62cb8f9 github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 github.com/container-storage-interface/spec v1.8.0 @@ -52,7 +53,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.32.3 // indirect - github.com/aws/smithy-go v1.22.0 // indirect github.com/cilium/ebpf v0.16.0 // indirect github.com/containerd/containerd v1.7.24 // indirect github.com/containerd/log v0.1.0 // indirect diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 827da18e539..04a74d17936 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -57,12 +57,14 @@ import ( v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + smithy "github.com/aws/smithy-go" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/docker/docker/api/types" "github.com/golang/mock/gomock" "github.com/gorilla/mux" @@ -3287,11 +3289,17 @@ func TestGetTaskProtection(t *testing.T) { setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, - awserr.NewRequestFailure( - awserr.New(apierrors.ErrCodeServerException, ecsErrMessage, nil), - http.StatusInternalServerError, - ecsRequestID, - ), + &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + }, + Err: &ecstypes.ServerException{Message: &ecsErrMessage}, + }, + RequestID: ecsRequestID, + }, ), expectedStatusCode: http.StatusInternalServerError, expectedResponseBody: tptypes.TaskProtectionResponse{ @@ -3311,11 +3319,17 @@ func TestGetTaskProtection(t *testing.T) { setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, - awserr.NewRequestFailure( - awserr.New(apierrors.ErrCodeAccessDeniedException, ecsErrMessage, nil), - http.StatusBadRequest, - ecsRequestID, - ), + &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + }, + Err: &ecstypes.AccessDeniedException{Message: &ecsErrMessage}, + }, + RequestID: ecsRequestID, + }, ), expectedStatusCode: http.StatusBadRequest, expectedResponseBody: tptypes.TaskProtectionResponse{ @@ -3335,7 +3349,7 @@ func TestGetTaskProtection(t *testing.T) { setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, - awserr.New(apierrors.ErrCodeInvalidParameterException, ecsErrMessage, nil)), + &ecstypes.InvalidParameterException{Message: &ecsErrMessage}), expectedStatusCode: http.StatusInternalServerError, expectedResponseBody: tptypes.TaskProtectionResponse{ Error: &tptypes.ErrorResponse{ @@ -3352,7 +3366,7 @@ func TestGetTaskProtection(t *testing.T) { setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( - nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)), + nil, &smithy.CanceledError{}), expectedStatusCode: http.StatusGatewayTimeout, expectedResponseBody: tptypes.TaskProtectionResponse{ Error: &tptypes.ErrorResponse{ @@ -3561,11 +3575,17 @@ func TestUpdateTaskProtection(t *testing.T) { setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, - awserr.NewRequestFailure( - awserr.New(apierrors.ErrCodeServerException, ecsErrMessage, nil), - http.StatusInternalServerError, - ecsRequestID, - ), + &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + }, + Err: &ecstypes.ServerException{Message: &ecsErrMessage}, + }, + RequestID: ecsRequestID, + }, ), expectedStatusCode: http.StatusInternalServerError, expectedResponseBody: tptypes.TaskProtectionResponse{ @@ -3583,11 +3603,17 @@ func TestUpdateTaskProtection(t *testing.T) { setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, - awserr.NewRequestFailure( - awserr.New(apierrors.ErrCodeAccessDeniedException, ecsErrMessage, nil), - http.StatusBadRequest, - ecsRequestID, - ), + &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + }, + Err: &ecstypes.AccessDeniedException{Message: &ecsErrMessage}, + }, + RequestID: ecsRequestID, + }, ), expectedStatusCode: http.StatusBadRequest, expectedResponseBody: tptypes.TaskProtectionResponse{ @@ -3605,7 +3631,7 @@ func TestUpdateTaskProtection(t *testing.T) { setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, - awserr.New(apierrors.ErrCodeInvalidParameterException, ecsErrMessage, nil)), + &ecstypes.InvalidParameterException{Message: &ecsErrMessage}), expectedStatusCode: http.StatusInternalServerError, expectedResponseBody: tptypes.TaskProtectionResponse{ Error: &tptypes.ErrorResponse{ @@ -3620,7 +3646,7 @@ func TestUpdateTaskProtection(t *testing.T) { setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( - nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)), + nil, &smithy.CanceledError{}), expectedStatusCode: http.StatusGatewayTimeout, expectedResponseBody: tptypes.TaskProtectionResponse{ Error: &tptypes.ErrorResponse{ diff --git a/agent/handlers/v2/response.go b/agent/handlers/v2/response.go index 52882fca18f..0ca327e3ab1 100644 --- a/agent/handlers/v2/response.go +++ b/agent/handlers/v2/response.go @@ -14,7 +14,11 @@ package v2 import ( + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/smithy-go" + "github.com/cihub/seelog" + "github.com/pkg/errors" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" @@ -25,9 +29,7 @@ import ( tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" - "github.com/aws/aws-sdk-go/aws" - "github.com/cihub/seelog" - "github.com/pkg/errors" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" ) // Agent versions >= 1.2.0: Null, zero, and CPU values of 1 @@ -280,5 +282,17 @@ func newErrorResponse(err error, field, resourceARN string) *tmdsv2.ErrorRespons } } + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + errResp.ErrorCode = apiErr.ErrorCode() + errResp.ErrorMessage = apiErr.ErrorMessage() + } + + var re *awshttp.ResponseError + if errors.As(err, &re) { + errResp.StatusCode = re.HTTPStatusCode() + errResp.RequestId = re.RequestID + } + return errResp } diff --git a/agent/handlers/v2/response_test.go b/agent/handlers/v2/response_test.go index 6b5efc6e434..079cfe652ff 100644 --- a/agent/handlers/v2/response_test.go +++ b/agent/handlers/v2/response_test.go @@ -19,6 +19,7 @@ package v2 import ( "encoding/json" "fmt" + "net/http" "testing" "time" @@ -30,12 +31,12 @@ import ( apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" + "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/docker/docker/api/types" "github.com/golang/mock/gomock" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -679,19 +680,38 @@ func TestTaskResponseWithV4TagsError(t *testing.T) { }, } - errCode := "ThrottlingException" errMessage := "Rate exceeded" errStatusCode := 400 containerTagsRequestId := "cef9da77-aee7-431d-84d5-f92b2d342c51" taskTagsRequestId := "45dbbc67-0c60-4248-855e-14fdf4c11870" - containerTagsErr := awserr.NewRequestFailure(awserr.Error(awserr.New(errCode, errMessage, errors.New(""))), errStatusCode, containerTagsRequestId) - taskTagsError := awserr.NewRequestFailure(awserr.Error(awserr.New(errCode, errMessage, errors.New(""))), errStatusCode, taskTagsRequestId) + containerTagsErr := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: errStatusCode, + }, + }, + Err: &ecstypes.LimitExceededException{Message: &errMessage}, + }, + RequestID: containerTagsRequestId, + } + taskTagsErr := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: errStatusCode, + }, + }, + Err: &ecstypes.LimitExceededException{Message: &errMessage}, + }, + RequestID: taskTagsRequestId, + } gomock.InOrder( state.EXPECT().TaskByArn(taskARN).Return(task, true), state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, containerTagsErr), - ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, taskTagsError), + ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, taskTagsErr), ) taskWithTagsResponse, err := NewTaskResponse(taskARN, state, ecsClient, cluster, availabilityZone, containerInstanceArn, true, true) @@ -699,13 +719,13 @@ func TestTaskResponseWithV4TagsError(t *testing.T) { _, err = json.Marshal(taskWithTagsResponse) assert.NoError(t, err) assert.Equal(t, taskWithTagsResponse.Errors[0].ErrorField, "ContainerInstanceTags") - assert.Equal(t, taskWithTagsResponse.Errors[0].ErrorCode, errCode) + assert.Equal(t, taskWithTagsResponse.Errors[0].ErrorCode, (&ecstypes.LimitExceededException{}).ErrorCode()) assert.Equal(t, taskWithTagsResponse.Errors[0].ErrorMessage, errMessage) assert.Equal(t, taskWithTagsResponse.Errors[0].StatusCode, errStatusCode) assert.Equal(t, taskWithTagsResponse.Errors[0].RequestId, containerTagsRequestId) assert.Equal(t, taskWithTagsResponse.Errors[0].ResourceARN, containerInstanceArn) assert.Equal(t, taskWithTagsResponse.Errors[1].ErrorField, "TaskTags") - assert.Equal(t, taskWithTagsResponse.Errors[1].ErrorCode, errCode) + assert.Equal(t, taskWithTagsResponse.Errors[1].ErrorCode, (&ecstypes.LimitExceededException{}).ErrorCode()) assert.Equal(t, taskWithTagsResponse.Errors[1].ErrorMessage, errMessage) assert.Equal(t, taskWithTagsResponse.Errors[1].StatusCode, errStatusCode) assert.Equal(t, taskWithTagsResponse.Errors[1].RequestId, taskTagsRequestId) diff --git a/agent/utils/utils.go b/agent/utils/utils.go index fd46b9a0f86..f3ffc86ed72 100644 --- a/agent/utils/utils.go +++ b/agent/utils/utils.go @@ -29,10 +29,12 @@ import ( "strings" commonutils "github.com/aws/amazon-ecs-agent/ecs-agent/utils" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/smithy-go" "github.com/pkg/errors" ) @@ -140,7 +142,16 @@ func Remove(slice []string, s int) []string { // the passed in error code. func IsAWSErrorCodeEqual(err error, code string) bool { awsErr, ok := err.(awserr.Error) - return ok && awsErr.Code() == code + if ok { + return awsErr.Code() == code + } + + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return apiErr.ErrorCode() == code + } + + return false } // GetRequestFailureStatusCode returns the status code from a @@ -148,8 +159,14 @@ func IsAWSErrorCodeEqual(err error, code string) bool { func GetRequestFailureStatusCode(err error) int { var statusCode int if reqErr, ok := err.(awserr.RequestFailure); ok { - statusCode = reqErr.StatusCode() + return reqErr.StatusCode() + } + + var re *awshttp.ResponseError + if errors.As(err, &re) { + return re.HTTPStatusCode() } + return statusCode } diff --git a/agent/utils/utils_test.go b/agent/utils/utils_test.go index e5732cdd99f..7c964cf889f 100644 --- a/agent/utils/utils_test.go +++ b/agent/utils/utils_test.go @@ -19,14 +19,18 @@ package utils import ( "errors" "fmt" + "net/http" "os" "sort" "testing" "time" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/smithy-go" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -115,6 +119,15 @@ func TestIsAWSErrorCodeEqual(t *testing.T) { err: awserr.New("errCode", "errMsg", errors.New("err")), res: false, }, + { + name: "Happy Path SDKv2", + err: &smithy.GenericAPIError{Code: apierrors.ErrCodeInvalidParameterException}, + res: true, + }, + { + name: "Wrong Error Code SDKv2", + err: &smithy.GenericAPIError{Code: "errCode"}, + }, { name: "Wrong Error Type", err: errors.New("err"), @@ -137,6 +150,19 @@ func TestGetRequestFailureStatusCode(t *testing.T) { }{ { name: "TestGetRequestFailureStatusCodeSuccess", + err: &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + }, + }, + }, + res: 400, + }, + { + name: "TestGetRequestFailureStatusCodeSuccess SDKv2", err: awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, ""), res: 400, },