Skip to content

Commit

Permalink
Update fault injection TMDS endpoints to include the operation type s…
Browse files Browse the repository at this point in the history
…tart/stop/status (#4366)

Co-authored-by: Tianze Shan <[email protected]>
  • Loading branch information
tshan2001 and Tianze Shan authored Sep 26, 2024
1 parent afb502c commit e155f96
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 110 deletions.
36 changes: 18 additions & 18 deletions agent/handlers/task_server_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,45 +210,45 @@ func registerFaultHandlers(

// Setting up handler endpoints for network blackhole port fault injections
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType),
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkBlackholePort()),
).Methods("PUT")
).Methods("POST")
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType),
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkBlackHolePort()),
).Methods("DELETE")
).Methods("POST")
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType),
fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkBlackHolePort()),
).Methods("GET")
).Methods("POST")

// Setting up handler endpoints for network latency fault injections
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.LatencyFaultType),
fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkLatency()),
).Methods("PUT")
).Methods("POST")
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.LatencyFaultType),
fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkLatency()),
).Methods("DELETE")
).Methods("POST")
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.LatencyFaultType),
fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkLatency()),
).Methods("GET")
).Methods("POST")

// Setting up handler endpoints for network packet loss fault injections
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.PacketLossFaultType),
fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkPacketLoss()),
).Methods("PUT")
).Methods("POST")
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.PacketLossFaultType),
fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkPacketLoss()),
).Methods("DELETE")
).Methods("POST")
muxRouter.Handle(
fault.NetworkFaultPath(faulttype.PacketLossFaultType),
fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix),
tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkPacketLoss()),
).Methods("GET")
).Methods("POST")

seelog.Debug("Successfully set up Fault TMDS handlers")
}
Expand Down
40 changes: 20 additions & 20 deletions agent/handlers/task_server_setup_integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,46 +105,46 @@ func TestRateLimiterIntegration(t *testing.T) {
}{
{
name: "Same network faults A1 + same methods B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-blackhole-port",
method1: "POST",
method2: "POST",
url1: "/api/container123/fault/v1/network-blackhole-port/status",
url2: "/api/container123/fault/v1/network-blackhole-port/status",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: false,
},
{
name: "Same network fault A1 + different methods B1, B2",
method1: "GET",
method2: "PUT",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-blackhole-port",
method1: "POST",
method2: "POST",
url1: "/api/container123/fault/v1/network-blackhole-port/status",
url2: "/api/container123/fault/v1/network-blackhole-port/start",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
{
name: "Different network faults A1, A2 + same method B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-latency",
method1: "POST",
method2: "POST",
url1: "/api/container123/fault/v1/network-blackhole-port/status",
url2: "/api/container123/fault/v1/network-latency/status",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
{
name: "Different network faults A1, A3 + same method B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-blackhole-port",
url2: "/api/container123/fault/v1/network-packet-loss",
method1: "POST",
method2: "POST",
url1: "/api/container123/fault/v1/network-blackhole-port/status",
url2: "/api/container123/fault/v1/network-packet-loss/status",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
{
name: "Different network faults A2, A3 + same methods B1",
method1: "GET",
method2: "GET",
url1: "/api/container123/fault/v1/network-latency",
url2: "/api/container123/fault/v1/network-packet-loss",
method1: "POST",
method2: "POST",
url1: "/api/container123/fault/v1/network-latency/status",
url2: "/api/container123/fault/v1/network-packet-loss/status",
expectedStatus2: http.StatusTooManyRequests,
assertNotEqual: true,
},
Expand Down
49 changes: 37 additions & 12 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks"
ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface"
"github.com/aws/amazon-ecs-agent/ecs-agent/stats"
faulthandler "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers"
faulttype "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types"
tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response"
tp "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers"
Expand All @@ -56,13 +57,13 @@ import (
v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2"
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state"
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
"github.com/gorilla/mux"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/docker/docker/api/types"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -3785,7 +3786,7 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
testRegisterFaultHandler(t, tcs, "PUT", faulttype.BlackHolePortFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix))
}

func TestRegisterStopBlackholePortFaultHandler(t *testing.T) {
Expand All @@ -3805,7 +3806,7 @@ func TestRegisterStopBlackholePortFaultHandler(t *testing.T) {
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("stop blackhole port", "stopped", setExecExpectations, happyBlackHolePortReqBody)
testRegisterFaultHandler(t, tcs, "DELETE", faulttype.BlackHolePortFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix))
}

func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) {
Expand All @@ -3819,7 +3820,7 @@ func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) {
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("check blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
testRegisterFaultHandler(t, tcs, "GET", faulttype.BlackHolePortFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix))
}

func TestRegisterStartLatencyFaultHandler(t *testing.T) {
Expand All @@ -3828,7 +3829,7 @@ func TestRegisterStartLatencyFaultHandler(t *testing.T) {

}
tcs := generateCommonNetworkFaultInjectionTestCases("start latency", "running", setExecExpectations, happyNetworkLatencyReqBody)
testRegisterFaultHandler(t, tcs, "PUT", faulttype.LatencyFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix))
}

func TestRegisterStopLatencyFaultHandler(t *testing.T) {
Expand All @@ -3837,7 +3838,7 @@ func TestRegisterStopLatencyFaultHandler(t *testing.T) {

}
tcs := generateCommonNetworkFaultInjectionTestCases("stop latency", "stopped", setExecExpectations, happyNetworkLatencyReqBody)
testRegisterFaultHandler(t, tcs, "DELETE", faulttype.LatencyFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix))
}

