Skip to content

Commit

Permalink
Add ratelimiter for fault injection handlers (aws#4340)
Browse files Browse the repository at this point in the history
* Add ratelimiter for fault injection handler

* Add integ tests and refactor code

* Refactor integ test & ratelimiter setup

* Refactor tests to table driven tests

* Add execWrapper to integ tests

* Add dynamic port allocation

---------

Co-authored-by: Harish Senthilkumar <[email protected]>
  • Loading branch information
2 people authored and mye956 committed Oct 3, 2024
1 parent f5a0eb7 commit 5e1f628
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 19 deletions.
2 changes: 1 addition & 1 deletion agent/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/containernetworking/cni v1.1.2
github.com/containernetworking/plugins v1.4.1
github.com/deniswernert/udev v0.0.0-20170418162847-a12666f7b5a1
github.com/didip/tollbooth v4.0.2+incompatible
github.com/docker/distribution v2.8.2+incompatible
github.com/docker/docker v24.0.9+incompatible
github.com/docker/go-connections v0.4.0
Expand Down Expand Up @@ -41,7 +42,6 @@ require (
github.com/cilium/ebpf v0.9.1 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/didip/tollbooth v4.0.2+incompatible // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
Expand Down
45 changes: 27 additions & 18 deletions agent/handlers/task_server_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry"

"github.com/cihub/seelog"
"github.com/didip/tollbooth"
"github.com/didip/tollbooth/limiter"
"github.com/gorilla/mux"
)

Expand Down Expand Up @@ -207,50 +209,57 @@ func registerFaultHandlers(
}

// Setting up handler endpoints for network blackhole port fault injections
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType),
handler.StartNetworkBlackholePort(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkBlackholePort()),
).Methods("PUT")
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType),
handler.StopNetworkBlackHolePort(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkBlackHolePort()),
).Methods("DELETE")
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType),
handler.CheckNetworkBlackHolePort(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkBlackHolePort()),
).Methods("GET")

// Setting up handler endpoints for network latency fault injections
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.LatencyFaultType),
handler.StartNetworkLatency(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkLatency()),
).Methods("PUT")
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.LatencyFaultType),
handler.StopNetworkLatency(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkLatency()),
).Methods("DELETE")
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.LatencyFaultType),
handler.CheckNetworkLatency(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkLatency()),
).Methods("GET")

// Setting up handler endpoints for network packet loss fault injections
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.PacketLossFaultType),
handler.StartNetworkPacketLoss(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkPacketLoss()),
).Methods("PUT")
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.PacketLossFaultType),
handler.StopNetworkPacketLoss(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkPacketLoss()),
).Methods("DELETE")
muxRouter.HandleFunc(
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.PacketLossFaultType),
handler.CheckNetworkPacketLoss(),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkPacketLoss()),
).Methods("GET")

seelog.Debug("Successfully set up Fault TMDS handlers")
}

// Creates a tollbooth ratelimiter for the Fault Handler APIs
func createRateLimiter() *limiter.Limiter {
lmt := tollbooth.NewLimiter(0.2, nil)
lmt.SetMessage("You have reached maximum request limit")
return lmt
}

// ServeTaskHTTPEndpoint serves task/container metadata, task/container stats, IAM Role Credentials, and Agent APIs
// for tasks being managed by the agent.
func ServeTaskHTTPEndpoint(
Expand Down
180 changes: 180 additions & 0 deletions agent/handlers/task_server_setup_integ_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
//go:build integration
// +build integration

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handlers

import (
"context"
"fmt"
"net"
"net/http"
"testing"
"time"

mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks"
agentV4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4"
mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock"
mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks"
mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks"
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
clusterName = "default"
availabilityzone = "us-west-2b"
vpcID = "test-vpc-id"
containerInstanceArn = "containerInstanceArn-test"
)

// This function starts the server and listens on a specified port
func startServer(t *testing.T) (*http.Server, int) {
router := mux.NewRouter()

// Mocks
ctrl := gomock.NewController(t)
defer ctrl.Finish()

state := mock_dockerstate.NewMockTaskEngineState(ctrl)
statsEngine := mock_stats.NewMockEngine(ctrl)
ecsClient := mock_ecs.NewMockECSClient(ctrl)

agentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn)
metricsFactory := mock_metrics.NewMockEntryFactory(ctrl)
execWrapper := mock_execwrapper.NewMockExec(ctrl)

registerFaultHandlers(router, agentState, metricsFactory, execWrapper)

server := &http.Server{
Addr: ":0", // Lets the system allocate an available port
Handler: router,
}

listener, err := net.Listen("tcp", server.Addr)
require.NoError(t, err)

port := listener.Addr().(*net.TCPAddr).Port
t.Logf("Server started on port: %d", port)

go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
t.Logf("ListenAndServe(): %s\n", err)
}
}()
return server, port
}

// This function shuts down the server after the test
func stopServer(t *testing.T, server *http.Server) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
t.Logf("Server Shutdown Failed:%+v", err)
} else {
t.Logf("Server Exited Properly")
}
}

// Table-driven tests for rate limiter
func TestRateLimiterIntegration(t *testing.T) {

testCases := []struct {
name string
method1 string
method2 string
url1 string
url2 string
expectedStatus2 int
assertNotEqual bool
}{
{
name: "Same network faults A1 + same methods B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-blackhole-port",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: false,
},
{
name: "Same network fault A1 + different methods B1, B2",
method1: "GET",
method2: "PUT",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-blackhole-port",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
{
name: "Different network faults A1, A2 + same method B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-latency",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
{
name: "Different network faults A1, A3 + same method B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-packet-loss",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
{
name: "Different network faults A2, A3 + same methods B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-latency",
url2: "/api/container123/fault/v1/network-packet-loss",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
server, port := startServer(t)
client := &http.Client{}
// First request
req1, err := http.NewRequest(tc.method1, getURL(port, tc.url1), nil)
require.NoError(t, err)
_, err = client.Do(req1)
require.NoError(t, err)

// Second request
req2, err := http.NewRequest(tc.method2, getURL(port, tc.url2), nil)
require.NoError(t, err)
resp2, err := client.Do(req2)
require.NoError(t, err)
if tc.assertNotEqual {
assert.NotEqual(t, tc.expectedStatus2, resp2.StatusCode)
} else {
assert.Equal(t, tc.expectedStatus2, resp2.StatusCode)
}
stopServer(t, server)
})
}
}

// Utility function to generate a URL with a dynamic port
func getURL(port int, path string) string {
return "http://localhost:" + fmt.Sprintf("%d", port) + path
}

0 comments on commit 5e1f628

Please sign in to comment.