Skip to content

Commit

Permalink
use test helper functions and sub tests
Browse files Browse the repository at this point in the history
No functional changes, but this will make the test output look
nicer (and easier to debug) when tests fail.
  • Loading branch information
marten-seemann committed Nov 9, 2017
1 parent 64af8ab commit 03b32d6
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 108 deletions.
10 changes: 10 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: 2
jobs:
build:
docker:
- image: circleci/golang:1.9
working_directory: /go/src/github.com/bifurcation/mint
steps:
- checkout
- run: go get -v -t -d ./...
- run: go test -v -race ./...
11 changes: 11 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,52 @@ func unhex(h string) []byte {
}

func assert(t *testing.T, test bool, msg string) {
t.Helper()
if !test {
t.Fatalf(msg)
}
}

func assertError(t *testing.T, err error, msg string) {
t.Helper()
assert(t, err != nil, msg)
}

func assertNotError(t *testing.T, err error, msg string) {
t.Helper()
if err != nil {
msg += ": " + err.Error()
}
assert(t, err == nil, msg)
}

func assertNil(t *testing.T, x interface{}, msg string) {
t.Helper()
assert(t, x == nil, msg)
}

func assertNotNil(t *testing.T, x interface{}, msg string) {
t.Helper()
assert(t, x != nil, msg)
}

func assertEquals(t *testing.T, a, b interface{}) {
t.Helper()
assert(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
}

func assertByteEquals(t *testing.T, a, b []byte) {
t.Helper()
assert(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
}

func assertNotByteEquals(t *testing.T, a, b []byte) {
t.Helper()
assert(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
}

func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
t.Helper()
assertEquals(t, a.Suite, b.Suite)
// Can't compare aeadFactory values
assertEquals(t, a.Hash, b.Hash)
Expand All @@ -62,10 +71,12 @@ func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
}

func assertDeepEquals(t *testing.T, a, b interface{}) {
t.Helper()
assert(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b))
}

func assertSameType(t *testing.T, a, b interface{}) {
t.Helper()
A := reflect.TypeOf(a)
B := reflect.TypeOf(b)
assert(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name()))
Expand Down
82 changes: 47 additions & 35 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,41 +281,53 @@ func computeExporter(t *testing.T, c *Conn, label string, context []byte, length
}

func TestBasicFlows(t *testing.T) {
for _, conf := range []*Config{basicConfig, hrrConfig, alpnConfig, ffdhConfig, x25519Config} {
cConn, sConn := pipe()

client := Client(cConn, conf)
server := Server(sConn, conf)

var clientAlert, serverAlert Alert

done := make(chan bool)
go func(t *testing.T) {
serverAlert = server.Handshake()
assertEquals(t, serverAlert, AlertNoAlert)
done <- true
}(t)

clientAlert = client.Handshake()
assertEquals(t, clientAlert, AlertNoAlert)

<-done

assertDeepEquals(t, client.state.Params, server.state.Params)
assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret)

emptyContext := []byte{}

assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20))
assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20))

tests := []struct {
name string
config *Config
}{
{"basic config", basicConfig},
{"HRR", hrrConfig},
{"ALPN", alpnConfig},
{"FFDH", ffdhConfig},
{"x25519", x25519Config},
}
for _, testcase := range tests {
t.Run(fmt.Sprintf("with %s", testcase.name), func(t *testing.T) {
conf := testcase.config
cConn, sConn := pipe()

client := Client(cConn, conf)
server := Server(sConn, conf)

var clientAlert, serverAlert Alert

done := make(chan bool)
go func(t *testing.T) {
serverAlert = server.Handshake()
assertEquals(t, serverAlert, AlertNoAlert)
done <- true
}(t)

clientAlert = client.Handshake()
assertEquals(t, clientAlert, AlertNoAlert)

<-done

assertDeepEquals(t, client.state.Params, server.state.Params)
assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret)

emptyContext := []byte{}

assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20))
assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20))
})
}
}

