From 5b987b00a6437d0ede54e9126ea5c833e6586424 Mon Sep 17 00:00:00 2001 From: Tianze Shan Date: Thu, 5 Sep 2024 21:10:36 +0000 Subject: [PATCH] Add check network packet loss implementation --- .../handlers/fault/v1/handlers/handlers.go | 181 +++++++++++++++++- .../ecs-agent/utils/execwrapper/exec.go | 114 +++++++++++ .../utils/execwrapper/generate_mocks.go | 16 ++ agent/vendor/modules.txt | 1 + .../handlers/fault/v1/handlers/handlers.go | 181 +++++++++++++++++- .../fault/v1/handlers/handlers_test.go | 149 ++++++++++++-- ecs-agent/utils/execwrapper/exec.go | 10 + .../execwrapper/mocks/execwrapper_mocks.go | 30 +++ 8 files changed, 649 insertions(+), 33 deletions(-) create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/exec.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/generate_mocks.go diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 438d0cdbc74..cb26a80cd3b 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -15,12 +15,17 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" + "net" "net/http" + "strconv" + "strings" "sync" + "time" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -30,6 +35,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" "github.com/gorilla/mux" ) @@ -42,6 +48,12 @@ const ( faultInjectionEnabledError = "fault 
injection is not enabled for task: %s" ) +var ( + tcCheckInjectionCommandString = "tc -j q show dev %s parent 1:1" + tcCheckIPFilterCommandString = "tc -j filter show dev %s" + nsenterCommandString = "nsenter --net=%s " +) + type FaultHandler struct { // mutexMap is used to avoid multiple clients to manipulate same resource at same // time. The 'key' is the the network namespace path and 'value' is the RWMutex. @@ -49,6 +61,7 @@ type FaultHandler struct { mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory + OsExecWrapper execwrapper.Exec } func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { @@ -56,6 +69,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { AgentState: agentState, MetricsFactory: mf, mutexMap: sync.Map{}, + OsExecWrapper: execwrapper.NewExec(), } } @@ -471,30 +485,45 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. return } - // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR + // Obtain the task metadata via the endpoint container ID. taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // To avoid multiple requests to manipulate same network resource + // To avoid multiple requests to manipulate same network resource. networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.RLock() defer rwMu.RUnlock() - // TODO: Check status of current fault injection - // TODO: Return the correct status state - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully checked status for fault", logger.Fields{ + // Check status of current fault injection. 
+ faultStatus, err := h.checkPacketLossFault(taskMetadata, request) + var responseBody types.NetworkFaultInjectionResponse + var stringToBeLogged string + var httpStatusCode int + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + stringToBeLogged = "Error: failed to check fault status" + httpStatusCode = http.StatusInternalServerError + } else { + stringToBeLogged = "Successfully checked status for fault" + httpStatusCode = http.StatusOK + if faultStatus { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + } else { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("not-running") + } + } + + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + httpStatusCode, responseBody, requestType, ) @@ -692,3 +721,139 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return nil } + +// checkPacketLossFault checks if there's existing network-packet-loss fault running. +func (h *FaultHandler) checkPacketLossFault(taskMetadata *state.TaskResponse, request types.NetworkPacketLossRequest) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + interfaceName := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName + lossPercent := request.LossPercent + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + ipSources := request.Sources + // If task's network mode is awsvpc, we need to run nsenter to access the task's network namespace. + nsenterPrefix := "" + if networkMode == "awsvpc" { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + } + + // We will run the following Linux command to assess if there existing fault. 
+ // "tc -j q show dev {INTERFACE} parent 1:1" + // The command above gives the output of "tc q show dev {INTERFACE} parent 1:1" in json format. + // We will then unmarshall the json string and evaluate the fields of it. + tcCheckInjectionCommandComposed := nsenterPrefix + fmt.Sprintf(tcCheckInjectionCommandString, interfaceName) + fmt.Println(tcCheckInjectionCommandComposed) + cmdOutput, err := h.runExecCommand(ctx, tcCheckInjectionCommandComposed) + if err != nil { + return false, errors.New("failed to check network-packet-loss-fault: " + string(cmdOutput[:]) + err.Error()) + } + // Log the command output to better help us debug. + logger.Info(fmt.Sprintf("%s command result: %s", tcCheckInjectionCommandComposed, string(cmdOutput[:]))) + var outputUnmarshalled []map[string]interface{} + err = json.Unmarshal(cmdOutput, &outputUnmarshalled) + if err != nil { + return false, errors.New("failed to unmarshal tc command output: " + err.Error()) + } + netemExists := false + for _, line := range outputUnmarshalled { + // First check field "kind":"netem" exists. + if line["kind"] == "netem" { + // Now check if field "loss":"" exists, and if the percentage matches with the value in the request. + if options := line["options"]; options != nil { + if lossRandom := options.(map[string]interface{})["loss-random"]; lossRandom != nil { + if loss := lossRandom.(map[string]interface{})["loss"]; loss != nil { + if lossValue, ok := loss.(float64); ok { + lossPercentInPercentage := float64(*lossPercent) / 100 + if lossValue == lossPercentInPercentage { + netemExists = true + break + } + } + } + } + } + } + } + // If we didn't find anything from above, there's no fault injected. + if !netemExists { + return false, nil + } + + // Now check if the desired IPs are properly added to the filters. + // We run the following command: "tc -j filter show dev " + // The output of the command above does not format into json properly. 
+ // It has a field like this: "options":{something here} + // Due to the field above, we won't be able to properly unmarshal this. + // Thus, we will use strings.contains directly to parse the output. + tcCheckIPCommandComposed := nsenterPrefix + fmt.Sprintf(tcCheckIPFilterCommandString, interfaceName) + cmdOutput, err = h.runExecCommand(ctx, tcCheckIPCommandComposed) + if err != nil { + return false, errors.New("failed to check network-packet-loss-fault: " + string(cmdOutput[:]) + err.Error()) + } + // Log the command output to better help us debug. + logger.Info(fmt.Sprintf("%s command result: %s", tcCheckIPCommandComposed, string(cmdOutput[:]))) + allIPAddressesInRequestExist := true + for _, ipAddress := range ipSources { + ipAddressInHex, err := convertIPAddressToHex(*ipAddress) + if err != nil { + return false, errors.New("failed to check network-packet-loss-fault: " + err.Error()) + } + patternString := "match " + ipAddressInHex + if !strings.Contains(string(cmdOutput[:]), patternString) { + allIPAddressesInRequestExist = false + break + } + } + if !allIPAddressesInRequestExist { + return false, nil + } + + return true, nil +} + +// runExecCommand wraps around the execwrapper, providing a convenient way of running any Linux command +// and getting the result in both stdout and stderr. +func (h *FaultHandler) runExecCommand(ctx context.Context, linuxCommandString string) ([]byte, error) { + cmdExec := h.OsExecWrapper.CommandContext(ctx, "/bin/sh", "-c", linuxCommandString) + return cmdExec.CombinedOutput() +} + +// convertIPAddressToHex converts an ipv4 address or ipv4 CIDR block string input into HEX format. +// If not specified, we will use the full ip namespace (mask will be /32). +// For example, string "192.168.1.100" will be converted to "c0a80164/ffffffff", +// and string "192.168.1.100/31" will be converted to "c0a80164/fffffffe". 
// convertIPAddressToHex converts an IPv4 address or IPv4 CIDR block string into the hexadecimal
// "<address>/<mask>" form.
// If no prefix length is given, the full address is used (mask will be /32).
// For example, string "192.168.1.100" will be converted to "c0a80164/ffffffff",
// and string "192.168.1.100/31" will be converted to "c0a80164/fffffffe".
// The address part is emitted as given; it is not masked by the prefix length.
//
// An error is returned when the input has more than one "/", the prefix length is not an
// integer in [0, 32], or the address is not made of exactly four decimal octets in [0, 255].
func convertIPAddressToHex(ipAddressInString string) (string, error) {
	ipAddressAndMaskSeparated := strings.Split(ipAddressInString, "/")
	if len(ipAddressAndMaskSeparated) > 2 {
		return "", errors.New("invalid IP address")
	}

	// If a mask is not specified in the IP address, by default use /32.
	mask := "ffffffff"
	if len(ipAddressAndMaskSeparated) == 2 {
		maskInInt, err := strconv.Atoi(ipAddressAndMaskSeparated[1])
		if err != nil {
			return "", err
		}
		// net.CIDRMask returns nil for a prefix length outside [0, 32]; formatting that nil
		// mask would yield the string "<nil>" instead of an error.
		cidrMask := net.CIDRMask(maskInInt, 32)
		if cidrMask == nil {
			return "", fmt.Errorf("invalid IPv4 prefix length: %d", maskInInt)
		}
		mask = cidrMask.String()
	}

	octets := strings.Split(ipAddressAndMaskSeparated[0], ".")
	if len(octets) != 4 {
		return "", fmt.Errorf("invalid IPv4 address: %q", ipAddressAndMaskSeparated[0])
	}
	var ipAddressInHex strings.Builder
	for _, octet := range octets {
		value, err := strconv.Atoi(octet)
		if err != nil {
			return "", err
		}
		if value < 0 || value > 255 {
			return "", fmt.Errorf("invalid IPv4 address: %q", ipAddressAndMaskSeparated[0])
		}
		// %02x zero-pads values below 16 so every octet contributes exactly two hex digits
		// (10 becomes "0a", not "a").
		fmt.Fprintf(&ipAddressInHex, "%02x", value)
	}
	return ipAddressInHex.String() + "/" + mask, nil
}
// Exec acts as a wrapper to functions exposed by the exec package.
// Having this interface enables us to create mock objects we can use
// for testing.
type Exec interface {
	// CommandContext returns a Cmd that runs the named program with the given arguments and
	// is bound to ctx.
	CommandContext(ctx context.Context, name string, arg ...string) Cmd
}

