Skip to content

Commit

Permalink
Updating agent state to include task default interface name and task …
Browse files Browse the repository at this point in the history
…network namespace
  • Loading branch information
mye956 committed Sep 5, 2024
1 parent 67ff392 commit b38e6ce
Show file tree
Hide file tree
Showing 9 changed files with 275 additions and 23 deletions.
24 changes: 23 additions & 1 deletion agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,12 @@ type Task struct {

IsInternal bool `json:"IsInternal,omitempty"`

// TODO: Will need to initialize/set the value in a follow-up PR
NetworkNamespace string `json:"NetworkNamespace,omitempty"`

// TODO: Will need to initialize/set the value in a follow-up PR
FaultInjectionEnabled bool `json:"FaultInjectionEnabled,omitempty"`

DefaultIfname string `json:"DefaultIfname,omitempty"`
}

// TaskFromACS translates ecsacs.Task to apitask.Task by first marshaling the received
Expand Down Expand Up @@ -3773,3 +3774,24 @@ func (task *Task) GetNetworkNamespace() string {

return task.NetworkNamespace
}

// SetNetworkNamespace records netNs as the task's network namespace. The
// assignment is made while holding the task's write lock.
func (task *Task) SetNetworkNamespace(netNs string) {
	task.lock.Lock()
	defer task.lock.Unlock()

	task.NetworkNamespace = netNs
}

// GetDefaultIfname returns the default interface name stored on the task. The
// field is read while holding the task's read lock.
func (task *Task) GetDefaultIfname() string {
	task.lock.RLock()
	defer task.lock.RUnlock()

	return task.DefaultIfname
}

// SetDefaultIfname records ifname as the task's default interface name. The
// assignment is made while holding the task's write lock.
func (task *Task) SetDefaultIfname(ifname string) {
	task.lock.Lock()
	defer task.lock.Unlock()

	task.DefaultIfname = ifname
}
6 changes: 6 additions & 0 deletions agent/engine/docker_task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2358,6 +2358,12 @@ func (engine *DockerTaskEngine) provisionContainerResourcesAwsvpc(task *apitask.
field.TaskID: task.GetID(),
"ip": taskIP,
})
task.SetNetworkNamespace(cniConfig.ContainerNetNS)
// Note: By default, the interface name is set to eth0 within the CNI configs. We can also always assume that the first entry of the CNI network configs is
// the task ENI. If there are no entries, then no task ENIs were passed down to the agent from the task payload.
if len(cniConfig.NetworkConfigs) > 0 {
task.SetDefaultIfname(cniConfig.NetworkConfigs[0].IfName)
}
engine.state.AddTaskIPAddress(taskIP, task.Arn)
task.SetLocalIPAddress(taskIP)
engine.saveTaskData(task)
Expand Down
9 changes: 5 additions & 4 deletions agent/engine/docker_task_engine_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ import (
)

const (
cgroupMountPath = "/sys/fs/cgroup"
testTaskDefFamily = "testFamily"
testTaskDefVersion = "1"
containerNetNS = "none"
cgroupMountPath = "/sys/fs/cgroup"
testTaskDefFamily = "testFamily"
testTaskDefVersion = "1"
containerNetNS = "none"
ExpectedNetworkNamespace = "/host/proc/123/ns/net"
)

