diff --git a/agent/handlers/task_server_setup_linux_test.go b/agent/handlers/task_server_setup_linux_test.go index 4f9bae31c68..03b09b40243 100644 --- a/agent/handlers/task_server_setup_linux_test.go +++ b/agent/handlers/task_server_setup_linux_test.go @@ -17,6 +17,7 @@ package handlers import ( + "net" "testing" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" @@ -25,9 +26,11 @@ import ( mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" ) func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { @@ -35,6 +38,7 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { tcs := []struct { name string setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) + setNetLinkExpectations func(netLink *mock_netlinkwrapper.MockNetLink) expectedTaskNetworkConfig *v4.TaskNetworkConfig }{ { @@ -72,6 +76,25 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(), ) }, + setNetLinkExpectations: func(netLink *mock_netlinkwrapper.MockNetLink) { + routes := []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: nil, + LinkIndex: 0, + }, + } + link := &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: "eth0", + }, + } + gomock.InOrder( + netLink.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(), + netLink.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(), + ) + }, expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.HostNetworkMode, hostNetworkNamespace, defaultIfname), }, { @@ -103,7 +126,16 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) { tc.setStateExpectations(state) } tmdsAgentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) - actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID, nil) + + netConfigClient := v4.NewNetworkConfigClient() + + if tc.setNetLinkExpectations != nil { + mock_netlinkwrapper := mock_netlinkwrapper.NewMockNetLink(ctrl) + tc.setNetLinkExpectations(mock_netlinkwrapper) + netConfigClient.NetlinkClient = mock_netlinkwrapper + } + + actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID, netConfigClient) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskNetworkConfig, actualTaskResponse.TaskNetworkConfig) diff --git a/agent/handlers/v4/tmdsstate_linux.go b/agent/handlers/v4/tmdsstate_linux.go index 2e25145ee2d..157923b1fec 100644 --- a/agent/handlers/v4/tmdsstate_linux.go +++ b/agent/handlers/v4/tmdsstate_linux.go @@ -17,14 +17,24 @@ package v4 import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig" ) func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string, networkConfigClient *tmdsv4.NetworkConfigClient) (tmdsv4.TaskResponse, error) { taskResponse, err := s.getTaskMetadata(v3EndpointID, false, true) if err == nil { if taskResponse.TaskNetworkConfig != nil && taskResponse.TaskNetworkConfig.NetworkMode == "host" { - taskResponse.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName = "" + hostDeviceName, err := netconfig.DefaultNetInterfaceName(networkConfigClient.NetlinkClient) + if err != nil { + logger.Warn("Unable to obtain default network interface on host", logger.Fields{ + field.TaskARN: taskResponse.TaskARN, + field.Error: err, + }) + } + taskResponse.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName = hostDeviceName } } return taskResponse, err diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index fbcc9b77aeb..df5986df263 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -933,7 +933,7 @@ func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, r func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request) (*state.TaskResponse, error) { var taskMetadata state.TaskResponse endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] - taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, nil) + taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, state.NewNetworkConfigClient()) if err != nil { code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse)) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go new file mode 100644 index 00000000000..1df375c5a14 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_linux.go @@ -0,0 +1,61 @@ +//go:build linux +// +build linux + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" + + "github.com/vishvananda/netlink" +) + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. If none exist, an empty string and nil will be returned. +func DefaultNetInterfaceName(netlinkClient netlinkwrapper.NetLink) (string, error) { + routes, err := netlinkClient.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return "", err + } + + // Iterate over all routes + for _, route := range routes { + logger.Debug("Found route", logger.Fields{"Route": route}) + if route.Gw == nil { + // A default route has a gateway. If it doesn't, skip it. + continue + } + + if route.Dst == nil || route.Dst.String() == "0.0.0.0/0" || route.Dst.String() == "::/0" { + // Get the link (interface) associated with the default route + link, err := netlinkClient.LinkByIndex(route.LinkIndex) + if err != nil { + logger.Warn("Not able to get the associated network interface by the index", logger.Fields{ + field.Error: err, + "LinkIndex": route.LinkIndex, + }) + } else { + logger.Debug("Found the associated network interface by the index", logger.Fields{ + "LinkName": link.Attrs().Name, + "LinkIndex": route.LinkIndex, + }) + return link.Attrs().Name, nil + } + } + } + return "", nil +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go new file mode 100644 index 00000000000..f7d24b7f63b --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig/netconfig_windows.go @@ -0,0 +1,25 @@ +//go:build windows +// +build windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import "errors" + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. This is not supported on windows as of now. +func DefaultNetInterfaceName() (string, error) { + return "", errors.New("Not supported on windows") +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go new file mode 100644 index 00000000000..6ef7a36e111 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go @@ -0,0 +1,108 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper (interfaces: NetLink) + +// Package mock_netlinkwrapper is a generated GoMock package. +package mock_netlinkwrapper + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + netlink "github.com/vishvananda/netlink" +) + +// MockNetLink is a mock of NetLink interface. +type MockNetLink struct { + ctrl *gomock.Controller + recorder *MockNetLinkMockRecorder +} + +// MockNetLinkMockRecorder is the mock recorder for MockNetLink. +type MockNetLinkMockRecorder struct { + mock *MockNetLink +} + +// NewMockNetLink creates a new mock instance. +func NewMockNetLink(ctrl *gomock.Controller) *MockNetLink { + mock := &MockNetLink{ctrl: ctrl} + mock.recorder = &MockNetLinkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetLink) EXPECT() *MockNetLinkMockRecorder { + return m.recorder +} + +// LinkByIndex mocks base method. +func (m *MockNetLink) LinkByIndex(arg0 int) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByIndex", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByIndex indicates an expected call of LinkByIndex. +func (mr *MockNetLinkMockRecorder) LinkByIndex(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByIndex", reflect.TypeOf((*MockNetLink)(nil).LinkByIndex), arg0) +} + +// LinkByName mocks base method. +func (m *MockNetLink) LinkByName(arg0 string) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByName", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByName indicates an expected call of LinkByName. +func (mr *MockNetLinkMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLink)(nil).LinkByName), arg0) +} + +// LinkSetUp mocks base method. +func (m *MockNetLink) LinkSetUp(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetUp", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetUp indicates an expected call of LinkSetUp. +func (mr *MockNetLinkMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLink)(nil).LinkSetUp), arg0) +} + +// RouteList mocks base method. +func (m *MockNetLink) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteList", arg0, arg1) + ret0, _ := ret[0].([]netlink.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RouteList indicates an expected call of RouteList. +func (mr *MockNetLinkMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLink)(nil).RouteList), arg0, arg1) +} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 36ab0152d66..a5e3c5c0629 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -71,6 +71,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4 github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state github.com/aws/amazon-ecs-agent/ecs-agent/tmds/logging github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig github.com/aws/amazon-ecs-agent/ecs-agent/utils github.com/aws/amazon-ecs-agent/ecs-agent/utils/arn github.com/aws/amazon-ecs-agent/ecs-agent/utils/cipher @@ -78,6 +79,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks github.com/aws/amazon-ecs-agent/ecs-agent/utils/httpproxy github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper +github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index fbcc9b77aeb..df5986df263 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -933,7 +933,7 @@ func validateRequest(w http.ResponseWriter, request types.NetworkFaultRequest, r func validateTaskMetadata(w http.ResponseWriter, agentState state.AgentState, requestType string, r *http.Request) (*state.TaskResponse, error) { var taskMetadata state.TaskResponse endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] - taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, nil) + taskMetadata, err := agentState.GetTaskMetadataWithTaskNetworkConfig(endpointContainerID, state.NewNetworkConfigClient()) if err != nil { code, errResponse := getTaskMetadataErrorResponse(endpointContainerID, requestType, err) responseBody := types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("%v", errResponse))