From 66464f399ee684bf3999016bb36d4bd2418a6f27 Mon Sep 17 00:00:00 2001 From: Xing Zheng Date: Mon, 26 Aug 2024 20:39:54 +0000 Subject: [PATCH 1/3] Add read/write lock in the fault injection handler --- agent/handlers/task_server_setup.go | 5 +- agent/handlers/v4/tmdsstate.go | 3 +- .../handlers/fault/v1/handlers/handlers.go | 134 ++++++++-- .../handlers/fault/v1/handlers/handlers.go | 134 ++++++++-- .../fault/v1/handlers/handlers_test.go | 249 ++++++++++++------ 5 files changed, 403 insertions(+), 122 deletions(-) diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index ec8affbf175..ec0ec020f23 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -196,10 +196,7 @@ func registerFaultHandlers( agentState *v4.TMDSAgentState, metricsFactory metrics.EntryFactory, ) { - handler := fault.FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } + handler := fault.New(agentState, metricsFactory) if muxRouter == nil { return diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 4daefdbb766..36f9c45f4b2 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -161,7 +161,8 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) Path: task.GetNetworkNamespace(), NetworkInterfaces: []*tmdsv4.NetworkInterface{ { - DeviceName: "", + ENIID: "eni-fake-id", + DeviceName: "ethx", }, }, }, diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index d550dafbd20..b38a959cf2e 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "sync" 
"github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -42,18 +43,36 @@ const ( ) type FaultHandler struct { - // TODO: Mutex will be used in a future PR - // mu sync.Mutex + // mutexMap is used to avoid multiple clients to manipulate same resource at same + // time. The 'key' is the ENI ID or the network namespace path and 'value' is the + // RWMutex. + // Using concurrent map here because the handler is shared by all requests. + mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory } +func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { + return &FaultHandler{ + AgentState: agentState, + MetricsFactory: mf, + mutexMap: sync.Map{}, + } +} + // NetworkFaultPath will take in a fault type and return the TMDS endpoint path func NetworkFaultPath(fault string) string { return fmt.Sprintf("/api/%s/fault/v1/%s", utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// loadLock returns the lock associated with given key. 
+func (h *FaultHandler) loadLock(key string) *sync.RWMutex { + mu := new(sync.RWMutex) + actualMu, _ := h.mutexMap.LoadOrStore(key, mu) + return actualMu.(*sync.RWMutex) +} + // StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -72,12 +91,18 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network blackhole port fault would be injected to given task network + // namespace. + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -118,12 +143,16 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -164,14 +193,17 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via 
the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // Check status of current fault injection + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") logger.Info("Successfully checked status for fault", logger.Fields{ @@ -206,12 +238,21 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network latency fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -248,12 +289,19 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. 
+ ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -290,12 +338,19 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -331,12 +386,21 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network packet loss fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. 
+ ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -373,12 +437,19 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -416,11 +487,19 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -613,9 +692,24 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("empty network namespaces within task network config") } + // Task network namespace path is required to inject faults in the associated task. 
+ if taskNetworkConfig.NetworkNamespaces[0].Path == "" { + return errors.New("no path in the network namespace within task network config") + } + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { return errors.New("empty network interfaces within task network config") } + // Device name is required to inject network faults to given ENI in the task. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName == "" { + return errors.New("no ENI device name in the network namespace within task network config") + } + + // ENIID is required to avoid race condition where multiple requests are manipulating the same ENI. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].ENIID == "" { + return errors.New("no ENI ID in the network namespace within task network config") + } + return nil } diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index d550dafbd20..b38a959cf2e 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "sync" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -42,18 +43,36 @@ const ( ) type FaultHandler struct { - // TODO: Mutex will be used in a future PR - // mu sync.Mutex + // mutexMap is used to avoid multiple clients to manipulate same resource at same + // time. The 'key' is the ENI ID or the network namespace path and 'value' is the + // RWMutex. + // Using concurrent map here because the handler is shared by all requests. 
+ mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory } +func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { + return &FaultHandler{ + AgentState: agentState, + MetricsFactory: mf, + mutexMap: sync.Map{}, + } +} + // NetworkFaultPath will take in a fault type and return the TMDS endpoint path func NetworkFaultPath(fault string) string { return fmt.Sprintf("/api/%s/fault/v1/%s", utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// loadLock returns the lock associated with given key. +func (h *FaultHandler) loadLock(key string) *sync.RWMutex { + mu := new(sync.RWMutex) + actualMu, _ := h.mutexMap.LoadOrStore(key, mu) + return actualMu.(*sync.RWMutex) +} + // StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -72,12 +91,18 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network blackhole port fault would be injected to given task network + // namespace. 
+ networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -118,12 +143,16 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -164,14 +193,17 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // Check status of current fault injection + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") logger.Info("Successfully checked status for fault", logger.Fields{ @@ -206,12 +238,21 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = 
validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network latency fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -248,12 +289,19 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -290,12 +338,19 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. 
+ ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -331,12 +386,21 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // the network packet loss fault would be injected to given elastic network + // interface/ENI. + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -373,12 +437,19 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -416,11 +487,19 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. 
// Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + eniID := taskMetadata.TaskNetworkConfig. + NetworkNamespaces[0]. + NetworkInterfaces[0]. + ENIID + rwMu := h.loadLock(eniID) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -613,9 +692,24 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("empty network namespaces within task network config") } + // Task network namespace path is required to inject faults in the associated task. + if taskNetworkConfig.NetworkNamespaces[0].Path == "" { + return errors.New("no path in the network namespace within task network config") + } + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { return errors.New("empty network interfaces within task network config") } + // Device name is required to inject network faults to given ENI in the task. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName == "" { + return errors.New("no ENI device name in the network namespace within task network config") + } + + // ENIID is required to avoid race condition where multiple requests are manipulating the same ENI. 
+ if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].ENIID == "" { + return errors.New("no ENI ID in the network namespace within task network config") + } + return nil } diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 7b01b8acd91..47b17ad592b 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -50,12 +50,28 @@ const ( awsvpcNetworkMode = "awsvpc" deviceName = "eth0" invalidNetworkMode = "invalid" + eniID = "eni-123" ) var ( + noDeviceNameInNetworkInterfaces = []*state.NetworkInterface{ + { + DeviceName: "", + ENIID: eniID, + }, + } + + noENIIDInNetworkInterfaces = []*state.NetworkInterface{ + { + DeviceName: deviceName, + ENIID: "", + }, + } + happyNetworkInterfaces = []*state.NetworkInterface{ { DeviceName: deviceName, + ENIID: eniID, }, } @@ -66,6 +82,13 @@ var ( }, } + noPathInNetworkNamespaces = []*state.NetworkNamespace{ + { + Path: "", + NetworkInterfaces: happyNetworkInterfaces, + }, + } + happyTaskNetworkConfig = state.TaskNetworkConfig{ NetworkMode: awsvpcNetworkMode, NetworkNamespaces: happyNetworkNamespaces, @@ -114,7 +137,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). 
+ Times(1) }, }, { @@ -128,7 +153,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -201,7 +228,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). + Times(1) }, }, { @@ -211,7 +240,8 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( - "Unable to generate metadata for task")) + "Unable to generate metadata for task")). 
+ Times(1) }, }, { @@ -220,7 +250,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, errors.New("unknown error")). + Times(1) }, }, { @@ -252,10 +284,11 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, }, { - name: fmt.Sprintf("%s empty task network config", name), - expectedStatusCode: 500, - requestBody: happyBlackHolePortReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + name: fmt.Sprintf("%s empty task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, @@ -264,6 +297,84 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, nil) }, }, + { + name: fmt.Sprintf("%s no task network namespace", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + 
setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: nil, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s no path in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: noPathInNetworkNamespaces, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s no device name in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: []*state.NetworkNamespace{ + &state.NetworkNamespace{ + Path: "/path", + NetworkInterfaces: noDeviceNameInNetworkInterfaces, + }, + }, + }, + }, nil) + }, + }, + { + name: fmt.Sprintf("%s no ENI ID in task network config", name), + expectedStatusCode: 500, + requestBody: 
happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: []*state.NetworkNamespace{ + &state.NetworkNamespace{ + Path: "/path", + NetworkInterfaces: noENIIDInNetworkInterfaces, + }, + }, + }, + }, nil) + }, + }, } return tcs } @@ -285,12 +396,7 @@ func TestStartNetworkBlackHolePort(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.BlackHolePortFaultType), handler.StartNetworkBlackholePort(), @@ -338,12 +444,7 @@ func TestStopNetworkBlackHolePort(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.BlackHolePortFaultType), handler.StopNetworkBlackHolePort(), @@ -388,11 +489,7 @@ func TestCheckNetworkBlackHolePort(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) @@ -442,7 +539,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: 
types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -456,7 +555,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -469,7 +570,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.DelayMilliseconds of type uint64"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -482,7 +585,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.JitterMilliseconds of type uint64"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). 
+ Times(0) }, }, { @@ -495,7 +600,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.Sources of type []*string"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -508,7 +615,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -520,7 +629,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -532,7 +643,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter JitterMilliseconds is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). 
+ Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -544,7 +657,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter DelayMilliseconds is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -557,7 +672,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 10.1.2.3.4 for parameter Sources"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -570,7 +687,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 52.95.154.0/33 for parameter Sources"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). 
+ Times(0) }, }, { @@ -579,7 +698,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). + Times(1) }, }, { @@ -588,8 +709,10 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( - "Unable to generate metadata for task")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( + "Unable to generate metadata for task")). + Times(1) }, }, { @@ -598,7 +721,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, errors.New("unknown error")). 
+ Times(1) }, }, { @@ -663,12 +788,7 @@ func TestStartNetworkLatency(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.LatencyFaultType), handler.StartNetworkLatency(), @@ -716,12 +836,7 @@ func TestStopNetworkLatency(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.LatencyFaultType), handler.StopNetworkLatency(), @@ -766,12 +881,7 @@ func TestCheckNetworkLatency(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } @@ -1043,12 +1153,7 @@ func TestStartNetworkPacketLoss(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.StartNetworkPacketLoss(), @@ -1096,12 +1201,7 @@ func TestStopNetworkPacketLoss(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.StopNetworkPacketLoss(), @@ -1146,12 +1246,7 @@ func TestCheckNetworkPacketLoss(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, 
metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } From 8c868c015bc6a9c12a04d5b767f7b53a65ff9cae Mon Sep 17 00:00:00 2001 From: Xing Zheng Date: Tue, 3 Sep 2024 19:12:42 +0000 Subject: [PATCH 2/3] Remove ENI ID which is used for RWMutex --- agent/handlers/v4/tmdsstate.go | 1 - .../handlers/fault/v1/handlers/handlers.go | 65 +++++++------------ .../tmds/handlers/v4/state/response.go | 2 - .../handlers/fault/v1/handlers/handlers.go | 65 +++++++------------ .../fault/v1/handlers/handlers_test.go | 44 ++----------- ecs-agent/tmds/handlers/v4/handlers_test.go | 1 - ecs-agent/tmds/handlers/v4/state/response.go | 2 - 7 files changed, 50 insertions(+), 130 deletions(-) diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 36f9c45f4b2..01b2b1c3858 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -161,7 +161,6 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) Path: task.GetNetworkNamespace(), NetworkInterfaces: []*tmdsv4.NetworkInterface{ { - ENIID: "eni-fake-id", DeviceName: "ethx", }, }, diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index b38a959cf2e..438d0cdbc74 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -44,8 +44,7 @@ const ( type FaultHandler struct { // mutexMap is used to avoid multiple clients to manipulate same resource at same - // time. The 'key' is the ENI ID or the network namespace path and 'value' is the - // RWMutex. + // time. The 'key' is the the network namespace path and 'value' is the RWMutex. // Using concurrent map here because the handler is shared by all requests. 
mutexMap sync.Map AgentState state.AgentState @@ -96,8 +95,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht return } - // the network blackhole port fault would be injected to given task network - // namespace. + // To avoid multiple requests to manipulate same network resource networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.Lock() @@ -148,6 +146,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt return } + // To avoid multiple requests to manipulate same network resource networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.Lock() @@ -198,6 +197,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht return } + // To avoid multiple requests to manipulate same network resource networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.RLock() @@ -243,13 +243,9 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req return } - // the network latency fault would be injected to given elastic network - // interface/ENI. - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -294,11 +290,9 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. 
- ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -343,11 +337,9 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.RLock() defer rwMu.RUnlock() @@ -391,13 +383,9 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. return } - // the network packet loss fault would be injected to given elastic network - // interface/ENI. - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -442,11 +430,9 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -492,11 +478,9 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. 
- ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.RLock() defer rwMu.RUnlock() @@ -706,10 +690,5 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("no ENI device name in the network namespace within task network config") } - // ENIID is required to avoid race condition where multiple requests are manunipuating same ENI. - if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].ENIID == "" { - return errors.New("no ENI ID in the network namespace within task network config") - } - return nil } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go index 124c2f961c7..2e1b9b8e468 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go @@ -56,8 +56,6 @@ type NetworkNamespace struct { type NetworkInterface struct { // DeviceName is the device name on the host. DeviceName string - // ENIID is the id of eni. - ENIID string } // Instance's clock drift status diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index b38a959cf2e..438d0cdbc74 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -44,8 +44,7 @@ const ( type FaultHandler struct { // mutexMap is used to avoid multiple clients to manipulate same resource at same - // time. The 'key' is the ENI ID or the network namespace path and 'value' is the - // RWMutex. + // time. The 'key' is the the network namespace path and 'value' is the RWMutex. 
// Using concurrent map here because the handler is shared by all requests. mutexMap sync.Map AgentState state.AgentState @@ -96,8 +95,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht return } - // the network blackhole port fault would be injected to given task network - // namespace. + // To avoid multiple requests to manipulate same network resource networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.Lock() @@ -148,6 +146,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt return } + // To avoid multiple requests to manipulate same network resource networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.Lock() @@ -198,6 +197,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht return } + // To avoid multiple requests to manipulate same network resource networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.RLock() @@ -243,13 +243,9 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req return } - // the network latency fault would be injected to given elastic network - // interface/ENI. - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -294,11 +290,9 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. 
- ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -343,11 +337,9 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.RLock() defer rwMu.RUnlock() @@ -391,13 +383,9 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. return } - // the network packet loss fault would be injected to given elastic network - // interface/ENI. - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -442,11 +430,9 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. - ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.Lock() defer rwMu.Unlock() @@ -492,11 +478,9 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. return } - eniID := taskMetadata.TaskNetworkConfig. - NetworkNamespaces[0]. - NetworkInterfaces[0]. 
- ENIID - rwMu := h.loadLock(eniID) + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) rwMu.RLock() defer rwMu.RUnlock() @@ -706,10 +690,5 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("no ENI device name in the network namespace within task network config") } - // ENIID is required to avoid race condition where multiple requests are manunipuating same ENI. - if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].ENIID == "" { - return errors.New("no ENI ID in the network namespace within task network config") - } - return nil } diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 47b17ad592b..c810a48b94c 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -50,28 +50,18 @@ const ( awsvpcNetworkMode = "awsvpc" deviceName = "eth0" invalidNetworkMode = "invalid" - eniID = "eni-123" ) var ( noDeviceNameInNetworkInterfaces = []*state.NetworkInterface{ { DeviceName: "", - ENIID: eniID, - }, - } - - noENIIDInNetworkInterfaces = []*state.NetworkInterface{ - { - DeviceName: deviceName, - ENIID: "", }, } happyNetworkInterfaces = []*state.NetworkInterface{ { DeviceName: deviceName, - ENIID: eniID, }, } @@ -264,7 +254,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: false, - }, nil) + }, nil).Times(1) }, }, { @@ -280,7 +270,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod NetworkMode: invalidNetworkMode, NetworkNamespaces: happyNetworkNamespaces, }, - }, nil) + }, nil).Times(1) }, }, { @@ -294,7 
+284,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: nil, - }, nil) + }, nil).Times(1) }, }, { @@ -311,7 +301,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod NetworkMode: awsvpcNetworkMode, NetworkNamespaces: nil, }, - }, nil) + }, nil).Times(1) }, }, { @@ -328,7 +318,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod NetworkMode: awsvpcNetworkMode, NetworkNamespaces: noPathInNetworkNamespaces, }, - }, nil) + }, nil).Times(1) }, }, { @@ -350,29 +340,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, }, }, - }, nil) - }, - }, - { - name: fmt.Sprintf("%s no ENI ID in task network config", name), - expectedStatusCode: 500, - requestBody: happyBlackHolePortReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( - fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ - TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, - FaultInjectionEnabled: true, - TaskNetworkConfig: &state.TaskNetworkConfig{ - NetworkMode: awsvpcNetworkMode, - NetworkNamespaces: []*state.NetworkNamespace{ - &state.NetworkNamespace{ - Path: "/path", - NetworkInterfaces: noENIIDInNetworkInterfaces, - }, - }, - }, - }, nil) + }, nil).Times(1) }, }, } diff --git a/ecs-agent/tmds/handlers/v4/handlers_test.go b/ecs-agent/tmds/handlers/v4/handlers_test.go index 55d7a5012ec..388170f7244 100644 --- a/ecs-agent/tmds/handlers/v4/handlers_test.go +++ b/ecs-agent/tmds/handlers/v4/handlers_test.go @@ -167,7 +167,6 @@ func taskResponse() *state.TaskResponse { NetworkInterfaces: []*state.NetworkInterface{ &state.NetworkInterface{ DeviceName: "eth1", - 
ENIID: "eni-013ff4ad5747a0f6a", }, }, }, diff --git a/ecs-agent/tmds/handlers/v4/state/response.go b/ecs-agent/tmds/handlers/v4/state/response.go index 124c2f961c7..2e1b9b8e468 100644 --- a/ecs-agent/tmds/handlers/v4/state/response.go +++ b/ecs-agent/tmds/handlers/v4/state/response.go @@ -56,8 +56,6 @@ type NetworkNamespace struct { type NetworkInterface struct { // DeviceName is the device name on the host. DeviceName string - // ENIID is the id of eni. - ENIID string } // Instance's clock drift status From c58e734921f4a47e568a0b4e779d50690eaa4384 Mon Sep 17 00:00:00 2001 From: Xing Zheng Date: Thu, 5 Sep 2024 16:59:36 +0000 Subject: [PATCH 3/3] Add a comment --- agent/handlers/v4/tmdsstate.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 01b2b1c3858..d033512164b 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -161,7 +161,12 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) Path: task.GetNetworkNamespace(), NetworkInterfaces: []*tmdsv4.NetworkInterface{ { - DeviceName: "ethx", + // TODO: fetch the correct device name. + // We are exposing this information via AgentState to facilitate the fault injection + // handler to start/stop/check network faults. + // Use 'eth0'(a fake value) for existing fault injection related unit tests for now and + // it will be updated in the future. + DeviceName: "eth0", }, }, },