diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..55509ae --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,10 @@ +version: 2 +jobs: + build: + docker: + - image: circleci/golang:1.9 + working_directory: /go/src/github.com/bifurcation/mint + steps: + - checkout + - run: go get -v -t -d ./... + - run: go test -v -race ./... diff --git a/common_test.go b/common_test.go index 0294595..72853f6 100644 --- a/common_test.go +++ b/common_test.go @@ -17,16 +17,19 @@ func unhex(h string) []byte { } func assert(t *testing.T, test bool, msg string) { + t.Helper() if !test { t.Fatalf(msg) } } func assertError(t *testing.T, err error, msg string) { + t.Helper() assert(t, err != nil, msg) } func assertNotError(t *testing.T, err error, msg string) { + t.Helper() if err != nil { msg += ": " + err.Error() } @@ -34,26 +37,32 @@ func assertNotError(t *testing.T, err error, msg string) { } func assertNil(t *testing.T, x interface{}, msg string) { + t.Helper() assert(t, x == nil, msg) } func assertNotNil(t *testing.T, x interface{}, msg string) { + t.Helper() assert(t, x != nil, msg) } func assertEquals(t *testing.T, a, b interface{}) { + t.Helper() assert(t, a == b, fmt.Sprintf("%+v != %+v", a, b)) } func assertByteEquals(t *testing.T, a, b []byte) { + t.Helper() assert(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b))) } func assertNotByteEquals(t *testing.T, a, b []byte) { + t.Helper() assert(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b))) } func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) { + t.Helper() assertEquals(t, a.Suite, b.Suite) // Can't compare aeadFactory values assertEquals(t, a.Hash, b.Hash) @@ -62,10 +71,12 @@ func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) { } func assertDeepEquals(t *testing.T, a, b interface{}) { + t.Helper() assert(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b)) } func assertSameType(t *testing.T, a, b interface{}) { + t.Helper() A := reflect.TypeOf(a) B := reflect.TypeOf(b) assert(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name())) diff --git a/conn_test.go b/conn_test.go index 965a657..77cad24 100644 --- a/conn_test.go +++ b/conn_test.go @@ -281,41 +281,53 @@ func computeExporter(t *testing.T, c *Conn, label string, context []byte, length } func TestBasicFlows(t *testing.T) { - for _, conf := range []*Config{basicConfig, hrrConfig, alpnConfig, ffdhConfig, x25519Config} { - cConn, sConn := pipe() - - client := Client(cConn, conf) - server := Server(sConn, conf) - - var clientAlert, serverAlert Alert - - done := make(chan bool) - go func(t *testing.T) { - serverAlert = server.Handshake() - assertEquals(t, serverAlert, AlertNoAlert) - done <- true - }(t) - - clientAlert = client.Handshake() - assertEquals(t, clientAlert, AlertNoAlert) - - <-done - - assertDeepEquals(t, client.state.Params, server.state.Params) - assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) - assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) - assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) - assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) - assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret) - - emptyContext := []byte{} - - assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20)) - assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21)) - assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20)) - assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20)) - assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20)) - + tests := []struct { + name string + config *Config + }{ + {"basic config", basicConfig}, + {"HRR", hrrConfig}, + {"ALPN", alpnConfig}, + {"FFDH", ffdhConfig}, + {"x25519", x25519Config}, + } + for _, testcase := range tests { + t.Run(fmt.Sprintf("with %s", testcase.name), func(t *testing.T) { + conf := testcase.config + cConn, sConn := pipe() + + client := Client(cConn, conf) + server := Server(sConn, conf) + + var clientAlert, serverAlert Alert + + done := make(chan bool) + go func(t *testing.T) { + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + done <- true + }(t) + + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertNoAlert) + + <-done + + assertDeepEquals(t, client.state.Params, server.state.Params) + assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) + assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) + assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) + assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) + assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret) + + emptyContext := []byte{} + + assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20)) + assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21)) + assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20)) + assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20)) + assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20)) + }) } } diff --git a/state-machine_test.go b/state-machine_test.go index 71eb8ed..754f9c7 100644 --- a/state-machine_test.go +++ b/state-machine_test.go @@ -310,91 +310,92 @@ func messagesFromActions(instructions []HandshakeAction) []*HandshakeMessage { // TODO: Unit tests for individual states func TestStateMachineIntegration(t *testing.T) { for caseName, params := range stateMachineIntegrationCases { - t.Logf("=== Integration Test (%s) ===", caseName) + t.Run(caseName, func(t *testing.T) { - var clientState, serverState HandshakeState - clientState = ClientStateStart{ - Caps: params.clientCapabilities, - Opts: params.clientOptions, - } - serverState = ServerStateStart{Caps: params.serverCapabilities} - t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) - t.Logf("Server: %s", reflect.TypeOf(serverState).Name()) + var clientState, serverState HandshakeState + clientState = ClientStateStart{ + Caps: params.clientCapabilities, + Opts: params.clientOptions, + } + serverState = ServerStateStart{Caps: params.serverCapabilities} + t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) + t.Logf("Server: %s", reflect.TypeOf(serverState).Name()) - clientStateSequence := []HandshakeState{clientState} - serverStateSequence := []HandshakeState{serverState} + clientStateSequence := []HandshakeState{clientState} + serverStateSequence := []HandshakeState{serverState} - // Create the ClientHello - clientState, clientInstr, alert := clientState.Next(nil) - clientToSend := messagesFromActions(clientInstr) - assertEquals(t, alert, AlertNoAlert) - t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) - clientStateSequence = append(clientStateSequence, clientState) - assertEquals(t, len(clientToSend), 1) + // Create the ClientHello + clientState, clientInstr, alert := clientState.Next(nil) + clientToSend := messagesFromActions(clientInstr) + assertEquals(t, alert, AlertNoAlert) + t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) + clientStateSequence = append(clientStateSequence, clientState) + assertEquals(t, len(clientToSend), 1) - for { - var clientInstr, serverInstr []HandshakeAction - var alert Alert + for { + var clientInstr, serverInstr []HandshakeAction + var alert Alert - // Client -> Server - serverToSend := []*HandshakeMessage{} - for _, body := range clientToSend { - t.Logf("C->S: %d", body.msgType) - serverState, serverInstr, alert = serverState.Next(body) - serverResponses := messagesFromActions(serverInstr) - assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from server [%v]", alert)) - serverStateSequence = append(serverStateSequence, serverState) - t.Logf("Server: %s", reflect.TypeOf(serverState).Name()) - serverToSend = append(serverToSend, serverResponses...) - } + // Client -> Server + serverToSend := []*HandshakeMessage{} + for _, body := range clientToSend { + t.Logf("C->S: %d", body.msgType) + serverState, serverInstr, alert = serverState.Next(body) + serverResponses := messagesFromActions(serverInstr) + assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from server [%v]", alert)) + serverStateSequence = append(serverStateSequence, serverState) + t.Logf("Server: %s", reflect.TypeOf(serverState).Name()) + serverToSend = append(serverToSend, serverResponses...) + } - // Server -> Client - clientToSend = []*HandshakeMessage{} - for _, body := range serverToSend { - t.Logf("S->C: %d", body.msgType) - clientState, clientInstr, alert = clientState.Next(body) - clientResponses := messagesFromActions(clientInstr) - assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert)) - clientStateSequence = append(clientStateSequence, clientState) - t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) - clientToSend = append(clientToSend, clientResponses...) - } + // Server -> Client + clientToSend = []*HandshakeMessage{} + for _, body := range serverToSend { + t.Logf("S->C: %d", body.msgType) + clientState, clientInstr, alert = clientState.Next(body) + clientResponses := messagesFromActions(clientInstr) + assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert)) + clientStateSequence = append(clientStateSequence, clientState) + t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) + clientToSend = append(clientToSend, clientResponses...) + } - clientConnected := reflect.TypeOf(clientState) == reflect.TypeOf(StateConnected{}) - serverConnected := reflect.TypeOf(serverState) == reflect.TypeOf(StateConnected{}) - if clientConnected && serverConnected { - c := clientState.(StateConnected) - s := serverState.(StateConnected) + clientConnected := reflect.TypeOf(clientState) == reflect.TypeOf(StateConnected{}) + serverConnected := reflect.TypeOf(serverState) == reflect.TypeOf(StateConnected{}) + if clientConnected && serverConnected { + c := clientState.(StateConnected) + s := serverState.(StateConnected) - // Test that we ended up at the same state - assertDeepEquals(t, c.Params, s.Params) - assertCipherSuiteParamsEquals(t, c.cryptoParams, s.cryptoParams) - assertByteEquals(t, c.resumptionSecret, s.resumptionSecret) - assertByteEquals(t, c.clientTrafficSecret, s.clientTrafficSecret) - assertByteEquals(t, c.serverTrafficSecret, s.serverTrafficSecret) + // Test that we ended up at the same state + assertDeepEquals(t, c.Params, s.Params) + assertCipherSuiteParamsEquals(t, c.cryptoParams, s.cryptoParams) + assertByteEquals(t, c.resumptionSecret, s.resumptionSecret) + assertByteEquals(t, c.clientTrafficSecret, s.clientTrafficSecret) + assertByteEquals(t, c.serverTrafficSecret, s.serverTrafficSecret) - // Test that the client went through the expected sequence of states - assertEquals(t, len(clientStateSequence), len(params.clientStateSequence)) - for i, state := range clientStateSequence { - t.Logf("-- %d %s", i, reflect.TypeOf(state).Name()) - assertSameType(t, state, params.clientStateSequence[i]) - } + // Test that the client went through the expected sequence of states + assertEquals(t, len(clientStateSequence), len(params.clientStateSequence)) + for i, state := range clientStateSequence { + t.Logf("-- %d %s", i, reflect.TypeOf(state).Name()) + assertSameType(t, state, params.clientStateSequence[i]) + } - // Test that the server went through the expected sequence of states - assertEquals(t, len(serverStateSequence), len(params.serverStateSequence)) - for i, state := range serverStateSequence { - t.Logf("-- %d %s", i, reflect.TypeOf(state).Name()) - assertSameType(t, state, params.serverStateSequence[i]) - } + // Test that the server went through the expected sequence of states + assertEquals(t, len(serverStateSequence), len(params.serverStateSequence)) + for i, state := range serverStateSequence { + t.Logf("-- %d %s", i, reflect.TypeOf(state).Name()) + assertSameType(t, state, params.serverStateSequence[i]) + } - break - } + break + } - clientStateName := reflect.TypeOf(clientState).Name() - serverStateName := reflect.TypeOf(serverState).Name() - if len(clientToSend) == 0 { - t.Fatalf("Deadlock at client=[%s] server=[%s]", clientStateName, serverStateName) + clientStateName := reflect.TypeOf(clientState).Name() + serverStateName := reflect.TypeOf(serverState).Name() + if len(clientToSend) == 0 { + t.Fatalf("Deadlock at client=[%s] server=[%s]", clientStateName, serverStateName) + } } - } + }) } }