diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index b4d60993e8d..7d052dc9bdc 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -18,6 +18,7 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -117,10 +118,18 @@ const ( hostNetworkNamespace = "host" defaultIfname = "eth0" - port = 1234 - protocol = "tcp" - trafficType = "ingress" - iptablesChainNotFoundError = "iptables: Bad rule (does a matching rule exist in that chain?)." + port = 1234 + protocol = "tcp" + trafficType = "ingress" + delayMilliseconds = 123456789 + jitterMilliseconds = 4567 + lossPercent = 6 + invalidNetworkMode = "invalid" + iptablesChainNotFoundError = "iptables: Bad rule (does a matching rule exist in that chain?)." + iptablesChainAlreadyExistError = "iptables: Chain already exists." + tcLossFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","dev":"eth0","parent":"1:1","options":{"limit":1000,"loss-random":{"loss":0.06,"correlation":0},"ecn":false,"gap":0}}]` + tcCommandEmptyOutput = `[]` + requestTimeoutDuration = 5 * time.Second ) var ( @@ -463,11 +472,24 @@ var ( ) } + ipSources = []string{"52.95.154.1", "52.95.154.2"} + happyBlackHolePortReqBody = map[string]interface{}{ "Port": port, "Protocol": protocol, "TrafficType": trafficType, } + + happyNetworkLatencyReqBody = map[string]interface{}{ + "DelayMilliseconds": delayMilliseconds, + "JitterMilliseconds": jitterMilliseconds, + "Sources": ipSources, + } + + happyNetworkPacketLossReqBody = map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + } ) func standardTask() *apitask.Task { @@ -3703,334 +3725,175 @@ func TestUpdateTaskProtection(t *testing.T) { })) } -type blackholePortFaultTestCase struct { +type execExpectations func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) + +type networkFaultTestCase struct { name string expectedStatusCode int requestBody interface{} expectedFaultResponse faulttype.NetworkFaultInjectionResponse setStateExpectations func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) - setExecExpectations func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) + setExecExpectations execExpectations faultInjectionEnabled bool networkMode string } -func getNetworkBlackHolePortHandlerTestCases(name, fault string) []blackholePortFaultTestCase { - tcs := []blackholePortFaultTestCase{ - { - name: fmt.Sprintf("%s malformed request body", name), - expectedStatusCode: 400, - requestBody: map[string]interface{}{ - "Port": "port", - "Protocol": protocol, - "TrafficType": trafficType, - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkBlackholePortRequest.Port of type uint16"), - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s incomplete request body", name), - expectedStatusCode: 400, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": protocol, - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"), - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s empty value request body", name), - expectedStatusCode: 400, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": protocol, - "TrafficType": "", - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"), - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s invalid protocol value request body", name), - expectedStatusCode: 400, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": "invalid", - "TrafficType": trafficType, - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter Protocol"), - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s invalid traffic type value request body", name), - expectedStatusCode: 400, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": protocol, - "TrafficType": "invalid", - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter TrafficType"), - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s task lookup fail", name), - expectedStatusCode: 404, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), - setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) { - gomock.InOrder( - state.EXPECT().TaskARNByV3EndpointID(endpointId).Return("", false), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s task metadata fetch fail", name), - expectedStatusCode: 500, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), - setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) { - gomock.InOrder( - state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true), - state.EXPECT().TaskByArn(taskARN).Return(nil, false), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s fault injection disabled", name), - expectedStatusCode: 400, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("enableFaultInjection is not enabled for task: %s", taskARN)), - setStateExpectations: agentStateExpectations, - faultInjectionEnabled: false, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: fmt.Sprintf("%s invalid network mode", name), - expectedStatusCode: 400, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid mode is not supported. Please use either host or awsvpc mode."), - setStateExpectations: agentStateExpectations, - faultInjectionEnabled: true, - networkMode: "invalid", - }, - } - return tcs -} - -func getStartNetworkBlackHolePortHandlerTestCases() []blackholePortFaultTestCase { - commonTcs := getNetworkBlackHolePortHandlerTestCases("start blackhole port", faulttype.BlackHolePortFaultType) - tcs := []blackholePortFaultTestCase{ +// generateCommonNetworkFaultInjectionTestCases generates and returns the happy cases for all network fault injection requests +// Note: A more robust test cases is defined in the actual HTTP handler directory +func generateCommonNetworkFaultInjectionTestCases(requestType, successResponse string, exec execExpectations, requestBody interface{}) []networkFaultTestCase { + tcs := []networkFaultTestCase{ { - name: "start blackhole port success host mode", + name: fmt.Sprintf("%s success host mode", requestType), expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("running"), + requestBody: requestBody, + expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(successResponse), setStateExpectations: agentStateExpectations, + setExecExpectations: exec, faultInjectionEnabled: true, networkMode: apitask.HostNetworkMode, }, { - name: "start blackhole port success awsvpc mode", + name: fmt.Sprintf("%s success awsvpc mode", requestType), expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("running"), - setStateExpectations: agentStateExpectations, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: "start blackhole port unknown request body", - expectedStatusCode: 200, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": protocol, - "TrafficType": trafficType, - "Unknown": "", - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("running"), + requestBody: requestBody, + expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(successResponse), setStateExpectations: agentStateExpectations, + setExecExpectations: exec, faultInjectionEnabled: true, networkMode: apitask.AWSVPCNetworkMode, }, } - return append(tcs, commonTcs...) -} - -func getStopNetworkBlackHolePortHandlerTestCases() []blackholePortFaultTestCase { - commonTcs := getNetworkBlackHolePortHandlerTestCases("stop blackhole port", faulttype.BlackHolePortFaultType) - tcs := []blackholePortFaultTestCase{ - { - name: "stop blackhole port success host mode", - expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("stopped"), - setStateExpectations: agentStateExpectations, - faultInjectionEnabled: true, - networkMode: apitask.HostNetworkMode, - }, - { - name: "stop blackhole port success awsvpc mode", - expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("stopped"), - setStateExpectations: agentStateExpectations, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: "stop blackhole port unknown request body", - expectedStatusCode: 200, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": protocol, - "TrafficType": trafficType, - "Unknown": "", - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("stopped"), - setStateExpectations: agentStateExpectations, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - } - return append(tcs, commonTcs...) -} - -func getCheckStatusNetworkBlackHolePortHandlerTestCases() []blackholePortFaultTestCase { - commonTcs := getNetworkBlackHolePortHandlerTestCases("start blackhole port", faulttype.BlackHolePortFaultType) - tcs := []blackholePortFaultTestCase{ - { - name: "check blackhole port success host mode running", - expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("running"), - setStateExpectations: agentStateExpectations, - setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { - cmdExec := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder( - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.HostNetworkMode, - }, - { - name: "check blackhole port success awsvpc mode running", - expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("running"), - setStateExpectations: agentStateExpectations, - setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { - cmdExec := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder( - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: "check blackhole port unknown request body", - expectedStatusCode: 200, - requestBody: map[string]interface{}{ - "Port": port, - "Protocol": protocol, - "TrafficType": trafficType, - "Unknown": "", - }, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("running"), - setStateExpectations: agentStateExpectations, - setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { - cmdExec := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder( - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: "check blackhole port success host mode not running", - expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("not-running"), - setStateExpectations: agentStateExpectations, - setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { - cmdExec := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder( - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit 1")), - exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), - exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.HostNetworkMode, - }, - { - name: "check blackhole port success awsvpc mode not running", - expectedStatusCode: 200, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse("not-running"), - setStateExpectations: agentStateExpectations, - setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { - cmdExec := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder( - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit 1")), - exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), - exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - { - name: "check blackhole port fail", - expectedStatusCode: 500, - requestBody: happyBlackHolePortReqBody, - expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("internal error"), - setStateExpectations: agentStateExpectations, - setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { - cmdExec := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder( - exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte("internal error"), errors.New("exit 1")), - exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, false), - ) - }, - faultInjectionEnabled: true, - networkMode: apitask.AWSVPCNetworkMode, - }, - } - return append(tcs, commonTcs...) + return tcs } func TestRegisterStartBlackholePortFaultHandler(t *testing.T) { - tcs := getStartNetworkBlackHolePortHandlerTestCases() + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + } + tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) testRegisterFaultHandler(t, tcs, "PUT", faulttype.BlackHolePortFaultType) } func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { - tcs := getStopNetworkBlackHolePortHandlerTestCases() + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + } + tcs := generateCommonNetworkFaultInjectionTestCases("stop blackhole port", "stopped", setExecExpectations, happyBlackHolePortReqBody) testRegisterFaultHandler(t, tcs, "DELETE", faulttype.BlackHolePortFaultType) } func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { - tcs := getCheckStatusNetworkBlackHolePortHandlerTestCases() + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + } + tcs := generateCommonNetworkFaultInjectionTestCases("check blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) testRegisterFaultHandler(t, tcs, "GET", faulttype.BlackHolePortFaultType) } -func testRegisterFaultHandler(t *testing.T, tcs []blackholePortFaultTestCase, method, fault string) { +func TestRegisterStartLatencyFaultHandler(t *testing.T) { + // TODO: Will need to set the correct os/exec exectation calls once this is implemented + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + + } + tcs := generateCommonNetworkFaultInjectionTestCases("start latency", "running", setExecExpectations, happyNetworkLatencyReqBody) + testRegisterFaultHandler(t, tcs, "PUT", faulttype.LatencyFaultType) +} + +func TestRegisterStopLatencyFaultHandler(t *testing.T) { + // TODO: Will need to set the correct os/exec exectation calls once this is implemented + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + + } + tcs := generateCommonNetworkFaultInjectionTestCases("stop latency", "stopped", setExecExpectations, happyNetworkLatencyReqBody) + testRegisterFaultHandler(t, tcs, "DELETE", faulttype.LatencyFaultType) +} + +func TestRegisterCheckLatencyFaultHandler(t *testing.T) { + // TODO: Will need to set the correct os/exec exectation calls once this is implemented + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + + } + tcs := generateCommonNetworkFaultInjectionTestCases("check latency", "running", setExecExpectations, happyNetworkLatencyReqBody) + testRegisterFaultHandler(t, tcs, "GET", faulttype.LatencyFaultType) +} + +func TestRegisterStartPacketLossFaultHandler(t *testing.T) { + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil) + } + tcs := generateCommonNetworkFaultInjectionTestCases("start packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody) + testRegisterFaultHandler(t, tcs, "PUT", faulttype.PacketLossFaultType) +} + +func TestRegisterStopPacketLossFaultHandler(t *testing.T) { + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) + } + tcs := generateCommonNetworkFaultInjectionTestCases("stop packet loss", "stopped", setExecExpectations, happyNetworkPacketLossReqBody) + testRegisterFaultHandler(t, tcs, "DELETE", faulttype.PacketLossFaultType) +} + +func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { + setExecExpectations := func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil), + ) + } + tcs := generateCommonNetworkFaultInjectionTestCases("check packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody) + testRegisterFaultHandler(t, tcs, "GET", faulttype.PacketLossFaultType) +} + +func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method, fault string) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 6534bc2fe07..db002641ea9 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -49,7 +49,13 @@ const ( requestTimedOutError = "%s: request timed out" requestTimeoutDuration = 5 * time.Second // Commands that will be used to start/stop/check fault. + iptablesNewChainCmd = "iptables -N %s" + iptablesAppendChainRuleCmd = "iptables -A %s -p %s --dport %s -j DROP" + iptablesInsertChainCmd = "iptables -I %s -j %s" iptablesChainExistCmd = "iptables -C %s -p %s --dport %s -j DROP" + iptablesClearChainCmd = "iptables -F %s" + iptablesDeleteFromTableCmd = "iptables -D %s -j %s" + iptablesDeleteChainCmd = "iptables -X %s" nsenterCommandString = "nsenter --net=%s " tcCheckInjectionCommandString = "tc -j q show dev %s parent 1:1" tcAddQdiscRootCommandString = "tc qdisc add dev %s root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2" @@ -121,24 +127,119 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht rwMu.Lock() defer rwMu.Unlock() - // TODO: Check status of current fault injection - // TODO: Invoke the start fault injection functionality if not running + ctx := context.Background() + ctxWithTimeout, cancel := h.osExecWrapper.NewExecContextWithTimeout(ctx, requestTimeoutDuration) + defer cancel() - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully started fault", logger.Fields{ + var responseBody types.NetworkFaultInjectionResponse + var statusCode int + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + taskArn := taskMetadata.TaskARN + stringToBeLogged := "Failed to start fault" + port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) + chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) + insertTable := "INPUT" + if aws.StringValue(request.TrafficType) == "egress" { + insertTable = "OUTPUT" + } + + cmdOutput, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + networkMode, networkNSPath, insertTable, taskArn) + if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) + } else if cmdErr != nil { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + } else { + statusCode = http.StatusOK + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + stringToBeLogged = "Successfully started fault" + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + statusCode, responseBody, requestType, ) } } +// startNetworkBlackholePort will start/inject a new black hole port fault if there isn't one with the specific traffic type, protocol, and port number that's running already. +// The general workflow is as followed: +// 1. Checks if there's not a already running chain with the specified protocol and port number already via checkNetworkBlackHolePort() +// 2. Creates a new chain via `iptables -N ` (the chain name is in the form of "--") +// 3. Appends a new rule to the newly created chain via `iptables -A -p --dport -j DROP` +// 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table +func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) { + running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn) + if err != nil { + return cmdOutput, err + } + if !running { + logger.Info("[INFO] Attempting to start network black hole port fault", logger.Fields{ + "netns": netNs, + "chain": chain, + "taskArn": taskArn, + }) + + // For host mode, the task network namespace is the host network namespace (i.e. we don't need to run nsenter) + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, netNs) + } + + // Creating a new chain + newChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesNewChainCmd, chain) + cmdOutput, err := h.runExecCommand(ctx, strings.Split(newChainCmdString, " ")) + if err != nil { + logger.Error("Unable to create new chain", logger.Fields{ + "netns": netNs, + "command": newChainCmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Appending a new rule based on the protocol and port number from the request body + appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, chain, protocol, port) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " ")) + if err != nil { + logger.Error("Unable to append rule to chain", logger.Fields{ + "netns": netNs, + "command": appendRuleCmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Inserting the chain into the built-in INPUT/OUTPUT table + insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, insertTable, chain) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(insertChainCmdString, " ")) + if err != nil { + logger.Error("Unable to insert chain to table", logger.Fields{ + "netns": netNs, + "command": insertChainCmdString, + "output": string(cmdOutput), + "insertTable": insertTable, + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + } + return "", nil +} + // StopNetworkBlackHolePort will return the request handler function for stopping a network blackhole port fault func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -172,24 +273,121 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt rwMu.Lock() defer rwMu.Unlock() - // TODO: Check status of current fault injection - // TODO: Invoke the stop fault injection functionality if running + ctx := context.Background() + ctxWithTimeout, cancel := h.osExecWrapper.NewExecContextWithTimeout(ctx, requestTimeoutDuration) + defer cancel() - responseBody := types.NewNetworkFaultInjectionSuccessResponse("stopped") - logger.Info("Successfully stopped fault", logger.Fields{ + var responseBody types.NetworkFaultInjectionResponse + var statusCode int + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + taskArn := taskMetadata.TaskARN + stringToBeLogged := "Failed to stop fault" + port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) + chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) + insertTable := "INPUT" + if aws.StringValue(request.TrafficType) == "egress" { + insertTable = "OUTPUT" + } + + cmdOutput, cmdErr := h.stopNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + networkMode, networkNSPath, insertTable, taskArn) + + if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) + } else if cmdErr != nil { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + } else { + statusCode = http.StatusOK + responseBody = types.NewNetworkFaultInjectionSuccessResponse("stopped") + stringToBeLogged = "Successfully stopped fault" + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + statusCode, responseBody, requestType, ) } } +// stopNetworkBlackHolePort will stop a black hole port fault based on the chain name which is generated via "--". +// The general workflow is as followed: +// 1. Checks if there's a running chain with the specified protocol and port number via checkNetworkBlackHolePort() +// 2. Clears all rules within the specific chain via `iptables -F ` +// 3. Removes the specific chain from the built-in INPUT/OUTPUT table via `iptables -D -j ` +// 4. Deletes the specific chain via `iptables -X ` +func (h *FaultHandler) stopNetworkBlackHolePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) { + running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn) + if err != nil { + return cmdOutput, err + } + if running { + logger.Info("[INFO] Attempting to stop network black hole port fault", logger.Fields{ + "netns": netNs, + "chain": chain, + "taskArn": taskArn, + }) + + // For host mode, the task network namespace is the host network namespace (i.e. we don't need to run nsenter) + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, netNs) + } + + // Clearing the appended rules that's associated to the chain + clearChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesClearChainCmd, chain) + cmdOutput, err := h.runExecCommand(ctx, strings.Split(clearChainCmdString, " ")) + if err != nil { + logger.Error("Unable to clear chain", logger.Fields{ + "netns": netNs, + "command": clearChainCmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Removing the chain from either the built-in INPUT/OUTPUT table + deleteFromTableCmdString := nsenterPrefix + fmt.Sprintf(iptablesDeleteFromTableCmd, insertTable, chain) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(deleteFromTableCmdString, " ")) + if err != nil { + logger.Error("Unable to delete chain from table", logger.Fields{ + "netns": netNs, + "command": deleteFromTableCmdString, + "output": string(cmdOutput), + "insertTable": insertTable, + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Deleting the chain + deleteChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesDeleteChainCmd, chain) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(deleteChainCmdString, " ")) + if err != nil { + logger.Error("Unable to delete chain", logger.Fields{ + "netns": netNs, + "command": deleteChainCmdString, + "output": string(cmdOutput), + "insertTable": insertTable, + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + } + return "", nil +} + // CheckNetworkBlackHolePort will return the request handler function for checking the status of a network blackhole port fault func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -224,31 +422,24 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht defer rwMu.RUnlock() ctx := context.Background() - ctxWithTimeout, cancel := context.WithTimeout(ctx, requestTimeoutDuration) + ctxWithTimeout, cancel := h.osExecWrapper.NewExecContextWithTimeout(ctx, requestTimeoutDuration) defer cancel() var responseBody types.NetworkFaultInjectionResponse var statusCode int + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + taskArn := taskMetadata.TaskARN + stringToBeLogged := "Failed to check fault" port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) running, cmdOutput, cmdErr := h.checkNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, - taskMetadata.TaskNetworkConfig.NetworkMode, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + networkMode, networkNSPath, taskArn) // We've timed out trying to check if the black hole port fault injection is running - if err := ctx.Err(); err == context.DeadlineExceeded { - logger.Error("Request timed out", logger.Fields{ - field.RequestType: requestType, - field.Request: request.ToString(), - field.Error: err, - }) + if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { statusCode = http.StatusInternalServerError responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { - logger.Error("Unknown error encountered for request", logger.Fields{ - field.RequestType: requestType, - field.Request: request.ToString(), - field.Error: cmdErr, - }) statusCode = http.StatusInternalServerError responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) } else { @@ -258,12 +449,13 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht } else { responseBody = types.NewNetworkFaultInjectionSuccessResponse("not-running") } - logger.Info("[INFO] Successfully checked status for fault", logger.Fields{ - field.RequestType: requestType, - field.Request: request.ToString(), - field.Response: responseBody.ToString(), - }) + stringToBeLogged = "Successfully check status fault" } + logger.Info(stringToBeLogged, logger.Fields{ + field.RequestType: requestType, + field.Request: request.ToString(), + field.Response: responseBody.ToString(), + }) utils.WriteJSONResponse( w, statusCode, @@ -274,16 +466,17 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht } // checkNetworkBlackHolePort will check if there's a running black hole port within the task network namespace based on the chain name and the passed in required request fields. -// It does so by calling iptables linux utility tool. -func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, port, chain, networkMode, netNs string) (bool, string, error) { - cmdString := fmt.Sprintf(iptablesChainExistCmd, chain, protocol, port) - cmdList := strings.Split(cmdString, " ") - +// It does so by calling `iptables -C -p --dport -j DROP`. +func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, port, chain, networkMode, netNs, taskArn string) (bool, string, error) { // For host mode, the task network namespace is the host network namespace (i.e. we don't need to run nsenter) - if networkMode != ecs.NetworkModeHost { - cmdList = append(strings.Split(fmt.Sprintf(nsenterCommandString, netNs), " "), cmdList...) + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, netNs) } + cmdString := nsenterPrefix + fmt.Sprintf(iptablesChainExistCmd, chain, protocol, port) + cmdList := strings.Split(cmdString, " ") + cmdOutput, err := h.runExecCommand(ctx, cmdList) if err != nil { if exitErr, eok := h.osExecWrapper.ConvertToExitError(err); eok { @@ -291,6 +484,7 @@ func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, "netns": netNs, "command": strings.Join(cmdList, " "), "output": string(cmdOutput), + "taskArn": taskArn, "exitCode": h.osExecWrapper.GetExitCode(exitErr), }) return false, string(cmdOutput), nil @@ -299,6 +493,7 @@ func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, "netns": netNs, "command": strings.Join(cmdList, " "), "output": string(cmdOutput), + "taskArn": taskArn, "err": err, }) return false, string(cmdOutput), err @@ -307,6 +502,7 @@ func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, "netns": netNs, "command": strings.Join(cmdList, " "), "output": string(cmdOutput), + "taskArn": taskArn, }) return true, string(cmdOutput), nil } @@ -485,7 +681,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. stringToBeLogged := "Failed to start fault" // All command executions for the start network packet loss workflow all together should finish within 5 seconds. // Thus, create the context here so that it can be shared by all os/exec calls. - ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + ctx, cancel := h.osExecWrapper.NewExecContextWithTimeout(context.Background(), requestTimeoutDuration) defer cancel() // Check the status of current fault injection. latencyFaultExists, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) @@ -534,7 +730,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkPacketLossRequest - requestType := fmt.Sprintf(startFaultRequestType, types.PacketLossFaultType) + requestType := fmt.Sprintf(stopFaultRequestType, types.PacketLossFaultType) // Parse the fault request err := decodeRequest(w, &request, requestType, r) @@ -564,7 +760,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R stringToBeLogged := "Failed to stop fault" // All command executions for the stop network packet loss workflow all together should finish within 5 seconds. // Thus, create the context here so that it can be shared by all os/exec calls. - ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + ctx, cancel := h.osExecWrapper.NewExecContextWithTimeout(context.Background(), requestTimeoutDuration) defer cancel() // Check the status of current fault injection. _, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) @@ -581,7 +777,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R responseBody = types.NewNetworkFaultInjectionSuccessResponse("stopped") httpStatusCode = http.StatusOK } else { - // Invoke the stop fault injection functionality if not running. + // Invoke the stop fault injection functionality if running. err := h.stopNetworkPacketLossFault(ctx, taskMetadata) if errors.Is(err, context.DeadlineExceeded) { responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) @@ -614,7 +810,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkPacketLossRequest - requestType := fmt.Sprintf(startFaultRequestType, types.PacketLossFaultType) + requestType := fmt.Sprintf(checkStatusFaultRequestType, types.PacketLossFaultType) // Parse the fault request. err := decodeRequest(w, &request, requestType, r) @@ -645,7 +841,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. stringToBeLogged := "Failed to check status for fault" // All command executions for the start network packet loss workflow all together should finish within 5 seconds. // Thus, create the context here so that it can be shared by all os/exec calls. - ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + ctx, cancel := h.osExecWrapper.NewExecContextWithTimeout(context.Background(), requestTimeoutDuration) defer cancel() // Check the status of current fault injection. _, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go index d1f94cb82c5..59758e4d976 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go @@ -18,6 +18,7 @@ import ( "io" "os" "os/exec" + "time" ) // Exec acts as a wrapper to functions exposed by the exec package. @@ -27,6 +28,7 @@ type Exec interface { CommandContext(ctx context.Context, name string, arg ...string) Cmd ConvertToExitError(err error) (*exec.ExitError, bool) GetExitCode(exitErr *exec.ExitError) int + NewExecContextWithTimeout(parent context.Context, duration time.Duration) (context.Context, context.CancelFunc) } // execWrapper is a placeholder struct which implements the Exec interface. @@ -53,6 +55,10 @@ func (e *execWrapper) GetExitCode(exitErr *exec.ExitError) int { return exitErr.ExitCode() } +func (e *execWrapper) NewExecContextWithTimeout(parent context.Context, duration time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, duration) +} + // Cmd acts as a wrapper to functions exposed by the exec.Cmd object. // Having this interface enables us to create mock objects we can use // for testing. diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go index 25bff9f5241..2157fa4f779 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go @@ -24,6 +24,7 @@ import ( os "os" exec "os/exec" reflect "reflect" + time "time" execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" gomock "github.com/golang/mock/gomock" @@ -250,3 +251,18 @@ func (mr *MockExecMockRecorder) GetExitCode(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExitCode", reflect.TypeOf((*MockExec)(nil).GetExitCode), arg0) } + +// NewExecContextWithTimeout mocks base method. +func (m *MockExec) NewExecContextWithTimeout(arg0 context.Context, arg1 time.Duration) (context.Context, context.CancelFunc) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewExecContextWithTimeout", arg0, arg1) + ret0, _ := ret[0].(context.Context) + ret1, _ := ret[1].(context.CancelFunc) + return ret0, ret1 +} + +// NewExecContextWithTimeout indicates an expected call of NewExecContextWithTimeout. +func (mr *MockExecMockRecorder) NewExecContextWithTimeout(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewExecContextWithTimeout", reflect.TypeOf((*MockExec)(nil).NewExecContextWithTimeout), arg0, arg1) +} diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 6534bc2fe07..db002641ea9 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -49,7 +49,13 @@ const ( requestTimedOutError = "%s: request timed out" requestTimeoutDuration = 5 * time.Second // Commands that will be used to start/stop/check fault. + iptablesNewChainCmd = "iptables -N %s" + iptablesAppendChainRuleCmd = "iptables -A %s -p %s --dport %s -j DROP" + iptablesInsertChainCmd = "iptables -I %s -j %s" iptablesChainExistCmd = "iptables -C %s -p %s --dport %s -j DROP" + iptablesClearChainCmd = "iptables -F %s" + iptablesDeleteFromTableCmd = "iptables -D %s -j %s" + iptablesDeleteChainCmd = "iptables -X %s" nsenterCommandString = "nsenter --net=%s " tcCheckInjectionCommandString = "tc -j q show dev %s parent 1:1" tcAddQdiscRootCommandString = "tc qdisc add dev %s root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2" @@ -121,24 +127,119 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht rwMu.Lock() defer rwMu.Unlock() - // TODO: Check status of current fault injection - // TODO: Invoke the start fault injection functionality if not running + ctx := context.Background() + ctxWithTimeout, cancel := h.osExecWrapper.NewExecContextWithTimeout(ctx, requestTimeoutDuration) + defer cancel() - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully started fault", logger.Fields{ + var responseBody types.NetworkFaultInjectionResponse + var statusCode int + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + taskArn := taskMetadata.TaskARN + stringToBeLogged := "Failed to start fault" + port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) + chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) + insertTable := "INPUT" + if aws.StringValue(request.TrafficType) == "egress" { + insertTable = "OUTPUT" + } + + cmdOutput, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + networkMode, networkNSPath, insertTable, taskArn) + if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) + } else if cmdErr != nil { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + } else { + statusCode = http.StatusOK + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + stringToBeLogged = "Successfully started fault" + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + statusCode, responseBody, requestType, ) } } +// startNetworkBlackholePort will start/inject a new black hole port fault if there isn't one with the specific traffic type, protocol, and port number that's running already. +// The general workflow is as followed: +// 1. Checks if there's not a already running chain with the specified protocol and port number already via checkNetworkBlackHolePort() +// 2. Creates a new chain via `iptables -N ` (the chain name is in the form of "--") +// 3. Appends a new rule to the newly created chain via `iptables -A -p --dport -j DROP` +// 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table +func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) { + running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn) + if err != nil { + return cmdOutput, err + } + if !running { + logger.Info("[INFO] Attempting to start network black hole port fault", logger.Fields{ + "netns": netNs, + "chain": chain, + "taskArn": taskArn, + }) + + // For host mode, the task network namespace is the host network namespace (i.e. we don't need to run nsenter) + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, netNs) + } + + // Creating a new chain + newChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesNewChainCmd, chain) + cmdOutput, err := h.runExecCommand(ctx, strings.Split(newChainCmdString, " ")) + if err != nil { + logger.Error("Unable to create new chain", logger.Fields{ + "netns": netNs, + "command": newChainCmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Appending a new rule based on the protocol and port number from the request body + appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, chain, protocol, port) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " ")) + if err != nil { + logger.Error("Unable to append rule to chain", logger.Fields{ + "netns": netNs, + "command": appendRuleCmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Inserting the chain into the built-in INPUT/OUTPUT table + insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, insertTable, chain) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(insertChainCmdString, " ")) + if err != nil { + logger.Error("Unable to insert chain to table", logger.Fields{ + "netns": netNs, + "command": insertChainCmdString, + "output": string(cmdOutput), + "insertTable": insertTable, + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + } + return "", nil +} + // StopNetworkBlackHolePort will return the request handler function for stopping a network blackhole port fault func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -172,24 +273,121 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt rwMu.Lock() defer rwMu.Unlock() - // TODO: Check status of current fault injection - // TODO: Invoke the stop fault injection functionality if running + ctx := context.Background() + ctxWithTimeout, cancel := h.osExecWrapper.NewExecContextWithTimeout(ctx, requestTimeoutDuration) + defer cancel() - responseBody := types.NewNetworkFaultInjectionSuccessResponse("stopped") - logger.Info("Successfully stopped fault", logger.Fields{ + var responseBody types.NetworkFaultInjectionResponse + var statusCode int + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + taskArn := taskMetadata.TaskARN + stringToBeLogged := "Failed to stop fault" + port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) + chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) + insertTable := "INPUT" + if aws.StringValue(request.TrafficType) == "egress" { + insertTable = "OUTPUT" + } + + cmdOutput, cmdErr := h.stopNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + networkMode, networkNSPath, insertTable, taskArn) + + if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) + } else if cmdErr != nil { + statusCode = http.StatusInternalServerError + responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + } else { + statusCode = http.StatusOK + responseBody = types.NewNetworkFaultInjectionSuccessResponse("stopped") + stringToBeLogged = "Successfully stopped fault" + } + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + statusCode, responseBody, requestType, ) } } +// stopNetworkBlackHolePort will stop a black hole port fault based on the chain name which is generated via "--". +// The general workflow is as followed: +// 1. Checks if there's a running chain with the specified protocol and port number via checkNetworkBlackHolePort() +// 2. Clears all rules within the specific chain via `iptables -F ` +// 3. Removes the specific chain from the built-in INPUT/OUTPUT table via `iptables -D -j ` +// 4. Deletes the specific chain via `iptables -X ` +func (h *FaultHandler) stopNetworkBlackHolePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) { + running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn) + if err != nil { + return cmdOutput, err + } + if running { + logger.Info("[INFO] Attempting to stop network black hole port fault", logger.Fields{ + "netns": netNs, + "chain": chain, + "taskArn": taskArn, + }) + + // For host mode, the task network namespace is the host network namespace (i.e. we don't need to run nsenter) + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, netNs) + } + + // Clearing the appended rules that's associated to the chain + clearChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesClearChainCmd, chain) + cmdOutput, err := h.runExecCommand(ctx, strings.Split(clearChainCmdString, " ")) + if err != nil { + logger.Error("Unable to clear chain", logger.Fields{ + "netns": netNs, + "command": clearChainCmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Removing the chain from either the built-in INPUT/OUTPUT table + deleteFromTableCmdString := nsenterPrefix + fmt.Sprintf(iptablesDeleteFromTableCmd, insertTable, chain) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(deleteFromTableCmdString, " ")) + if err != nil { + logger.Error("Unable to delete chain from table", logger.Fields{ + "netns": netNs, + "command": deleteFromTableCmdString, + "output": string(cmdOutput), + "insertTable": insertTable, + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + + // Deleting the chain + deleteChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesDeleteChainCmd, chain) + cmdOutput, err = h.runExecCommand(ctx, strings.Split(deleteChainCmdString, " ")) + if err != nil { + logger.Error("Unable to delete chain", logger.Fields{ + "netns": netNs, + "command": deleteChainCmdString, + "output": string(cmdOutput), + "insertTable": insertTable, + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + } + return "", nil +} + // CheckNetworkBlackHolePort will return the request handler function for checking the status of a network blackhole port fault func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -224,31 +422,24 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht defer rwMu.RUnlock() ctx := context.Background() - ctxWithTimeout, cancel := context.WithTimeout(ctx, requestTimeoutDuration) + ctxWithTimeout, cancel := h.osExecWrapper.NewExecContextWithTimeout(ctx, requestTimeoutDuration) defer cancel() var responseBody types.NetworkFaultInjectionResponse var statusCode int + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + taskArn := taskMetadata.TaskARN + stringToBeLogged := "Failed to check fault" port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) running, cmdOutput, cmdErr := h.checkNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, - taskMetadata.TaskNetworkConfig.NetworkMode, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + networkMode, networkNSPath, taskArn) // We've timed out trying to check if the black hole port fault injection is running - if err := ctx.Err(); err == context.DeadlineExceeded { - logger.Error("Request timed out", logger.Fields{ - field.RequestType: requestType, - field.Request: request.ToString(), - field.Error: err, - }) + if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { statusCode = http.StatusInternalServerError responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { - logger.Error("Unknown error encountered for request", logger.Fields{ - field.RequestType: requestType, - field.Request: request.ToString(), - field.Error: cmdErr, - }) statusCode = http.StatusInternalServerError responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) } else { @@ -258,12 +449,13 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht } else { responseBody = types.NewNetworkFaultInjectionSuccessResponse("not-running") } - logger.Info("[INFO] Successfully checked status for fault", logger.Fields{ - field.RequestType: requestType, - field.Request: request.ToString(), - field.Response: responseBody.ToString(), - }) + stringToBeLogged = "Successfully check status fault" } + logger.Info(stringToBeLogged, logger.Fields{ + field.RequestType: requestType, + field.Request: request.ToString(), + field.Response: responseBody.ToString(), + }) utils.WriteJSONResponse( w, statusCode, @@ -274,16 +466,17 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht } // checkNetworkBlackHolePort will check if there's a running black hole port within the task network namespace based on the chain name and the passed in required request fields. -// It does so by calling iptables linux utility tool. -func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, port, chain, networkMode, netNs string) (bool, string, error) { - cmdString := fmt.Sprintf(iptablesChainExistCmd, chain, protocol, port) - cmdList := strings.Split(cmdString, " ") - +// It does so by calling `iptables -C -p --dport -j DROP`. +func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, port, chain, networkMode, netNs, taskArn string) (bool, string, error) { // For host mode, the task network namespace is the host network namespace (i.e. we don't need to run nsenter) - if networkMode != ecs.NetworkModeHost { - cmdList = append(strings.Split(fmt.Sprintf(nsenterCommandString, netNs), " "), cmdList...) + nsenterPrefix := "" + if networkMode == ecs.NetworkModeAwsvpc { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, netNs) } + cmdString := nsenterPrefix + fmt.Sprintf(iptablesChainExistCmd, chain, protocol, port) + cmdList := strings.Split(cmdString, " ") + cmdOutput, err := h.runExecCommand(ctx, cmdList) if err != nil { if exitErr, eok := h.osExecWrapper.ConvertToExitError(err); eok { @@ -291,6 +484,7 @@ func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, "netns": netNs, "command": strings.Join(cmdList, " "), "output": string(cmdOutput), + "taskArn": taskArn, "exitCode": h.osExecWrapper.GetExitCode(exitErr), }) return false, string(cmdOutput), nil @@ -299,6 +493,7 @@ func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, "netns": netNs, "command": strings.Join(cmdList, " "), "output": string(cmdOutput), + "taskArn": taskArn, "err": err, }) return false, string(cmdOutput), err @@ -307,6 +502,7 @@ func (h *FaultHandler) checkNetworkBlackHolePort(ctx context.Context, protocol, "netns": netNs, "command": strings.Join(cmdList, " "), "output": string(cmdOutput), + "taskArn": taskArn, }) return true, string(cmdOutput), nil } @@ -485,7 +681,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. stringToBeLogged := "Failed to start fault" // All command executions for the start network packet loss workflow all together should finish within 5 seconds. // Thus, create the context here so that it can be shared by all os/exec calls. - ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + ctx, cancel := h.osExecWrapper.NewExecContextWithTimeout(context.Background(), requestTimeoutDuration) defer cancel() // Check the status of current fault injection. latencyFaultExists, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) @@ -534,7 +730,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkPacketLossRequest - requestType := fmt.Sprintf(startFaultRequestType, types.PacketLossFaultType) + requestType := fmt.Sprintf(stopFaultRequestType, types.PacketLossFaultType) // Parse the fault request err := decodeRequest(w, &request, requestType, r) @@ -564,7 +760,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R stringToBeLogged := "Failed to stop fault" // All command executions for the stop network packet loss workflow all together should finish within 5 seconds. // Thus, create the context here so that it can be shared by all os/exec calls. - ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + ctx, cancel := h.osExecWrapper.NewExecContextWithTimeout(context.Background(), requestTimeoutDuration) defer cancel() // Check the status of current fault injection. _, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) @@ -581,7 +777,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R responseBody = types.NewNetworkFaultInjectionSuccessResponse("stopped") httpStatusCode = http.StatusOK } else { - // Invoke the stop fault injection functionality if not running. + // Invoke the stop fault injection functionality if running. err := h.stopNetworkPacketLossFault(ctx, taskMetadata) if errors.Is(err, context.DeadlineExceeded) { responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) @@ -614,7 +810,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var request types.NetworkPacketLossRequest - requestType := fmt.Sprintf(startFaultRequestType, types.PacketLossFaultType) + requestType := fmt.Sprintf(checkStatusFaultRequestType, types.PacketLossFaultType) // Parse the fault request. err := decodeRequest(w, &request, requestType, r) @@ -645,7 +841,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. stringToBeLogged := "Failed to check status for fault" // All command executions for the start network packet loss workflow all together should finish within 5 seconds. // Thus, create the context here so that it can be shared by all os/exec calls. - ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + ctx, cancel := h.osExecWrapper.NewExecContextWithTimeout(context.Background(), requestTimeoutDuration) defer cancel() // Check the status of current fault injection. _, packetLossFaultExists, err := h.checkTCFault(ctx, taskMetadata) diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 1dbca9bd5f5..014f1ad4e0d 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -18,6 +18,7 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -25,6 +26,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types" @@ -51,6 +53,8 @@ const ( awsvpcNetworkMode = "awsvpc" deviceName = "eth0" invalidNetworkMode = "invalid" + internalError = "internal error" + iptablesChainAlreadyExistError = "iptables: Chain already exists." iptablesChainNotFoundError = "iptables: Bad rule (does a matching rule exist in that chain?)." tcLatencyFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","parent":"1:1","options":{"limit":1000,"delay":{"delay":0.1,"jitter":0,"correlation":0},"ecn":false,"gap":0}}]` tcLossFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","dev":"eth0","parent":"1:1","options":{"limit":1000,"loss-random":{"loss":0.06,"correlation":0},"ecn":false,"gap":0}}]` @@ -107,6 +111,13 @@ var ( } ipSources = []string{"52.95.154.1", "52.95.154.2"} + + startNetworkBlackHolePortTestPrefix = fmt.Sprintf(startFaultRequestType, types.BlackHolePortFaultType) + stopNetworkBlackHolePortTestPrefix = fmt.Sprintf(stopFaultRequestType, types.BlackHolePortFaultType) + checkNetworkBlackHolePortTestPrefix = fmt.Sprintf(checkStatusFaultRequestType, types.BlackHolePortFaultType) + startNetworkPacketLossTestPrefix = fmt.Sprintf(startFaultRequestType, types.PacketLossFaultType) + stopNetworkPacketLossTestPrefix = fmt.Sprintf(stopFaultRequestType, types.PacketLossFaultType) + checkNetworkPacketLossTestPrefix = fmt.Sprintf(checkStatusFaultRequestType, types.PacketLossFaultType) ) type networkFaultInjectionTestCase struct { @@ -227,7 +238,7 @@ func testNetworkFaultInjectionCommon(t *testing.T, } } -func generateNetworkBlackHolePortTestCases(name string) []networkFaultInjectionTestCase { +func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInjectionTestCase { tcs := []networkFaultInjectionTestCase{ { name: fmt.Sprintf("%s no request body", name), @@ -433,15 +444,36 @@ func generateNetworkBlackHolePortTestCases(name string) []networkFaultInjectionT }, nil).Times(1) }, }, + { + name: fmt.Sprintf("%s request timed out", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, name)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), -1*time.Second) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, errors.New("signal: killed")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, false), + ) + }, + }, } return tcs } func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { - commonTcs := generateNetworkBlackHolePortTestCases("start blackhole port") + commonTcs := generateCommonNetworkBlackHolePortTestCases(startNetworkBlackHolePortTestPrefix) tcs := []networkFaultInjectionTestCase{ { - name: "start blackhole port success running", + name: fmt.Sprintf("%s success running", startNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), @@ -451,10 +483,25 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) }, }, { - name: "start blackhole unknown request body", + name: fmt.Sprintf("%s unknown request body", startNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: map[string]interface{}{ "Port": port, @@ -469,6 +516,119 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + }, + }, + { + name: fmt.Sprintf("%s success already running", startNetworkBlackHolePortTestPrefix), + expectedStatusCode: 200, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + }, + }, + { + name: fmt.Sprintf("%s fail duplicate chain", startNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(iptablesChainAlreadyExistError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainAlreadyExistError), errors.New("exit status 1")), + ) + }, + }, + { + name: fmt.Sprintf("%s fail append rule to chain", startNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) + }, + }, + { + name: fmt.Sprintf("%s fail insert chain to table", startNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) }, }, } @@ -477,10 +637,10 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase } func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { - commonTcs := generateNetworkBlackHolePortTestCases("stop blackhole port") + commonTcs := generateCommonNetworkBlackHolePortTestCases(stopNetworkBlackHolePortTestPrefix) tcs := []networkFaultInjectionTestCase{ { - name: "stop blackhole port success running", + name: fmt.Sprintf("%s success running", stopNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), @@ -490,10 +650,23 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) }, }, { - name: "stop blackhole unknown request body", + name: fmt.Sprintf("%s unknown request body", stopNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: map[string]interface{}{ "Port": port, @@ -508,6 +681,113 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + }, + }, + { + name: fmt.Sprintf("%s success already stopped", stopNetworkBlackHolePortTestPrefix), + expectedStatusCode: 200, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + ) + }, + }, + { + name: fmt.Sprintf("%s fail clear chain", stopNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) + }, + }, + { + name: fmt.Sprintf("%s fail delete chain from table", stopNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) + }, + }, + { + name: fmt.Sprintf("%s fail delete chain", stopNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) }, }, } @@ -515,10 +795,10 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { } func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTestCase { - commonTcs := generateNetworkBlackHolePortTestCases("check blackhole port") + commonTcs := generateCommonNetworkBlackHolePortTestCases(checkNetworkBlackHolePortTestPrefix) tcs := []networkFaultInjectionTestCase{ { - name: "check blackhole port success running", + name: fmt.Sprintf("%s success running", checkNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), @@ -528,15 +808,17 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) cmdExec := mock_execwrapper.NewMockCmd(ctrl) gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) }, }, { - name: "check blackhole unknown request body", + name: fmt.Sprintf("%s unknown request body", checkNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: map[string]interface{}{ "Port": port, @@ -551,15 +833,17 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) cmdExec := mock_execwrapper.NewMockCmd(ctrl) gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) }, }, { - name: "check blackhole port success not running", + name: fmt.Sprintf("%s success not running", checkNetworkBlackHolePortTestPrefix), expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("not-running"), @@ -569,8 +853,10 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) cmdExec := mock_execwrapper.NewMockCmd(ctrl) gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), @@ -579,20 +865,22 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes }, }, { - name: "check blackhole port failure", + name: fmt.Sprintf("%s failure", checkNetworkBlackHolePortTestPrefix), expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("internal error"), + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId). Return(happyTaskResponse, nil). Times(1) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) cmdExec := mock_execwrapper.NewMockCmd(ctrl) gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), - cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte("internal error"), errors.New("exit 2")), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit 2")), exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, false), ) }, @@ -1088,7 +1376,9 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte("["), nil) }, @@ -1106,17 +1396,37 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte{}, errors.New("signal: killed")) }, }, + { + name: "request timed out", + expectedStatusCode: 500, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, name)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), -1*time.Second) + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) + }, + }, } return tcs } func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { - commonTcs := generateCommonNetworkPacketLossTestCases("start network-packet-loss") + commonTcs := generateCommonNetworkPacketLossTestCases(startNetworkPacketLossTestPrefix) tcs := []networkFaultInjectionTestCase{ { name: "no-existing-fault", @@ -1127,9 +1437,13 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil) }, @@ -1143,7 +1457,9 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil) }, @@ -1157,7 +1473,9 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil) }, @@ -1175,9 +1493,13 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil) }, @@ -1187,7 +1509,7 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { } func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { - commonTcs := generateCommonNetworkPacketLossTestCases("stop network-packet-loss") + commonTcs := generateCommonNetworkPacketLossTestCases(stopNetworkPacketLossTestPrefix) tcs := []networkFaultInjectionTestCase{ { name: "no-existing-fault", @@ -1198,9 +1520,13 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) }, }, { @@ -1212,7 +1538,9 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil) }, @@ -1226,9 +1554,13 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil), + ) exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(3).Return(mockCMD) mockCMD.EXPECT().CombinedOutput().Times(3).Return([]byte(""), nil) }, @@ -1246,9 +1578,13 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) }, }, } @@ -1256,7 +1592,7 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { } func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { - commonTcs := generateCommonNetworkPacketLossTestCases("check network-packet-loss") + commonTcs := generateCommonNetworkPacketLossTestCases(checkNetworkPacketLossTestPrefix) tcs := []networkFaultInjectionTestCase{ { name: "no-existing-fault", @@ -1267,9 +1603,13 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) }, }, { @@ -1281,9 +1621,13 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil), + ) }, }, { @@ -1295,9 +1639,13 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil), + ) }, }, { @@ -1313,9 +1661,13 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) mockCMD := mock_execwrapper.NewMockCmd(ctrl) - gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), - mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil)) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil), + ) }, }, } diff --git a/ecs-agent/utils/execwrapper/exec.go b/ecs-agent/utils/execwrapper/exec.go index d1f94cb82c5..59758e4d976 100644 --- a/ecs-agent/utils/execwrapper/exec.go +++ b/ecs-agent/utils/execwrapper/exec.go @@ -18,6 +18,7 @@ import ( "io" "os" "os/exec" + "time" ) // Exec acts as a wrapper to functions exposed by the exec package. @@ -27,6 +28,7 @@ type Exec interface { CommandContext(ctx context.Context, name string, arg ...string) Cmd ConvertToExitError(err error) (*exec.ExitError, bool) GetExitCode(exitErr *exec.ExitError) int + NewExecContextWithTimeout(parent context.Context, duration time.Duration) (context.Context, context.CancelFunc) } // execWrapper is a placeholder struct which implements the Exec interface. @@ -53,6 +55,10 @@ func (e *execWrapper) GetExitCode(exitErr *exec.ExitError) int { return exitErr.ExitCode() } +func (e *execWrapper) NewExecContextWithTimeout(parent context.Context, duration time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, duration) +} + // Cmd acts as a wrapper to functions exposed by the exec.Cmd object. // Having this interface enables us to create mock objects we can use // for testing. diff --git a/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go b/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go index 25bff9f5241..2157fa4f779 100644 --- a/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go +++ b/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go @@ -24,6 +24,7 @@ import ( os "os" exec "os/exec" reflect "reflect" + time "time" execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" gomock "github.com/golang/mock/gomock" @@ -250,3 +251,18 @@ func (mr *MockExecMockRecorder) GetExitCode(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExitCode", reflect.TypeOf((*MockExec)(nil).GetExitCode), arg0) } + +// NewExecContextWithTimeout mocks base method. +func (m *MockExec) NewExecContextWithTimeout(arg0 context.Context, arg1 time.Duration) (context.Context, context.CancelFunc) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewExecContextWithTimeout", arg0, arg1) + ret0, _ := ret[0].(context.Context) + ret1, _ := ret[1].(context.CancelFunc) + return ret0, ret1 +} + +// NewExecContextWithTimeout indicates an expected call of NewExecContextWithTimeout. +func (mr *MockExecMockRecorder) NewExecContextWithTimeout(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewExecContextWithTimeout", reflect.TypeOf((*MockExec)(nil).NewExecContextWithTimeout), arg0, arg1) +}