Skip to content

Commit

Permalink
Merge pull request #147 from marten-seemann/test-helpers
Browse files Browse the repository at this point in the history
use test helper functions and sub tests
  • Loading branch information
bifurcation authored Dec 6, 2017
2 parents 64af8ab + 03b32d6 commit 1f893e2
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 108 deletions.
10 changes: 10 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: 2
jobs:
build:
docker:
- image: circleci/golang:1.9
working_directory: /go/src/github.com/bifurcation/mint
steps:
- checkout
- run: go get -v -t -d ./...
- run: go test -v -race ./...
11 changes: 11 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,52 @@ func unhex(h string) []byte {
}

func assert(t *testing.T, test bool, msg string) {
t.Helper()
if !test {
t.Fatalf(msg)
}
}

func assertError(t *testing.T, err error, msg string) {
t.Helper()
assert(t, err != nil, msg)
}

func assertNotError(t *testing.T, err error, msg string) {
t.Helper()
if err != nil {
msg += ": " + err.Error()
}
assert(t, err == nil, msg)
}

func assertNil(t *testing.T, x interface{}, msg string) {
t.Helper()
assert(t, x == nil, msg)
}

func assertNotNil(t *testing.T, x interface{}, msg string) {
t.Helper()
assert(t, x != nil, msg)
}

func assertEquals(t *testing.T, a, b interface{}) {
t.Helper()
assert(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
}

func assertByteEquals(t *testing.T, a, b []byte) {
t.Helper()
assert(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
}

func assertNotByteEquals(t *testing.T, a, b []byte) {
t.Helper()
assert(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
}

func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
t.Helper()
assertEquals(t, a.Suite, b.Suite)
// Can't compare aeadFactory values
assertEquals(t, a.Hash, b.Hash)
Expand All @@ -62,10 +71,12 @@ func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
}

func assertDeepEquals(t *testing.T, a, b interface{}) {
t.Helper()
assert(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b))
}

func assertSameType(t *testing.T, a, b interface{}) {
t.Helper()
A := reflect.TypeOf(a)
B := reflect.TypeOf(b)
assert(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name()))
Expand Down
82 changes: 47 additions & 35 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,41 +281,53 @@ func computeExporter(t *testing.T, c *Conn, label string, context []byte, length
}

func TestBasicFlows(t *testing.T) {
for _, conf := range []*Config{basicConfig, hrrConfig, alpnConfig, ffdhConfig, x25519Config} {
cConn, sConn := pipe()

client := Client(cConn, conf)
server := Server(sConn, conf)

var clientAlert, serverAlert Alert

done := make(chan bool)
go func(t *testing.T) {
serverAlert = server.Handshake()
assertEquals(t, serverAlert, AlertNoAlert)
done <- true
}(t)

clientAlert = client.Handshake()
assertEquals(t, clientAlert, AlertNoAlert)

<-done

assertDeepEquals(t, client.state.Params, server.state.Params)
assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret)

emptyContext := []byte{}

assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20))
assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20))

tests := []struct {
name string
config *Config
}{
{"basic config", basicConfig},
{"HRR", hrrConfig},
{"ALPN", alpnConfig},
{"FFDH", ffdhConfig},
{"x25519", x25519Config},
}
for _, testcase := range tests {
t.Run(fmt.Sprintf("with %s", testcase.name), func(t *testing.T) {
conf := testcase.config
cConn, sConn := pipe()

client := Client(cConn, conf)
server := Server(sConn, conf)

var clientAlert, serverAlert Alert

done := make(chan bool)
go func(t *testing.T) {
serverAlert = server.Handshake()
assertEquals(t, serverAlert, AlertNoAlert)
done <- true
}(t)

clientAlert = client.Handshake()
assertEquals(t, clientAlert, AlertNoAlert)

<-done

assertDeepEquals(t, client.state.Params, server.state.Params)
assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret)

emptyContext := []byte{}

assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21))
assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20))
assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20))
assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20))
})
}
}