func init() {
Expand Down
3 changes: 3 additions & 0 deletions agent/engine/docker_task_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ const (
containerNetworkMode = "none"
serviceConnectContainerName = "service-connect"
mediaTypeManifestV2 = "application/vnd.docker.distribution.manifest.v2+json"
defaultIfname = "eth0"
)

var (
Expand Down Expand Up @@ -1098,6 +1099,8 @@ func TestProvisionContainerResourcesAwsvpcSetPausePIDInVolumeResources(t *testin
require.Nil(t, taskEngine.(*DockerTaskEngine).provisionContainerResources(testTask, pauseContainer).Error)
assert.Equal(t, strconv.Itoa(containerPid), volRes.GetPauseContainerPID())
assert.Equal(t, taskIP, testTask.GetLocalIPAddress())
assert.Equal(t, defaultIfname, testTask.GetDefaultIfname())
assert.Equal(t, ExpectedNetworkNamespace, testTask.GetNetworkNamespace())
savedTasks, err := dataClient.GetTasks()
require.NoError(t, err)
assert.Len(t, savedTasks, 1)
Expand Down
3 changes: 2 additions & 1 deletion agent/engine/docker_task_engine_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ import (
)

const (
containerNetNS = "container:abcd"
containerNetNS = "container:abcd"
ExpectedNetworkNamespace = "none"
)

func TestDeleteTask(t *testing.T) {
Expand Down
189 changes: 189 additions & 0 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ const (
taskCredentialsID = "taskCredentialsId"
endpointId = "endpointId"
networkNamespace = "/path"
hostNetworkNamespace = "host"
defaultIfname = "eth0"

port = 1234
protocol = "tcp"
Expand Down Expand Up @@ -270,6 +272,7 @@ var (
PullStoppedAtUnsafe: now,
ExecutionStoppedAtUnsafe: now,
LaunchType: "EC2",
NetworkMode: bridgeMode,
}
container1 = &apicontainer.Container{
Name: containerName,
Expand Down Expand Up @@ -404,6 +407,30 @@ var (
Type: containerType,
},
}
expectedV4HostContainerResponse = v4.ContainerResponse{
ContainerResponse: &v2.ContainerResponse{
ID: containerID,
Name: containerName,
DockerName: containerName,
Image: imageName,
ImageID: imageID,
DesiredStatus: statusRunning,
KnownStatus: statusRunning,
ContainerARN: "arn:aws:ecs:ap-northnorth-1:NNN:container/NNNNNNNN-aaaa-4444-bbbb-00000000000",
Limits: v2.LimitsResponse{
CPU: aws.Float64(cpu),
Memory: aws.Int64(memory),
},
Type: containerType,
Labels: labels,
Ports: []tmdsresponse.PortResponse{
{
ContainerPort: containerPort,
Protocol: containerPortProtocol,
},
},
},
}
expectedV4BridgeContainerResponse = v4ContainerResponseFromV2(expectedBridgeContainerResponse, []v4.Network{{
Network: tmdsresponse.Network{
NetworkMode: bridgeMode,
Expand All @@ -423,6 +450,7 @@ var (
task.FaultInjectionEnabled = faultInjectionEnabled
task.NetworkMode = networkMode
task.NetworkNamespace = networkNamespace
task.DefaultIfname = defaultIfname
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
Expand Down Expand Up @@ -466,6 +494,13 @@ func standardTask() *apitask.Task {
return &task
}

// standardHostTask returns the task produced by standardTask, switched to host
// network mode and with its ENIs cleared.
func standardHostTask() *apitask.Task {
	hostTask := standardTask()
	hostTask.NetworkMode = apitask.HostNetworkMode
	hostTask.ENIs = nil
	return hostTask
}

// Returns a standard v2 task response. This getter function protects against tests mutating the
// response.
func expectedTaskResponse() v2.TaskResponse {
Expand Down Expand Up @@ -524,6 +559,34 @@ func expectedV4TaskResponse() v4.TaskResponse {
)
}

// expectedV4TaskNetworkConfig builds the task network config a test expects to
// find in a v4 task response, using the given network mode, network namespace
// path, and device name.
//
// The faultInjectionEnabled argument is accepted but not used when building the
// result.
func expectedV4TaskNetworkConfig(faultInjectionEnabled bool, networkMode, path, deviceName string) *v4.TaskNetworkConfig {
	return v4.NewTaskNetworkConfig(
		networkMode,
		path,
		deviceName,
	)
}

// expectedV4TaskResponseHostMode returns the v4 task response expected for the
// host-mode task: the standard v2 task fields wrapped together with
// expectedV4HostContainerResponse as the only container. A fresh value is built
// on every call.
func expectedV4TaskResponseHostMode() v4.TaskResponse {
	taskLimits := &v2.LimitsResponse{
		CPU:    aws.Float64(cpu),
		Memory: aws.Int64(memory),
	}

	v2Response := v2.TaskResponse{
		Cluster:       clusterName,
		TaskARN:       taskARN,
		Family:        family,
		Revision:      version,
		DesiredStatus: statusRunning,
		KnownStatus:   statusRunning,
		Limits:        taskLimits,
		// Each timestamp gets its own pointer, so they are never aliased.
		PullStartedAt:      aws.Time(now.UTC()),
		PullStoppedAt:      aws.Time(now.UTC()),
		ExecutionStoppedAt: aws.Time(now.UTC()),
		AvailabilityZone:   availabilityzone,
		LaunchType:         "EC2",
	}

	containers := []v4.ContainerResponse{expectedV4HostContainerResponse}

	return v4TaskResponseFromV2(v2Response, containers, vpcID)
}

// Returns a standard v4 task response including pulled containers response. This getter function
// protects against tests mutating the response.
func expectedV4PulledTaskResponse() v4.TaskResponse {
Expand Down Expand Up @@ -1994,6 +2057,51 @@ func TestV4TaskMetadata(t *testing.T) {
expectedResponseBody: expectedV4PulledTaskResponse(),
})
})

t.Run("happy case with fault injection enabled using awsvpc mode", func(t *testing.T) {
testTMDSRequest(t, TMDSTestCase[v4.TaskResponse]{
path: v4BasePath + v3EndpointID + "/task",
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
task.FaultInjectionEnabled = true
task.NetworkNamespace = networkNamespace
task.DefaultIfname = defaultIfname
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
},
expectedStatusCode: http.StatusOK,
expectedResponseBody: expectedV4TaskResponse(),
})
})

t.Run("happy case with fault injection enabled using host mode", func(t *testing.T) {
testTMDSRequest(t, TMDSTestCase[v4.TaskResponse]{
path: v4BasePath + v3EndpointID + "/task",
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
hostTask := standardHostTask()
hostTask.FaultInjectionEnabled = true
hostTask.NetworkNamespace = networkNamespace
hostTask.DefaultIfname = defaultIfname
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(hostTask, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
)
},
expectedStatusCode: http.StatusOK,
expectedResponseBody: expectedV4TaskResponseHostMode(),
})
})

t.Run("bridge mode container not found", func(t *testing.T) {
testTMDSRequest(t, TMDSTestCase[v4.TaskResponse]{
path: v4BasePath + v3EndpointID + "/task",
Expand Down Expand Up @@ -3804,3 +3912,84 @@ func testRegisterFaultHandler(t *testing.T, tcs []blackholePortFaultTestCase, me
})
}
}

// TestV4GetTaskMetadataWithTaskNetworkConfig checks the TaskNetworkConfig that
// GetTaskMetadata attaches to the v4 task response for awsvpc, host, and bridge
// mode tasks.
func TestV4GetTaskMetadataWithTaskNetworkConfig(t *testing.T) {
	type networkConfigCase struct {
		name                      string
		setStateExpectations      func(state *mock_dockerstate.MockTaskEngineState)
		expectedTaskNetworkConfig *v4.TaskNetworkConfig
	}

	cases := []networkConfigCase{
		{
			name: "happy case with awsvpc mode",
			setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
				awsvpcTask := standardTask()
				awsvpcTask.FaultInjectionEnabled = true
				awsvpcTask.NetworkNamespace = networkNamespace
				awsvpcTask.DefaultIfname = defaultIfname
				gomock.InOrder(
					state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
					state.EXPECT().TaskByArn(taskARN).Return(awsvpcTask, true).Times(2),
					state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
					state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
					state.EXPECT().TaskByArn(taskARN).Return(awsvpcTask, true),
					state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
					state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
				)
			},
			// The namespace path and interface name set on the task are expected back unchanged.
			expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.AWSVPCNetworkMode, networkNamespace, defaultIfname),
		},
		{
			name: "happy case with host mode",
			setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
				hostTask := standardHostTask()
				hostTask.FaultInjectionEnabled = true
				hostTask.NetworkNamespace = networkNamespace
				hostTask.DefaultIfname = defaultIfname
				gomock.InOrder(
					state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
					state.EXPECT().TaskByArn(taskARN).Return(hostTask, true).Times(2),
					state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
					state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
					state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
					state.EXPECT().ContainerByID(containerID).Return(nil, false).AnyTimes(),
				)
			},
			// The task carries networkNamespace, but the expected path for a host-mode
			// task is hostNetworkNamespace.
			expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, apitask.HostNetworkMode, hostNetworkNamespace, defaultIfname),
		},
		{
			name: "happy bridge mode",
			setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
				gomock.InOrder(
					state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
					state.EXPECT().TaskByArn(taskARN).Return(bridgeTask, true).Times(2),
					state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToBridgeContainer, true),
					state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
					state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
					state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true).AnyTimes(),
				)
			},
			// An empty namespace path and an empty device name are expected for the bridge task.
			expectedTaskNetworkConfig: expectedV4TaskNetworkConfig(true, bridgeMode, "", ""),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			mockState := mock_dockerstate.NewMockTaskEngineState(ctrl)
			mockStatsEngine := mock_stats.NewMockEngine(ctrl)
			mockECSClient := mock_ecs.NewMockECSClient(ctrl)

			// Register the per-case mock expectations before exercising the agent state.
			if tc.setStateExpectations != nil {
				tc.setStateExpectations(mockState)
			}

			agentState := agentV4.NewTMDSAgentState(mockState, mockStatsEngine, mockECSClient, clusterName, availabilityzone, vpcID, containerInstanceArn)
			taskResponse, err := agentState.GetTaskMetadata(v3EndpointID)

			assert.NoError(t, err)
			assert.Equal(t, tc.expectedTaskNetworkConfig, taskResponse.TaskNetworkConfig)
		})
	}
}
32 changes: 15 additions & 17 deletions agent/handlers/v4/tmdsstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state"
)

