diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index fe78e0857ed..cd428759117 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -51,8 +51,12 @@ const ( ) var ( - iptablesChainExistCmd = "iptables -C %s -p %s --dport %s -j DROP" - nsenterCommandString = "nsenter --net=%s" + iptablesChainExistCmd = "iptables -C %s -p %s --dport %s -j DROP" + nsenterCommandString = "nsenter --net=%s " + tcCheckInjectionCommandString = "tc -j q show dev %s parent 1:1" + tcAddQdiscRootCommandString = "tc qdisc add dev %s root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2" + tcAddQdiscLossCommandString = "tc qdisc add dev %s parent 1:1 handle 10: netem loss %d%%" + tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 1 u32 match ip dst %s flowid 1:1" ) type FaultHandler struct { @@ -475,18 +479,47 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. rwMu.Lock() defer rwMu.Unlock() - // TODO: Check status of current fault injection - // TODO: Invoke the start fault injection functionality if not running - - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully started fault", logger.Fields{ + var responseBody types.NetworkFaultInjectionResponse + var httpStatusCode int + stringToBeLogged := "Failed to start fault" + // All command executions for the start network packet loss workflow all together should finish within 5 seconds. + // Thus, create the context here so that it can be shared by all os/exec calls. + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + defer cancel() + // Check the status of current fault injection. + latencyFaultExists, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + httpStatusCode = http.StatusInternalServerError + } else { + // If there already exists a fault in the task network namespace. + if latencyFaultExists { + responseBody = types.NewNetworkFaultInjectionErrorResponse("There is already one network latency fault running") + httpStatusCode = http.StatusConflict + } else if packetLossFaultExists { + responseBody = types.NewNetworkFaultInjectionErrorResponse("There is already one network packet loss fault running") + httpStatusCode = http.StatusConflict + } else { + // Invoke the start fault injection functionality if not running. + err := h.startNetworkPacketLossFault(ctx, taskMetadata, request) + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse("Failed to inject network-packet-loss fault") + httpStatusCode = http.StatusInternalServerError + } else { + stringToBeLogged = "Successfully started fault" + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + httpStatusCode = http.StatusOK + } + } + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + httpStatusCode, responseBody, requestType, ) @@ -794,6 +827,128 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return nil } +// startNetworkPacketLossFault invokes the linux TC utility tool to start the network-packet-loss fault. +func (h *FaultHandler) startNetworkPacketLossFault(ctx context.Context, taskMetadata *state.TaskResponse, request types.NetworkPacketLossRequest) error { + interfaceName := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + // If task's network mode is awsvpc, we need to run nsenter to access the task's network namespace. + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + } + lossPercent := aws.Uint64Value(request.LossPercent) + + // Command to be executed: + // tc qdisc add dev root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 + // "tc qdisc add dev parent 1:1 handle 10: netem loss %" + tcAddQdiscRootCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddQdiscRootCommandString, interfaceName) + cmdList := strings.Split(tcAddQdiscRootCommandComposed, " ") + cmdOutput, err := h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcAddQdiscRootCommandComposed, err, string(cmdOutput[:]))) + return err + } + logger.Info(fmt.Sprintf("'%s' command result: '%s'", tcAddQdiscRootCommandComposed, string(cmdOutput[:]))) + tcAddQdiscLossCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddQdiscLossCommandString, interfaceName, lossPercent) + cmdList = strings.Split(tcAddQdiscLossCommandComposed, " ") + cmdOutput, err = h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcAddQdiscLossCommandComposed, err, string(cmdOutput[:]))) + return err + } + logger.Info(fmt.Sprintf("'%s' command result: '%s'", tcAddQdiscLossCommandComposed, string(cmdOutput[:]))) + // After creating the queueing discipline, create filters to associate the IPs in the request with the handle. + for _, ip := range request.Sources { + tcAddFilterForIPCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddFilterForIPCommandString, interfaceName, *ip) + cmdList = strings.Split(tcAddFilterForIPCommandComposed, " ") + _, err = h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcAddFilterForIPCommandComposed, err, string(cmdOutput[:]))) + return err + } + } + + return nil +} + +// checkTCFault check if there's existing network-latency fault or network-packet-loss fault. +func (h *FaultHandler) checkTCFault(ctx context.Context, taskMetadata *state.TaskResponse) (bool, bool, error) { + interfaceName := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + // If task's network mode is awsvpc, we need to run nsenter to access the task's network namespace. + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + } + + // We will run the following Linux command to assess if there existing fault. + // "tc -j q show dev {INTERFACE} parent 1:1" + // The command above gives the output of "tc q show dev {INTERFACE} parent 1:1" in json format. + // We will then unmarshall the json string and evaluate the fields of it. + tcCheckInjectionCommandComposed := nsenterPrefix + fmt.Sprintf(tcCheckInjectionCommandString, interfaceName) + cmdList := strings.Split(tcCheckInjectionCommandComposed, " ") + cmdOutput, err := h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcCheckInjectionCommandComposed, err, string(cmdOutput[:]))) + return false, false, fmt.Errorf("failed to check existing network fault: '%s' command failed with the following error: '%s'. std output: '%s'", + tcCheckInjectionCommandComposed, err, string(cmdOutput[:])) + } + // Log the command output to better help us debug. + logger.Info(fmt.Sprintf("'%s' command result: '%s'", tcCheckInjectionCommandComposed, string(cmdOutput[:]))) + + // Check whether latency fault exists and whether packet loss fault exists separately. + var outputUnmarshalled []map[string]interface{} + err = json.Unmarshal(cmdOutput, &outputUnmarshalled) + if err != nil { + return false, false, errors.New("failed to check existing network fault: failed to unmarshal tc command output: " + err.Error()) + } + latencyFaultExists, err := checkLatencyFault(outputUnmarshalled) + if err != nil { + return false, false, errors.New("failed to check existing network fault: " + err.Error()) + } + packetLossFaultExists, err := checkPacketLossFault(outputUnmarshalled) + if err != nil { + return false, false, errors.New("failed to check existing network fault: " + err.Error()) + } + return latencyFaultExists, packetLossFaultExists, nil +} + +// checkLatencyFault parses the tc command output and checks if there's existing network-latency fault running. +func checkLatencyFault(outputUnmarshalled []map[string]interface{}) (bool, error) { + for _, line := range outputUnmarshalled { + // Check if field "kind":"netem" exists. + if line["kind"] == "netem" { + // Now check if network packet loss fault exists. + if options := line["options"]; options != nil { + if delay := options.(map[string]interface{})["delay"]; delay != nil { + return true, nil + } + } + } + } + return false, nil +} + +// checkPacketLossFault parses the tc command output and checks if there's existing network-packet-loss fault running. +func checkPacketLossFault(outputUnmarshalled []map[string]interface{}) (bool, error) { + for _, line := range outputUnmarshalled { + // First check field "kind":"netem" exists. + if line["kind"] == "netem" { + // Now check if field "loss":"" exists, and if the percentage matches with the value in the request. + if options := line["options"]; options != nil { + if lossRandom := options.(map[string]interface{})["loss-random"]; lossRandom != nil { + return true, nil + } + } + } + } + return false, nil +} + // runExecCommand wraps around the execwrapper, providing a convenient way of running any Linux command // and getting the result in both stdout and stderr. func (h *FaultHandler) runExecCommand(ctx context.Context, cmdList []string) ([]byte, error) { diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index fe78e0857ed..cd428759117 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -51,8 +51,12 @@ const ( ) var ( - iptablesChainExistCmd = "iptables -C %s -p %s --dport %s -j DROP" - nsenterCommandString = "nsenter --net=%s" + iptablesChainExistCmd = "iptables -C %s -p %s --dport %s -j DROP" + nsenterCommandString = "nsenter --net=%s " + tcCheckInjectionCommandString = "tc -j q show dev %s parent 1:1" + tcAddQdiscRootCommandString = "tc qdisc add dev %s root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2" + tcAddQdiscLossCommandString = "tc qdisc add dev %s parent 1:1 handle 10: netem loss %d%%" + tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 1 u32 match ip dst %s flowid 1:1" ) type FaultHandler struct { @@ -475,18 +479,47 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. rwMu.Lock() defer rwMu.Unlock() - // TODO: Check status of current fault injection - // TODO: Invoke the start fault injection functionality if not running - - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully started fault", logger.Fields{ + var responseBody types.NetworkFaultInjectionResponse + var httpStatusCode int + stringToBeLogged := "Failed to start fault" + // All command executions for the start network packet loss workflow all together should finish within 5 seconds. + // Thus, create the context here so that it can be shared by all os/exec calls. + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + defer cancel() + // Check the status of current fault injection. + latencyFaultExists, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + httpStatusCode = http.StatusInternalServerError + } else { + // If there already exists a fault in the task network namespace. + if latencyFaultExists { + responseBody = types.NewNetworkFaultInjectionErrorResponse("There is already one network latency fault running") + httpStatusCode = http.StatusConflict + } else if packetLossFaultExists { + responseBody = types.NewNetworkFaultInjectionErrorResponse("There is already one network packet loss fault running") + httpStatusCode = http.StatusConflict + } else { + // Invoke the start fault injection functionality if not running. + err := h.startNetworkPacketLossFault(ctx, taskMetadata, request) + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse("Failed to inject network-packet-loss fault") + httpStatusCode = http.StatusInternalServerError + } else { + stringToBeLogged = "Successfully started fault" + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + httpStatusCode = http.StatusOK + } + } + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + httpStatusCode, responseBody, requestType, ) @@ -794,6 +827,128 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return nil } +// startNetworkPacketLossFault invokes the linux TC utility tool to start the network-packet-loss fault. +func (h *FaultHandler) startNetworkPacketLossFault(ctx context.Context, taskMetadata *state.TaskResponse, request types.NetworkPacketLossRequest) error { + interfaceName := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + // If task's network mode is awsvpc, we need to run nsenter to access the task's network namespace. + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + } + lossPercent := aws.Uint64Value(request.LossPercent) + + // Command to be executed: + // tc qdisc add dev root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 + // "tc qdisc add dev parent 1:1 handle 10: netem loss %" + tcAddQdiscRootCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddQdiscRootCommandString, interfaceName) + cmdList := strings.Split(tcAddQdiscRootCommandComposed, " ") + cmdOutput, err := h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcAddQdiscRootCommandComposed, err, string(cmdOutput[:]))) + return err + } + logger.Info(fmt.Sprintf("'%s' command result: '%s'", tcAddQdiscRootCommandComposed, string(cmdOutput[:]))) + tcAddQdiscLossCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddQdiscLossCommandString, interfaceName, lossPercent) + cmdList = strings.Split(tcAddQdiscLossCommandComposed, " ") + cmdOutput, err = h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcAddQdiscLossCommandComposed, err, string(cmdOutput[:]))) + return err + } + logger.Info(fmt.Sprintf("'%s' command result: '%s'", tcAddQdiscLossCommandComposed, string(cmdOutput[:]))) + // After creating the queueing discipline, create filters to associate the IPs in the request with the handle. + for _, ip := range request.Sources { + tcAddFilterForIPCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddFilterForIPCommandString, interfaceName, *ip) + cmdList = strings.Split(tcAddFilterForIPCommandComposed, " ") + _, err = h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcAddFilterForIPCommandComposed, err, string(cmdOutput[:]))) + return err + } + } + + return nil +} + +// checkTCFault check if there's existing network-latency fault or network-packet-loss fault. +func (h *FaultHandler) checkTCFault(ctx context.Context, taskMetadata *state.TaskResponse) (bool, bool, error) { + interfaceName := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + // If task's network mode is awsvpc, we need to run nsenter to access the task's network namespace. + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + } + + // We will run the following Linux command to assess if there existing fault. + // "tc -j q show dev {INTERFACE} parent 1:1" + // The command above gives the output of "tc q show dev {INTERFACE} parent 1:1" in json format. + // We will then unmarshall the json string and evaluate the fields of it. + tcCheckInjectionCommandComposed := nsenterPrefix + fmt.Sprintf(tcCheckInjectionCommandString, interfaceName) + cmdList := strings.Split(tcCheckInjectionCommandComposed, " ") + cmdOutput, err := h.runExecCommand(ctx, cmdList) + if err != nil { + logger.Error(fmt.Sprintf("'%s' command failed with the following error: '%s'. std output: '%s'", + tcCheckInjectionCommandComposed, err, string(cmdOutput[:]))) + return false, false, fmt.Errorf("failed to check existing network fault: '%s' command failed with the following error: '%s'. std output: '%s'", + tcCheckInjectionCommandComposed, err, string(cmdOutput[:])) + } + // Log the command output to better help us debug. + logger.Info(fmt.Sprintf("'%s' command result: '%s'", tcCheckInjectionCommandComposed, string(cmdOutput[:]))) + + // Check whether latency fault exists and whether packet loss fault exists separately. + var outputUnmarshalled []map[string]interface{} + err = json.Unmarshal(cmdOutput, &outputUnmarshalled) + if err != nil { + return false, false, errors.New("failed to check existing network fault: failed to unmarshal tc command output: " + err.Error()) + } + latencyFaultExists, err := checkLatencyFault(outputUnmarshalled) + if err != nil { + return false, false, errors.New("failed to check existing network fault: " + err.Error()) + } + packetLossFaultExists, err := checkPacketLossFault(outputUnmarshalled) + if err != nil { + return false, false, errors.New("failed to check existing network fault: " + err.Error()) + } + return latencyFaultExists, packetLossFaultExists, nil +} + +// checkLatencyFault parses the tc command output and checks if there's existing network-latency fault running. +func checkLatencyFault(outputUnmarshalled []map[string]interface{}) (bool, error) { + for _, line := range outputUnmarshalled { + // Check if field "kind":"netem" exists. + if line["kind"] == "netem" { + // Now check if network packet loss fault exists. + if options := line["options"]; options != nil { + if delay := options.(map[string]interface{})["delay"]; delay != nil { + return true, nil + } + } + } + } + return false, nil +} + +// checkPacketLossFault parses the tc command output and checks if there's existing network-packet-loss fault running. +func checkPacketLossFault(outputUnmarshalled []map[string]interface{}) (bool, error) { + for _, line := range outputUnmarshalled { + // First check field "kind":"netem" exists. + if line["kind"] == "netem" { + // Now check if field "loss":"" exists, and if the percentage matches with the value in the request. + if options := line["options"]; options != nil { + if lossRandom := options.(map[string]interface{})["loss-random"]; lossRandom != nil { + return true, nil + } + } + } + } + return false, nil +} + // runExecCommand wraps around the execwrapper, providing a convenient way of running any Linux command // and getting the result in both stdout and stderr. func (h *FaultHandler) runExecCommand(ctx context.Context, cmdList []string) ([]byte, error) { diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 0f80eb96321..a295ea3345b 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -40,18 +40,22 @@ import ( ) const ( - endpointId = "endpointId" - port = 1234 - protocol = "tcp" - trafficType = "ingress" - delayMilliseconds = 123456789 - jitterMilliseconds = 4567 - lossPercent = 6 - taskARN = "taskArn" - awsvpcNetworkMode = "awsvpc" - deviceName = "eth0" - invalidNetworkMode = "invalid" - iptablesChainNotFoundError = "iptables: Bad rule (does a matching rule exist in that chain?)." + endpointId = "endpointId" + port = 1234 + protocol = "tcp" + trafficType = "ingress" + delayMilliseconds = 123456789 + jitterMilliseconds = 4567 + lossPercent = 6 + taskARN = "taskArn" + awsvpcNetworkMode = "awsvpc" + deviceName = "eth0" + invalidNetworkMode = "invalid" + iptablesChainNotFoundError = "iptables: Bad rule (does a matching rule exist in that chain?)." + tcLatencyFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","parent":"1:1","options":{"limit":1000,"delay":{"delay":0.1,"jitter":0,"correlation":0},"ecn":false,"gap":0}}]` + tcLossFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","dev":"eth0","parent":"1:1","options":{"limit":1000,"loss-random":{"loss":0.06,"correlation":0},"ecn":false,"gap":0}}]` + tcLossFaultDoesNotExistCommandOutput = `[{"kind":"dummyname"}]` + tcCommandEmptyOutput = `[]` ) var ( @@ -1054,43 +1058,12 @@ func TestCheckNetworkLatency(t *testing.T) { } } -func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) []networkFaultInjectionTestCase { +func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjectionTestCase { happyNetworkPacketLossReqBody := map[string]interface{}{ "LossPercent": lossPercent, "Sources": ipSources, } tcs := []networkFaultInjectionTestCase{ - { - name: fmt.Sprintf("%s success", name), - expectedStatusCode: 200, - requestBody: happyNetworkPacketLossReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) - }, - }, - { - name: fmt.Sprintf("%s unknown request body", name), - expectedStatusCode: 200, - requestBody: map[string]interface{}{ - "LossPercent": lossPercent, - "Sources": ipSources, - "Unknown": "", - }, - expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) - }, - }, - { - name: fmt.Sprintf("%s no request body", name), - expectedStatusCode: 400, - requestBody: nil, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required request body is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) - }, - }, { name: fmt.Sprintf("%s malformed request body 1", name), expectedStatusCode: 400, @@ -1281,9 +1254,131 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) } return tcs } +func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { + happyNetworkPacketLossReqBody := map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + } + commonTcs := generateCommonNetworkPacketLossTestCases("start network-packet-loss") + tcs := []networkFaultInjectionTestCase{ + { + name: "no-existing-fault", + expectedStatusCode: 200, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil) + }, + }, + { + name: "existing-network-latency-fault", + expectedStatusCode: 409, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("There is already one network latency fault running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil) + }, + }, + { + name: "existing-network-packet-loss-fault", + expectedStatusCode: 409, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("There is already one network packet loss fault running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil) + }, + }, + { + name: "unknown-request-body-no-existing-fault", + expectedStatusCode: 200, + requestBody: map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + "Unknown": "", + }, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil) + }, + }, + { + name: "failed-to-unmarshal-json", + expectedStatusCode: 500, + requestBody: map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + "Unknown": "", + }, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to check existing network fault: failed to unmarshal tc command output: unexpected end of JSON input"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte("["), nil) + }, + }, + { + name: "os/exec-times-out", + expectedStatusCode: 500, + requestBody: map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + "Unknown": "", + }, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to check existing network fault: 'nsenter --net=/some/path tc -j q show dev eth0 parent 1:1' command failed with the following error: 'signal: killed'. std output: ''"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte{}, errors.New("signal: killed")) + }, + }, + } + return append(tcs, commonTcs...) +} + +func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { + commonTcs := generateCommonNetworkPacketLossTestCases("start network-packet-loss") + tcs := []networkFaultInjectionTestCase{} + return append(tcs, commonTcs...) +} + +func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { + commonTcs := generateCommonNetworkPacketLossTestCases("start network-packet-loss") + tcs := []networkFaultInjectionTestCase{} + return append(tcs, commonTcs...) +} func TestStartNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("start network packet loss", "running") + tcs := generateStartNetworkPacketLossTestCases() for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -1293,19 +1388,23 @@ func TestStartNetworkPacketLoss(t *testing.T) { agentState := mock_state.NewMockAgentState(ctrl) metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) - execWrapper := mock_execwrapper.NewMockExec(ctrl) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } router := mux.NewRouter() - handler := New(agentState, metricsFactory, execWrapper) + mockExec := mock_execwrapper.NewMockExec(ctrl) + handler := New(agentState, metricsFactory, mockExec) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.StartNetworkPacketLoss(), ).Methods("PUT") + if tc.setExecExpectations != nil { + tc.setExecExpectations(mockExec, ctrl) + } + method := "PUT" var requestBody io.Reader if tc.requestBody != nil { @@ -1333,7 +1432,7 @@ func TestStartNetworkPacketLoss(t *testing.T) { } func TestStopNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("stop network packet loss", "stopped") + tcs := generateStopNetworkPacketLossTestCases() for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks @@ -1342,14 +1441,14 @@ func TestStopNetworkPacketLoss(t *testing.T) { agentState := mock_state.NewMockAgentState(ctrl) metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) - execWrapper := mock_execwrapper.NewMockExec(ctrl) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } router := mux.NewRouter() - handler := New(agentState, metricsFactory, execWrapper) + mockExec := mock_execwrapper.NewMockExec(ctrl) + handler := New(agentState, metricsFactory, mockExec) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.StopNetworkPacketLoss(), @@ -1382,7 +1481,7 @@ func TestStopNetworkPacketLoss(t *testing.T) { } func TestCheckNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("check network packet loss", "running") + tcs := generateCheckNetworkPacketLossTestCases() for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -1392,14 +1491,17 @@ func TestCheckNetworkPacketLoss(t *testing.T) { agentState := mock_state.NewMockAgentState(ctrl) metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) - execWrapper := mock_execwrapper.NewMockExec(ctrl) + router := mux.NewRouter() + mockExec := mock_execwrapper.NewMockExec(ctrl) + handler := New(agentState, metricsFactory, mockExec) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } + if tc.setExecExpectations != nil { + tc.setExecExpectations(mockExec, ctrl) + } - router := mux.NewRouter() - handler := New(agentState, metricsFactory, execWrapper) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.CheckNetworkPacketLoss(),