From c1ee70b46b76417e87b4583f17270a6106d456d6 Mon Sep 17 00:00:00 2001 From: mye956 Date: Fri, 16 Aug 2024 20:56:27 +0000 Subject: [PATCH] Consuming updated task response for fault injection in fault injection handlers --- agent/api/task/task.go | 27 +++ agent/handlers/task_server_setup_test.go | 89 +++++-- agent/handlers/v4/tmdsstate.go | 19 ++ .../handlers/fault/v1/handlers/handlers.go | 88 ++++++- .../handlers/fault/v1/handlers/handlers.go | 88 ++++++- .../fault/v1/handlers/handlers_test.go | 218 ++++++++++++++---- 6 files changed, 461 insertions(+), 68 deletions(-) diff --git a/agent/api/task/task.go b/agent/api/task/task.go index 94eeae32b09..f353e750f67 100644 --- a/agent/api/task/task.go +++ b/agent/api/task/task.go @@ -296,6 +296,12 @@ type Task struct { NetworkMode string `json:"NetworkMode,omitempty"` IsInternal bool `json:"IsInternal,omitempty"` + + // TODO: Will need to initialize/set the value in a follow PR + NetworkNamespace string `json:"NetworkNamespace,omitempty"` + + // TODO: Will need to initialize/set the value in a follow PR + FaultInjectionEnabled bool `json:"FaultInjectionEnabled,omitempty"` } // TaskFromACS translates ecsacs.Task to apitask.Task by first marshaling the received @@ -3743,3 +3749,24 @@ func (task *Task) HasAContainerWithResolvedDigest() bool { } return false } + +func (task *Task) IsFaultInjectionEnabled() bool { + task.lock.RLock() + defer task.lock.RUnlock() + + return task.FaultInjectionEnabled +} + +func (task *Task) GetNetworkMode() string { + task.lock.RLock() + defer task.lock.RUnlock() + + return task.NetworkMode +} + +func (task *Task) GetNetworkNamespace() string { + task.lock.RLock() + defer task.lock.RUnlock() + + return task.NetworkNamespace +} diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 9fa1bd246a7..7966c5527e3 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -112,6 +112,7 @@ const ( subnetGatewayIpv4Address = "172.31.32.1/20" taskCredentialsID = "taskCredentialsId" endpointId = "endpointId" + networkNamespace = "/path" port = 1234 protocol = "tcp" @@ -416,6 +417,21 @@ var ( SubnetGatewayIPV4Address: "", }}, }) + + agentStateExpectations = func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) { + task := standardTask() + task.FaultInjectionEnabled = faultInjectionEnabled + task.NetworkMode = networkMode + task.NetworkNamespace = networkNamespace + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true), + state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), + state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), + state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), + state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), + ) + } ) func standardTask() *apitask.Task { @@ -3576,7 +3592,9 @@ type blackholePortFaultTestCase struct { expectedStatusCode int requestBody interface{} expectedFaultResponse faulttype.NetworkFaultInjectionResponse - setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) + setStateExpectations func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) + faultInjectionEnabled bool + networkMode string } func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyResponseBody string) []blackholePortFaultTestCase { @@ -3585,24 +3603,25 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "Protocol": protocol, "TrafficType": trafficType, } - happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) { - task := standardTask() - gomock.InOrder( - state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true), - state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), - state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), - state.EXPECT().TaskByArn(taskARN).Return(task, true), - state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(), - state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), - ) - } + tcs := []blackholePortFaultTestCase{ { - name: fmt.Sprintf("%s success", name), + name: fmt.Sprintf("%s success host mode", name), + expectedStatusCode: 200, + requestBody: happyBlackHolePortReqBody, + expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), + setStateExpectations: agentStateExpectations, + faultInjectionEnabled: true, + networkMode: apitask.HostNetworkMode, + }, + { + name: fmt.Sprintf("%s success awsvpc mode", name), expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), - setStateExpectations: happyStateExpectations, + setStateExpectations: agentStateExpectations, + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s unknown request body", name), @@ -3614,7 +3633,9 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "Unknown": "", }, expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), - setStateExpectations: happyStateExpectations, + setStateExpectations: agentStateExpectations, + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s malformed request body", name), @@ -3625,6 +3646,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "TrafficType": trafficType, }, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkBlackholePortRequest.Port of type uint16"), + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s incomplete request body", name), @@ -3634,6 +3657,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "Protocol": protocol, }, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"), + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s empty value request body", name), @@ -3644,6 +3669,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "TrafficType": "", }, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"), + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s invalid protocol value request body", name), @@ -3654,6 +3681,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "TrafficType": trafficType, }, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter Protocol"), + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s invalid traffic type value request body", name), @@ -3664,29 +3693,53 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe "TrafficType": "invalid", }, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter TrafficType"), + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s task lookup fail", name), expectedStatusCode: 404, requestBody: happyBlackHolePortReqBody, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), - setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(endpointId).Return("", false), ) }, + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, }, { name: fmt.Sprintf("%s task metadata fetch fail", name), expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), - setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(nil, false), ) }, + faultInjectionEnabled: true, + networkMode: apitask.AWSVPCNetworkMode, + }, + { + name: fmt.Sprintf("%s fault injection disabled", name), + expectedStatusCode: 400, + requestBody: happyBlackHolePortReqBody, + expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("fault injection is not enabled for task: %s", taskARN)), + setStateExpectations: agentStateExpectations, + faultInjectionEnabled: false, + networkMode: apitask.AWSVPCNetworkMode, + }, + { + name: fmt.Sprintf("%s invalid network mode", name), + expectedStatusCode: 400, + requestBody: happyBlackHolePortReqBody, + expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid mode is not supported. Please use either host or awsvpc mode."), + setStateExpectations: agentStateExpectations, + faultInjectionEnabled: true, + networkMode: "invalid", }, } return tcs @@ -3722,7 +3775,7 @@ func testRegisterFaultHandler(t *testing.T, tcs []blackholePortFaultTestCase, me metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) if tc.setStateExpectations != nil { - tc.setStateExpectations(state) + tc.setStateExpectations(state, tc.faultInjectionEnabled, tc.networkMode) } router := mux.NewRouter() diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 3e8e098dc27..4daefdbb766 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -151,6 +151,25 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) NewPulledContainerResponse(dockerContainer, task.GetPrimaryENI())) } + if task.IsFaultInjectionEnabled() { + // TODO: The correct values for the task network config will need to be set/initialized + taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled() + taskNetworkConfig := tmdsv4.TaskNetworkConfig{ + NetworkMode: task.GetNetworkMode(), + NetworkNamespaces: []*tmdsv4.NetworkNamespace{ + { + Path: task.GetNetworkNamespace(), + NetworkInterfaces: []*tmdsv4.NetworkInterface{ + { + DeviceName: "", + }, + }, + }, + }, + } + taskResponse.TaskNetworkConfig = &taskNetworkConfig + } + return *taskResponse, nil } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 768624c32c0..d550dafbd20 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -21,6 +21,7 @@ import ( "io" "net/http" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" @@ -36,6 +37,8 @@ const ( startFaultRequestType = "start %s" stopFaultRequestType = "stop %s" checkStatusFaultRequestType = "check status %s" + invalidNetworkModeError = "%s mode is not supported. Please use either host or awsvpc mode." + faultInjectionEnabledError = "fault injection is not enabled for task: %s" ) type FaultHandler struct { @@ -51,6 +54,7 @@ func NetworkFaultPath(fault string) string { utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkBlackholePortRequest @@ -66,9 +70,6 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht if err != nil { return } - logger.Debug("Successfully parsed fault request payload", logger.Fields{ - field.Request: request.ToString(), - }) // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR @@ -95,6 +96,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } } +// StopNetworkBlackHolePort will return the request handler function for stopping a network blackhole port fault func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkBlackholePortRequest @@ -140,6 +142,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt } } +// CheckNetworkBlackHolePort will return the request handler function for checking the status of a network blackhole port fault func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkBlackholePortRequest @@ -435,6 +438,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } } +// decodeRequest will translate/unmarshal an incoming fault injection request into one of the network fault structs func decodeRequest(w http.ResponseWriter, request types.NetworkFaultRequest, requestType string, r *http.Request) error { logRequest(requestType, r) jsonDecoder := json.NewDecoder(r.Body) @@ -458,6 +462,7 @@ func decodeRequest(w http.ResponseWriter, request types.NetworkFaultRequest, req return nil } +// validateRequest will validate that the incoming fault injection request will have the required fields. func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, requestType string) error { if err := request.ValidateRequest(); err != nil { responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", err)) @@ -502,12 +507,67 @@ func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, re return nil, errResponse } - // TODO: Check if task is FIS-enabled - // TODO: Check if task is using a valid network mode + // Check if task is FIS-enabled + if !taskMetadata.FaultInjectionEnabled { + errResponse := fmt.Sprintf(faultInjectionEnabledError, taskMetadata.TaskARN) + responseBody := types.NewNetworkFaultInjectionErrorResponse(errResponse) + logger.Error("Error: Task is not fault injection enabled.", logger.Fields{ + field.RequestType: requestType, + field.TMDSEndpointContainerID: endpointContainerID, + field.Response: responseBody.ToString(), + field.TaskARN: taskMetadata.TaskARN, + field.Error: errResponse, + }) + utils.WriteJSONResponse( + w, + http.StatusBadRequest, + responseBody, + requestType, + ) + return nil, errors.New(errResponse) + } + + if err := validateTaskNetworkConfig(taskMetadata.TaskNetworkConfig); err != nil { + code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) + responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) + logger.Error("Error: Unable to resolve task network config within task metadata", logger.Fields{ + field.Error: err, + field.RequestType: requestType, + field.Response: responseBody.ToString(), + field.TMDSEndpointContainerID: endpointContainerID, + }) + utils.WriteJSONResponse( + w, + code, + responseBody, + requestType, + ) + return nil, errResponse + } + + // Check if task is using a valid network mode + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + if networkMode != ecs.NetworkModeHost && networkMode != ecs.NetworkModeAwsvpc { + errResponse := fmt.Sprintf(invalidNetworkModeError, networkMode) + responseBody := types.NewNetworkFaultInjectionErrorResponse(errResponse) + logger.Error("Error: Invalid network mode for fault injection", logger.Fields{ + field.RequestType: requestType, + field.NetworkMode: networkMode, + field.Response: responseBody.ToString(), + }) + utils.WriteJSONResponse( + w, + http.StatusBadRequest, + responseBody, + requestType, + ) + return nil, errors.New(errResponse) + } return &taskMetadata, nil } +// getTaskMetadataErrorResponse will be used to classify certain errors that was returned from a GetTaskMetadata function call. func getTaskMetadataErrorResponse(endpointContainerID, requestType string, err error) (int, error) { var errContainerLookupFailed *state.ErrorLookupFailure if errors.As(err, &errContainerLookupFailed) { @@ -526,6 +586,7 @@ func getTaskMetadataErrorResponse(endpointContainerID, requestType string, err e return http.StatusInternalServerError, fmt.Errorf("failed to get task metadata due to internal server error for container: %s", endpointContainerID) } +// logRequest is used to log incoming fault injection requests. func logRequest(requestType string, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -541,3 +602,20 @@ func logRequest(requestType string, r *http.Request) { }) r.Body = io.NopCloser(bytes.NewBuffer(body)) } + +// validateTaskNetworkConfig validates the passed in task network config for any null/empty values. +func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error { + if taskNetworkConfig == nil { + return errors.New("TaskNetworkConfig is empty within task metadata") + } + + if len(taskNetworkConfig.NetworkNamespaces) == 0 || taskNetworkConfig.NetworkNamespaces[0] == nil { + return errors.New("empty network namespaces within task network config") + } + + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { + return errors.New("empty network interfaces within task network config") + } + + return nil +} diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 768624c32c0..d550dafbd20 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -21,6 +21,7 @@ import ( "io" "net/http" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" @@ -36,6 +37,8 @@ const ( startFaultRequestType = "start %s" stopFaultRequestType = "stop %s" checkStatusFaultRequestType = "check status %s" + invalidNetworkModeError = "%s mode is not supported. Please use either host or awsvpc mode." + faultInjectionEnabledError = "fault injection is not enabled for task: %s" ) type FaultHandler struct { @@ -51,6 +54,7 @@ func NetworkFaultPath(fault string) string { utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkBlackholePortRequest @@ -66,9 +70,6 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht if err != nil { return } - logger.Debug("Successfully parsed fault request payload", logger.Fields{ - field.Request: request.ToString(), - }) // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR @@ -95,6 +96,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } } +// StopNetworkBlackHolePort will return the request handler function for stopping a network blackhole port fault func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkBlackholePortRequest @@ -140,6 +142,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt } } +// CheckNetworkBlackHolePort will return the request handler function for checking the status of a network blackhole port fault func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkBlackholePortRequest @@ -435,6 +438,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } } +// decodeRequest will translate/unmarshal an incoming fault injection request into one of the network fault structs func decodeRequest(w http.ResponseWriter, request types.NetworkFaultRequest, requestType string, r *http.Request) error { logRequest(requestType, r) jsonDecoder := json.NewDecoder(r.Body) @@ -458,6 +462,7 @@ func decodeRequest(w http.ResponseWriter, request types.NetworkFaultRequest, req return nil } +// validateRequest will validate that the incoming fault injection request will have the required fields. func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, requestType string) error { if err := request.ValidateRequest(); err != nil { responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", err)) @@ -502,12 +507,67 @@ func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, re return nil, errResponse } - // TODO: Check if task is FIS-enabled - // TODO: Check if task is using a valid network mode + // Check if task is FIS-enabled + if !taskMetadata.FaultInjectionEnabled { + errResponse := fmt.Sprintf(faultInjectionEnabledError, taskMetadata.TaskARN) + responseBody := types.NewNetworkFaultInjectionErrorResponse(errResponse) + logger.Error("Error: Task is not fault injection enabled.", logger.Fields{ + field.RequestType: requestType, + field.TMDSEndpointContainerID: endpointContainerID, + field.Response: responseBody.ToString(), + field.TaskARN: taskMetadata.TaskARN, + field.Error: errResponse, + }) + utils.WriteJSONResponse( + w, + http.StatusBadRequest, + responseBody, + requestType, + ) + return nil, errors.New(errResponse) + } + + if err := validateTaskNetworkConfig(taskMetadata.TaskNetworkConfig); err != nil { + code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) + responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) + logger.Error("Error: Unable to resolve task network config within task metadata", logger.Fields{ + field.Error: err, + field.RequestType: requestType, + field.Response: responseBody.ToString(), + field.TMDSEndpointContainerID: endpointContainerID, + }) + utils.WriteJSONResponse( + w, + code, + responseBody, + requestType, + ) + return nil, errResponse + } + + // Check if task is using a valid network mode + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + if networkMode != ecs.NetworkModeHost && networkMode != ecs.NetworkModeAwsvpc { + errResponse := fmt.Sprintf(invalidNetworkModeError, networkMode) + responseBody := types.NewNetworkFaultInjectionErrorResponse(errResponse) + logger.Error("Error: Invalid network mode for fault injection", logger.Fields{ + field.RequestType: requestType, + field.NetworkMode: networkMode, + field.Response: responseBody.ToString(), + }) + utils.WriteJSONResponse( + w, + http.StatusBadRequest, + responseBody, + requestType, + ) + return nil, errors.New(errResponse) + } return &taskMetadata, nil } +// getTaskMetadataErrorResponse will be used to classify certain errors that was returned from a GetTaskMetadata function call. func getTaskMetadataErrorResponse(endpointContainerID, requestType string, err error) (int, error) { var errContainerLookupFailed *state.ErrorLookupFailure if errors.As(err, &errContainerLookupFailed) { @@ -526,6 +586,7 @@ func getTaskMetadataErrorResponse(endpointContainerID, requestType string, err e return http.StatusInternalServerError, fmt.Errorf("failed to get task metadata due to internal server error for container: %s", endpointContainerID) } +// logRequest is used to log incoming fault injection requests. func logRequest(requestType string, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -541,3 +602,20 @@ func logRequest(requestType string, r *http.Request) { }) r.Body = io.NopCloser(bytes.NewBuffer(body)) } + +// validateTaskNetworkConfig validates the passed in task network config for any null/empty values. +func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error { + if taskNetworkConfig == nil { + return errors.New("TaskNetworkConfig is empty within task metadata") + } + + if len(taskNetworkConfig.NetworkNamespaces) == 0 || taskNetworkConfig.NetworkNamespaces[0] == nil { + return errors.New("empty network namespaces within task network config") + } + + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { + return errors.New("empty network interfaces within task network config") + } + + return nil +} diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index f97015609ed..7b01b8acd91 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -28,6 +28,7 @@ import ( mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types" + v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks" @@ -45,9 +46,47 @@ const ( delayMilliseconds = 123456789 jitterMilliseconds = 4567 lossPercent = 6 + taskARN = "taskArn" + awsvpcNetworkMode = "awsvpc" + deviceName = "eth0" + invalidNetworkMode = "invalid" ) -var ipSources = []string{"52.95.154.1", "52.95.154.2"} +var ( + happyNetworkInterfaces = []*state.NetworkInterface{ + { + DeviceName: deviceName, + }, + } + + happyNetworkNamespaces = []*state.NetworkNamespace{ + { + Path: "/some/path", + NetworkInterfaces: happyNetworkInterfaces, + }, + } + + happyTaskNetworkConfig = state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: happyNetworkNamespaces, + } + + happyTaskResponse = state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + TaskNetworkConfig: &happyTaskNetworkConfig, + FaultInjectionEnabled: true, + } + + ipSources = []string{"52.95.154.1", "52.95.154.2"} +) + +type networkFaultInjectionTestCase struct { + name string + expectedStatusCode int + requestBody interface{} + expectedResponseBody types.NetworkFaultInjectionResponse + setAgentStateExpectations func(agentState *mock_state.MockAgentState) +} // Tests the path for Fault Network Faults API func TestFaultBlackholeFaultPath(t *testing.T) { @@ -62,28 +101,20 @@ func TestFaultPacketLossFaultPath(t *testing.T) { assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-packet-loss", NetworkFaultPath(types.PacketLossFaultType)) } -type networkBlackHolePortTestCase struct { - name string - expectedStatusCode int - requestBody interface{} - expectedResponseBody types.NetworkFaultInjectionResponse - setAgentStateExpectations func(agentState *mock_state.MockAgentState) -} - -func getNetworkBlackHolePortTestCases(name string, expectedHappyResponseBody string) []networkBlackHolePortTestCase { +func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBody string) []networkFaultInjectionTestCase { happyBlackHolePortReqBody := map[string]interface{}{ "Port": port, "Protocol": protocol, "TrafficType": trafficType, } - tcs := []networkBlackHolePortTestCase{ + tcs := []networkFaultInjectionTestCase{ { name: fmt.Sprintf("%s success", name), expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil) + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, }, { @@ -97,7 +128,7 @@ func getNetworkBlackHolePortTestCases(name string, expectedHappyResponseBody str }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil) + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, }, { @@ -192,12 +223,53 @@ func getNetworkBlackHolePortTestCases(name string, expectedHappyResponseBody str agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) }, }, + { + name: fmt.Sprintf("%s fault injection disabled", name), + expectedStatusCode: 400, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(faultInjectionEnabledError, taskARN)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: false, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s invalid network mode", name), + expectedStatusCode: 400, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(invalidNetworkModeError, invalidNetworkMode)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: invalidNetworkMode, + NetworkNamespaces: happyNetworkNamespaces, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s empty task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: nil, + }, nil) + }, + }, } return tcs } func TestStartNetworkBlackHolePort(t *testing.T) { - tcs := getNetworkBlackHolePortTestCases("start blackhole port", "running") + tcs := generateNetworkBlackHolePortTestCases("start blackhole port", "running") for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -251,7 +323,7 @@ func TestStartNetworkBlackHolePort(t *testing.T) { } func TestStopNetworkBlackHolePort(t *testing.T) { - tcs := getNetworkBlackHolePortTestCases("stop blackhole port", "stopped") + tcs := generateNetworkBlackHolePortTestCases("stop blackhole port", "stopped") for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks @@ -304,7 +376,7 @@ func TestStopNetworkBlackHolePort(t *testing.T) { } func TestCheckNetworkBlackHolePort(t *testing.T) { - tcs := getNetworkBlackHolePortTestCases("check blackhole port", "running") + tcs := generateNetworkBlackHolePortTestCases("check blackhole port", "running") for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -357,28 +429,20 @@ func TestCheckNetworkBlackHolePort(t *testing.T) { } } -type networkLatencyTestCase struct { - name string - expectedStatusCode int - requestBody interface{} - expectedResponseBody types.NetworkFaultInjectionResponse - setAgentStateExpectations func(agentState *mock_state.MockAgentState) -} - -func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []networkLatencyTestCase { +func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []networkFaultInjectionTestCase { happyNetworkLatencyReqBody := map[string]interface{}{ "DelayMilliseconds": delayMilliseconds, "JitterMilliseconds": jitterMilliseconds, "Sources": ipSources, } - tcs := []networkLatencyTestCase{ + tcs := []networkFaultInjectionTestCase{ { name: fmt.Sprintf("%s success", name), expectedStatusCode: 200, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil) + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, }, { @@ -392,7 +456,7 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil) + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, }, { @@ -537,6 +601,47 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) }, }, + { + name: fmt.Sprintf("%s fault injection disabled", name), + expectedStatusCode: 400, + requestBody: happyNetworkLatencyReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(faultInjectionEnabledError, taskARN)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: false, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s invalid network mode", name), + expectedStatusCode: 400, + requestBody: happyNetworkLatencyReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(invalidNetworkModeError, invalidNetworkMode)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: invalidNetworkMode, + NetworkNamespaces: happyNetworkNamespaces, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s empty task network config", name), + expectedStatusCode: 500, + requestBody: happyNetworkLatencyReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: nil, + }, nil) + }, + }, } return tcs } @@ -702,27 +807,19 @@ func TestCheckNetworkLatency(t *testing.T) { } } -type networkPacketLossTestCase struct { - name string - expectedStatusCode int - requestBody interface{} - expectedResponseBody types.NetworkFaultInjectionResponse - setAgentStateExpectations func(agentState *mock_state.MockAgentState) -} - -func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) []networkPacketLossTestCase { +func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) []networkFaultInjectionTestCase { happyNetworkPacketLossReqBody := map[string]interface{}{ "LossPercent": lossPercent, "Sources": ipSources, } - tcs := []networkPacketLossTestCase{ + tcs := []networkFaultInjectionTestCase{ { name: fmt.Sprintf("%s success", name), expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil) + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, }, { @@ -735,7 +832,7 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil) + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, }, { @@ -884,6 +981,47 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) }, }, + { + name: fmt.Sprintf("%s fault injection disabled", name), + expectedStatusCode: 400, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(faultInjectionEnabledError, taskARN)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: false, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s invalid network mode", name), + expectedStatusCode: 400, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(invalidNetworkModeError, invalidNetworkMode)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: invalidNetworkMode, + NetworkNamespaces: happyNetworkNamespaces, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s empty task network config", name), + expectedStatusCode: 500, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: nil, + }, nil) + }, + }, } return tcs }