Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consume updated TaskResponse in network fault injection handlers #4302

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@ type Task struct {
NetworkMode string `json:"NetworkMode,omitempty"`

IsInternal bool `json:"IsInternal,omitempty"`

// TODO: Will need to initialize/set the value in a follow PR
NetworkNamespace string `json:"NetworkNamespace,omitempty"`

// TODO: Will need to initialize/set the value in a follow PR
FaultInjectionEnabled bool `json:"FaultInjectionEnabled,omitempty"`
}

// TaskFromACS translates ecsacs.Task to apitask.Task by first marshaling the received
Expand Down Expand Up @@ -3743,3 +3749,24 @@ func (task *Task) HasAContainerWithResolvedDigest() bool {
}
return false
}

func (task *Task) IsFaultInjectionEnabled() bool {
task.lock.RLock()
defer task.lock.RUnlock()

return task.FaultInjectionEnabled
}

func (task *Task) GetNetworkMode() string {
task.lock.RLock()
defer task.lock.RUnlock()

return task.NetworkMode
}

func (task *Task) GetNetworkNamespace() string {
task.lock.RLock()
defer task.lock.RUnlock()

return task.NetworkNamespace
}
89 changes: 71 additions & 18 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ const (
subnetGatewayIpv4Address = "172.31.32.1/20"
taskCredentialsID = "taskCredentialsId"
endpointId = "endpointId"
networkNamespace = "/path"

port = 1234
protocol = "tcp"
Expand Down Expand Up @@ -416,6 +417,21 @@ var (
SubnetGatewayIPV4Address: "",
}},
})

agentStateExpectations = func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) {
task := standardTask()
task.FaultInjectionEnabled = faultInjectionEnabled
task.NetworkMode = networkMode
task.NetworkNamespace = networkNamespace
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
}
)

func standardTask() *apitask.Task {
Expand Down Expand Up @@ -3576,7 +3592,9 @@ type blackholePortFaultTestCase struct {
expectedStatusCode int
requestBody interface{}
expectedFaultResponse faulttype.NetworkFaultInjectionResponse
setStateExpectations func(state *mock_dockerstate.MockTaskEngineState)
setStateExpectations func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string)
faultInjectionEnabled bool
networkMode string
}

func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyResponseBody string) []blackholePortFaultTestCase {
Expand All @@ -3585,24 +3603,25 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"Protocol": protocol,
"TrafficType": trafficType,
}
happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) {
task := standardTask()
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2),
state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true).AnyTimes(),
state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true),
)
}

tcs := []blackholePortFaultTestCase{
xxx0624 marked this conversation as resolved.
Show resolved Hide resolved
{
name: fmt.Sprintf("%s success", name),
name: fmt.Sprintf("%s success host mode", name),
expectedStatusCode: 200,
requestBody: happyBlackHolePortReqBody,
expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody),
setStateExpectations: agentStateExpectations,
faultInjectionEnabled: true,
networkMode: apitask.HostNetworkMode,
},
{
name: fmt.Sprintf("%s success awsvpc mode", name),
expectedStatusCode: 200,
requestBody: happyBlackHolePortReqBody,
expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody),
setStateExpectations: happyStateExpectations,
setStateExpectations: agentStateExpectations,
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s unknown request body", name),
Expand All @@ -3614,7 +3633,9 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"Unknown": "",
},
expectedFaultResponse: faulttype.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody),
setStateExpectations: happyStateExpectations,
setStateExpectations: agentStateExpectations,
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s malformed request body", name),
Expand All @@ -3625,6 +3646,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"TrafficType": trafficType,
},
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkBlackholePortRequest.Port of type uint16"),
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s incomplete request body", name),
Expand All @@ -3634,6 +3657,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"Protocol": protocol,
},
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"),
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s empty value request body", name),
Expand All @@ -3644,6 +3669,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"TrafficType": "",
},
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("required parameter TrafficType is missing"),
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s invalid protocol value request body", name),
Expand All @@ -3654,6 +3681,8 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"TrafficType": trafficType,
},
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter Protocol"),
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s invalid traffic type value request body", name),
Expand All @@ -3664,29 +3693,53 @@ func getNetworkBlackHolePortHandlerTestCases(name, fault string, expectedHappyRe
"TrafficType": "invalid",
},
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid value invalid for parameter TrafficType"),
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s task lookup fail", name),
expectedStatusCode: 404,
requestBody: happyBlackHolePortReqBody,
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)),
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(endpointId).Return("", false),
)
},
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s task metadata fetch fail", name),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)),
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState, faultInjectionEnabled bool, networkMode string) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(endpointId).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(nil, false),
)
},
faultInjectionEnabled: true,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s fault injection disabled", name),
expectedStatusCode: 400,
requestBody: happyBlackHolePortReqBody,
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("fault injection is not enabled for task: %s", taskARN)),
setStateExpectations: agentStateExpectations,
faultInjectionEnabled: false,
networkMode: apitask.AWSVPCNetworkMode,
},
{
name: fmt.Sprintf("%s invalid network mode", name),
expectedStatusCode: 400,
requestBody: happyBlackHolePortReqBody,
expectedFaultResponse: faulttype.NewNetworkFaultInjectionErrorResponse("invalid mode is not supported. Please use either host or awsvpc mode."),
setStateExpectations: agentStateExpectations,
faultInjectionEnabled: true,
networkMode: "invalid",
},
}
return tcs
Expand Down Expand Up @@ -3722,7 +3775,7 @@ func testRegisterFaultHandler(t *testing.T, tcs []blackholePortFaultTestCase, me
metricsFactory := mock_metrics.NewMockEntryFactory(ctrl)

if tc.setStateExpectations != nil {
tc.setStateExpectations(state)
tc.setStateExpectations(state, tc.faultInjectionEnabled, tc.networkMode)
}

router := mux.NewRouter()
Expand Down
19 changes: 19 additions & 0 deletions agent/handlers/v4/tmdsstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,25 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool)
NewPulledContainerResponse(dockerContainer, task.GetPrimaryENI()))
}

if task.IsFaultInjectionEnabled() {
// TODO: The correct values for the task network config will need to be set/initialized
taskResponse.FaultInjectionEnabled = task.IsFaultInjectionEnabled()
taskNetworkConfig := tmdsv4.TaskNetworkConfig{
NetworkMode: task.GetNetworkMode(),
NetworkNamespaces: []*tmdsv4.NetworkNamespace{
{
Path: task.GetNetworkNamespace(),
NetworkInterfaces: []*tmdsv4.NetworkInterface{
{
DeviceName: "",
},
},
},
},
}
taskResponse.TaskNetworkConfig = &taskNetworkConfig
}

return *taskResponse, nil
}

Expand Down
Loading
Loading