diff --git a/ecs-agent/tmds/utils/netconfig/netconfig_linux.go b/ecs-agent/tmds/utils/netconfig/netconfig_linux.go new file mode 100644 index 00000000000..1df375c5a14 --- /dev/null +++ b/ecs-agent/tmds/utils/netconfig/netconfig_linux.go @@ -0,0 +1,61 @@ +//go:build linux +// +build linux + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper" + + "github.com/vishvananda/netlink" +) + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. If none exist, an empty string and nil will be returned. +func DefaultNetInterfaceName(netlinkClient netlinkwrapper.NetLink) (string, error) { + routes, err := netlinkClient.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return "", err + } + + // Iterate over all routes + for _, route := range routes { + logger.Debug("Found route", logger.Fields{"Route": route}) + if route.Gw == nil { + // A default route has a gateway. If it doesn't, skip it. + continue + } + + if route.Dst == nil || route.Dst.String() == "0.0.0.0/0" || route.Dst.String() == "::/0" { + // Get the link (interface) associated with the default route + link, err := netlinkClient.LinkByIndex(route.LinkIndex) + if err != nil { + logger.Warn("Not able to get the associated network interface by the index", logger.Fields{ + field.Error: err, + "LinkIndex": route.LinkIndex, + }) + } else { + logger.Debug("Found the associated network interface by the index", logger.Fields{ + "LinkName": link.Attrs().Name, + "LinkIndex": route.LinkIndex, + }) + return link.Attrs().Name, nil + } + } + } + return "", nil +} diff --git a/ecs-agent/tmds/utils/netconfig/netconfig_linux_test.go b/ecs-agent/tmds/utils/netconfig/netconfig_linux_test.go new file mode 100644 index 00000000000..d0433657be1 --- /dev/null +++ b/ecs-agent/tmds/utils/netconfig/netconfig_linux_test.go @@ -0,0 +1,169 @@ +//go:build linux && unit +// +build linux,unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import ( + "net" + "testing" + + mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" +) + +func TestDefaultNetInterfaceName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + _, allIpNet, err := net.ParseCIDR("0.0.0.0/0") + assert.NoError(t, err) + _, randomIpNet, err := net.ParseCIDR("192.168.1.0/24") + assert.NoError(t, err) + + tcs := []struct { + name string + routes []netlink.Route + link netlink.Link + expectedDefaultNetInterfaceName string + expectedErrMsg string + }{ + { + name: "no default route 1", + routes: []netlink.Route{ + netlink.Route{ + Gw: nil, + Dst: nil, + LinkIndex: 0, + }, + }, + link: &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: "eni-0", + }, + }, + expectedDefaultNetInterfaceName: "", + expectedErrMsg: "", + }, + { + name: "no default route 2", + routes: []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: randomIpNet, + LinkIndex: 0, + }, + }, + link: &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: "eni-0", + }, + }, + expectedDefaultNetInterfaceName: "", + expectedErrMsg: "", + }, + { + name: "one default route 1", + routes: []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: nil, + LinkIndex: 0, + }, + }, + link: &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: "eni-0", + }, + }, + expectedDefaultNetInterfaceName: "eni-0", + expectedErrMsg: "", + }, + { + name: "one default route 2", + routes: []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: allIpNet, + LinkIndex: 1, + }, + }, + link: &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 1, + Name: "eni-1", + }, + }, + expectedDefaultNetInterfaceName: "eni-1", + expectedErrMsg: "", + }, + { + name: "two default routes", + routes: []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: randomIpNet, + LinkIndex: 0, + }, + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: allIpNet, + LinkIndex: 1, + }, + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: nil, + LinkIndex: 2, + }, + }, + link: &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 1, + Name: "eni-0", + }, + }, + expectedDefaultNetInterfaceName: "eni-0", + expectedErrMsg: "", + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + netLink := mock_netlinkwrapper.NewMockNetLink(ctrl) + gomock.InOrder( + netLink.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(tc.routes, nil).AnyTimes(), + netLink.EXPECT().LinkByIndex(tc.link.Attrs().Index).Return(tc.link, nil).AnyTimes(), + ) + + defaultNetInterfaceName, err := DefaultNetInterfaceName(netLink) + errMsg := "" + if err != nil { + errMsg = err.Error() + } + + assert.Equal(t, tc.expectedErrMsg, errMsg) + assert.Equal(t, tc.expectedDefaultNetInterfaceName, defaultNetInterfaceName) + }) + } +} diff --git a/ecs-agent/tmds/utils/netconfig/netconfig_windows.go b/ecs-agent/tmds/utils/netconfig/netconfig_windows.go new file mode 100644 index 00000000000..f7d24b7f63b --- /dev/null +++ b/ecs-agent/tmds/utils/netconfig/netconfig_windows.go @@ -0,0 +1,25 @@ +//go:build windows +// +build windows + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package netconfig + +import "errors" + +// DefaultNetInterfaceName returns the device name of the first default network interface +// available on the instance. This is not supported on windows as of now. +func DefaultNetInterfaceName() (string, error) { + return "", errors.New("Not supported on windows") +} diff --git a/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go b/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go index e2b972fff03..6ef7a36e111 100644 --- a/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go +++ b/ecs-agent/utils/netlinkwrapper/mocks/netlinkwrapper_mocks_linux.go @@ -48,6 +48,21 @@ func (m *MockNetLink) EXPECT() *MockNetLinkMockRecorder { return m.recorder } +// LinkByIndex mocks base method. +func (m *MockNetLink) LinkByIndex(arg0 int) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByIndex", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByIndex indicates an expected call of LinkByIndex. +func (mr *MockNetLinkMockRecorder) LinkByIndex(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByIndex", reflect.TypeOf((*MockNetLink)(nil).LinkByIndex), arg0) +} + // LinkByName mocks base method. func (m *MockNetLink) LinkByName(arg0 string) (netlink.Link, error) { m.ctrl.T.Helper() @@ -76,3 +91,18 @@ func (mr *MockNetLinkMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLink)(nil).LinkSetUp), arg0) } + +// RouteList mocks base method. +func (m *MockNetLink) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RouteList", arg0, arg1) + ret0, _ := ret[0].([]netlink.Route) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RouteList indicates an expected call of RouteList. +func (mr *MockNetLinkMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLink)(nil).RouteList), arg0, arg1) +} diff --git a/ecs-agent/utils/netlinkwrapper/netlink_linux.go b/ecs-agent/utils/netlinkwrapper/netlink_linux.go index 0c733f23c27..dafbece6f21 100644 --- a/ecs-agent/utils/netlinkwrapper/netlink_linux.go +++ b/ecs-agent/utils/netlinkwrapper/netlink_linux.go @@ -23,6 +23,8 @@ import ( type NetLink interface { LinkByName(name string) (netlink.Link, error) LinkSetUp(link netlink.Link) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + LinkByIndex(index int) (netlink.Link, error) } type netLink struct{} @@ -38,3 +40,13 @@ func (nl *netLink) LinkByName(name string) (netlink.Link, error) { func (nl *netLink) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) } + +// RouteList gets a list of routes in the system. Equivalent to: `ip route show`. +// The list can be filtered by link and ip family. +func (nl *netLink) RouteList(link netlink.Link, family int) ([]netlink.Route, error) { + return netlink.RouteList(link, family) +} + +func (nl *netLink) LinkByIndex(index int) (netlink.Link, error) { + return netlink.LinkByIndex(index) +}