Expand Down
147 changes: 74 additions & 73 deletions state-machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,91 +310,92 @@ func messagesFromActions(instructions []HandshakeAction) []*HandshakeMessage {
// TODO: Unit tests for individual states
func TestStateMachineIntegration(t *testing.T) {
for caseName, params := range stateMachineIntegrationCases {
t.Logf("=== Integration Test (%s) ===", caseName)
t.Run(caseName, func(t *testing.T) {

var clientState, serverState HandshakeState
clientState = ClientStateStart{
Caps: params.clientCapabilities,
Opts: params.clientOptions,
}
serverState = ServerStateStart{Caps: params.serverCapabilities}
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())
var clientState, serverState HandshakeState
clientState = ClientStateStart{
Caps: params.clientCapabilities,
Opts: params.clientOptions,
}
serverState = ServerStateStart{Caps: params.serverCapabilities}
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())

clientStateSequence := []HandshakeState{clientState}
serverStateSequence := []HandshakeState{serverState}
clientStateSequence := []HandshakeState{clientState}
serverStateSequence := []HandshakeState{serverState}

// Create the ClientHello
clientState, clientInstr, alert := clientState.Next(nil)
clientToSend := messagesFromActions(clientInstr)
assertEquals(t, alert, AlertNoAlert)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientStateSequence = append(clientStateSequence, clientState)
assertEquals(t, len(clientToSend), 1)
// Create the ClientHello
clientState, clientInstr, alert := clientState.Next(nil)
clientToSend := messagesFromActions(clientInstr)
assertEquals(t, alert, AlertNoAlert)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientStateSequence = append(clientStateSequence, clientState)
assertEquals(t, len(clientToSend), 1)

for {
var clientInstr, serverInstr []HandshakeAction
var alert Alert
for {
var clientInstr, serverInstr []HandshakeAction
var alert Alert

// Client -> Server
serverToSend := []*HandshakeMessage{}
for _, body := range clientToSend {
t.Logf("C->S: %d", body.msgType)
serverState, serverInstr, alert = serverState.Next(body)
serverResponses := messagesFromActions(serverInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from server [%v]", alert))
serverStateSequence = append(serverStateSequence, serverState)
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())
serverToSend = append(serverToSend, serverResponses...)
}
// Client -> Server
serverToSend := []*HandshakeMessage{}
for _, body := range clientToSend {
t.Logf("C->S: %d", body.msgType)
serverState, serverInstr, alert = serverState.Next(body)
serverResponses := messagesFromActions(serverInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from server [%v]", alert))
serverStateSequence = append(serverStateSequence, serverState)
t.Logf("Server: %s", reflect.TypeOf(serverState).Name())
serverToSend = append(serverToSend, serverResponses...)
}

// Server -> Client
clientToSend = []*HandshakeMessage{}
for _, body := range serverToSend {
t.Logf("S->C: %d", body.msgType)
clientState, clientInstr, alert = clientState.Next(body)
clientResponses := messagesFromActions(clientInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert))
clientStateSequence = append(clientStateSequence, clientState)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientToSend = append(clientToSend, clientResponses...)
}
// Server -> Client
clientToSend = []*HandshakeMessage{}
for _, body := range serverToSend {
t.Logf("S->C: %d", body.msgType)
clientState, clientInstr, alert = clientState.Next(body)
clientResponses := messagesFromActions(clientInstr)
assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert))
clientStateSequence = append(clientStateSequence, clientState)
t.Logf("Client: %s", reflect.TypeOf(clientState).Name())
clientToSend = append(clientToSend, clientResponses...)
}

clientConnected := reflect.TypeOf(clientState) == reflect.TypeOf(StateConnected{})
serverConnected := reflect.TypeOf(serverState) == reflect.TypeOf(StateConnected{})
if clientConnected && serverConnected {
c := clientState.(StateConnected)
s := serverState.(StateConnected)
clientConnected := reflect.TypeOf(clientState) == reflect.TypeOf(StateConnected{})
serverConnected := reflect.TypeOf(serverState) == reflect.TypeOf(StateConnected{})
if clientConnected && serverConnected {
c := clientState.(StateConnected)
s := serverState.(StateConnected)

// Test that we ended up at the same state
assertDeepEquals(t, c.Params, s.Params)
assertCipherSuiteParamsEquals(t, c.cryptoParams, s.cryptoParams)
assertByteEquals(t, c.resumptionSecret, s.resumptionSecret)
assertByteEquals(t, c.clientTrafficSecret, s.clientTrafficSecret)
assertByteEquals(t, c.serverTrafficSecret, s.serverTrafficSecret)
// Test that we ended up at the same state
assertDeepEquals(t, c.Params, s.Params)
assertCipherSuiteParamsEquals(t, c.cryptoParams, s.cryptoParams)
assertByteEquals(t, c.resumptionSecret, s.resumptionSecret)
assertByteEquals(t, c.clientTrafficSecret, s.clientTrafficSecret)
assertByteEquals(t, c.serverTrafficSecret, s.serverTrafficSecret)

// Test that the client went through the expected sequence of states
assertEquals(t, len(clientStateSequence), len(params.clientStateSequence))
for i, state := range clientStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.clientStateSequence[i])
}
// Test that the client went through the expected sequence of states
assertEquals(t, len(clientStateSequence), len(params.clientStateSequence))
for i, state := range clientStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.clientStateSequence[i])
}

// Test that the server went through the expected sequence of states
assertEquals(t, len(serverStateSequence), len(params.serverStateSequence))
for i, state := range serverStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.serverStateSequence[i])
}
// Test that the server went through the expected sequence of states
assertEquals(t, len(serverStateSequence), len(params.serverStateSequence))
for i, state := range serverStateSequence {
t.Logf("-- %d %s", i, reflect.TypeOf(state).Name())
assertSameType(t, state, params.serverStateSequence[i])
}

break
}
break
}

clientStateName := reflect.TypeOf(clientState).Name()
serverStateName := reflect.TypeOf(serverState).Name()
if len(clientToSend) == 0 {
t.Fatalf("Deadlock at client=[%s] server=[%s]", clientStateName, serverStateName)
clientStateName := reflect.TypeOf(clientState).Name()
serverStateName := reflect.TypeOf(serverState).Name()
if len(clientToSend) == 0 {
t.Fatalf("Deadlock at client=[%s] server=[%s]", clientStateName, serverStateName)
}
}
}
})
}
}

0 comments on commit 1f893e2

Please sign in to comment.