From 907c54bbae1982e1421b3cee40820398856e1f74 Mon Sep 17 00:00:00 2001 From: mye956 Date: Mon, 30 Sep 2024 17:35:17 +0000 Subject: [PATCH] Incorporating telemetry middleware into fault handlers --- agent/handlers/task_server_setup.go | 90 +++++++++++++++++-- .../handlers/task_server_setup_integ_test.go | 4 +- agent/handlers/task_server_setup_test.go | 28 +++--- 3 files changed, 101 insertions(+), 21 deletions(-) diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index 9633ac746af..47718bd5c00 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -211,43 +211,115 @@ func registerFaultHandlers( // Setting up handler endpoints for network blackhole port fault injections muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkBlackholePort()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StartNetworkBlackholePort(), + ), + metricsFactory, + faulttype.StartNetworkFaultPostfix, + faulttype.BlackHolePortFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkBlackHolePort()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StopNetworkBlackHolePort(), + ), + metricsFactory, + faulttype.StopNetworkFaultPostfix, + faulttype.BlackHolePortFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkBlackHolePort()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.CheckNetworkBlackHolePort(), + ), + metricsFactory, + faulttype.CheckNetworkFaultPostfix, + faulttype.BlackHolePortFaultType, + ), ).Methods("POST") // Setting up handler endpoints for network latency fault injections muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkLatency()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StartNetworkLatency(), + ), + metricsFactory, + faulttype.StartNetworkFaultPostfix, + faulttype.LatencyFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkLatency()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StopNetworkLatency(), + ), + metricsFactory, + faulttype.StopNetworkFaultPostfix, + faulttype.LatencyFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkLatency()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.CheckNetworkLatency(), + ), + metricsFactory, + faulttype.CheckNetworkFaultPostfix, + faulttype.LatencyFaultType, + ), ).Methods("POST") // Setting up handler endpoints for network packet loss fault injections muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkPacketLoss()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StartNetworkPacketLoss(), + ), + metricsFactory, + faulttype.StartNetworkFaultPostfix, + faulttype.PacketLossFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkPacketLoss()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StopNetworkPacketLoss(), + ), + metricsFactory, + faulttype.StopNetworkFaultPostfix, + faulttype.PacketLossFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkPacketLoss()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.CheckNetworkPacketLoss(), + ), + metricsFactory, + faulttype.CheckNetworkFaultPostfix, + faulttype.PacketLossFaultType, + ), ).Methods("POST") seelog.Debug("Successfully set up Fault TMDS handlers") diff --git a/agent/handlers/task_server_setup_integ_test.go b/agent/handlers/task_server_setup_integ_test.go index 63e057d3174..109e1f08d22 100644 --- a/agent/handlers/task_server_setup_integ_test.go +++ b/agent/handlers/task_server_setup_integ_test.go @@ -28,7 +28,7 @@ import ( agentV4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4" mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" - mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" "github.com/golang/mock/gomock" "github.com/gorilla/mux" @@ -56,7 +56,7 @@ func startServer(t *testing.T) (*http.Server, int) { ecsClient := mock_ecs.NewMockECSClient(ctrl) agentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) - metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + metricsFactory := metrics.NewNopEntryFactory() execWrapper := mock_execwrapper.NewMockExec(ctrl) registerFaultHandlers(router, agentState, metricsFactory, execWrapper) diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 4ae6597b922..c328cd077a9 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -132,6 +132,7 @@ const ( tcLatencyFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","parent":"1:1","options":{"limit":1000,"delay":{"delay":123456789,"jitter":4567,"correlation":0},"ecn":false,"gap":0}}]` tcCommandEmptyOutput = `[]` requestTimeoutDuration = 5 * time.Second + durationMetricPrefix = "MetadataServer.%s%sDuration" ) var ( @@ -3808,7 +3809,7 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix), faulttype.StartNetworkFaultPostfix, faulttype.BlackHolePortFaultType) } func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { @@ -3828,7 +3829,7 @@ func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop blackhole port", "stopped", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix), faulttype.StopNetworkFaultPostfix, faulttype.BlackHolePortFaultType) } func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { @@ -3842,7 +3843,7 @@ func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix), faulttype.CheckNetworkFaultPostfix, faulttype.BlackHolePortFaultType) } func TestRegisterStartLatencyFaultHandler(t *testing.T) { @@ -3858,7 +3859,7 @@ func TestRegisterStartLatencyFaultHandler(t *testing.T) { mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil) } tcs := generateCommonNetworkFaultInjectionTestCases("start latency", "running", setExecExpectations, happyNetworkLatencyReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix), faulttype.StartNetworkFaultPostfix, faulttype.LatencyFaultType) } func TestRegisterStopLatencyFaultHandler(t *testing.T) { @@ -3872,7 +3873,7 @@ func TestRegisterStopLatencyFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop latency", "stopped", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix), faulttype.StopNetworkFaultPostfix, faulttype.LatencyFaultType) } func TestRegisterCheckLatencyFaultHandler(t *testing.T) { @@ -3886,7 +3887,7 @@ func TestRegisterCheckLatencyFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check latency", "running", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix), faulttype.CheckNetworkFaultPostfix, faulttype.LatencyFaultType) } func TestRegisterStartPacketLossFaultHandler(t *testing.T) { @@ -3902,7 +3903,7 @@ func TestRegisterStartPacketLossFaultHandler(t *testing.T) { mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil) } tcs := generateCommonNetworkFaultInjectionTestCases("start packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix), faulttype.StartNetworkFaultPostfix, faulttype.PacketLossFaultType) } func TestRegisterStopPacketLossFaultHandler(t *testing.T) { @@ -3916,7 +3917,7 @@ func TestRegisterStopPacketLossFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop packet loss", "stopped", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix), faulttype.StopNetworkFaultPostfix, faulttype.PacketLossFaultType) } func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { @@ -3930,10 +3931,10 @@ func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check packet loss", "running", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix), faulttype.CheckNetworkFaultPostfix, faulttype.PacketLossFaultType) } -func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndpoint string) { +func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndpoint, faultOperation, faultType string) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks @@ -3946,6 +3947,13 @@ func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndp agentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + durationMetricEntry := mock_metrics.NewMockEntry(ctrl) + gomock.InOrder( + metricsFactory.EXPECT().New(fmt.Sprintf(durationMetricPrefix, faultOperation, faultType)).Return(durationMetricEntry).Times(1), + durationMetricEntry.EXPECT().WithFields(gomock.Any()).Return(durationMetricEntry).Times(1), + durationMetricEntry.EXPECT().WithGauge(gomock.Any()).Return(durationMetricEntry).Times(1), + durationMetricEntry.EXPECT().Done(nil).Times(1), + ) execWrapper := mock_execwrapper.NewMockExec(ctrl) if tc.setStateExpectations != nil {