diff --git a/agent/go.mod b/agent/go.mod index bc5db0c90d1..f916d5d64ec 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -14,6 +14,7 @@ require ( github.com/containernetworking/cni v1.1.2 github.com/containernetworking/plugins v1.4.1 github.com/deniswernert/udev v0.0.0-20170418162847-a12666f7b5a1 + github.com/didip/tollbooth v4.0.2+incompatible github.com/docker/distribution v2.8.2+incompatible github.com/docker/docker v24.0.9+incompatible github.com/docker/go-connections v0.4.0 @@ -41,7 +42,6 @@ require ( github.com/cilium/ebpf v0.9.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/didip/tollbooth v4.0.2+incompatible // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index 057de85f435..5d4974f95e8 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -41,6 +41,8 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" "github.com/cihub/seelog" + "github.com/didip/tollbooth" + "github.com/didip/tollbooth/limiter" "github.com/gorilla/mux" ) @@ -207,50 +209,57 @@ func registerFaultHandlers( } // Setting up handler endpoints for network blackhole port fault injections - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType), - handler.StartNetworkBlackholePort(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkBlackholePort()), ).Methods("PUT") - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType), - handler.StopNetworkBlackHolePort(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkBlackHolePort()), ).Methods("DELETE") - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType), - handler.CheckNetworkBlackHolePort(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkBlackHolePort()), ).Methods("GET") // Setting up handler endpoints for network latency fault injections - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType), - handler.StartNetworkLatency(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkLatency()), ).Methods("PUT") - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType), - handler.StopNetworkLatency(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkLatency()), ).Methods("DELETE") - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType), - handler.CheckNetworkLatency(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkLatency()), ).Methods("GET") // Setting up handler endpoints for network packet loss fault injections - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType), - handler.StartNetworkPacketLoss(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkPacketLoss()), ).Methods("PUT") - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType), - handler.StopNetworkPacketLoss(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkPacketLoss()), ).Methods("DELETE") - muxRouter.HandleFunc( + muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType), - handler.CheckNetworkPacketLoss(), + tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkPacketLoss()), ).Methods("GET") seelog.Debug("Successfully set up Fault TMDS handlers") } +// Creates a tollbooth ratelimiter for the Fault Handler APIs +func createRateLimiter() *limiter.Limiter { + lmt := tollbooth.NewLimiter(0.2, nil) + lmt.SetMessage("You have reached maximum request limit") + return lmt +} + // ServeTaskHTTPEndpoint serves task/container metadata, task/container stats, IAM Role Credentials, and Agent APIs // for tasks being managed by the agent. func ServeTaskHTTPEndpoint( diff --git a/agent/handlers/task_server_setup_integ_test.go b/agent/handlers/task_server_setup_integ_test.go new file mode 100644 index 00000000000..7cbccb7a474 --- /dev/null +++ b/agent/handlers/task_server_setup_integ_test.go @@ -0,0 +1,180 @@ +//go:build integration +// +build integration + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handlers + +import ( + "context" + "fmt" + "net" + "net/http" + "testing" + "time" + + mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" + agentV4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4" + mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" + mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" + mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + clusterName = "default" + availabilityzone = "us-west-2b" + vpcID = "test-vpc-id" + containerInstanceArn = "containerInstanceArn-test" +) + +// This function starts the server and listens on a specified port +func startServer(t *testing.T) (*http.Server, int) { + router := mux.NewRouter() + + // Mocks + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + state := mock_dockerstate.NewMockTaskEngineState(ctrl) + statsEngine := mock_stats.NewMockEngine(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) + + agentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) + metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + execWrapper := mock_execwrapper.NewMockExec(ctrl) + + registerFaultHandlers(router, agentState, metricsFactory, execWrapper) + + server := &http.Server{ + Addr: ":0", // Lets the system allocate an available port + Handler: router, + } + + listener, err := net.Listen("tcp", server.Addr) + require.NoError(t, err) + + port := listener.Addr().(*net.TCPAddr).Port + t.Logf("Server started on port: %d", port) + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + t.Logf("ListenAndServe(): %s\n", err) + } + }() + return server, port +} + +// This function shuts down the server after the test +func stopServer(t *testing.T, server *http.Server) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + t.Logf("Server Shutdown Failed:%+v", err) + } else { + t.Logf("Server Exited Properly") + } +} + +// Table-driven tests for rate limiter +func TestRateLimiterIntegration(t *testing.T) { + + testCases := []struct { + name string + method1 string + method2 string + url1 string + url2 string + expectedStatus2 int + assertNotEqual bool + }{ + { + name: "Same network faults A1 + same methods B1", + method1: "GET", + method2: "GET", + url1: "/api/container123/fault/v1/network-blackhole-port", + url2: "/api/container123/fault/v1/network-blackhole-port", + expectedStatus2: http.StatusTooManyRequests, + assertNotEqual: false, + }, + { + name: "Same network fault A1 + different methods B1, B2", + method1: "GET", + method2: "PUT", + url1: "/api/container123/fault/v1/network-blackhole-port", + url2: "/api/container123/fault/v1/network-blackhole-port", + expectedStatus2: http.StatusTooManyRequests, + assertNotEqual: true, + }, + { + name: "Different network faults A1, A2 + same method B1", + method1: "GET", + method2: "GET", + url1: "/api/container123/fault/v1/network-blackhole-port", + url2: "/api/container123/fault/v1/network-latency", + expectedStatus2: http.StatusTooManyRequests, + assertNotEqual: true, + }, + { + name: "Different network faults A1, A3 + same method B1", + method1: "GET", + method2: "GET", + url1: "/api/container123/fault/v1/network-blackhole-port", + url2: "/api/container123/fault/v1/network-packet-loss", + expectedStatus2: http.StatusTooManyRequests, + assertNotEqual: true, + }, + { + name: "Different network faults A2, A3 + same methods B1", + method1: "GET", + method2: "GET", + url1: "/api/container123/fault/v1/network-latency", + url2: "/api/container123/fault/v1/network-packet-loss", + expectedStatus2: http.StatusTooManyRequests, + assertNotEqual: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server, port := startServer(t) + client := &http.Client{} + // First request + req1, err := http.NewRequest(tc.method1, getURL(port, tc.url1), nil) + require.NoError(t, err) + _, err = client.Do(req1) + require.NoError(t, err) + + // Second request + req2, err := http.NewRequest(tc.method2, getURL(port, tc.url2), nil) + require.NoError(t, err) + resp2, err := client.Do(req2) + require.NoError(t, err) + if tc.assertNotEqual { + assert.NotEqual(t, tc.expectedStatus2, resp2.StatusCode) + } else { + assert.Equal(t, tc.expectedStatus2, resp2.StatusCode) + } + stopServer(t, server) + }) + } +} + +// Utility function to generate a URL with a dynamic port +func getURL(port int, path string) string { + return "http://localhost:" + fmt.Sprintf("%d", port) + path +}