diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 7d052dc9bdc..48c8d30808e 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -24,6 +24,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -56,6 +57,7 @@ import ( v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks" "github.com/gorilla/mux" "github.com/aws/aws-sdk-go/aws" @@ -65,6 +67,7 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" ) const ( @@ -3948,6 +3951,7 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { tcs := []struct { name string setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) + setNetlinkExpectations func(netlinkClient *mock_netlinkwrapper.MockNetLink) expectedTaskNetworkConfig *v4.TaskNetworkConfig }{ { @@ -3985,6 +3989,26 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), ) }, + setNetlinkExpectations: func(netlinkClient *mock_netlinkwrapper.MockNetLink) { + routes := []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: nil, + LinkIndex: 0, + }, + } + + link := &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: "eth0", + }, + } + gomock.InOrder( + netlinkClient.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(), + netlinkClient.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(), + ) + }, expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.HostNetworkMode, hostNetworkNamespace, defaultIfname), }, { @@ -4012,11 +4036,17 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_ecs.NewMockECSClient(ctrl) + mock_netlinkClient := mock_netlinkwrapper.NewMockNetLink(ctrl) + if tc.setStateExpectations != nil { tc.setStateExpectations(state) } + + if tc.setNetlinkExpectations != nil { + tc.setNetlinkExpectations(mock_netlinkClient) + } tmdsAgentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) - actualTaskResponse, err := tmdsAgentState.GetTaskMetadata(v3EndpointID) + actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID, mock_netlinkClient) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskNetworkConfig, actualTaskResponse.TaskNetworkConfig) diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index c82869962a1..60bd2889062 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -21,6 +21,8 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" ) const ( @@ -91,17 +93,40 @@ func (s *TMDSAgentState) GetContainerMetadata(v3EndpointID string) (tmdsv4.Conta // Returns task metadata in v4 format for the task identified by the provided endpointContainerID. func (s *TMDSAgentState) GetTaskMetadata(v3EndpointID string) (tmdsv4.TaskResponse, error) { - return s.getTaskMetadata(v3EndpointID, false) + return s.getTaskMetadata(v3EndpointID, false, false) } // Returns task metadata including task and container instance tags in v4 format for the // task identified by the provided endpointContainerID. func (s *TMDSAgentState) GetTaskMetadataWithTags(v3EndpointID string) (tmdsv4.TaskResponse, error) { - return s.getTaskMetadata(v3EndpointID, true) + return s.getTaskMetadata(v3EndpointID, true, false) +} + +func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string, netlinkClient netlinkwrapper.NetLink) (tmdsv4.TaskResponse, error) { + taskResponse, err := s.getTaskMetadata(v3EndpointID, false, true) + if err == nil { + if taskResponse.TaskNetworkConfig.NetworkMode == "host" && taskResponse.TaskNetworkConfig != nil { + taskARN := taskResponse.TaskARN + deviceName, err := netconfig.DefaultNetInterfaceName(netlinkClient) + if err != nil { + logger.Warn("Unable to obtain default network interface name on host for task.", logger.Fields{ + field.TaskARN: taskARN, + field.Error: err, + }) + } else { + logger.Info("Obtained default network interface name on host for task", logger.Fields{ + field.TaskARN: taskARN, + "defaultInterfaceName": deviceName, + }) + taskResponse.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName = deviceName + } + } + } + return taskResponse, err } // Returns task metadata in v4 format for the task identified by the provided endpointContainerID. -func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) (tmdsv4.TaskResponse, error) { +func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, includeTaskNetworkConfig bool) (tmdsv4.TaskResponse, error) { taskARN, ok := s.state.TaskARNByV3EndpointID(v3EndpointID) if !ok { return tmdsv4.TaskResponse{}, tmdsv4.NewErrorLookupFailure(fmt.Sprintf( @@ -156,18 +181,18 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) } taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled() - var taskNetworkConfig *tmdsv4.TaskNetworkConfig - if task.IsNetworkModeHost() { - // For host most, we don't really need the network namespace in order to do anything within the host instance network namespace - // and so we will set this to an arbitrary value such as "host". - // TODO: Will need to find/obtain the interface name of the default network interface on the host instance - taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname()) - } else { - taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname()) + if includeTaskNetworkConfig { + var taskNetworkConfig *tmdsv4.TaskNetworkConfig + if task.IsNetworkModeHost() { + // For host most, we don't really need the network namespace in order to do anything within the host instance network namespace + // and so we will set this to an arbitrary value such as "host". + taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname()) + } else { + taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname()) + } + taskResponse.TaskNetworkConfig = taskNetworkConfig } - taskResponse.TaskNetworkConfig = taskNetworkConfig - return *taskResponse, nil } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index db002641ea9..6fe3f7596d7 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -35,6 +35,7 @@ import ( v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" "github.com/aws/aws-sdk-go/aws" "github.com/gorilla/mux" @@ -74,6 +75,7 @@ type FaultHandler struct { AgentState state.AgentState MetricsFactory metrics.EntryFactory osExecWrapper execwrapper.Exec + netlinkClient netlinkwrapper.NetLink } func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execwrapper.Exec) *FaultHandler { @@ -82,6 +84,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw MetricsFactory: mf, mutexMap: sync.Map{}, osExecWrapper: execWrapper, + netlinkClient: netlinkwrapper.New(), } } @@ -116,7 +119,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -262,7 +265,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -410,7 +413,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -525,7 +528,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -572,7 +575,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -619,7 +622,7 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -665,7 +668,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -744,7 +747,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -824,7 +827,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID. - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -930,10 +933,10 @@ func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, r // validateTaskMetadata will first fetch the associated task metadata and then validate it to make sure // the task has enabled fault injection and the corresponding network mode is supported. -func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request) (*state.TaskResponse, error) { +func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request, netlinkClient netlinkwrapper.NetLink) (*state.TaskResponse, error) { var taskMetadata state.TaskResponse endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] - taskMetadata, err := agentState.GetTaskMetadata(endpointContainerID) + taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, netlinkClient) if err != nil { code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go index 0774ba52a2b..f556ef4a018 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go @@ -13,7 +13,11 @@ package state -import "fmt" +import ( + "fmt" + + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" +) // Error to be returned when container or task lookup failed type ErrorLookupFailure struct { @@ -109,6 +113,8 @@ type AgentState interface { // Returns ErrorMetadataFetchFailure if something else goes wrong. GetTaskMetadataWithTags(endpointContainerID string) (TaskResponse, error) + GetTaskMetadataWithTaskNetworkConfig(endpointContainerID string, netlinkClient netlinkwrapper.NetLink) (TaskResponse, error) + // Returns container stats in v4 format for the container identified by the provided // endpointContainerID. // Returns ErrorStatsLookupFailure if container lookup fails. diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go new file mode 100644 index 00000000000..1df375c5a14 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go @@ -0,0 +1,61 @@ +//go:build linux +// +build linux + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" + + "github.com/vishvananda/netlink" +) + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. If none exist, an empty string and nil will be returned. +func DefaultNetInterfaceName(netlinkClient netlinkwrapper.NetLink) (string, error) { + routes, err := netlinkClient.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return "", err + } + + // Iterate over all routes + for _, route := range routes { + logger.Debug("Found route", logger.Fields{"Route": route}) + if route.Gw == nil { + // A default route has a gateway. If it doesn't, skip it. + continue + } + + if route.Dst == nil || route.Dst.String() == "0.0.0.0/0" || route.Dst.String() == "::/0" { + // Get the link (interface) associated with the default route + link, err := netlinkClient.LinkByIndex(route.LinkIndex) + if err != nil { + logger.Warn("Not able to get the associated network interface by the index", logger.Fields{ + field.Error: err, + "LinkIndex": route.LinkIndex, + }) + } else { + logger.Debug("Found the associated network interface by the index", logger.Fields{ + "LinkName": link.Attrs().Name, + "LinkIndex": route.LinkIndex, + }) + return link.Attrs().Name, nil + } + } + } + return "", nil +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go new file mode 100644 index 00000000000..f7d24b7f63b --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go @@ -0,0 +1,25 @@ +//go:build windows +// +build windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import "errors" + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. This is not supported on windows as of now. +func DefaultNetInterfaceName() (string, error) { + return "", errors.New("Not supported on windows") +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/generate_mocks.go new file mode 100644 index 00000000000..c24db0fdfc3 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/generate_mocks.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netlinkwrapper + +//go:generate mockgen -destination=mocks/netlinkwrapper_mocks_linux.go -copyright_file=../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper NetLink diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go new file mode 100644 index 00000000000..6ef7a36e111 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go @@ -0,0 +1,108 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper (interfaces: NetLink) + +// Package mock_netlinkwrapper is a generated GoMock package. +package mock_netlinkwrapper + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + netlink "github.com/vishvananda/netlink" +) + +// MockNetLink is a mock of NetLink interface. +type MockNetLink struct { + ctrl *gomock.Controller + recorder *MockNetLinkMockRecorder +} + +// MockNetLinkMockRecorder is the mock recorder for MockNetLink. +type MockNetLinkMockRecorder struct { + mock *MockNetLink +} + +// NewMockNetLink creates a new mock instance. +func NewMockNetLink(ctrl *gomock.Controller) *MockNetLink { + mock := &MockNetLink{ctrl: ctrl} + mock.recorder = &MockNetLinkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetLink) EXPECT() *MockNetLinkMockRecorder { + return m.recorder +} + +// LinkByIndex mocks base method. +func (m *MockNetLink) LinkByIndex(arg0 int) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByIndex", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByIndex indicates an expected call of LinkByIndex. +func (mr *MockNetLinkMockRecorder) LinkByIndex(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByIndex", reflect.TypeOf((*MockNetLink)(nil).LinkByIndex), arg0) +} + +// LinkByName mocks base method. +func (m *MockNetLink) LinkByName(arg0 string) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByName", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByName indicates an expected call of LinkByName. +func (mr *MockNetLinkMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLink)(nil).LinkByName), arg0) +} + +// LinkSetUp mocks base method. +func (m *MockNetLink) LinkSetUp(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetUp", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetUp indicates an expected call of LinkSetUp. +func (mr *MockNetLinkMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLink)(nil).LinkSetUp), arg0) +} + +// RouteList mocks base method. +func (m *MockNetLink) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteList", arg0, arg1) + ret0, _ := ret[0].([]netlink.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RouteList indicates an expected call of RouteList. +func (mr *MockNetLinkMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLink)(nil).RouteList), arg0, arg1) +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_linux.go new file mode 100644 index 00000000000..dafbece6f21 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_linux.go @@ -0,0 +1,52 @@ +//go:build !windows +// +build !windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netlinkwrapper + +import ( + "github.com/vishvananda/netlink" +) + +type NetLink interface { + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + LinkByIndex(index int) (netlink.Link, error) +} + +type netLink struct{} + +func New() NetLink { + return &netLink{} +} + +func (nl *netLink) LinkByName(name string) (netlink.Link, error) { + return netlink.LinkByName(name) +} + +func (nl *netLink) LinkSetUp(link netlink.Link) error { + return netlink.LinkSetUp(link) +} + +// RouteList gets a list of routes in the system. Equivalent to: `ip route show`. +// The list can be filtered by link and ip family. +func (nl *netLink) RouteList(link netlink.Link, family int) ([]netlink.Route, error) { + return netlink.RouteList(link, family) +} + +func (nl *netLink) LinkByIndex(index int) (netlink.Link, error) { + return netlink.LinkByIndex(index) +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_windows.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_windows.go new file mode 100644 index 00000000000..7e5b50c30d4 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_windows.go @@ -0,0 +1,50 @@ +//go:build windows +// +build windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netlinkwrapper + +import ( + "github.com/vishvananda/netlink" +) + +type NetLink interface { + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + LinkByIndex(index int) (netlink.Link, error) +} + +type netLink struct{} + +func New() NetLink { + return &netLink{} +} + +func (nl *netLink) LinkByName(name string) (netlink.Link, error) { + return nil, nil +} + +func (nl *netLink) LinkSetUp(link netlink.Link) error { + return nil +} + +func (nl *netLink) RouteList(link netlink.Link, family int) ([]netlink.Route, error) { + return nil, nil +} + +func (nl *netLink) LinkByIndex(index int) (netlink.Link, error) { + return nil, nil +} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 777a834f94f..a5e3c5c0629 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -71,12 +71,15 @@ github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4 github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig github.com/aws/amazon-ecs-agent/ecs-agent/utils github.com/aws/amazon-ecs-agent/ecs-agent/utils/arn github.com/aws/amazon-ecs-agent/ecs-agent/utils/cipher github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks github.com/aws/amazon-ecs-agent/ecs-agent/utils/httpproxy +github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper +github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index db002641ea9..6fe3f7596d7 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -35,6 +35,7 @@ import ( v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" "github.com/aws/aws-sdk-go/aws" "github.com/gorilla/mux" @@ -74,6 +75,7 @@ type FaultHandler struct { AgentState state.AgentState MetricsFactory metrics.EntryFactory osExecWrapper execwrapper.Exec + netlinkClient netlinkwrapper.NetLink } func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execwrapper.Exec) *FaultHandler { @@ -82,6 +84,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw MetricsFactory: mf, mutexMap: sync.Map{}, osExecWrapper: execWrapper, + netlinkClient: netlinkwrapper.New(), } } @@ -116,7 +119,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -262,7 +265,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -410,7 +413,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -525,7 +528,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -572,7 +575,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -619,7 +622,7 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -665,7 +668,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -744,7 +747,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -824,7 +827,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID. - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -930,10 +933,10 @@ func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, r // validateTaskMetadata will first fetch the associated task metadata and then validate it to make sure // the task has enabled fault injection and the corresponding network mode is supported. -func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request) (*state.TaskResponse, error) { +func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request, netlinkClient netlinkwrapper.NetLink) (*state.TaskResponse, error) { var taskMetadata state.TaskResponse endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] - taskMetadata, err := agentState.GetTaskMetadata(endpointContainerID) + taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, netlinkClient) if err != nil { code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 014f1ad4e0d..450b7d35f77 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -34,6 +34,7 @@ import ( state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks" "github.com/golang/mock/gomock" "github.com/gorilla/mux" @@ -125,7 +126,7 @@ type networkFaultInjectionTestCase struct { expectedStatusCode int requestBody interface{} expectedResponseBody types.NetworkFaultInjectionResponse - setAgentStateExpectations func(agentState *mock_state.MockAgentState) + setAgentStateExpectations func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) setExecExpectations func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) } @@ -159,8 +160,11 @@ func testNetworkFaultInjectionCommon(t *testing.T, router := mux.NewRouter() mockExec := mock_execwrapper.NewMockExec(ctrl) handler := New(agentState, metricsFactory, mockExec) + if tc.setAgentStateExpectations != nil { - tc.setAgentStateExpectations(agentState) + mock_netlinkClient := mock_netlinkwrapper.NewMockNetLink(ctrl) + handler.netlinkClient = mock_netlinkClient + tc.setAgentStateExpectations(agentState, mock_netlinkClient) } if tc.setExecExpectations != nil { tc.setExecExpectations(mockExec, ctrl) @@ -245,8 +249,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 400, requestBody: nil, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required request body is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -258,8 +262,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje "TrafficType": trafficType, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkBlackholePortRequest.Port of type uint16"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -270,8 +274,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje "Protocol": protocol, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -283,8 +287,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje "TrafficType": "", }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -296,8 +300,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje "TrafficType": trafficType, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter Protocol"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -309,8 +313,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje "TrafficType": "invalid", }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter TrafficType"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -318,8 +322,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 404, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). Times(1) }, @@ -329,8 +333,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( "Unable to generate metadata for task")). Times(1) }, @@ -340,8 +344,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, errors.New("unknown error")). Times(1) }, @@ -351,8 +355,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 400, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(faultInjectionEnabledError, taskARN)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: false, }, nil).Times(1) @@ -363,8 +367,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 400, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(invalidNetworkModeError, invalidNetworkMode)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: &state.TaskNetworkConfig{ @@ -380,8 +384,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: nil, @@ -394,8 +398,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: &state.TaskNetworkConfig{ @@ -411,8 +415,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: &state.TaskNetworkConfig{ @@ -428,8 +432,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: &state.TaskNetworkConfig{ @@ -449,8 +453,8 @@ func generateCommonNetworkBlackHolePortTestCases(name string) []networkFaultInje expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, name)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -477,8 +481,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -510,8 +514,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -538,8 +542,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -558,8 +562,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(iptablesChainAlreadyExistError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -582,8 +586,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -608,8 +612,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -644,8 +648,8 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -675,8 +679,8 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -701,8 +705,8 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -723,8 +727,8 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -745,8 +749,8 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -769,8 +773,8 @@ func generateStopBlackHolePortFaultTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -802,8 +806,8 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -827,8 +831,8 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -847,8 +851,8 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes expectedStatusCode: 200, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("not-running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -869,8 +873,8 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -917,8 +921,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 200, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -933,8 +937,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(happyTaskResponse, nil). Times(1) }, @@ -944,8 +948,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 400, requestBody: nil, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required request body is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -957,8 +961,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.DelayMilliseconds of type uint64"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -972,8 +976,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.JitterMilliseconds of type uint64"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -987,8 +991,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": "", }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.Sources of type []*string"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1002,8 +1006,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": []string{}, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1016,8 +1020,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "JitterMilliseconds": jitterMilliseconds, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1030,8 +1034,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter JitterMilliseconds is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1044,8 +1048,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter DelayMilliseconds is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1059,8 +1063,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": []string{"10.1.2.3.4"}, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 10.1.2.3.4 for parameter Sources"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1074,8 +1078,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n "Sources": []string{"52.95.154.0/33"}, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 52.95.154.0/33 for parameter Sources"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, nil). Times(0) }, @@ -1085,8 +1089,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 404, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). Times(1) }, @@ -1096,8 +1100,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 500, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( "Unable to generate metadata for task")). Times(1) @@ -1108,8 +1112,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 500, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId). + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient). Return(state.TaskResponse{}, errors.New("unknown error")). Times(1) }, @@ -1119,8 +1123,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 400, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(faultInjectionEnabledError, taskARN)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: false, }, nil) @@ -1131,8 +1135,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 400, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(invalidNetworkModeError, invalidNetworkMode)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: &state.TaskNetworkConfig{ @@ -1147,8 +1151,8 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n expectedStatusCode: 500, requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: nil, @@ -1184,8 +1188,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkPacketLossRequest.LossPercent of type uint64"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1196,8 +1200,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": "", }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkPacketLossRequest.Sources of type []*string"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1208,8 +1212,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": []string{}, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1219,8 +1223,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "LossPercent": lossPercent, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1230,8 +1234,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter LossPercent is missing"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1242,8 +1246,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal number -1 into Go struct field NetworkPacketLossRequest.LossPercent of type uint64"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1254,8 +1258,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 101 for parameter LossPercent"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1266,8 +1270,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": ipSources, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 0 for parameter LossPercent"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1278,8 +1282,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": []string{"10.1.2.3.4"}, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 10.1.2.3.4 for parameter Sources"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1290,8 +1294,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Sources": []string{"52.95.154.0/33"}, }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 52.95.154.0/33 for parameter Sources"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, nil).Times(0) }, }, { @@ -1299,8 +1303,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 404, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) }, }, { @@ -1308,8 +1312,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 500, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( "Unable to generate metadata for task")) }, }, @@ -1318,8 +1322,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 500, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{}, errors.New("unknown error")) }, }, { @@ -1327,8 +1331,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 400, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(faultInjectionEnabledError, taskARN)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: false, }, nil) @@ -1339,8 +1343,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 400, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(invalidNetworkModeError, invalidNetworkMode)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: &state.TaskNetworkConfig{ @@ -1355,8 +1359,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 500, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: nil, @@ -1372,8 +1376,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to check existing network fault: failed to unmarshal tc command output: unexpected end of JSON input. TaskArn: taskArn"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1392,8 +1396,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to check existing network fault: 'nsenter --net=/some/path tc -j q show dev eth0 parent 1:1' command failed with the following error: 'signal: killed'. std output: ''. TaskArn: taskArn"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1408,8 +1412,8 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti expectedStatusCode: 500, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, name)), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), -1*time.Second) @@ -1433,8 +1437,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1453,8 +1457,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 409, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("There is already one network latency fault running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1469,8 +1473,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 409, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("There is already one network packet loss fault running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1489,8 +1493,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase { "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1516,8 +1520,8 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1534,8 +1538,8 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1550,8 +1554,8 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1574,8 +1578,8 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase { "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("stopped"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1599,8 +1603,8 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("not-running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1617,8 +1621,8 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("not-running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1635,8 +1639,8 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) @@ -1657,8 +1661,8 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { "Unknown": "", }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("not-running"), - setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netlinkClient *mock_netlinkwrapper.MockNetLink) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netlinkClient).Return(happyTaskResponse, nil) }, setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeoutDuration) diff --git a/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go b/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go index 35d769eaca8..5760af8ce6c 100644 --- a/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go +++ b/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go @@ -22,6 +22,7 @@ import ( reflect "reflect" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" gomock "github.com/golang/mock/gomock" ) @@ -108,6 +109,21 @@ func (mr *MockAgentStateMockRecorder) GetTaskMetadataWithTags(arg0 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskMetadataWithTags", reflect.TypeOf((*MockAgentState)(nil).GetTaskMetadataWithTags), arg0) } +// GetTaskMetadataWithTaskNetworkConfig mocks base method. +func (m *MockAgentState) GetTaskMetadataWithTaskNetworkConfig(arg0 string, arg1 netlinkwrapper.NetLink) (state.TaskResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskMetadataWithTaskNetworkConfig", arg0, arg1) + ret0, _ := ret[0].(state.TaskResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTaskMetadataWithTaskNetworkConfig indicates an expected call of GetTaskMetadataWithTaskNetworkConfig. +func (mr *MockAgentStateMockRecorder) GetTaskMetadataWithTaskNetworkConfig(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskMetadataWithTaskNetworkConfig", reflect.TypeOf((*MockAgentState)(nil).GetTaskMetadataWithTaskNetworkConfig), arg0, arg1) +} + // GetTaskStats mocks base method. func (m *MockAgentState) GetTaskStats(arg0 string) (map[string]*state.StatsResponse, error) { m.ctrl.T.Helper() diff --git a/ecs-agent/tmds/handlers/v4/state/state.go b/ecs-agent/tmds/handlers/v4/state/state.go index 0774ba52a2b..f556ef4a018 100644 --- a/ecs-agent/tmds/handlers/v4/state/state.go +++ b/ecs-agent/tmds/handlers/v4/state/state.go @@ -13,7 +13,11 @@ package state -import "fmt" +import ( + "fmt" + + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" +) // Error to be returned when container or task lookup failed type ErrorLookupFailure struct { @@ -109,6 +113,8 @@ type AgentState interface { // Returns ErrorMetadataFetchFailure if something else goes wrong. GetTaskMetadataWithTags(endpointContainerID string) (TaskResponse, error) + GetTaskMetadataWithTaskNetworkConfig(endpointContainerID string, netlinkClient netlinkwrapper.NetLink) (TaskResponse, error) + // Returns container stats in v4 format for the container identified by the provided // endpointContainerID. // Returns ErrorStatsLookupFailure if container lookup fails. diff --git a/ecs-agent/utils/netlinkwrapper/netlink_windows.go b/ecs-agent/utils/netlinkwrapper/netlink_windows.go new file mode 100644 index 00000000000..7e5b50c30d4 --- /dev/null +++ b/ecs-agent/utils/netlinkwrapper/netlink_windows.go @@ -0,0 +1,50 @@ +//go:build windows +// +build windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netlinkwrapper + +import ( + "github.com/vishvananda/netlink" +) + +type NetLink interface { + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + LinkByIndex(index int) (netlink.Link, error) +} + +type netLink struct{} + +func New() NetLink { + return &netLink{} +} + +func (nl *netLink) LinkByName(name string) (netlink.Link, error) { + return nil, nil +} + +func (nl *netLink) LinkSetUp(link netlink.Link) error { + return nil +} + +func (nl *netLink) RouteList(link netlink.Link, family int) ([]netlink.Route, error) { + return nil, nil +} + +func (nl *netLink) LinkByIndex(index int) (netlink.Link, error) { + return nil, nil +}