// execWrapper is a placeholder struct which implements the Exec interface.
type execWrapper struct{}

// NewExec returns an Exec backed by the os/exec package.
func NewExec() Exec {
	return &execWrapper{}
}

// CommandContext essentially acts as a wrapper function for exec.CommandContext function.
func (e *execWrapper) CommandContext(ctx context.Context, name string, arg ...string) Cmd {
	return NewCMDContext(ctx, name, arg...)
}

// Cmd acts as a wrapper to functions exposed by the exec.Cmd object.
// Having this interface enables us to create mock objects we can use
// for testing.
type Cmd interface {
	Run() error
	Start() error
	Wait() error
	KillProcess() error
	AppendExtraFiles(...*os.File)
	Args() []string
	SetIOStreams(io.Reader, io.Writer, io.Writer)
	Output() ([]byte, error)
	CombinedOutput() ([]byte, error)
}

// notStartedError is returned by KillProcess when the wrapped command has no process yet.
type notStartedError struct{}

// Error implements the error interface.
func (notStartedError) Error() string {
	return "execwrapper: process not started"
}

// cmdWrapper implements Cmd by delegating to an embedded *exec.Cmd.
type cmdWrapper struct {
	*exec.Cmd
}

// NewCMDContext wraps exec.CommandContext and returns the resulting command as a Cmd.
func NewCMDContext(ctx context.Context, name string, arg ...string) Cmd {
	cmd := exec.CommandContext(ctx, name, arg...)
	return &cmdWrapper{Cmd: cmd}
}

