diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 158e08fd203..48c8d30808e 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -24,6 +24,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -56,6 +57,7 @@ import ( v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks" "github.com/gorilla/mux" "github.com/aws/aws-sdk-go/aws" @@ -65,6 +67,7 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" ) const ( @@ -3948,6 +3951,7 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { tcs := []struct { name string setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) + setNetlinkExpectations func(netlinkClient *mock_netlinkwrapper.MockNetLink) expectedTaskNetworkConfig *v4.TaskNetworkConfig }{ { @@ -3985,6 +3989,26 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), ) }, + setNetlinkExpectations: func(netlinkClient *mock_netlinkwrapper.MockNetLink) { + routes := []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: nil, + LinkIndex: 0, + }, + } + + link := &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: "eth0", + }, + } + gomock.InOrder( + netlinkClient.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(), + netlinkClient.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(), + ) + }, expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.HostNetworkMode, hostNetworkNamespace, defaultIfname), }, { @@ -4012,11 +4036,17 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_ecs.NewMockECSClient(ctrl) + mock_netlinkClient := mock_netlinkwrapper.NewMockNetLink(ctrl) + if tc.setStateExpectations != nil { tc.setStateExpectations(state) } + + if tc.setNetlinkExpectations != nil { + tc.setNetlinkExpectations(mock_netlinkClient) + } tmdsAgentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) - actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID) + actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID, mock_netlinkClient) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskNetworkConfig, actualTaskResponse.TaskNetworkConfig) diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index af0d8a18f83..a7b8aa3da3e 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -21,6 +21,8 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" ) const ( @@ -91,21 +93,21 @@ func (s *TMDSAgentState) GetContainerMetadata(v3EndpointID string) (tmdsv4.Conta // Returns task metadata in v4 format for the task identified by the provided endpointContainerID. func (s *TMDSAgentState) GetTaskMetadata(v3EndpointID string) (tmdsv4.TaskResponse, error) { - return s.getTaskMetadata(v3EndpointID, false, false) + return s.getTaskMetadata(v3EndpointID, false, false, nil) } // Returns task metadata including task and container instance tags in v4 format for the // task identified by the provided endpointContainerID. func (s *TMDSAgentState) GetTaskMetadataWithTags(v3EndpointID string) (tmdsv4.TaskResponse, error) { - return s.getTaskMetadata(v3EndpointID, true, false) + return s.getTaskMetadata(v3EndpointID, true, false, nil) } -func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string) (tmdsv4.TaskResponse, error) { - return s.getTaskMetadata(v3EndpointID, false, true) +func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string, netlinkClient netlinkwrapper.NetLink) (tmdsv4.TaskResponse, error) { + return s.getTaskMetadata(v3EndpointID, false, true, netlinkClient) } // Returns task metadata in v4 format for the task identified by the provided endpointContainerID. -func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, includeTaskNetworkConfig bool) (tmdsv4.TaskResponse, error) { +func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, includeTaskNetworkConfig bool, netlinkClient netlinkwrapper.NetLink) (tmdsv4.TaskResponse, error) { taskARN, ok := s.state.TaskARNByV3EndpointID(v3EndpointID) if !ok { return tmdsv4.TaskResponse{}, tmdsv4.NewErrorLookupFailure(fmt.Sprintf( @@ -163,10 +165,23 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, inclu taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled() var taskNetworkConfig *tmdsv4.TaskNetworkConfig if task.IsNetworkModeHost() { + deviceName := "" + if netlinkClient != nil { + deviceName, err = netconfig.DefaultNetInterfaceName(netlinkClient) + if err != nil { + logger.Warn("Unable to obtain default network interface name on host for task.", logger.Fields{ + field.TaskARN: taskARN, + field.Error: err, + }) + } + logger.Info("Obtained default network interface name on host for task", logger.Fields{ + field.TaskARN: taskARN, + "defaultInterfaceName": deviceName, + }) + } // For host most, we don't really need the network namespace in order to do anything within the host instance network namespace // and so we will set this to an arbitrary value such as "host". - // TODO: Will need to find/obtain the interface name of the default network interface on the host instance - taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname()) + taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, deviceName) } else { taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname()) } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index be87c3a5066..6fe3f7596d7 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -35,6 +35,7 @@ import ( v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" "github.com/aws/aws-sdk-go/aws" "github.com/gorilla/mux" @@ -74,6 +75,7 @@ type FaultHandler struct { AgentState state.AgentState MetricsFactory metrics.EntryFactory osExecWrapper execwrapper.Exec + netlinkClient netlinkwrapper.NetLink } func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execwrapper.Exec) *FaultHandler { @@ -82,6 +84,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw MetricsFactory: mf, mutexMap: sync.Map{}, osExecWrapper: execWrapper, + netlinkClient: netlinkwrapper.New(), } } @@ -116,7 +119,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -262,7 +265,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -410,7 +413,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -525,7 +528,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -572,7 +575,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -619,7 +622,7 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -665,7 +668,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -744,7 +747,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -824,7 +827,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID. - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -930,10 +933,10 @@ func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, r // validateTaskMetadata will first fetch the associated task metadata and then validate it to make sure // the task has enabled fault injection and the corresponding network mode is supported. -func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request) (*state.TaskResponse, error) { +func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request, netlinkClient netlinkwrapper.NetLink) (*state.TaskResponse, error) { var taskMetadata state.TaskResponse endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] - taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID) + taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, netlinkClient) if err != nil { code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go index e7b1cb703be..f556ef4a018 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/state.go @@ -13,7 +13,11 @@ package state -import "fmt" +import ( + "fmt" + + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" +) // Error to be returned when container or task lookup failed type ErrorLookupFailure struct { @@ -109,7 +113,7 @@ type AgentState interface { // Returns ErrorMetadataFetchFailure if something else goes wrong. GetTaskMetadataWithTags(endpointContainerID string) (TaskResponse, error) - GetTaskMetadataWithTaskNetworkConfig(endpointContainerID string) (TaskResponse, error) + GetTaskMetadataWithTaskNetworkConfig(endpointContainerID string, netlinkClient netlinkwrapper.NetLink) (TaskResponse, error) // Returns container stats in v4 format for the container identified by the provided // endpointContainerID. diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go new file mode 100644 index 00000000000..1df375c5a14 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go @@ -0,0 +1,61 @@ +//go:build linux +// +build linux + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" + + "github.com/vishvananda/netlink" +) + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. If none exist, an empty string and nil will be returned. +func DefaultNetInterfaceName(netlinkClient netlinkwrapper.NetLink) (string, error) { + routes, err := netlinkClient.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return "", err + } + + // Iterate over all routes + for _, route := range routes { + logger.Debug("Found route", logger.Fields{"Route": route}) + if route.Gw == nil { + // A default route has a gateway. If it doesn't, skip it. + continue + } + + if route.Dst == nil || route.Dst.String() == "0.0.0.0/0" || route.Dst.String() == "::/0" { + // Get the link (interface) associated with the default route + link, err := netlinkClient.LinkByIndex(route.LinkIndex) + if err != nil { + logger.Warn("Not able to get the associated network interface by the index", logger.Fields{ + field.Error: err, + "LinkIndex": route.LinkIndex, + }) + } else { + logger.Debug("Found the associated network interface by the index", logger.Fields{ + "LinkName": link.Attrs().Name, + "LinkIndex": route.LinkIndex, + }) + return link.Attrs().Name, nil + } + } + } + return "", nil +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go new file mode 100644 index 00000000000..f7d24b7f63b --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go @@ -0,0 +1,25 @@ +//go:build windows +// +build windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import "errors" + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. This is not supported on windows as of now. +func DefaultNetInterfaceName() (string, error) { + return "", errors.New("Not supported on windows") +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/generate_mocks.go new file mode 100644 index 00000000000..c24db0fdfc3 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/generate_mocks.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netlinkwrapper + +//go:generate mockgen -destination=mocks/netlinkwrapper_mocks_linux.go -copyright_file=../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper NetLink diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go new file mode 100644 index 00000000000..6ef7a36e111 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go @@ -0,0 +1,108 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper (interfaces: NetLink) + +// Package mock_netlinkwrapper is a generated GoMock package. +package mock_netlinkwrapper + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + netlink "github.com/vishvananda/netlink" +) + +// MockNetLink is a mock of NetLink interface. +type MockNetLink struct { + ctrl *gomock.Controller + recorder *MockNetLinkMockRecorder +} + +// MockNetLinkMockRecorder is the mock recorder for MockNetLink. +type MockNetLinkMockRecorder struct { + mock *MockNetLink +} + +// NewMockNetLink creates a new mock instance. +func NewMockNetLink(ctrl *gomock.Controller) *MockNetLink { + mock := &MockNetLink{ctrl: ctrl} + mock.recorder = &MockNetLinkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetLink) EXPECT() *MockNetLinkMockRecorder { + return m.recorder +} + +// LinkByIndex mocks base method. +func (m *MockNetLink) LinkByIndex(arg0 int) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByIndex", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByIndex indicates an expected call of LinkByIndex. +func (mr *MockNetLinkMockRecorder) LinkByIndex(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByIndex", reflect.TypeOf((*MockNetLink)(nil).LinkByIndex), arg0) +} + +// LinkByName mocks base method. +func (m *MockNetLink) LinkByName(arg0 string) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByName", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByName indicates an expected call of LinkByName. +func (mr *MockNetLinkMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLink)(nil).LinkByName), arg0) +} + +// LinkSetUp mocks base method. +func (m *MockNetLink) LinkSetUp(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetUp", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetUp indicates an expected call of LinkSetUp. +func (mr *MockNetLinkMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLink)(nil).LinkSetUp), arg0) +} + +// RouteList mocks base method. +func (m *MockNetLink) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteList", arg0, arg1) + ret0, _ := ret[0].([]netlink.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RouteList indicates an expected call of RouteList. +func (mr *MockNetLinkMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLink)(nil).RouteList), arg0, arg1) +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_linux.go new file mode 100644 index 00000000000..dafbece6f21 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/netlink_linux.go @@ -0,0 +1,52 @@ +//go:build !windows +// +build !windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netlinkwrapper + +import ( + "github.com/vishvananda/netlink" +) + +type NetLink interface { + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + LinkByIndex(index int) (netlink.Link, error) +} + +type netLink struct{} + +func New() NetLink { + return &netLink{} +} + +func (nl *netLink) LinkByName(name string) (netlink.Link, error) { + return netlink.LinkByName(name) +} + +func (nl *netLink) LinkSetUp(link netlink.Link) error { + return netlink.LinkSetUp(link) +} + +// RouteList gets a list of routes in the system. Equivalent to: `ip route show`. +// The list can be filtered by link and ip family. +func (nl *netLink) RouteList(link netlink.Link, family int) ([]netlink.Route, error) { + return netlink.RouteList(link, family) +} + +func (nl *netLink) LinkByIndex(index int) (netlink.Link, error) { + return netlink.LinkByIndex(index) +} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 777a834f94f..a5e3c5c0629 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -71,12 +71,15 @@ github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4 github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig github.com/aws/amazon-ecs-agent/ecs-agent/utils github.com/aws/amazon-ecs-agent/ecs-agent/utils/arn github.com/aws/amazon-ecs-agent/ecs-agent/utils/cipher github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks github.com/aws/amazon-ecs-agent/ecs-agent/utils/httpproxy +github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper +github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index be87c3a5066..6fe3f7596d7 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -35,6 +35,7 @@ import ( v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" "github.com/aws/aws-sdk-go/aws" "github.com/gorilla/mux" @@ -74,6 +75,7 @@ type FaultHandler struct { AgentState state.AgentState MetricsFactory metrics.EntryFactory osExecWrapper execwrapper.Exec + netlinkClient netlinkwrapper.NetLink } func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execwrapper.Exec) *FaultHandler { @@ -82,6 +84,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw MetricsFactory: mf, mutexMap: sync.Map{}, osExecWrapper: execWrapper, + netlinkClient: netlinkwrapper.New(), } } @@ -116,7 +119,7 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -262,7 +265,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -410,7 +413,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -525,7 +528,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -572,7 +575,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -619,7 +622,7 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -665,7 +668,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -744,7 +747,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -824,7 +827,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID. - taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r, h.netlinkClient) if err != nil { return } @@ -930,10 +933,10 @@ func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, r // validateTaskMetadata will first fetch the associated task metadata and then validate it to make sure // the task has enabled fault injection and the corresponding network mode is supported. -func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request) (*state.TaskResponse, error) { +func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request, netlinkClient netlinkwrapper.NetLink) (*state.TaskResponse, error) { var taskMetadata state.TaskResponse endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] - taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID) + taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, netlinkClient) if err != nil { code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) diff --git a/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go b/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go index 07bb387f942..5760af8ce6c 100644 --- a/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go +++ b/ecs-agent/tmds/handlers/v4/state/mocks/state_mock.go @@ -22,6 +22,7 @@ import ( reflect "reflect" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" gomock "github.com/golang/mock/gomock" ) @@ -109,18 +110,18 @@ func (mr *MockAgentStateMockRecorder) GetTaskMetadataWithTags(arg0 interface{}) } // GetTaskMetadataWithTaskNetworkConfig mocks base method. -func (m *MockAgentState) GetTaskMetadataWithTaskNetworkConfig(arg0 string) (state.TaskResponse, error) { +func (m *MockAgentState) GetTaskMetadataWithTaskNetworkConfig(arg0 string, arg1 netlinkwrapper.NetLink) (state.TaskResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskMetadataWithTaskNetworkConfig", arg0) + ret := m.ctrl.Call(m, "GetTaskMetadataWithTaskNetworkConfig", arg0, arg1) ret0, _ := ret[0].(state.TaskResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskMetadataWithTaskNetworkConfig indicates an expected call of GetTaskMetadataWithTaskNetworkConfig. -func (mr *MockAgentStateMockRecorder) GetTaskMetadataWithTaskNetworkConfig(arg0 interface{}) *gomock.Call { +func (mr *MockAgentStateMockRecorder) GetTaskMetadataWithTaskNetworkConfig(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskMetadataWithTaskNetworkConfig", reflect.TypeOf((*MockAgentState)(nil).GetTaskMetadataWithTaskNetworkConfig), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskMetadataWithTaskNetworkConfig", reflect.TypeOf((*MockAgentState)(nil).GetTaskMetadataWithTaskNetworkConfig), arg0, arg1) } // GetTaskStats mocks base method. diff --git a/ecs-agent/tmds/handlers/v4/state/state.go b/ecs-agent/tmds/handlers/v4/state/state.go index e7b1cb703be..f556ef4a018 100644 --- a/ecs-agent/tmds/handlers/v4/state/state.go +++ b/ecs-agent/tmds/handlers/v4/state/state.go @@ -13,7 +13,11 @@ package state -import "fmt" +import ( + "fmt" + + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" +) // Error to be returned when container or task lookup failed type ErrorLookupFailure struct { @@ -109,7 +113,7 @@ type AgentState interface { // Returns ErrorMetadataFetchFailure if something else goes wrong. GetTaskMetadataWithTags(endpointContainerID string) (TaskResponse, error) - GetTaskMetadataWithTaskNetworkConfig(endpointContainerID string) (TaskResponse, error) + GetTaskMetadataWithTaskNetworkConfig(endpointContainerID string, netlinkClient netlinkwrapper.NetLink) (TaskResponse, error) // Returns container stats in v4 format for the container identified by the provided // endpointContainerID.