Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP DO NOT REVIEW #4360

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -56,6 +57,7 @@ import (
v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2"
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state"
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks"
"github.com/gorilla/mux"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -65,6 +67,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)

const (
Expand Down Expand Up @@ -3948,6 +3951,7 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
tcs := []struct {
name string
setStateExpectations func(state *mock_dockerstate.MockTaskEngineState)
setNetlinkExpectations func(netlinkClient *mock_netlinkwrapper.MockNetLink)
expectedTaskNetworkConfig *v4.TaskNetworkConfig
}{
{
Expand Down Expand Up @@ -3985,6 +3989,26 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
)
},
setNetlinkExpectations: func(netlinkClient *mock_netlinkwrapper.MockNetLink) {
routes := []netlink.Route{
netlink.Route{
Gw: net.ParseIP("10.194.20.1"),
Dst: nil,
LinkIndex: 0,
},
}

link := &netlink.Device{
LinkAttrs: netlink.LinkAttrs{
Index: 0,
Name: "eth0",
},
}
gomock.InOrder(
netlinkClient.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(),
netlinkClient.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(),
)
},
expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.HostNetworkMode, hostNetworkNamespace, defaultIfname),
},
{
Expand Down Expand Up @@ -4012,11 +4036,17 @@ func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
statsEngine := mock_stats.NewMockEngine(ctrl)
ecsClient := mock_ecs.NewMockECSClient(ctrl)

mock_netlinkClient := mock_netlinkwrapper.NewMockNetLink(ctrl)

if tc.setStateExpectations != nil {
tc.setStateExpectations(state)
}

if tc.setNetlinkExpectations != nil {
tc.setNetlinkExpectations(mock_netlinkClient)
}
tmdsAgentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn)
actualTaskResponse, err := tmdsAgentState.GetTaskMetadata(v3EndpointID)
actualTaskResponse, err := tmdsAgentState.GetTaskMetadataWithTaskNetworkConfig(v3EndpointID, mock_netlinkClient)

assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskNetworkConfig, actualTaskResponse.TaskNetworkConfig)
Expand Down
51 changes: 38 additions & 13 deletions agent/handlers/v4/tmdsstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper"
)

const (
Expand Down Expand Up @@ -91,17 +93,40 @@ func (s *TMDSAgentState) GetContainerMetadata(v3EndpointID string) (tmdsv4.Conta

// Returns task metadata in v4 format for the task identified by the provided endpointContainerID.
func (s *TMDSAgentState) GetTaskMetadata(v3EndpointID string) (tmdsv4.TaskResponse, error) {
return s.getTaskMetadata(v3EndpointID, false)
return s.getTaskMetadata(v3EndpointID, false, false)
}

// Returns task metadata including task and container instance tags in v4 format for the
// task identified by the provided endpointContainerID.
func (s *TMDSAgentState) GetTaskMetadataWithTags(v3EndpointID string) (tmdsv4.TaskResponse, error) {
return s.getTaskMetadata(v3EndpointID, true)
return s.getTaskMetadata(v3EndpointID, true, false)
}

func (s *TMDSAgentState) GetTaskMetadataWithTaskNetworkConfig(v3EndpointID string, netlinkClient netlinkwrapper.NetLink) (tmdsv4.TaskResponse, error) {
taskResponse, err := s.getTaskMetadata(v3EndpointID, false, true)
if err == nil {
if taskResponse.TaskNetworkConfig.NetworkMode == "host" && taskResponse.TaskNetworkConfig != nil {
taskARN := taskResponse.TaskARN
deviceName, err := netconfig.DefaultNetInterfaceName(netlinkClient)
if err != nil {
logger.Warn("Unable to obtain default network interface name on host for task.", logger.Fields{
field.TaskARN: taskARN,
field.Error: err,
})
} else {
logger.Info("Obtained default network interface name on host for task", logger.Fields{
field.TaskARN: taskARN,
"defaultInterfaceName": deviceName,
})
taskResponse.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName = deviceName
}
}
}
return taskResponse, err
}

// Returns task metadata in v4 format for the task identified by the provided endpointContainerID.
func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) (tmdsv4.TaskResponse, error) {
func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags, includeTaskNetworkConfig bool) (tmdsv4.TaskResponse, error) {
taskARN, ok := s.state.TaskARNByV3EndpointID(v3EndpointID)
if !ok {
return tmdsv4.TaskResponse{}, tmdsv4.NewErrorLookupFailure(fmt.Sprintf(
Expand Down Expand Up @@ -156,18 +181,18 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool)
}

taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled()
var taskNetworkConfig *tmdsv4.TaskNetworkConfig
if task.IsNetworkModeHost() {
// For host most, we don't really need the network namespace in order to do anything within the host instance network namespace
// and so we will set this to an arbitrary value such as "host".
// TODO: Will need to find/obtain the interface name of the default network interface on the host instance
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname())
} else {
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname())
if includeTaskNetworkConfig {
var taskNetworkConfig *tmdsv4.TaskNetworkConfig
if task.IsNetworkModeHost() {
// For host most, we don't really need the network namespace in order to do anything within the host instance network namespace
// and so we will set this to an arbitrary value such as "host".
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname())
} else {
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname())
}
taskResponse.TaskNetworkConfig = taskNetworkConfig
}

taskResponse.TaskNetworkConfig = taskNetworkConfig

return *taskResponse, nil
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading