Skip to content

Commit

Permalink
Incorporating getting default network interface on host mode in task …
Browse files Browse the repository at this point in the history
…metadata
  • Loading branch information
mye956 committed Sep 24, 2024
1 parent 2612f4c commit f4742c9
Show file tree
Hide file tree
Showing 13 changed files with 363 additions and 38 deletions.
32 changes: 31 additions & 1 deletion agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -56,6 +57,7 @@ import (
v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2"
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state"
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks"
"github.com/gorilla/mux"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -65,6 +67,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)

const (
Expand Down Expand Up @@ -3948,6 +3951,7 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
tcs := []struct {
name string
setStateExpectations func(state *mock_dockerstate.MockTaskEngineState)
setNetlinkExpectations func(netlinkClient *mock_netlinkwrapper.MockNetLink)
expectedTaskNetworkConfig *v4.TaskNetworkConfig
}{
{
Expand Down Expand Up @@ -3985,6 +3989,26 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
)
},
setNetlinkExpectations: func(netlinkClient *mock_netlinkwrapper.MockNetLink) {
routes := []netlink.Route{
netlink.Route{
Gw: net.ParseIP("10.194.20.1"),
Dst: nil,
LinkIndex: 0,
},
}

link := &netlink.Device{
LinkAttrs: netlink.LinkAttrs{
Index: 0,
Name: "eth0",
},
}
gomock.InOrder(
netlinkClient.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(),
netlinkClient.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(),
)
},
expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.HostNetworkMode, hostNetworkNamespace, defaultIfname),
},
{
Expand Down Expand Up @@ -4012,11 +4036,17 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
statsEngine := mock_stats.NewMockEngine(ctrl)
ecsClient := mock_ecs.NewMockECSClient(ctrl)

mock_netlinkClient := mock_netlinkwrapper.NewMockNetLink(ctrl)

if tc.setStateExpectations != nil {
tc.setStateExpectations(state)
}

if tc.setNetlinkExpectations != nil {
tc.setNetlinkExpectations(mock_netlinkClient)
}
tmdsAgentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn)
actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID)
actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID, mock_netlinkClient)

assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskNetworkConfig, actualTaskResponse.TaskNetworkConfig)
Expand Down
29 changes: 22 additions & 7 deletions agent/handlers/v4/tmdsstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper"
)

const (
Expand Down Expand Up @@ -91,21 +93,21 @@ func (s *TMDSAgentState) GetContainerMetadata(v3EndpointID string) (tmdsv4.Conta

// Returns task metadata in v4 format for the task identified by the provided endpointContainerID.
func (s *TMDSAgentState) GetTaskMetadata(v3EndpointID string) (tmdsv4.TaskResponse, error) {
return s.getTaskMetadata(v3EndpointID, false, false)
return s.getTaskMetadata(v3EndpointID, false, false, nil)
}

// Returns task metadata including task and container instance tags in v4 format for the
// task identified by the provided endpointContainerID.
func (s *TMDSAgentState) GetTaskMetadataWithTags(v3EndpointID string) (tmdsv4.TaskResponse, error) {
return s.getTaskMetadata(v3EndpointID, true, false)
return s.getTaskMetadata(v3EndpointID, true, false, nil)
}

func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string) (tmdsv4.TaskResponse, error) {
return s.getTaskMetadata(v3EndpointID, false, true)
func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string, netlinkClient netlinkwrapper.NetLink) (tmdsv4.TaskResponse, error) {
return s.getTaskMetadata(v3EndpointID, false, true, netlinkClient)
}

// Returns task metadata in v4 format for the task identified by the provided endpointContainerID.
func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, includeTaskNetworkConfig bool) (tmdsv4.TaskResponse, error) {
func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, includeTaskNetworkConfig bool, netlinkClient netlinkwrapper.NetLink) (tmdsv4.TaskResponse, error) {
taskARN, ok := s.state.TaskARNByV3EndpointID(v3EndpointID)
if !ok {
return tmdsv4.TaskResponse{}, tmdsv4.NewErrorLookupFailure(fmt.Sprintf(
Expand Down Expand Up @@ -163,10 +165,23 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, inclu
taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled()
var taskNetworkConfig *tmdsv4.TaskNetworkConfig
if task.IsNetworkModeHost() {
deviceName := ""
if netlinkClient != nil {
deviceName, err = netconfig.DefaultNetInterfaceName(netlinkClient)
if err != nil {
logger.Warn("Unable to obtain default network interface name on host for task.", logger.Fields{
field.TaskARN: taskARN,
field.Error: err,
})
}
logger.Info("Obtained default network interface name on host for task", logger.Fields{
field.TaskARN: taskARN,
"defaultInterfaceName": deviceName,
})
}
// For host most, we don't really need the network namespace in order to do anything within the host instance network namespace
// and so we will set this to an arbitrary value such as "host".
// TODO: Will need to find/obtain the interface name of the default network interface on the host instance
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname())
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, deviceName)
} else {
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname())
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit f4742c9

Please sign in to comment.