const (
defaultHostNetworkNamespace = "host"
)

// Implements AgentState interface for TMDS v4.
type TMDSAgentState struct {
state dockerstate.TaskEngineState
Expand Down Expand Up @@ -151,25 +155,19 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool)
NewPulledContainerResponse(dockerContainer, task.GetPrimaryENI()))
}

if task.IsFaultInjectionEnabled() {
// TODO: The correct values for the task network config will need to be set/initialized
taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled()
taskNetworkConfig := tmdsv4.TaskNetworkConfig{
NetworkMode: task.GetNetworkMode(),
NetworkNamespaces: []*tmdsv4.NetworkNamespace{
{
Path: task.GetNetworkNamespace(),
NetworkInterfaces: []*tmdsv4.NetworkInterface{
{
DeviceName: "",
},
},
},
},
}
taskResponse.TaskNetworkConfig = &taskNetworkConfig
taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled()
var taskNetworkConfig *tmdsv4.TaskNetworkConfig
if task.IsNetworkModeHost() {
// For host mode, we don't need the task's network namespace in order to operate within the host instance's network namespace,
// so we set this to an arbitrary value such as "host".
// TODO: Will need to find/obtain the interface name of the default network interface on the host instance
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), defaultHostNetworkNamespace, task.GetDefaultIfname())
} else {
taskNetworkConfig = tmdsv4.NewTaskNetworkConfig(task.GetNetworkMode(), task.GetNetworkNamespace(), task.GetDefaultIfname())
}

taskResponse.TaskNetworkConfig = taskNetworkConfig

return *taskResponse, nil
}

Expand Down
Loading

0 comments on commit b38e6ce

Please sign in to comment.