From af3afb46d1cd6d4cbad9310a1f7a3afb30070822 Mon Sep 17 00:00:00 2001 From: Xing Zheng Date: Mon, 26 Aug 2024 20:39:54 +0000 Subject: [PATCH] Add read/write lock in the fault injection handler --- agent/handlers/task_server_setup.go | 5 +- .../handlers/fault/v1/handlers/handlers.go | 134 ++++++++-- .../handlers/fault/v1/handlers/handlers.go | 134 ++++++++-- .../fault/v1/handlers/handlers_test.go | 228 +++++++++++++----- 4 files changed, 398 insertions(+), 103 deletions(-) diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index ec8affbf175..ec0ec020f23 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -196,10 +196,7 @@ func registerFaultHandlers( agentState *v4.TMDSAgentState, metricsFactory metrics.EntryFactory, ) { - handler := fault.FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } + handler := fault.New(agentState, metricsFactory) if muxRouter == nil { return diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index d550dafbd20..b38a959cf2e 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "sync" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -42,18 +43,36 @@ const ( ) type FaultHandler struct { - // TODO: Mutex will be used in a future PR - // mu sync.Mutex + // mutexMap is used to avoid multiple clients to manipulate same resource at same + // time. The 'key' is the ENI ID or the network namespace path and 'value' is the + // RWMutex. + // Using concurrent map here because the handler is shared by all requests. + mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory } +func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { + return &FaultHandler{ + AgentState: agentState, + MetricsFactory: mf, + mutexMap: sync.Map{}, + } +} + // NetworkFaultPath will take in a fault type and return the TMDS endpoint path func NetworkFaultPath(fault string) string { return fmt.Sprintf("/api/%s/fault/v1/%s", utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// loadLock returns the lock associated with given key. +func (h *FaultHandler) loadLock(key string) *sync.RWMutex { + mu := new(sync.RWMutex) + actualMu, _ := h.mutexMap.LoadOrStore(key, mu) + return actualMu.(*sync.RWMutex) +} + // StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -72,12 +91,18 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network blackhole port fault would be injected to given task network + // namespace. + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -118,12 +143,16 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -164,14 +193,17 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // Check status of current fault injection + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") logger.Info("Successfully checked status for fault", logger.Fields{ @@ -206,12 +238,21 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network latency fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -248,12 +289,19 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -290,12 +338,19 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -331,12 +386,21 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network packet loss fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -373,12 +437,19 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -416,11 +487,19 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -613,9 +692,24 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("empty network namespaces within task network config") } + // Task network namespace path is required to inject faults in the associated task. + if taskNetworkConfig.NetworkNamespaces[0].Path == "" { + return errors.New("no path in the network namespace within task network config") + } + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { return errors.New("empty network interfaces within task network config") } + // Device name is required to inject network faults to given ENI in the task. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName == "" { + return errors.New("no ENI device name in the network namespace within task network config") + } + + // ENIID is required to avoid race condition where multiple requests are manunipuating same ENI. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].ENIID == "" { + return errors.New("no ENI ID in the network namespace within task network config") + } + return nil } diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index d550dafbd20..b38a959cf2e 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "sync" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -42,18 +43,36 @@ const ( ) type FaultHandler struct { - // TODO: Mutex will be used in a future PR - // mu sync.Mutex + // mutexMap is used to avoid multiple clients to manipulate same resource at same + // time. The 'key' is the ENI ID or the network namespace path and 'value' is the + // RWMutex. + // Using concurrent map here because the handler is shared by all requests. + mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory } +func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { + return &FaultHandler{ + AgentState: agentState, + MetricsFactory: mf, + mutexMap: sync.Map{}, + } +} + // NetworkFaultPath will take in a fault type and return the TMDS endpoint path func NetworkFaultPath(fault string) string { return fmt.Sprintf("/api/%s/fault/v1/%s", utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// loadLock returns the lock associated with given key. +func (h *FaultHandler) loadLock(key string) *sync.RWMutex { + mu := new(sync.RWMutex) + actualMu, _ := h.mutexMap.LoadOrStore(key, mu) + return actualMu.(*sync.RWMutex) +} + // StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -72,12 +91,18 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network blackhole port fault would be injected to given task network + // namespace. + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -118,12 +143,16 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -164,14 +193,17 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // Check status of current fault injection + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") logger.Info("Successfully checked status for fault", logger.Fields{ @@ -206,12 +238,21 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network latency fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -248,12 +289,19 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -290,12 +338,19 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -331,12 +386,21 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network packet loss fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -373,12 +437,19 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -416,11 +487,19 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -613,9 +692,24 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("empty network namespaces within task network config") } + // Task network namespace path is required to inject faults in the associated task. + if taskNetworkConfig.NetworkNamespaces[0].Path == "" { + return errors.New("no path in the network namespace within task network config") + } + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { return errors.New("empty network interfaces within task network config") } + // Device name is required to inject network faults to given ENI in the task. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName == "" { + return errors.New("no ENI device name in the network namespace within task network config") + } + + // ENIID is required to avoid race condition where multiple requests are manunipuating same ENI. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].ENIID == "" { + return errors.New("no ENI ID in the network namespace within task network config") + } + return nil } diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 7b01b8acd91..03e529fe794 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -50,12 +50,28 @@ const ( awsvpcNetworkMode = "awsvpc" deviceName = "eth0" invalidNetworkMode = "invalid" + eniID = "eni-123" ) var ( + noDeviceNameInNetworkInterfaces = []*state.NetworkInterface{ + { + DeviceName: "", + ENIID: eniID, + }, + } + + noENIIDInNetworkInterfaces = []*state.NetworkInterface{ + { + DeviceName: deviceName, + ENIID: "", + }, + } + happyNetworkInterfaces = []*state.NetworkInterface{ { DeviceName: deviceName, + ENIID: eniID, }, } @@ -66,6 +82,13 @@ var ( }, } + noPathInNetworkNamespaces = []*state.NetworkNamespace{ + { + Path: "", + NetworkInterfaces: happyNetworkInterfaces, + }, + } + happyTaskNetworkConfig = state.TaskNetworkConfig{ NetworkMode: awsvpcNetworkMode, NetworkNamespaces: happyNetworkNamespaces, @@ -114,7 +137,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -128,7 +153,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -201,7 +228,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). + Times(1) }, }, { @@ -211,7 +240,8 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( - "Unable to generate metadata for task")) + "Unable to generate metadata for task")). + Times(1) }, }, { @@ -220,7 +250,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, errors.New("unknown error")). + Times(1) }, }, { @@ -252,10 +284,11 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, }, { - name: fmt.Sprintf("%s empty task network config", name), - expectedStatusCode: 500, - requestBody: happyBlackHolePortReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + name: fmt.Sprintf("%s empty task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, @@ -264,6 +297,84 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, nil) }, }, + { + name: fmt.Sprintf("%s no task network namespace", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: nil, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s no path in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: noPathInNetworkNamespaces, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s no device name in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: []*state.NetworkNamespace{ + &state.NetworkNamespace{ + Path: "/path", + NetworkInterfaces: noDeviceNameInNetworkInterfaces, + }, + }, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s no ENI ID in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: []*state.NetworkNamespace{ + &state.NetworkNamespace{ + Path: "/path", + NetworkInterfaces: noENIIDInNetworkInterfaces, + }, + }, + }, + }, nil) + }, + }, } return tcs } @@ -285,12 +396,7 @@ func TestStartNetworkBlackHolePort(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.BlackHolePortFaultType), handler.StartNetworkBlackholePort(), @@ -338,12 +444,7 @@ func TestStopNetworkBlackHolePort(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.BlackHolePortFaultType), handler.StopNetworkBlackHolePort(), @@ -388,11 +489,7 @@ func TestCheckNetworkBlackHolePort(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) @@ -442,7 +539,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -456,7 +555,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -469,7 +570,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.DelayMilliseconds of type uint64"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -482,7 +585,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.JitterMilliseconds of type uint64"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -495,7 +600,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.Sources of type []*string"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -508,7 +615,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -520,7 +629,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -532,7 +643,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter JitterMilliseconds is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -544,7 +657,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter DelayMilliseconds is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -557,7 +672,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 10.1.2.3.4 for parameter Sources"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -570,7 +687,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 52.95.154.0/33 for parameter Sources"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -579,7 +698,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). + Times(1) }, }, { @@ -588,8 +709,10 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( - "Unable to generate metadata for task")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( + "Unable to generate metadata for task")). + Times(1) }, }, { @@ -598,7 +721,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, errors.New("unknown error")). + Times(1) }, }, { @@ -663,12 +788,7 @@ func TestStartNetworkLatency(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.LatencyFaultType), handler.StartNetworkLatency(), @@ -716,12 +836,7 @@ func TestStopNetworkLatency(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.LatencyFaultType), handler.StopNetworkLatency(), @@ -766,12 +881,7 @@ func TestCheckNetworkLatency(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) }