func TestRegisterCheckLatencyFaultHandler(t *testing.T) {
Expand All @@ -3846,7 +3847,7 @@ func TestRegisterCheckLatencyFaultHandler(t *testing.T) {

}
tcs := generateCommonNetworkFaultInjectionTestCases("check latency", "running", setExecExpectations, happyNetworkLatencyReqBody)
testRegisterFaultHandler(t, tcs, "GET", faulttype.LatencyFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix))
}

func TestRegisterStartPacketLossFaultHandler(t *testing.T) {
Expand All @@ -3862,7 +3863,7 @@ func TestRegisterStartPacketLossFaultHandler(t *testing.T) {
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody)
testRegisterFaultHandler(t, tcs, "PUT", faulttype.PacketLossFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix))
}

func TestRegisterStopPacketLossFaultHandler(t *testing.T) {
Expand All @@ -3876,7 +3877,7 @@ func TestRegisterStopPacketLossFaultHandler(t *testing.T) {
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("stop packet loss", "stopped", setExecExpectations, happyNetworkPacketLossReqBody)
testRegisterFaultHandler(t, tcs, "DELETE", faulttype.PacketLossFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix))
}

func TestRegisterCheckPacketLossFaultHandler(t *testing.T) {
Expand All @@ -3890,10 +3891,10 @@ func TestRegisterCheckPacketLossFaultHandler(t *testing.T) {
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("check packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody)
testRegisterFaultHandler(t, tcs, "GET", faulttype.PacketLossFaultType)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix))
}

func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method, fault string) {
func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndpoint string) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// Mocks
Expand All @@ -3916,6 +3917,30 @@ func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method,
tc.setExecExpectations(execWrapper, ctrl)
}

var tmdsAPI string
switch tmdsEndpoint {
case faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/start"
case faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/stop"
case faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-blackhole-port/status"
case faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-latency/start"
case faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-latency/stop"
case faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-latency/status"
case faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-packet-loss/start"
case faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-packet-loss/stop"
case faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix):
tmdsAPI = "/api/%s/fault/v1/network-packet-loss/status"
default:
t.Error("Unrecognized TMDS Endpoint")
}

router := mux.NewRouter()
registerFaultHandlers(router, agentState, metricsFactory, execWrapper)
var requestBody io.Reader
Expand All @@ -3925,7 +3950,7 @@ func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, method,
requestBody = bytes.NewReader(reqBodyBytes)
}

req, err := http.NewRequest(method, fmt.Sprintf("/api/%s/fault/v1/%s", endpointId, fault),
req, err := http.NewRequest("POST", fmt.Sprintf(tmdsAPI, endpointId),
requestBody)
require.NoError(t, err)

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ func New(agentState state.AgentState, mf metrics.EntryFactory, execWrapper execw
}

// NetworkFaultPath will take in a fault type and return the TMDS endpoint path
func NetworkFaultPath(fault string) string {
return fmt.Sprintf("/api/%s/fault/v1/%s",
utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault)
func NetworkFaultPath(fault string, operationType string) string {
return fmt.Sprintf("/api/%s/fault/v1/%s/%s",
utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault, operationType)
}

// loadLock returns the lock associated with given key.
Expand Down
Loading

0 comments on commit e155f96

Please sign in to comment.