Expand Down
147 changes: 74 additions & 73 deletions state-machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,91 +310,92 @@ func messagesFromActions(instructions []HandshakeAction) []*HandshakeMessage {
// TODO: Unit tests for individual states
func TestStateMachineIntegration(t *testing.T) {
for caseName, params := range stateMachineIntegrationCases {
t.Logf("=== Integration Test (%s) ===", caseName)
t.Run(caseName, func(t *testing.T) {

var clientState, serverState HandshakeState
clientState = ClientStateStart{
Caps: params.clientCapabilities,
Opts: params.clientOptions,
}
serverState = ServerStateStart{Caps: params.serverCapabilities}
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())
var clientState, serverState HandshakeState
clientState = ClientStateStart{
Caps: params.clientCapabilities,
Opts: params.clientOptions,
}
serverState = ServerStateStart{Caps: params.serverCapabilities}
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())

clientStateSequence := []HandshakeState{clientState}
serverStateSequence := []HandshakeState{serverState}
clientStateSequence := []HandshakeState{clientState}
serverStateSequence := []HandshakeState{serverState}

// Create the ClientHello
clientState, clientInstr, alert := clientState.Next(nil)
clientToSend := messagesFromActions(clientInstr)
assertEquals(t, alert, AlertNoAlert)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientStateSequence = append(clientStateSequence, clientState)
assertEquals(t, len(clientToSend), 1)
// Create the ClientHello
clientState, clientInstr, alert := clientState.Next(nil)
clientToSend := messagesFromActions(clientInstr)
assertEquals(t, alert, AlertNoAlert)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientStateSequence = append(clientStateSequence, clientState)
assertEquals(t, len(clientToSend), 1)

for {
var clientInstr, serverInstr []HandshakeAction
var alert Alert
for {
var clientInstr, serverInstr []HandshakeAction
var alert Alert

// Client -> Server
serverToSend := []*HandshakeMessage{}
for _, body := range clientToSend {
t.Logf("C->S: %d", body.msgType)
serverState, serverInstr, alert = serverState.Next(body)
serverResponses := messagesFromActions(serverInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from server [%v]", alert))
serverStateSequence = append(serverStateSequence, serverState)
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())
serverToSend = append(serverToSend, serverResponses...)
}
// Client -> Server
serverToSend := []*HandshakeMessage{}
for _, body := range clientToSend {
t.Logf("C->S: %d", body.msgType)
serverState, serverInstr, alert = serverState.Next(body)
serverResponses := messagesFromActions(serverInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from server [%v]", alert))
serverStateSequence = append(serverStateSequence, serverState)
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())
serverToSend = append(serverToSend, serverResponses...)
}

// Server -> Client
clientToSend = []*HandshakeMessage{}
for _, body := range serverToSend {
t.Logf("S->C: %d", body.msgType)
clientState, clientInstr, alert = clientState.Next(body)
clientResponses := messagesFromActions(clientInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert))
clientStateSequence = append(clientStateSequence, clientState)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientToSend = append(clientToSend, clientResponses...)
}
// Server -> Client
clientToSend = []*HandshakeMessage{}
for _, body := range serverToSend {
t.Logf("S->C: %d", body.msgType)
clientState, clientInstr, alert = clientState.Next(body)
clientResponses := messagesFromActions(clientInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert))
clientStateSequence = append(clientStateSequence, clientState)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientToSend = append(clientToSend, clientResponses...)
}

clientConnected := reflect.TypeOf(clientState) == reflect.TypeOf(StateConnected{})
serverConnected := reflect.TypeOf(serverState) == reflect.TypeOf(StateConnected{})
if clientConnected && serverConnected {
c := clientState.(StateConnected)
s := serverState.(StateConnected)
clientConnected := reflect.TypeOf(clientState) == reflect.TypeOf(StateConnected{})
serverConnected := reflect.TypeOf(serverState) == reflect.TypeOf(StateConnected{})
if clientConnected && serverConnected {
c := clientState.(StateConnected)
s := serverState.(StateConnected)

// Test that we ended up at the same state
assertDeepEquals(t, c.Params, s.Params)
assertCipherSuiteParamsEquals(t, c.cryptoParams, s.cryptoParams)
assertByteEquals(t, c.resumptionSecret, s.resumptionSecret)
assertByteEquals(t, c.clientTrafficSecret, s.clientTrafficSecret)
assertByteEquals(t, c.serverTrafficSecret, s.serverTrafficSecret)
// Test that we ended up at the same state
assertDeepEquals(t, c.Params, s.Params)
assertCipherSuiteParamsEquals(t, c.cryptoParams, s.cryptoParams)
assertByteEquals(t, c.resumptionSecret, s.resumptionSecret)
assertByteEquals(t, c.clientTrafficSecret, s.clientTrafficSecret)
assertByteEquals(t, c.serverTrafficSecret, s.serverTrafficSecret)

// Test that the client went through the expected sequence of states
assertEquals(t, len(clientStateSequence), len(params.clientStateSequence))
for i, state := range clientStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.clientStateSequence[i])
}
// Test that the client went through the expected sequence of states
assertEquals(t, len(clientStateSequence), len(params.clientStateSequence))
for i, state := range clientStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.clientStateSequence[i])
}

// Test that the server went through the expected sequence of states
assertEquals(t, len(serverStateSequence), len(params.serverStateSequence))
for i, state := range serverStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.serverStateSequence[i])
}
// Test that the server went through the expected sequence of states
assertEquals(t, len(serverStateSequence), len(params.serverStateSequence))
for i, state := range serverStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.serverStateSequence[i])
}

break
}
break
}

clientStateName := reflect.TypeOf(clientState).Name()
serverStateName := reflect.TypeOf(serverState).Name()
if len(clientToSend) == 0 {
t.Fatalf("Deadlock at client=[%s] server=[%s]", clientStateName, serverStateName)
clientStateName := reflect.TypeOf(clientState).Name()
serverStateName := reflect.TypeOf(serverState).Name()
if len(clientToSend) == 0 {
t.Fatalf("Deadlock at client=[%s] server=[%s]", clientStateName, serverStateName)
}
}
}
})
}
}

0 comments on commit 03b32d6

Please sign in to comment.