// NewCMD wraps exec.Command and returns the resulting command as a Cmd.
func NewCMD(name string, arg ...string) Cmd {
	cmd := exec.Command(name, arg...)
	return &cmdWrapper{Cmd: cmd}
}

// Run delegates to exec.Cmd.Run.
func (c *cmdWrapper) Run() error {
	return c.Cmd.Run()
}

// Start delegates to exec.Cmd.Start.
func (c *cmdWrapper) Start() error {
	return c.Cmd.Start()
}

// Wait delegates to exec.Cmd.Wait.
func (c *cmdWrapper) Wait() error {
	return c.Cmd.Wait()
}

// KillProcess kills the process of the wrapped command.
// exec.Cmd.Process is nil until the command has been started, so an error is returned in that
// case instead of dereferencing a nil pointer.
func (c *cmdWrapper) KillProcess() error {
	if c.Cmd.Process == nil {
		return notStartedError{}
	}
	return c.Cmd.Process.Kill()
}

// AppendExtraFiles appends ef to the ExtraFiles of the wrapped command.
func (c *cmdWrapper) AppendExtraFiles(ef ...*os.File) {
	c.ExtraFiles = append(c.ExtraFiles, ef...)
}

// Args returns the Args field of the wrapped command.
func (c *cmdWrapper) Args() []string {
	return c.Cmd.Args
}

