diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index 5d4974f95e8..9633ac746af 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -210,45 +210,45 @@ func registerFaultHandlers( // Setting up handler endpoints for network blackhole port fault injections muxRouter.Handle( - fault.NetworkFaultPath(faulttype.BlackHolePortFaultType), + fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkBlackholePort()), - ).Methods("PUT") + ).Methods("POST") muxRouter.Handle( - fault.NetworkFaultPath(faulttype.BlackHolePortFaultType), + fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkBlackHolePort()), - ).Methods("DELETE") + ).Methods("POST") muxRouter.Handle( - fault.NetworkFaultPath(faulttype.BlackHolePortFaultType), + fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkBlackHolePort()), - ).Methods("GET") + ).Methods("POST") // Setting up handler endpoints for network latency fault injections muxRouter.Handle( - fault.NetworkFaultPath(faulttype.LatencyFaultType), + fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkLatency()), - ).Methods("PUT") + ).Methods("POST") muxRouter.Handle( - fault.NetworkFaultPath(faulttype.LatencyFaultType), + fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkLatency()), - ).Methods("DELETE") + ).Methods("POST") muxRouter.Handle( - fault.NetworkFaultPath(faulttype.LatencyFaultType), + fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkLatency()), - ).Methods("GET") + ).Methods("POST") // Setting up handler endpoints for network packet loss fault injections muxRouter.Handle( - fault.NetworkFaultPath(faulttype.PacketLossFaultType), + fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkPacketLoss()), - ).Methods("PUT") + ).Methods("POST") muxRouter.Handle( - fault.NetworkFaultPath(faulttype.PacketLossFaultType), + fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkPacketLoss()), - ).Methods("DELETE") + ).Methods("POST") muxRouter.Handle( - fault.NetworkFaultPath(faulttype.PacketLossFaultType), + fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix), tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkPacketLoss()), - ).Methods("GET") + ).Methods("POST") seelog.Debug("Successfully set up Fault TMDS handlers") } diff --git a/agent/handlers/task_server_setup_integ_test.go b/agent/handlers/task_server_setup_integ_test.go index 7cbccb7a474..c57a12b6c58 100644 --- a/agent/handlers/task_server_setup_integ_test.go +++ b/agent/handlers/task_server_setup_integ_test.go @@ -105,46 +105,46 @@ func TestRateLimiterIntegration(t *testing.T) { }{ { name: "Same network faults A1 + same methods B1", - method1: "GET", - method2: "GET", - url1: "/api/container123/fault/v1/network-blackhole-port", - url2: "/api/container123/fault/v1/network-blackhole-port", + method1: "POST", + method2: "POST", + url1: "/api/container123/fault/v1/network-blackhole-port/status", + url2: "/api/container123/fault/v1/network-blackhole-port/status", expectedStatus2: http.StatusTooManyRequests, assertNotEqual: false, }, { name: "Same network fault A1 + different methods B1, B2", - method1: "GET", - method2: "PUT", - url1: "/api/container123/fault/v1/network-blackhole-port", - url2: "/api/container123/fault/v1/network-blackhole-port", + method1: "POST", + method2: "POST", + url1: "/api/container123/fault/v1/network-blackhole-port/status", + url2: "/api/container123/fault/v1/network-blackhole-port/start", expectedStatus2: http.StatusTooManyRequests, assertNotEqual: true, }, { name: "Different network faults A1, A2 + same method B1", - method1: "GET", - method2: "GET", - url1: "/api/container123/fault/v1/network-blackhole-port", - url2: "/api/container123/fault/v1/network-latency", + method1: "POST", + method2: "POST", + url1: "/api/container123/fault/v1/network-blackhole-port/status", + url2: "/api/container123/fault/v1/network-latency/status", expectedStatus2: http.StatusTooManyRequests, assertNotEqual: true, }, { name: "Different network faults A1, A3 + same method B1", - method1: "GET", - method2: "GET", - url1: "/api/container123/fault/v1/network-blackhole-port", - url2: "/api/container123/fault/v1/network-packet-loss", + method1: "POST", + method2: "POST", + url1: "/api/container123/fault/v1/network-blackhole-port/status", + url2: "/api/container123/fault/v1/network-packet-loss/status", expectedStatus2: http.StatusTooManyRequests, assertNotEqual: true, }, { name: "Different network faults A2, A3 + same methods B1", - method1: "GET", - method2: "GET", - url1: "/api/container123/fault/v1/network-latency", - url2: "/api/container123/fault/v1/network-packet-loss", + method1: "POST", + method2: "POST", + url1: "/api/container123/fault/v1/network-latency/status", + url2: "/api/container123/fault/v1/network-packet-loss/status", expectedStatus2: http.StatusTooManyRequests, assertNotEqual: true, }, diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 7d052dc9bdc..a1cb61dfa66 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -47,6 +47,7 @@ import ( mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" "github.com/aws/amazon-ecs-agent/ecs-agent/stats" + faulthandler "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers" faulttype "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types" tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" tp "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" @@ -56,13 +57,13 @@ import ( v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" - "github.com/gorilla/mux" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/docker/docker/api/types" "github.com/golang/mock/gomock" + "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -3785,7 +3786,7 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, "PUT", faulttype.BlackHolePortFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix)) } func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { @@ -3805,7 +3806,7 @@ func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop blackhole port", "stopped", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, "DELETE", faulttype.BlackHolePortFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix)) } func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { @@ -3819,7 +3820,7 @@ func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, "GET", faulttype.BlackHolePortFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix)) } func TestRegisterStartLatencyFaultHandler(t *testing.T) { @@ -3828,7 +3829,7 @@ func TestRegisterStartLatencyFaultHandler(t *testing.T) { } tcs := generateCommonNetworkFaultInjectionTestCases("start latency", "running", setExecExpectations, happyNetworkLatencyReqBody) - testRegisterFaultHandler(t, tcs, "PUT", faulttype.LatencyFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix)) } func TestRegisterStopLatencyFaultHandler(t *testing.T) { @@ -3837,7 +3838,7 @@ func TestRegisterStopLatencyFaultHandler(t *testing.T) { } tcs := generateCommonNetworkFaultInjectionTestCases("stop latency", "stopped", setExecExpectations, happyNetworkLatencyReqBody) - testRegisterFaultHandler(t, tcs, "DELETE", faulttype.LatencyFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix)) } func TestRegisterCheckLatencyFaultHandler(t *testing.T) { @@ -3846,7 +3847,7 @@ func TestRegisterCheckLatencyFaultHandler(t *testing.T) { } tcs := generateCommonNetworkFaultInjectionTestCases("check latency", "running", setExecExpectations, happyNetworkLatencyReqBody) - testRegisterFaultHandler(t, tcs, "GET", faulttype.LatencyFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix)) } func TestRegisterStartPacketLossFaultHandler(t *testing.T) { @@ -3862,7 +3863,7 @@ func TestRegisterStartPacketLossFaultHandler(t *testing.T) { mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil) } tcs := generateCommonNetworkFaultInjectionTestCases("start packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody) - testRegisterFaultHandler(t, tcs, "PUT", faulttype.PacketLossFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix)) } func TestRegisterStopPacketLossFaultHandler(t *testing.T) { @@ -3876,7 +3877,7 @@ func TestRegisterStopPacketLossFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop packet loss", "stopped", setExecExpectations, happyNetworkPacketLossReqBody) - testRegisterFaultHandler(t, tcs, "DELETE", faulttype.PacketLossFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix)) } func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { @@ -3890,10 +3891,10 @@ func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody) - testRegisterFaultHandler(t, tcs, "GET", faulttype.PacketLossFaultType) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix)) } -func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method, fault string) { +func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndpoint string) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks @@ -3916,6 +3917,30 @@ func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method, tc.setExecExpectations(execWrapper, ctrl) } + var tmdsAPI string + switch tmdsEndpoint { + case faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/start" + case faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/stop" + case faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/status" + case faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-latency/start" + case faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-latency/stop" + case faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-latency/status" + case faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-packet-loss/start" + case faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-packet-loss/stop" + case faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-packet-loss/status" + default: + t.Error("Unrecognized TMDS Endpoint") + } + router := mux.NewRouter() registerFaultHandlers(router, agentState, metricsFactory, execWrapper) var requestBody io.Reader @@ -3925,7 +3950,7 @@ func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method, requestBody = bytes.NewReader(reqBodyBytes) } - req, err := http.NewRequest(method, fmt.Sprintf("/api/%s/fault/v1/%s", endpointId, fault), + req, err := http.NewRequest("POST", fmt.Sprintf(tmdsAPI, endpointId), requestBody) require.NoError(t, err) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index db002641ea9..1e3e30814f1 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -86,9 +86,9 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw } // NetworkFaultPath will take in a fault type and return the TMDS endpoint path -func NetworkFaultPath(fault string) string { - return fmt.Sprintf("/api/%s/fault/v1/%s", - utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) +func NetworkFaultPath(fault string, operationType string) string { + return fmt.Sprintf("/api/%s/fault/v1/%s/%s", + utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault, operationType) } // loadLock returns the lock associated with given key. diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go index 1e69c7359d9..80a00550e1e 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go @@ -26,6 +26,9 @@ const ( BlackHolePortFaultType = "network-blackhole-port" LatencyFaultType = "network-latency" PacketLossFaultType = "network-packet-loss" + StartNetworkFaultPostfix = "start" + StopNetworkFaultPostfix = "stop" + CheckNetworkFaultPostfix = "status" missingRequiredFieldError = "required parameter %s is missing" MissingRequestBodyError = "required request body is missing" invalidValueError = "invalid value %s for parameter %s" diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index db002641ea9..1e3e30814f1 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -86,9 +86,9 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw } // NetworkFaultPath will take in a fault type and return the TMDS endpoint path -func NetworkFaultPath(fault string) string { - return fmt.Sprintf("/api/%s/fault/v1/%s", - utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) +func NetworkFaultPath(fault string, operationType string) string { + return fmt.Sprintf("/api/%s/fault/v1/%s/%s", + utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault, operationType) } // loadLock returns the lock associated with given key. diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 014f1ad4e0d..87c1a611cfa 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -131,22 +131,28 @@ type networkFaultInjectionTestCase struct { // Tests the path for Fault Network Faults API func TestFaultBlackholeFaultPath(t *testing.T) { - assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-blackhole-port", NetworkFaultPath(types.BlackHolePortFaultType)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-blackhole-port/start", NetworkFaultPath(types.BlackHolePortFaultType, types.StartNetworkFaultPostfix)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-blackhole-port/stop", NetworkFaultPath(types.BlackHolePortFaultType, types.StopNetworkFaultPostfix)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-blackhole-port/status", NetworkFaultPath(types.BlackHolePortFaultType, types.CheckNetworkFaultPostfix)) } func TestFaultLatencyFaultPath(t *testing.T) { - assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-latency", NetworkFaultPath(types.LatencyFaultType)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-latency/start", NetworkFaultPath(types.LatencyFaultType, types.StartNetworkFaultPostfix)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-latency/stop", NetworkFaultPath(types.LatencyFaultType, types.StopNetworkFaultPostfix)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-latency/status", NetworkFaultPath(types.LatencyFaultType, types.CheckNetworkFaultPostfix)) } func TestFaultPacketLossFaultPath(t *testing.T) { - assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-packet-loss", NetworkFaultPath(types.PacketLossFaultType)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-packet-loss/start", NetworkFaultPath(types.PacketLossFaultType, types.StartNetworkFaultPostfix)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-packet-loss/stop", NetworkFaultPath(types.PacketLossFaultType, types.StopNetworkFaultPostfix)) + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/fault/v1/network-packet-loss/status", NetworkFaultPath(types.PacketLossFaultType, types.CheckNetworkFaultPostfix)) } // testNetworkFaultInjectionCommon will be used by unit tests for all 9 fault injection Network Fault APIs. // Unit tests for all 9 APIs interact with the TMDS server and share similar logic. // Thus, use a shared base method to reduce duplicated code. func testNetworkFaultInjectionCommon(t *testing.T, - tcs []networkFaultInjectionTestCase, faultType string, httpMethod string) { + tcs []networkFaultInjectionTestCase, tmdsEndpoint string) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks @@ -168,51 +174,42 @@ func testNetworkFaultInjectionCommon(t *testing.T, var handleMethod func(http.ResponseWriter, *http.Request) var tmdsAPI string - switch faultType { - case types.BlackHolePortFaultType: - tmdsAPI = "/api/%s/fault/v1/network-blackhole-port" - switch httpMethod { - case http.MethodPut: - handleMethod = handler.StartNetworkBlackholePort() - case http.MethodDelete: - handleMethod = handler.StopNetworkBlackHolePort() - case http.MethodGet: - handleMethod = handler.CheckNetworkBlackHolePort() - default: - t.Error("Unrecognized HTTP method") - } - case types.LatencyFaultType: - tmdsAPI = "/api/%s/fault/v1/network-latency" - switch httpMethod { - case http.MethodPut: - handleMethod = handler.StartNetworkLatency() - case http.MethodDelete: - handleMethod = handler.StopNetworkLatency() - case http.MethodGet: - handleMethod = handler.CheckNetworkLatency() - default: - t.Error("Unrecognized HTTP method") - } - case types.PacketLossFaultType: - tmdsAPI = "/api/%s/fault/v1/network-packet-loss" - switch httpMethod { - case http.MethodPut: - handleMethod = handler.StartNetworkPacketLoss() - case http.MethodDelete: - handleMethod = handler.StopNetworkPacketLoss() - case http.MethodGet: - handleMethod = handler.CheckNetworkPacketLoss() - default: - t.Error("Unrecognized HTTP method") - } + switch tmdsEndpoint { + case NetworkFaultPath(types.BlackHolePortFaultType, types.StartNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/start" + handleMethod = handler.StartNetworkBlackholePort() + case NetworkFaultPath(types.BlackHolePortFaultType, types.StopNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/stop" + handleMethod = handler.StopNetworkBlackHolePort() + case NetworkFaultPath(types.BlackHolePortFaultType, types.CheckNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/status" + handleMethod = handler.CheckNetworkBlackHolePort() + case NetworkFaultPath(types.LatencyFaultType, types.StartNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-latency/start" + handleMethod = handler.StartNetworkLatency() + case NetworkFaultPath(types.LatencyFaultType, types.StopNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-latency/stop" + handleMethod = handler.StopNetworkLatency() + case NetworkFaultPath(types.LatencyFaultType, types.CheckNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-latency/status" + handleMethod = handler.CheckNetworkLatency() + case NetworkFaultPath(types.PacketLossFaultType, types.StartNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-packet-loss/start" + handleMethod = handler.StartNetworkPacketLoss() + case NetworkFaultPath(types.PacketLossFaultType, types.StopNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-packet-loss/stop" + handleMethod = handler.StopNetworkPacketLoss() + case NetworkFaultPath(types.PacketLossFaultType, types.CheckNetworkFaultPostfix): + tmdsAPI = "/api/%s/fault/v1/network-packet-loss/status" + handleMethod = handler.CheckNetworkPacketLoss() default: - t.Error("Unrecognized fault type") + t.Error("Unrecognized TMDS Endpoint") } router.HandleFunc( - NetworkFaultPath(faultType), + tmdsEndpoint, handleMethod, - ).Methods(httpMethod) + ).Methods(http.MethodPost) var requestBody io.Reader if tc.requestBody != nil { @@ -220,7 +217,7 @@ func testNetworkFaultInjectionCommon(t *testing.T, require.NoError(t, err) requestBody = bytes.NewReader(reqBodyBytes) } - req, err := http.NewRequest(httpMethod, fmt.Sprintf(tmdsAPI, endpointId), requestBody) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf(tmdsAPI, endpointId), requestBody) require.NoError(t, err) // Send the request and record the response @@ -892,17 +889,17 @@ func generateCheckBlackHolePortFaultStatusTestCases() []networkFaultInjectionTes func TestStartNetworkBlackHolePort(t *testing.T) { tcs := generateStartBlackHolePortFaultTestCases() - testNetworkFaultInjectionCommon(t, tcs, types.BlackHolePortFaultType, http.MethodPut) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.BlackHolePortFaultType, types.StartNetworkFaultPostfix)) } func TestStopNetworkBlackHolePort(t *testing.T) { tcs := generateStopBlackHolePortFaultTestCases() - testNetworkFaultInjectionCommon(t, tcs, types.BlackHolePortFaultType, http.MethodDelete) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.BlackHolePortFaultType, types.StopNetworkFaultPostfix)) } func TestCheckNetworkBlackHolePort(t *testing.T) { tcs := generateCheckBlackHolePortFaultStatusTestCases() - testNetworkFaultInjectionCommon(t, tcs, types.BlackHolePortFaultType, http.MethodGet) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.BlackHolePortFaultType, types.CheckNetworkFaultPostfix)) } func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []networkFaultInjectionTestCase { @@ -1161,17 +1158,17 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n func TestStartNetworkLatency(t *testing.T) { tcs := generateNetworkLatencyTestCases("start network latency", "running") - testNetworkFaultInjectionCommon(t, tcs, types.LatencyFaultType, http.MethodPut) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.LatencyFaultType, types.StartNetworkFaultPostfix)) } func TestStopNetworkLatency(t *testing.T) { tcs := generateNetworkLatencyTestCases("stop network latency", "stopped") - testNetworkFaultInjectionCommon(t, tcs, types.LatencyFaultType, http.MethodDelete) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.LatencyFaultType, types.StopNetworkFaultPostfix)) } func TestCheckNetworkLatency(t *testing.T) { tcs := generateNetworkLatencyTestCases("check network latency", "running") - testNetworkFaultInjectionCommon(t, tcs, types.LatencyFaultType, http.MethodGet) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.LatencyFaultType, types.CheckNetworkFaultPostfix)) } func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjectionTestCase { @@ -1676,15 +1673,15 @@ func generateCheckNetworkPacketLossTestCases() []networkFaultInjectionTestCase { func TestStartNetworkPacketLoss(t *testing.T) { tcs := generateStartNetworkPacketLossTestCases() - testNetworkFaultInjectionCommon(t, tcs, types.PacketLossFaultType, http.MethodPut) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.PacketLossFaultType, types.StartNetworkFaultPostfix)) } func TestStopNetworkPacketLoss(t *testing.T) { tcs := generateStopNetworkPacketLossTestCases() - testNetworkFaultInjectionCommon(t, tcs, types.PacketLossFaultType, http.MethodDelete) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.PacketLossFaultType, types.StopNetworkFaultPostfix)) } func TestCheckNetworkPacketLoss(t *testing.T) { tcs := generateCheckNetworkPacketLossTestCases() - testNetworkFaultInjectionCommon(t, tcs, types.PacketLossFaultType, http.MethodGet) + testNetworkFaultInjectionCommon(t, tcs, NetworkFaultPath(types.PacketLossFaultType, types.CheckNetworkFaultPostfix)) } diff --git a/ecs-agent/tmds/handlers/fault/v1/types/types.go b/ecs-agent/tmds/handlers/fault/v1/types/types.go index 1e69c7359d9..80a00550e1e 100644 --- a/ecs-agent/tmds/handlers/fault/v1/types/types.go +++ b/ecs-agent/tmds/handlers/fault/v1/types/types.go @@ -26,6 +26,9 @@ const ( BlackHolePortFaultType = "network-blackhole-port" LatencyFaultType = "network-latency" PacketLossFaultType = "network-packet-loss" + StartNetworkFaultPostfix = "start" + StopNetworkFaultPostfix = "stop" + CheckNetworkFaultPostfix = "status" missingRequiredFieldError = "required parameter %s is missing" MissingRequestBodyError = "required request body is missing" invalidValueError = "invalid value %s for parameter %s"