// SetIOStreams sets the stdin, stdout and stderr of the wrapped command.
// A nil argument leaves the corresponding stream unchanged.
func (c *cmdWrapper) SetIOStreams(stdin io.Reader, stdout io.Writer, stderr io.Writer) {
	if stdin != nil {
		c.Stdin = stdin
	}
	if stdout != nil {
		c.Stdout = stdout
	}
	if stderr != nil {
		c.Stderr = stderr
	}
}

// Output delegates to exec.Cmd.Output.
func (c *cmdWrapper) Output() ([]byte, error) {
	return c.Cmd.Output()
}

// CombinedOutput delegates to exec.Cmd.CombinedOutput.
func (c *cmdWrapper) CombinedOutput() ([]byte, error) {
	return c.Cmd.CombinedOutput()
}
+ +//go:generate mockgen -build_flags=--mod=mod -destination=mocks/execwrapper_mocks.go -copyright_file=../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper Cmd,Exec + +package execwrapper diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 54ddb3d910b..9ea6b1cbd6b 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -74,6 +74,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/mux github.com/aws/amazon-ecs-agent/ecs-agent/utils github.com/aws/amazon-ecs-agent/ecs-agent/utils/arn github.com/aws/amazon-ecs-agent/ecs-agent/utils/cipher +github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper github.com/aws/amazon-ecs-agent/ecs-agent/utils/httpproxy github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 438d0cdbc74..cb26a80cd3b 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -15,12 +15,17 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" "io" + "net" "net/http" + "strconv" + "strings" "sync" + "time" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -30,6 +35,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" "github.com/gorilla/mux" ) @@ -42,6 +48,12 @@ const ( faultInjectionEnabledError = "fault injection is not enabled for task: %s" ) +var ( + tcCheckInjectionCommandString = "tc -j q show dev %s parent 1:1" + tcCheckIPFilterCommandString = "tc -j filter show dev %s" + nsenterCommandString = 
"nsenter --net=%s " +) + type FaultHandler struct { // mutexMap is used to avoid multiple clients to manipulate same resource at same // time. The 'key' is the the network namespace path and 'value' is the RWMutex. @@ -49,6 +61,7 @@ type FaultHandler struct { mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory + OsExecWrapper execwrapper.Exec } func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { @@ -56,6 +69,7 @@ func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { AgentState: agentState, MetricsFactory: mf, mutexMap: sync.Map{}, + OsExecWrapper: execwrapper.NewExec(), } } @@ -471,30 +485,45 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. return } - // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR + // Obtain the task metadata via the endpoint container ID. taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // To avoid multiple requests to manipulate same network resource + // To avoid multiple requests to manipulate same network resource. networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path rwMu := h.loadLock(networkNSPath) rwMu.RLock() defer rwMu.RUnlock() - // TODO: Check status of current fault injection - // TODO: Return the correct status state - responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") - logger.Info("Successfully checked status for fault", logger.Fields{ + // Check status of current fault injection. 
+ faultStatus, err := h.checkPacketLossFault(taskMetadata, request) + var responseBody types.NetworkFaultInjectionResponse + var stringToBeLogged string + var httpStatusCode int + if err != nil { + responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + stringToBeLogged = "Error: failed to check fault status" + httpStatusCode = http.StatusInternalServerError + } else { + stringToBeLogged = "Successfully checked status for fault" + httpStatusCode = http.StatusOK + if faultStatus { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") + } else { + responseBody = types.NewNetworkFaultInjectionSuccessResponse("not-running") + } + } + + logger.Info(stringToBeLogged, logger.Fields{ field.RequestType: requestType, field.Request: request.ToString(), field.Response: responseBody.ToString(), }) utils.WriteJSONResponse( w, - http.StatusOK, + httpStatusCode, responseBody, requestType, ) @@ -692,3 +721,139 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return nil } + +// checkPacketLossFault checks if there's existing network-packet-loss fault running. +func (h *FaultHandler) checkPacketLossFault(taskMetadata *state.TaskResponse, request types.NetworkPacketLossRequest) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + interfaceName := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName + lossPercent := request.LossPercent + networkMode := taskMetadata.TaskNetworkConfig.NetworkMode + ipSources := request.Sources + // If task's network mode is awsvpc, we need to run nsenter to access the task's network namespace. + nsenterPrefix := "" + if networkMode == "awsvpc" { + nsenterPrefix = fmt.Sprintf(nsenterCommandString, taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path) + } + + // We will run the following Linux command to assess if there existing fault. 
+ // "tc -j q show dev {INTERFACE} parent 1:1" + // The command above gives the output of "tc q show dev {INTERFACE} parent 1:1" in json format. + // We will then unmarshall the json string and evaluate the fields of it. + tcCheckInjectionCommandComposed := nsenterPrefix + fmt.Sprintf(tcCheckInjectionCommandString, interfaceName) + fmt.Println(tcCheckInjectionCommandComposed) + cmdOutput, err := h.runExecCommand(ctx, tcCheckInjectionCommandComposed) + if err != nil { + return false, errors.New("failed to check network-packet-loss-fault: " + string(cmdOutput[:]) + err.Error()) + } + // Log the command output to better help us debug. + logger.Info(fmt.Sprintf("%s command result: %s", tcCheckInjectionCommandComposed, string(cmdOutput[:]))) + var outputUnmarshalled []map[string]interface{} + err = json.Unmarshal(cmdOutput, &outputUnmarshalled) + if err != nil { + return false, errors.New("failed to unmarshal tc command output: " + err.Error()) + } + netemExists := false + for _, line := range outputUnmarshalled { + // First check field "kind":"netem" exists. + if line["kind"] == "netem" { + // Now check if field "loss":"" exists, and if the percentage matches with the value in the request. + if options := line["options"]; options != nil { + if lossRandom := options.(map[string]interface{})["loss-random"]; lossRandom != nil { + if loss := lossRandom.(map[string]interface{})["loss"]; loss != nil { + if lossValue, ok := loss.(float64); ok { + lossPercentInPercentage := float64(*lossPercent) / 100 + if lossValue == lossPercentInPercentage { + netemExists = true + break + } + } + } + } + } + } + } + // If we didn't find anything from above, there's no fault injected. + if !netemExists { + return false, nil + } + + // Now check if the desired IPs are properly added to the filters. + // We run the following command: "tc -j filter show dev " + // The output of the command above does not format into json properly. 
+ // It has a field like this: "options":{something here} + // Due to the field above, we won't be able to properly unmarshal this. + // Thus, we will use strings.contains directly to parse the output. + tcCheckIPCommandComposed := nsenterPrefix + fmt.Sprintf(tcCheckIPFilterCommandString, interfaceName) + cmdOutput, err = h.runExecCommand(ctx, tcCheckIPCommandComposed) + if err != nil { + return false, errors.New("failed to check network-packet-loss-fault: " + string(cmdOutput[:]) + err.Error()) + } + // Log the command output to better help us debug. + logger.Info(fmt.Sprintf("%s command result: %s", tcCheckIPCommandComposed, string(cmdOutput[:]))) + allIPAddressesInRequestExist := true + for _, ipAddress := range ipSources { + ipAddressInHex, err := convertIPAddressToHex(*ipAddress) + if err != nil { + return false, errors.New("failed to check network-packet-loss-fault: " + err.Error()) + } + patternString := "match " + ipAddressInHex + if !strings.Contains(string(cmdOutput[:]), patternString) { + allIPAddressesInRequestExist = false + break + } + } + if !allIPAddressesInRequestExist { + return false, nil + } + + return true, nil +} + +// runExecCommand wraps around the execwrapper, providing a convenient way of running any Linux command +// and getting the result in both stdout and stderr. +func (h *FaultHandler) runExecCommand(ctx context.Context, linuxCommandString string) ([]byte, error) { + cmdExec := h.OsExecWrapper.CommandContext(ctx, "/bin/sh", "-c", linuxCommandString) + return cmdExec.CombinedOutput() +} + +// convertIPAddressToHex converts an ipv4 address or ipv4 CIDR block string input into HEX format. +// If not specified, we will use the full ip namespace (mask will be /32). +// For example, string "192.168.1.100" will be converted to "c0a80164/ffffffff", +// and string "192.168.1.100/31" will be converted to "c0a80164/fffffffe". 
// convertIPAddressToHex converts an IPv4 address or IPv4 CIDR block string into the hexadecimal
// "<address>/<mask>" form.
// If no prefix length is given, the full address is used (mask will be /32).
// For example, string "192.168.1.100" will be converted to "c0a80164/ffffffff",
// and string "192.168.1.100/31" will be converted to "c0a80164/fffffffe".
// The address part is emitted as given; it is not masked by the prefix length.
//
// An error is returned when the input has more than one "/", the prefix length is not an
// integer in [0, 32], or the address is not made of exactly four decimal octets in [0, 255].
func convertIPAddressToHex(ipAddressInString string) (string, error) {
	ipAddressAndMaskSeparated := strings.Split(ipAddressInString, "/")
	if len(ipAddressAndMaskSeparated) > 2 {
		return "", errors.New("invalid IP address")
	}

	// If a mask is not specified in the IP address, by default use /32.
	mask := "ffffffff"
	if len(ipAddressAndMaskSeparated) == 2 {
		maskInInt, err := strconv.Atoi(ipAddressAndMaskSeparated[1])
		if err != nil {
			return "", err
		}
		// net.CIDRMask returns nil for a prefix length outside [0, 32]; formatting that nil
		// mask would yield the string "<nil>" instead of an error.
		cidrMask := net.CIDRMask(maskInInt, 32)
		if cidrMask == nil {
			return "", fmt.Errorf("invalid IPv4 prefix length: %d", maskInInt)
		}
		mask = cidrMask.String()
	}

	octets := strings.Split(ipAddressAndMaskSeparated[0], ".")
	if len(octets) != 4 {
		return "", fmt.Errorf("invalid IPv4 address: %q", ipAddressAndMaskSeparated[0])
	}
	var ipAddressInHex strings.Builder
	for _, octet := range octets {
		value, err := strconv.Atoi(octet)
		if err != nil {
			return "", err
		}
		if value < 0 || value > 255 {
			return "", fmt.Errorf("invalid IPv4 address: %q", ipAddressAndMaskSeparated[0])
		}
		// %02x zero-pads values below 16 so every octet contributes exactly two hex digits
		// (10 becomes "0a", not "a").
		fmt.Fprintf(&ipAddressInHex, "%02x", value)
	}
	return ipAddressInHex.String() + "/" + mask, nil
}
+ tcLossFilterAllIPsExistCommandOutput = `[{"parent":"1:","protocol":"ip","pref":1,"kind":"u32","chain":0,"options":{fh 800::805 order 2053 key ht 800 bkt 0 flowid 1:1 not_in_hw +match 345f9a01/ffffffff at 16}},{"parent":"1:","protocol":"ip","pref":1,"kind":"u32","chain":0,"options":{fh 800::805 order 2053 key ht 800 bkt 0 flowid 1:1 not_in_hw +match 345f9a02/fffffffe at 16}}]` + tcLossFilterOnlyOneIPCommandOutput = `[{"parent":"1:","protocol":"ip","pref":1,"kind":"u32","chain":0,"options":{fh 800::805 order 2053 key ht 800 bkt 0 flowid 1:1 not_in_hw +match 345f9a01/ffffffff at 16}}]` ) var ( @@ -84,13 +94,20 @@ var ( NetworkNamespaces: happyNetworkNamespaces, } + happyTaskContainers = state.ContainerResponse{ + ContainerResponse: &v2.ContainerResponse{ + DockerName: "taskid-containerid", + }, + } + happyTaskResponse = state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, TaskNetworkConfig: &happyTaskNetworkConfig, FaultInjectionEnabled: true, + Containers: []state.ContainerResponse{happyTaskContainers}, } - ipSources = []string{"52.95.154.1", "52.95.154.2"} + ipSources = []string{"52.95.154.1", "52.95.154.2/31"} ) type networkFaultInjectionTestCase struct { @@ -99,6 +116,7 @@ type networkFaultInjectionTestCase struct { requestBody interface{} expectedResponseBody types.NetworkFaultInjectionResponse setAgentStateExpectations func(agentState *mock_state.MockAgentState) + setExecExpectations func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) } // Tests the path for Fault Network Faults API @@ -885,20 +903,71 @@ func TestCheckNetworkLatency(t *testing.T) { } } -func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) []networkFaultInjectionTestCase { +func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string, expectedUnhappyResponseBody string) []networkFaultInjectionTestCase { happyNetworkPacketLossReqBody := map[string]interface{}{ "LossPercent": lossPercent, "Sources": ipSources, } tcs := 
[]networkFaultInjectionTestCase{ { - name: fmt.Sprintf("%s success", name), + name: fmt.Sprintf("%s success-running", name), expectedStatusCode: 200, requestBody: happyNetworkPacketLossReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFilterAllIPsExistCommandOutput), nil)) + }, + }, + { + name: fmt.Sprintf("%s success-not-running", name), + expectedStatusCode: 200, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedUnhappyResponseBody), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultDoesNotExistCommandOutput), nil) + }, + }, + { + name: fmt.Sprintf("%s success-not-running-but-latency-is-running", name), + expectedStatusCode: 200, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedUnhappyResponseBody), + 
setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil) + }, + }, + { + name: fmt.Sprintf("%s only-one-ip-exists-not-running", name), + expectedStatusCode: 200, + requestBody: happyNetworkPacketLossReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedUnhappyResponseBody), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFilterOnlyOneIPCommandOutput), nil)) + }, }, { name: fmt.Sprintf("%s unknown request body", name), @@ -912,6 +981,33 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder(exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), 
gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD), + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFilterAllIPsExistCommandOutput), nil)) + + }, + }, + { + name: fmt.Sprintf("%s failed to unmarshal json", name), + expectedStatusCode: 500, + requestBody: map[string]interface{}{ + "LossPercent": lossPercent, + "Sources": ipSources, + "Unknown": "", + }, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to unmarshal tc command output: unexpected end of JSON input"), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + mockCMD := mock_execwrapper.NewMockCmd(ctrl) + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD) + mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(""), nil) + + }, }, { name: fmt.Sprintf("%s malformed request body 1", name), @@ -1105,9 +1201,16 @@ func generateNetworkPacketLossTestCases(name, expectedHappyResponseBody string) } func TestStartNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("start network packet loss", "running") + tcs := generateNetworkPacketLossTestCases("start network packet loss", "running", "") for _, tc := range tcs { + // Currently the logic that the following test case covers is only implemented for CheckNetworkPacketLoss(). + // It will fail for Start and Stop. Thus, skipping them until the logic is fully implemented. 
+ if strings.Contains(tc.name, "failed to unmarshal json") || + strings.Contains(tc.name, "success-not-running") || + strings.Contains(tc.name, "only-one-ip-exists-not-running") { + continue + } t.Run(tc.name, func(t *testing.T) { // Mocks ctrl := gomock.NewController(t) @@ -1154,8 +1257,15 @@ func TestStartNetworkPacketLoss(t *testing.T) { } func TestStopNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("stop network packet loss", "stopped") + tcs := generateNetworkPacketLossTestCases("stop network packet loss", "stopped", "") for _, tc := range tcs { + // Currently the logic that the following test case covers is only implemented for CheckNetworkPacketLoss(). + // It will fail for Start and Stop. Thus, skipping them until the logic is fully implemented. + if strings.Contains(tc.name, "failed to unmarshal json") || + strings.Contains(tc.name, "success-not-running") || + strings.Contains(tc.name, "only-one-ip-exists-not-running") { + continue + } t.Run(tc.name, func(t *testing.T) { // Mocks ctrl := gomock.NewController(t) @@ -1202,7 +1312,7 @@ func TestStopNetworkPacketLoss(t *testing.T) { } func TestCheckNetworkPacketLoss(t *testing.T) { - tcs := generateNetworkPacketLossTestCases("check network packet loss", "running") + tcs := generateNetworkPacketLossTestCases("check network packet loss", "running", "not-running") for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -1214,10 +1324,15 @@ func TestCheckNetworkPacketLoss(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() + mockExec := mock_execwrapper.NewMockExec(ctrl) handler := New(agentState, metricsFactory) + handler.OsExecWrapper = mockExec if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } + if tc.setExecExpectations != nil { + tc.setExecExpectations(mockExec, ctrl) + } router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), diff --git a/ecs-agent/utils/execwrapper/exec.go 
b/ecs-agent/utils/execwrapper/exec.go index 776a7dfb57f..674ceef8d7c 100644 --- a/ecs-agent/utils/execwrapper/exec.go +++ b/ecs-agent/utils/execwrapper/exec.go @@ -51,6 +51,8 @@ type Cmd interface { AppendExtraFiles(...*os.File) Args() []string SetIOStreams(io.Reader, io.Writer, io.Writer) + Output() ([]byte, error) + CombinedOutput() ([]byte, error) } type cmdWrapper struct { @@ -102,3 +104,11 @@ func (c *cmdWrapper) SetIOStreams(stdin io.Reader, stdout io.Writer, stderr io.W c.Stderr = stderr } } + +func (c *cmdWrapper) Output() ([]byte, error) { + return c.Cmd.Output() +} + +func (c *cmdWrapper) CombinedOutput() ([]byte, error) { + return c.Cmd.CombinedOutput() +} diff --git a/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go b/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go index 962077bd141..f61f2e777e3 100644 --- a/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go +++ b/ecs-agent/utils/execwrapper/mocks/execwrapper_mocks.go @@ -81,6 +81,21 @@ func (mr *MockCmdMockRecorder) Args() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Args", reflect.TypeOf((*MockCmd)(nil).Args)) } +// CombinedOutput mocks base method. +func (m *MockCmd) CombinedOutput() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CombinedOutput") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CombinedOutput indicates an expected call of CombinedOutput. +func (mr *MockCmdMockRecorder) CombinedOutput() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CombinedOutput", reflect.TypeOf((*MockCmd)(nil).CombinedOutput)) +} + // KillProcess mocks base method. func (m *MockCmd) KillProcess() error { m.ctrl.T.Helper() @@ -95,6 +110,21 @@ func (mr *MockCmdMockRecorder) KillProcess() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KillProcess", reflect.TypeOf((*MockCmd)(nil).KillProcess)) } +// Output mocks base method. 
+func (m *MockCmd) Output() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Output") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Output indicates an expected call of Output. +func (mr *MockCmdMockRecorder) Output() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Output", reflect.TypeOf((*MockCmd)(nil).Output)) +} + // Run mocks base method. func (m *MockCmd) Run() error { m.ctrl.T.Helper()