diff --git a/client-state-machine.go b/client-state-machine.go index ffca45e..369493e 100644 --- a/client-state-machine.go +++ b/client-state-machine.go @@ -58,7 +58,7 @@ type clientStateStart struct { cookie []byte firstClientHello *HandshakeMessage helloRetryRequest *HandshakeMessage - hsCtx HandshakeContext + hsCtx *HandshakeContext } var _ HandshakeState = &clientStateStart{} @@ -172,8 +172,10 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } ch.CipherSuites = compatibleSuites + // TODO(ekr@rtfm.com): Check that the ticket can be used for early + // data. // Signal early data if we're going to do it - if len(state.Opts.EarlyData) > 0 { + if state.Config.AllowEarlyData { state.Params.ClientSendingEarlyData = true ed = &EarlyDataExtension{} err = ch.Extensions.Add(ed) @@ -255,9 +257,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) - } else if len(state.Opts.EarlyData) > 0 { - logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") - return nil, nil, AlertInternalError } else { clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch) if err != nil { @@ -291,7 +290,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ if state.Params.ClientSendingEarlyData { toSend = append(toSend, []HandshakeAction{ RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - SendEarlyData{}, }...) } @@ -302,7 +300,7 @@ type clientStateWaitSH struct { Config *Config Opts ConnectionOptions Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext OfferedDH map[NamedGroup][]byte OfferedPSK PreSharedKey PSK []byte @@ -412,6 +410,11 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, body: h.Sum(nil), } + state.hsCtx.receivedEndOfFlight() + + // TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT + // mode. In DTLS, we also need to bump the sequence number. + // This is a pre-existing defect in Mint. Issue #175. logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") return clientStateStart{ Config: state.Config, @@ -515,7 +518,6 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") nextState := clientStateWaitEE{ Config: state.Config, @@ -530,13 +532,20 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, toSend := []HandshakeAction{ RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, } + // We're definitely not going to have to send anything with + // early data. + if !state.Params.ClientSendingEarlyData { + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)}) + } + return nextState, toSend, AlertNoAlert } type clientStateWaitEE struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash masterSecret []byte @@ -596,6 +605,14 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, state.handshakeHash.Write(hm.Marshal()) + toSend := []HandshakeAction{} + + if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData { + // We didn't get 0-RTT, so rekey to handshake. + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } + if state.Params.UsingPSK { logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") nextState := clientStateWaitFinished{ @@ -608,7 +625,7 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") @@ -622,13 +639,13 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } type clientStateWaitCertCR struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash masterSecret []byte @@ -706,7 +723,7 @@ func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta type clientStateWaitCert struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -760,7 +777,7 @@ func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState type clientStateWaitCV struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -861,7 +878,7 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, type clientStateWaitFinished struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -933,6 +950,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS toSend := []HandshakeAction{} if state.Params.UsingEarlyData { + logf(logTypeHandshake, "Sending end of early data") // Note: We only send EOED if the server is actually going to use the early // data. Otherwise, it will never see it, and the transcripts will // mismatch. @@ -942,10 +960,11 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS state.handshakeHash.Write(eoedm.Marshal()) logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) - } - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}) + // And then rekey to handshake + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } if state.Params.UsingClientAuth { // Extract constraints from certicateRequest @@ -1045,6 +1064,8 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, }...) + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") nextState := stateConnected{ Params: state.Params, diff --git a/common.go b/common.go index 565d15e..4c7e999 100644 --- a/common.go +++ b/common.go @@ -25,6 +25,7 @@ const ( RecordTypeAlert RecordType = 21 RecordTypeHandshake RecordType = 22 RecordTypeApplicationData RecordType = 23 + RecordTypeAck RecordType = 25 ) // enum {...} HandshakeType; @@ -166,6 +167,8 @@ const ( type State uint8 const ( + StateInit = 0 + // states valid for the client StateClientStart State = iota StateClientWaitSH @@ -179,6 +182,7 @@ const ( StateServerStart State = iota StateServerRecvdCH StateServerNegotiated + StateServerReadPastEarlyData StateServerWaitEOED StateServerWaitFlight2 StateServerWaitCert @@ -209,6 +213,8 @@ func (s State) String() string { return "Server RECVD_CH" case StateServerNegotiated: return "Server NEGOTIATED" + case StateServerReadPastEarlyData: + return "Server READ_PAST_EARLY_DATA" case StateServerWaitEOED: return "Server WAIT_EOED" case StateServerWaitFlight2: @@ -250,3 +256,9 @@ func (e Epoch) label() string { } return "Application data (updated)" } + +func assert(b bool) { + if !b { + panic("Assertion failed") + } +} diff --git a/common_test.go b/common_test.go index 3bcd96b..d23203f 100644 --- a/common_test.go +++ b/common_test.go @@ -17,7 +17,7 @@ func unhex(h string) []byte { return b } -func assert(t *testing.T, test bool, msg string) { +func assertTrue(t *testing.T, test bool, msg string) { t.Helper() prefix := string("") for i := 1; ; i++ { @@ -34,7 +34,7 @@ func assert(t *testing.T, test bool, msg string) { func assertError(t *testing.T, err error, msg string) { t.Helper() - assert(t, err != nil, msg) + assertTrue(t, err != nil, msg) } func assertNotError(t *testing.T, err error, msg string) { @@ -42,32 +42,32 @@ func assertNotError(t *testing.T, err error, msg string) { if err != nil { msg += ": " + err.Error() } - assert(t, err == nil, msg) + assertTrue(t, err == nil, msg) } func assertNil(t *testing.T, x interface{}, msg string) { t.Helper() - assert(t, x == nil, msg) + assertTrue(t, x == nil, msg) } func assertNotNil(t *testing.T, x interface{}, msg string) { t.Helper() - assert(t, x != nil, msg) + assertTrue(t, x != nil, msg) } func assertEquals(t *testing.T, a, b interface{}) { t.Helper() - assert(t, a == b, fmt.Sprintf("%+v != %+v", a, b)) + assertTrue(t, a == b, fmt.Sprintf("%+v != %+v", a, b)) } func assertByteEquals(t *testing.T, a, b []byte) { t.Helper() - assert(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b))) + assertTrue(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b))) } func assertNotByteEquals(t *testing.T, a, b []byte) { t.Helper() - assert(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b))) + assertTrue(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b))) } func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) { @@ -81,12 +81,81 @@ func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) { func assertDeepEquals(t *testing.T, a, b interface{}) { t.Helper() - assert(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b)) + assertTrue(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b)) } func assertSameType(t *testing.T, a, b interface{}) { t.Helper() A := reflect.TypeOf(a) B := reflect.TypeOf(b) - assert(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name())) + assertTrue(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name())) +} + +// Utilities for parametrized tests +// Represents the configuration for a given test instance. +type testInstanceState interface { + set(string, string) + get(string) string +} + +// This complicated mixin stuff is in case we want to extend this to +// pass other parameters that aren't srings. Arguably YAGNI. +type testInstanceStateMixin struct { + params map[string]string +} + +func (m *testInstanceStateMixin) set(k string, v string) { + m.params[k] = v +} + +func (m *testInstanceStateMixin) get(k string) string { + return m.params[k] +} + +func newTestInstanceStateMixin() testInstanceStateMixin { + return testInstanceStateMixin{make(map[string]string, 0)} +} + +type testInstanceStateBase struct { + testInstanceStateMixin +} + +func newTestInstanceStateBase() testInstanceState { + return &testInstanceStateBase{newTestInstanceStateMixin()} +} + +// Helper function. +func runParametrizedInner(t *testing.T, name string, state testInstanceState, inparams []testParameter, f parametrizedTest) { + next := inparams[1:] + + param := inparams[0] + for _, p := range param.vals { + state.set(param.name, p) + var n string + if len(name) > 0 { + n = name + "/" + } + n = n + param.name + "=" + p + + if len(next) == 0 { + t.Run(n, func(t *testing.T) { + f(t, n, state) + }) + continue + } + runParametrizedInner(t, n, state, next, f) + } +} + +// Nominally public API. +type testParameter struct { + name string + vals []string +} + +type parametrizedTest func(t *testing.T, name string, p testInstanceState) + +// This is the function you call. +func runParametrizedTest(t *testing.T, inparams []testParameter, f parametrizedTest) { + runParametrizedInner(t, "", newTestInstanceStateBase(), inparams, f) } diff --git a/conn.go b/conn.go index 0ce05b2..5455f56 100644 --- a/conn.go +++ b/conn.go @@ -13,8 +13,6 @@ import ( "time" ) -var WouldBlock = fmt.Errorf("Would have blocked") - type Certificate struct { Chain []*x509.Certificate PrivateKey crypto.Signer @@ -253,6 +251,8 @@ type ConnectionState struct { PeerCertificates []*x509.Certificate // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates NextProto string // Selected ALPN proto + UsingPSK bool // Are we using PSK. + UsingEarlyData bool // Did we negotiate 0-RTT. } // Conn implements the net.Conn interface, as with "crypto/tls" @@ -263,8 +263,6 @@ type Conn struct { conn net.Conn isClient bool - EarlyData []byte - state stateConnected hState HandshakeState handshakeMutex sync.Mutex @@ -273,22 +271,27 @@ type Conn struct { readBuffer []byte in, out *RecordLayer - hsCtx HandshakeContext + hsCtx *HandshakeContext } func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { - c := &Conn{conn: conn, config: config, isClient: isClient} + c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}} if !config.UseDTLS { - c.in = NewRecordLayerTLS(c.conn) - c.out = NewRecordLayerTLS(c.conn) - c.hsCtx.hIn = NewHandshakeLayerTLS(c.in) - c.hsCtx.hOut = NewHandshakeLayerTLS(c.out) + c.in = NewRecordLayerTLS(c.conn, directionRead) + c.out = NewRecordLayerTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out) } else { - c.in = NewRecordLayerDTLS(c.conn) - c.out = NewRecordLayerDTLS(c.conn) - c.hsCtx.hIn = NewHandshakeLayerDTLS(c.in) - c.hsCtx.hOut = NewHandshakeLayerDTLS(c.out) - } + c.in = NewRecordLayerDTLS(c.conn, directionRead) + c.out = NewRecordLayerDTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out) + c.hsCtx.timeoutMS = initialTimeout + c.hsCtx.timers = newTimerSet() + c.hsCtx.waitingNextFlight = true + } + c.in.label = c.label() + c.out.label = c.label() c.hsCtx.hIn.nonblocking = c.config.NonBlocking return c } @@ -374,20 +377,54 @@ func (c *Conn) consumeRecord() error { return io.EOF } + case RecordTypeAck: + if !c.hsCtx.hIn.datagram { + logf(logTypeHandshake, "Received ACK in TLS mode") + return AlertUnexpectedMessage + } + return c.hsCtx.processAck(pt.fragment) + case RecordTypeApplicationData: c.readBuffer = append(c.readBuffer, pt.fragment...) logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } return err } +func readPartial(in *[]byte, buffer []byte) int { + logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in))) + read := copy(buffer, *in) + *in = (*in)[read:] + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read +} + // Read application data up to the size of buffer. Handshake and alert records // are consumed by the Conn object directly. func (c *Conn) Read(buffer []byte) (int, error) { if _, connected := c.hState.(stateConnected); !connected { - return 0, errors.New("Read called before the handshake completed") + // Clients can't call Read prior to handshake completion. + if c.isClient { + return 0, errors.New("Read called before the handshake completed") + } + + // Neither can servers that don't allow early data. + if !c.config.AllowEarlyData { + return 0, errors.New("Read called before the handshake completed") + } + + // If there's no early data, then return WouldBlock + if len(c.hsCtx.earlyData) == 0 { + return 0, AlertWouldBlock + } + + return readPartial(&c.hsCtx.earlyData, buffer), nil } + + // The handshake is now connected. logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) if alert := c.Handshake(); alert != AlertNoAlert { return 0, alert @@ -397,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) { return 0, nil } + // Run our timers. + if c.config.UseDTLS { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return 0, AlertInternalError + } + } + // Lock the input channel c.in.Lock() defer c.in.Unlock() @@ -406,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) { // err can be nil if consumeRecord processed a non app-data // record. if err != nil { - if c.config.NonBlocking || err != WouldBlock { + if c.config.NonBlocking || err != AlertWouldBlock { logf(logTypeIO, "conn.Read returns err=%v", err) return 0, err } } } - var read int - n := len(buffer) - logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) - if len(c.readBuffer) <= n { - buffer = buffer[:len(c.readBuffer)] - copy(buffer, c.readBuffer) - read = len(c.readBuffer) - c.readBuffer = c.readBuffer[:0] - } else { - logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) - copy(buffer[:n], c.readBuffer[:n]) - c.readBuffer = c.readBuffer[n:] - read = n - } - - logf(logTypeVerbose, "Returning %v", string(buffer)) - return read, nil + return readPartial(&c.readBuffer, buffer), nil } // Write application data @@ -438,6 +466,12 @@ func (c *Conn) Write(buffer []byte) (int, error) { c.out.Lock() defer c.out.Unlock() + if _, connected := c.hState.(stateConnected); !connected { + if !c.isClient || c.out.cipher.epoch != EpochEarlyData { + return 0, errors.New("Write called before the handshake completed (and early data not in use)") + } + } + // Send full-size fragments var start int sent := 0 @@ -549,11 +583,16 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } case SendQueuedHandshake: - err := c.hsCtx.hOut.SendQueuedMessages() + _, err := c.hsCtx.hOut.SendQueuedMessages() if err != nil { logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) return AlertInternalError } + if c.config.UseDTLS { + c.hsCtx.timers.start(retransmitTimerLabel, + c.hsCtx.handshakeRetransmit, + c.hsCtx.timeoutMS) + } case RekeyIn: logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet) err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) @@ -570,62 +609,6 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { return AlertInternalError } - case SendEarlyData: - logf(logTypeHandshake, "%s Sending early data...", label) - _, err := c.Write(c.EarlyData) - if err != nil { - logf(logTypeHandshake, "%s Error writing early data: %v", label, err) - return AlertInternalError - } - - case ReadPastEarlyData: - logf(logTypeHandshake, "%s Reading past early data...", label) - // Scan past all records that fail to decrypt - _, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok := err.(DecryptError) - - for ok { - _, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok = err.(DecryptError) - } - - case ReadEarlyData: - logf(logTypeHandshake, "%s Reading early data...", label) - t, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type(1): %v", label, t) - - for t == RecordTypeApplicationData { - // Read a record into the buffer. Note that this is safe - // in blocking mode because we read the record in in - // PeekRecordType. - pt, err := c.in.ReadRecord() - if err != nil { - logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) - return AlertInternalError - } - - logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) - c.EarlyData = append(c.EarlyData, pt.fragment...) - - t, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type (2): %v", label, t) - } - logf(logTypeHandshake, "%s Done reading early data", label) - case StorePSK: logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) if c.isClient { @@ -637,7 +620,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } default: - logf(logTypeHandshake, "%s Unknown actionuction type", label) + logf(logTypeHandshake, "%s Unknown action type", label) + assert(false) return AlertInternalError } @@ -657,7 +641,6 @@ func (c *Conn) HandshakeSetup() Alert { opts := ConnectionOptions{ ServerName: c.config.ServerName, NextProtos: c.config.NextProtos, - EarlyData: c.EarlyData, } if c.isClient { @@ -706,19 +689,22 @@ type handshakeMessageReaderImpl struct { var _ handshakeMessageReader = &handshakeMessageReaderImpl{} func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) { - hm, err := r.hsCtx.hIn.ReadMessage() - if err == WouldBlock { - return nil, AlertWouldBlock - } - if err != nil { - logf(logTypeHandshake, "[client] Error reading message: %v", err) - return nil, AlertCloseNotify + var hm *HandshakeMessage + var err error + for { + hm, err = r.hsCtx.hIn.ReadMessage() + if err == AlertWouldBlock { + return nil, AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "Error reading message: %v", err) + return nil, AlertCloseNotify + } + if hm != nil { + break + } } - // Once you have read a message, you no longer need the outgoing queue - // for DTLS. - r.hsCtx.hOut.ClearQueuedMessages() - return hm, AlertNoAlert } @@ -753,14 +739,21 @@ func (c *Conn) Handshake() Alert { state := c.hState _, connected := state.(stateConnected) - hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx} + hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx} for !connected { var alert Alert var actions []HandshakeAction + // Advance the state machine state, actions, alert = state.Next(hmr) - if alert == WouldBlock { + if alert == AlertWouldBlock { logf(logTypeHandshake, "%s Would block reading message: %s", label, alert) + // If we blocked, then run our timers to see if any have expired. + if c.hsCtx.hIn.datagram { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return AlertInternalError + } + } return AlertWouldBlock } if alert == AlertCloseNotify { @@ -788,6 +781,31 @@ func (c *Conn) Handshake() Alert { if connected { c.state = state.(stateConnected) c.handshakeComplete = true + + if !c.isClient { + // Send NewSessionTicket if configured to + if c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } + } + + // If there is early data, move it into the main buffer + if c.hsCtx.earlyData != nil { + c.readBuffer = c.hsCtx.earlyData + c.hsCtx.earlyData = nil + } } if c.config.NonBlocking { @@ -798,23 +816,6 @@ func (c *Conn) Handshake() Alert { } } - // Send NewSessionTicket if acting as server - if !c.isClient && c.config.SendSessionTickets { - actions, alert := c.state.NewSessionTicket( - c.config.TicketLen, - c.config.TicketLifetime, - c.config.EarlyDataLifetime) - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - } - return AlertNoAlert } @@ -848,6 +849,9 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error { } func (c *Conn) GetHsState() State { + if c.hState == nil { + return StateInit + } return c.hState.State() } @@ -878,7 +882,16 @@ func (c *Conn) ConnectionState() ConnectionState { state.NextProto = c.state.Params.NextProto state.VerifiedChains = c.state.verifiedChains state.PeerCertificates = c.state.peerCertificates + state.UsingPSK = c.state.Params.UsingPSK + state.UsingEarlyData = c.state.Params.UsingEarlyData } return state } + +func (c *Conn) label() string { + if c.isClient { + return "client" + } + return "server" +} diff --git a/conn_test.go b/conn_test.go index 1a050d1..46069cd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -86,14 +86,38 @@ func (p *pipeConn) RemoteAddr() net.Addr { return nil } func (p *pipeConn) SetDeadline(t time.Time) error { return nil } func (p *pipeConn) SetReadDeadline(t time.Time) error { return nil } func (p *pipeConn) SetWriteDeadline(t time.Time) error { return nil } +func (p *pipeConn) Left() int { return p.r.Len() } type bufferedConn struct { - buffer bytes.Buffer - w net.Conn + autoflush bool + buffer bytes.Buffer + w net.Conn + ctr int + lossPattern map[int]bool } -func (b *bufferedConn) Write(buf []byte) (n int, err error) { - return b.buffer.Write(buf) +func (b *bufferedConn) Write(buf []byte) (int, error) { + ctr := b.ctr + b.ctr++ + if ok := b.lossPattern[ctr]; ok { + fmt.Println("Losing write ", ctr) + return 0, nil + } + + n, err := b.buffer.Write(buf) + if err != nil { + return 0, err + } + if n != len(buf) { + return n, fmt.Errorf("Incomplete write") + } + if b.autoflush { + err := b.Flush() + if err != nil { + return 0, err + } + } + return 0, nil } func (p *bufferedConn) Read(data []byte) (n int, err error) { @@ -108,6 +132,13 @@ func (p *bufferedConn) RemoteAddr() net.Addr { return nil } func (p *bufferedConn) SetDeadline(t time.Time) error { return nil } func (p *bufferedConn) SetReadDeadline(t time.Time) error { return nil } func (p *bufferedConn) SetWriteDeadline(t time.Time) error { return nil } +func (b *bufferedConn) SetAutoflush() { + b.autoflush = true +} +func (b *bufferedConn) Left() int { + p := b.w.(*pipeConn) + return p.Left() +} func (b *bufferedConn) Flush() error { buf := b.buffer.Bytes() @@ -123,28 +154,38 @@ func (b *bufferedConn) Flush() error { return nil } +func (b *bufferedConn) Lose(m int) { + b.lossPattern[m] = true +} + +func (b *bufferedConn) Clear() { + b.buffer.Reset() +} + func newBufferedConn(p net.Conn) *bufferedConn { - return &bufferedConn{bytes.Buffer{}, p} + return &bufferedConn{ + false, bytes.Buffer{}, p, 0, make(map[int]bool, 0), + } } var ( serverKey, clientKey crypto.Signer serverCert, clientCert *x509.Certificate certificates, clientCertificates []*Certificate + clientName, serverName string psk PreSharedKey psks *PSKMapCache - basicConfig, dtlsConfig, nbConfig, hrrConfig, alpnConfig, pskConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config + basicConfig, dtlsConfig, nbConfig, nbConfigDTLS, hrrConfig, alpnConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config ) -const ( +func init() { + var err error + serverName = "example.com" clientName = "example.org" -) -func init() { - var err error serverKey, serverCert, err = MakeNewSelfSignedCert(serverName, ECDSA_P256_SHA256) if err != nil { panic(err) @@ -197,6 +238,14 @@ func init() { InsecureSkipVerify: true, } + nbConfigDTLS = &Config{ + ServerName: serverName, + Certificates: certificates, + NonBlocking: true, + UseDTLS: true, + InsecureSkipVerify: true, + } + hrrConfig = &Config{ ServerName: serverName, Certificates: certificates, @@ -219,6 +268,16 @@ func init() { InsecureSkipVerify: true, } + pskDTLSConfig = &Config{ + ServerName: serverName, + CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, + PSKs: psks, + AllowEarlyData: true, + UseDTLS: true, + NonBlocking: true, + InsecureSkipVerify: true, + } + pskECDHEConfig = &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, @@ -274,55 +333,81 @@ func computeExporter(t *testing.T, c *Conn, label string, context []byte, length return res } +func checkConsistency(t *testing.T, client *Conn, server *Conn) { + assertDeepEquals(t, client.state.Params, server.state.Params) + assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) + assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) + assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) + assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) + assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret) + + emptyContext := []byte{} + + assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20)) + assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21)) + assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20)) + assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20)) + assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20)) +} + +func testConnInner(t *testing.T, name string, p testInstanceState) { + // Configs array: + configs := map[string]*Config{"basic config": basicConfig, + "HRR": hrrConfig, + "ALPN": alpnConfig, + "FFDH": ffdhConfig, + "x25519": x25519Config, + } + + c := configs[p.get("config")] + conf := *c + + // Set up the test parameters. + if p.get("nonblocking") == "true" { + conf.NonBlocking = true + } + + cConn, sConn := pipe() + + client := Client(cConn, &conf) + server := Server(sConn, &conf) + + var clientAlert, serverAlert Alert + + done := make(chan bool) + go func(t *testing.T) { + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + done <- true + }(t) + + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertNoAlert) + + <-done + + checkConsistency(t, client, server) +} + func TestBasicFlows(t *testing.T) { - tests := []struct { - name string - config *Config - }{ - {"basic config", basicConfig}, - {"HRR", hrrConfig}, - {"ALPN", alpnConfig}, - {"FFDH", ffdhConfig}, - {"x25519", x25519Config}, - } - for _, testcase := range tests { - t.Run(fmt.Sprintf("with %s", testcase.name), func(t *testing.T) { - conf := testcase.config - cConn, sConn := pipe() - - client := Client(cConn, conf) - server := Server(sConn, conf) - - var clientAlert, serverAlert Alert - - done := make(chan bool) - go func(t *testing.T) { - serverAlert = server.Handshake() - assertEquals(t, serverAlert, AlertNoAlert) - done <- true - }(t) - - clientAlert = client.Handshake() - assertEquals(t, clientAlert, AlertNoAlert) - - <-done - - assertDeepEquals(t, client.state.Params, server.state.Params) - assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) - assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) - assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) - assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) - assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret) - - emptyContext := []byte{} - - assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20)) - assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21)) - assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20)) - assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20)) - assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20)) - }) + params := []testParameter{ + testParameter{ + "config", + []string{ + "basic config", + "HRR", + "ALPN", + "FFDH", + "x25519", + }, + }, + testParameter{ + "blocking", + []string{"true", "false"}, + }, } + + runParametrizedTest(t, params, testConnInner) } func TestInvalidSelfSigned(t *testing.T) { @@ -578,12 +663,8 @@ func TestClientAuth(t *testing.T) { <-done - assertDeepEquals(t, client.state.Params, server.state.Params) - assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) - assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) - assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) - assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) - assert(t, client.state.Params.UsingClientAuth, "Session did not negotiate client auth") + checkConsistency(t, client, server) + assertTrue(t, client.state.Params.UsingClientAuth, "Session did not negotiate client auth") } func TestClientAuthVerifyPeerAccepted(t *testing.T) { @@ -681,12 +762,9 @@ func TestPSKFlows(t *testing.T) { <-done - assertDeepEquals(t, client.state.Params, server.state.Params) - assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) - assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) - assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) - assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) - assert(t, client.state.Params.UsingPSK, "Session did not use the provided PSK") + checkConsistency(t, client, server) + + assertTrue(t, client.state.Params.UsingPSK, "Session did not use the provided PSK") } } @@ -724,11 +802,7 @@ func TestResumption(t *testing.T) { assertEquals(t, 1, n) <-done - assertDeepEquals(t, client1.state.Params, server1.state.Params) - assertCipherSuiteParamsEquals(t, client1.state.cryptoParams, server1.state.cryptoParams) - assertByteEquals(t, client1.state.resumptionSecret, server1.state.resumptionSecret) - assertByteEquals(t, client1.state.clientTrafficSecret, server1.state.clientTrafficSecret) - assertByteEquals(t, client1.state.serverTrafficSecret, server1.state.serverTrafficSecret) + checkConsistency(t, client1, server1) assertEquals(t, clientConfig.PSKs.Size(), 1) assertEquals(t, serverConfig.PSKs.Size(), 1) @@ -755,8 +829,8 @@ func TestResumption(t *testing.T) { receivedDelta := clientPSK.ReceivedAt.Sub(serverPSK.ReceivedAt) / time.Millisecond expiresDelta := clientPSK.ExpiresAt.Sub(serverPSK.ExpiresAt) / time.Millisecond - assert(t, receivedDelta < 10 && receivedDelta > -10, "Unequal received times") - assert(t, expiresDelta < 10 && expiresDelta > -10, "Unequal received times") + assertTrue(t, receivedDelta < 10 && receivedDelta > -10, "Unequal received times") + assertTrue(t, expiresDelta < 10 && expiresDelta > -10, "Unequal received times") // Phase 2: Verify that the session ticket gets used as a PSK cConn2, sConn2 := pipe() @@ -775,42 +849,54 @@ func TestResumption(t *testing.T) { client2.Read(nil) <-done - assertDeepEquals(t, client2.state.Params, server2.state.Params) - assertCipherSuiteParamsEquals(t, client2.state.cryptoParams, server2.state.cryptoParams) - assertByteEquals(t, client2.state.resumptionSecret, server2.state.resumptionSecret) - assertByteEquals(t, client2.state.clientTrafficSecret, server2.state.clientTrafficSecret) - assertByteEquals(t, client2.state.serverTrafficSecret, server2.state.serverTrafficSecret) - assert(t, client2.state.Params.UsingPSK, "Session did not use the provided PSK") + checkConsistency(t, client2, server2) + assertTrue(t, client2.state.Params.UsingPSK, "Session did not use the provided PSK") } -func Test0xRTT(t *testing.T) { - conf := pskConfig - cConn, sConn := pipe() - - client := Client(cConn, conf) - client.EarlyData = []byte("hello 0xRTT world!") +func test0xRTT(t *testing.T, name string, p testInstanceState) { + conf := *pskConfig + conf.NonBlocking = true - server := Server(sConn, conf) - - done := make(chan bool) - go func(t *testing.T) { - alert := server.Handshake() - assertEquals(t, alert, AlertNoAlert) - done <- true - }(t) - - alert := client.Handshake() - assertEquals(t, alert, AlertNoAlert) + if p.get("dtls") == "true" { + conf.UseDTLS = true + } - <-done + cConn, sConn := pipe() + cbConn := newBufferedConn(cConn) + cbConn.SetAutoflush() + sbConn := newBufferedConn(sConn) + sbConn.SetAutoflush() + + client := Client(cbConn, &conf) + server := Server(sbConn, &conf) + + client.Handshake() // This sends CH + zdata := []byte("ABC") + n, err := client.Write(zdata) // This should succeeed + assertNotError(t, err, "Client was able to write") + assertEquals(t, n, len(zdata)) + hsUntilBlocked(t, server, sbConn) // Read CH and early data. + tmp := make([]byte, 10) + n, err = server.Read(tmp) + assertNotError(t, err, "Error reading early data") + tmp = tmp[:n] + assertByteEquals(t, zdata, tmp) + hsRunHandshakeOneThread(t, client, server) + + assertTrue(t, client.state.Params.UsingEarlyData, "Session did not negotiate early data") + n, err = server.Read(tmp) + assertEquals(t, AlertWouldBlock, err) + assertEquals(t, 0, n) +} - assertDeepEquals(t, client.state.Params, server.state.Params) - assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) - assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) - assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) - assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) - assert(t, client.state.Params.UsingEarlyData, "Session did not negotiate early data") - assertByteEquals(t, client.EarlyData, server.EarlyData) +func Test0xRTT(t *testing.T) { + params := []testParameter{ + testParameter{ + "dtls", + []string{"true", "false"}, + }, + } + runParametrizedTest(t, params, test0xRTT) } func Test0xRTTFailure(t *testing.T) { @@ -831,7 +917,6 @@ func Test0xRTTFailure(t *testing.T) { cConn, sConn := pipe() client := Client(cConn, clientConfig) - client.EarlyData = []byte("hello 0xRTT world!") server := Server(sConn, serverConfig) @@ -1076,11 +1161,11 @@ func (h *testExtensionHandler) Check(t *testing.T, hs []HandshakeType) { for _, ht := range hs { v, ok := h.sent[ht] - assert(t, ok, "Cannot find handshake type in sent") - assert(t, v, "Value wasn't true in sent") + assertTrue(t, ok, "Cannot find handshake type in sent") + assertTrue(t, v, "Value wasn't true in sent") v, ok = h.rcvd[ht] - assert(t, ok, "Cannot find handshake type in rcvd") - assert(t, v, "Value wasn't true in rcvd") + assertTrue(t, ok, "Cannot find handshake type in rcvd") + assertTrue(t, v, "Value wasn't true in rcvd") } } @@ -1190,3 +1275,336 @@ func TestDTLS(t *testing.T) { HandshakeTypeEncryptedExtensions, }) } + +func TestNonblockingHandshakeAndDataFlowDTLS(t *testing.T) { + cConn, sConn := pipe() + + // Wrap these in a buffer so we can simulate blocking + cbConn := newBufferedConn(cConn) + sbConn := newBufferedConn(sConn) + + client := Client(cbConn, nbConfigDTLS) + server := Server(sbConn, nbConfigDTLS) + + var clientAlert, serverAlert Alert + + // Send ClientHello + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertNoAlert) + assertEquals(t, client.GetHsState(), StateClientWaitSH) + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + assertEquals(t, server.GetHsState(), StateServerStart) + + // Release ClientHello + cbConn.Flush() + + // Process ClientHello, send server first flight. + states := []State{StateServerNegotiated, StateServerWaitFlight2, StateServerWaitFinished} + for _, state := range states { + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + assertEquals(t, server.GetHsState(), state) + } + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertWouldBlock) + + // Release server first flight + sbConn.Flush() + states = []State{StateClientWaitEE, StateClientWaitCertCR, StateClientWaitCV, StateClientWaitFinished, StateClientConnected} + for _, state := range states { + clientAlert = client.Handshake() + assertEquals(t, client.GetHsState(), state) + assertEquals(t, clientAlert, AlertNoAlert) + } + + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + assertEquals(t, server.GetHsState(), StateServerWaitFinished) + + // Release client's second flight. + cbConn.Flush() + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + assertEquals(t, server.GetHsState(), StateServerConnected) + + assertDeepEquals(t, client.state.Params, server.state.Params) + assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) + assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) + assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) + assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) + + buf := []byte{'a', 'b', 'c'} + n, err := client.Write(buf) + assertNotError(t, err, "Couldn't write") + assertEquals(t, n, len(buf)) + + // read := make([]byte, 5) + // n, err = server.Read(buf) +} + +func TestTimeoutAndRetransmissionDTLS(t *testing.T) { + cConn, sConn := pipe() + + // Wrap these in a buffer so we can simulate blocking + cbConn := newBufferedConn(cConn) + sbConn := newBufferedConn(sConn) + + client := Client(cbConn, nbConfigDTLS) + server := Server(sbConn, nbConfigDTLS) + + var clientAlert, serverAlert Alert + + // Send ClientHello + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertNoAlert) + assertEquals(t, client.GetHsState(), StateClientWaitSH) + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + assertEquals(t, server.GetHsState(), StateServerStart) + + // Simulate loss for the ClientHello + cbConn.Clear() + + // Only client should be running a timer. + waiting, timeout := server.GetDTLSTimeout() + assertTrue(t, !waiting, fmt.Sprintf("Server timer armed: %v", timeout)) + + waiting, timeout = client.GetDTLSTimeout() + assertTrue(t, waiting, "Client timer not armed") + + // Now check the timer. + time.Sleep(timeout) + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertWouldBlock) + assertEquals(t, client.GetHsState(), StateClientWaitSH) + + // Release ClientHello + cbConn.Flush() + + // Process ClientHello, send server first flight. + states := []State{StateServerNegotiated, StateServerWaitFlight2, StateServerWaitFinished} + for _, state := range states { + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + assertEquals(t, server.GetHsState(), state) + } + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + + // Simulate loss for the server's first flight. + sbConn.Clear() + + // Both sides should be running timers + waiting, timeout = client.GetDTLSTimeout() + assertTrue(t, waiting, "Client timer not armed") + + waiting, timeout = server.GetDTLSTimeout() + assertTrue(t, waiting, "Server timer not armed") + + // Now check the timer. + time.Sleep(timeout) + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + assertEquals(t, server.GetHsState(), StateServerWaitFinished) + + sbConn.Flush() + states = []State{StateClientWaitEE, StateClientWaitCertCR, StateClientWaitCV, StateClientWaitFinished, StateClientConnected} + for _, state := range states { + clientAlert = client.Handshake() + assertEquals(t, client.GetHsState(), state) + assertEquals(t, clientAlert, AlertNoAlert) + } + + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertWouldBlock) + assertEquals(t, server.GetHsState(), StateServerWaitFinished) + + // Release client's second flight. + cbConn.Flush() + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + assertEquals(t, server.GetHsState(), StateServerConnected) + + assertDeepEquals(t, client.state.Params, server.state.Params) + assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) + assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) + assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) + assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) +} + +func checkTimers(t *testing.T, c *Conn, labels []string) { + armed := c.hsCtx.timers.getAllTimers() + + ma := make(map[string]bool) + mb := make(map[string]bool) + + // Check that the arrays are the same + for _, a := range armed { + ma[a] = true + } + + for _, a := range labels { + mb[a] = true + assertTrue(t, ma[a], fmt.Sprintf("Timer should have been armed: %v", a)) + } + + for _, a := range armed { + assertTrue(t, mb[a], fmt.Sprintf("Timer should not have been armed: %v", a)) + } + +} + +func hsUntilBlocked(t *testing.T, c *Conn, b *bufferedConn) { + // First run until we have consumed all the data + for b.Left() > 0 { + alert := c.Handshake() + switch alert { + default: + t.Fatalf("Unexpected alert") + case AlertWouldBlock, AlertNoAlert: + } + } + + // Now run until we block + for { + alert := c.Handshake() + if alert == AlertWouldBlock { + return + } + assertEquals(t, alert, AlertNoAlert) + } +} + +func hsUntilComplete(t *testing.T, c *Conn) { + for { + alert := c.Handshake() + assertTrue(t, + alert == AlertWouldBlock || + alert == AlertNoAlert, + "Unexpected alert") + + if c.GetHsState() == StateClientConnected || + c.GetHsState() == StateServerConnected { + break + } + } +} + +func hsRunHandshakeOneThread(t *testing.T, client *Conn, server *Conn) { + assertTrue(t, client.config.NonBlocking && server.config.NonBlocking, "Both sides need to be in nonblocking mode") + for client.GetHsState() != StateClientConnected || server.GetHsState() != StateServerConnected { + alert := client.Handshake() + switch alert { + default: + t.Fatalf("Unexpected alert") + case AlertWouldBlock, AlertNoAlert: + } + + alert = server.Handshake() + switch alert { + default: + t.Fatalf("Unexpected alert %v", alert) + case AlertWouldBlock, AlertNoAlert: + } + } + checkConsistency(t, client, server) +} + +func runAllTimers(t *testing.T, c *Conn) { + for { + waiting, timeout := c.GetDTLSTimeout() + if !waiting { + return + } + + if timeout > 0 { + time.Sleep(timeout) + } + + alert := c.Handshake() + assertEquals(t, alert, AlertWouldBlock) + } +} + +func TestAckDTLSNormal(t *testing.T) { + cConn, sConn := pipe() + + cbConn := newBufferedConn(cConn) + sbConn := newBufferedConn(sConn) + cbConn.SetAutoflush() + sbConn.SetAutoflush() + + client := Client(cbConn, nbConfigDTLS) + server := Server(sbConn, nbConfigDTLS) + + // Send ClientHello + hsUntilBlocked(t, client, cbConn) + + // Process ClientHello, send server first flight. + hsUntilBlocked(t, server, sbConn) + + // Both sides should be have armed retransmit timers. + checkTimers(t, client, []string{retransmitTimerLabel}) + checkTimers(t, server, []string{retransmitTimerLabel}) + + // Now run the client and server to completion + hsUntilComplete(t, client) + hsUntilComplete(t, server) + + // Client will have retransmit until we read the ACK + checkTimers(t, client, []string{retransmitTimerLabel}) + + // Server should have no timer + checkTimers(t, server, []string{}) + + // Now read some data from the server so we get the ACK + b := make([]byte, 10) + n, _ := client.Read(b) + assertEquals(t, 0, n) + + // Client will now have no timers + checkTimers(t, client, []string{}) +} + +func TestAckDTLSLoseEE(t *testing.T) { + cConn, sConn := pipe() + + cbConn := newBufferedConn(cConn) + sbConn := newBufferedConn(sConn) + sbConn.Lose(1) // Lose EE + cbConn.SetAutoflush() + sbConn.SetAutoflush() + + client := Client(cbConn, nbConfigDTLS) + server := Server(sbConn, nbConfigDTLS) + + // Send ClientHello + hsUntilBlocked(t, client, cbConn) + + // Process ClientHello, send server first flight. + hsUntilBlocked(t, server, sbConn) + + // Both sides should be have armed retransmit timers. + checkTimers(t, client, []string{retransmitTimerLabel}) + checkTimers(t, server, []string{retransmitTimerLabel}) + + // Now process as much of the server first flight as is there. + hsUntilBlocked(t, client, cbConn) + + // Client should now have the ACK timer armed + checkTimers(t, client, []string{ackTimerLabel}) + + // Now expire the timers + runAllTimers(t, client) + + // Process ACK + hsUntilBlocked(t, server, sbConn) + + // Now run the client and server to completion + hsUntilComplete(t, client) + hsUntilComplete(t, server) +} diff --git a/crypto_test.go b/crypto_test.go index 52c92d6..21f21b7 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -60,8 +60,8 @@ func TestNewKeyShare(t *testing.T) { crv := curveFromNamedGroup(group) x, y := elliptic.Unmarshal(crv, pub) - assert(t, x != nil && y != nil, "Public key failed to unmarshal") - assert(t, crv.Params().IsOnCurve(x, y), "Public key not on curve") + assertTrue(t, x != nil && y != nil, "Public key failed to unmarshal") + assertTrue(t, crv.Params().IsOnCurve(x, y), "Public key not on curve") } for _, group := range nonECGroups { @@ -145,13 +145,13 @@ func TestNewSigningKey(t *testing.T) { privRSA, err := newSigningKey(RSA_PKCS1_SHA256) assertNotError(t, err, "failed to generate RSA private key") _, ok := privRSA.(*rsa.PrivateKey) - assert(t, ok, "New RSA key was not actually an RSA key") + assertTrue(t, ok, "New RSA key was not actually an RSA key") // Test ECDSA success (P-256) privECDSA, err := newSigningKey(ECDSA_P256_SHA256) assertNotError(t, err, "failed to generate RSA private key") _, ok = privECDSA.(*ecdsa.PrivateKey) - assert(t, ok, "New ECDSA key was not actually an ECDSA key") + assertTrue(t, ok, "New ECDSA key was not actually an ECDSA key") pub := privECDSA.(*ecdsa.PrivateKey).Public().(*ecdsa.PublicKey) assertEquals(t, P256, namedGroupFromECDSAKey(pub)) @@ -159,7 +159,7 @@ func TestNewSigningKey(t *testing.T) { privECDSA, err = newSigningKey(ECDSA_P384_SHA384) assertNotError(t, err, "failed to generate RSA private key") _, ok = privECDSA.(*ecdsa.PrivateKey) - assert(t, ok, "New ECDSA key was not actually an ECDSA key") + assertTrue(t, ok, "New ECDSA key was not actually an ECDSA key") pub = privECDSA.(*ecdsa.PrivateKey).Public().(*ecdsa.PublicKey) assertEquals(t, P384, namedGroupFromECDSAKey(pub)) @@ -167,7 +167,7 @@ func TestNewSigningKey(t *testing.T) { privECDSA, err = newSigningKey(ECDSA_P521_SHA512) assertNotError(t, err, "failed to generate RSA private key") _, ok = privECDSA.(*ecdsa.PrivateKey) - assert(t, ok, "New ECDSA key was not actually an ECDSA key") + assertTrue(t, ok, "New ECDSA key was not actually an ECDSA key") pub = privECDSA.(*ecdsa.PrivateKey).Public().(*ecdsa.PublicKey) assertEquals(t, P521, namedGroupFromECDSAKey(pub)) @@ -184,7 +184,7 @@ func TestSelfSigned(t *testing.T) { alg := ECDSA_P256_SHA256 cert, err := newSelfSigned("example.com", alg, priv) assertNotError(t, err, "Failed to sign certificate") - assert(t, len(cert.Raw) > 0, "Certificate had empty raw value") + assertTrue(t, len(cert.Raw) > 0, "Certificate had empty raw value") assertEquals(t, cert.SignatureAlgorithm, x509AlgMap[alg]) // Test failure on unknown signature algorithm diff --git a/dtls.go b/dtls.go index df4f1aa..91dee33 100644 --- a/dtls.go +++ b/dtls.go @@ -2,14 +2,35 @@ package mint import ( "fmt" + "github.com/bifurcation/mint/syntax" + "time" ) // This file is a placeholder. DTLS-specific stuff (timer management, // ACKs, retransmits, etc. will eventually go here. const ( - initialMtu = 1200 + initialMtu = 1200 + initialTimeout = 100 ) +// labels for timers +const ( + retransmitTimerLabel = "handshake retransmit" + ackTimerLabel = "ack timer" +) + +type SentHandshakeFragment struct { + seq uint32 + offset int + fragLength int + record uint64 + acked bool +} + +type DtlsAck struct { + RecordNumbers []uint64 `tls:"head=2"` +} + func wireVersion(h *HandshakeLayer) uint16 { if h.datagram { return dtls12WireVersion @@ -26,3 +47,180 @@ func dtlsConvertVersion(version uint16) uint16 { } panic(fmt.Sprintf("Internal error, unexpected version=%d", version)) } + +func (h *HandshakeContext) handshakeRetransmit() error { + if _, err := h.hOut.SendQueuedMessages(); err != nil { + return err + } + + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + // TODO(ekr@rtfm.com): Back off timer + return nil +} + +func (h *HandshakeContext) sendAck() error { + toack := h.hIn.recvdRecords + + count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU + if len(toack) > count { + toack = toack[:count] + } + logf(logTypeHandshake, "Sending ACK: [%x]", toack) + + ack := &DtlsAck{toack} + body, err := syntax.Marshal(&ack) + if err != nil { + return err + } + err = h.hOut.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAck, + fragment: body, + }) + if err != nil { + return err + } + return nil +} + +func (h *HandshakeContext) processAck(data []byte) error { + // Cancel the retransmit timer because we will be resending + // and possibly re-arming later. + h.timers.cancel(retransmitTimerLabel) + + ack := &DtlsAck{} + read, err := syntax.Unmarshal(data, &ack) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") + } + logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers) + + for _, r := range ack.RecordNumbers { + for _, m := range h.sentFragments { + if r == m.record { + logf(logTypeHandshake, "Marking %v %v(%v) as acked", + m.seq, m.offset, m.fragLength) + m.acked = true + } + } + } + + count, err := h.hOut.SendQueuedMessages() + if err != nil { + return err + } + + if count == 0 { + logf(logTypeHandshake, "All messages ACKed") + h.hOut.ClearQueuedMessages() + return nil + } + + // Reset the timer + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + return nil +} + +type connTimerCb func(c *Conn) error + +func (c *Conn) GetDTLSTimeout() (bool, time.Duration) { + return c.hsCtx.timers.remaining() +} + +func (h *HandshakeContext) receivedHandshakeMessage() { + logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight) + // This just enables tests. + if h.hIn == nil { + return + } + + if !h.hIn.datagram { + return + } + + if h.waitingNextFlight { + logf(logTypeHandshake, "Received the start of the flight") + + // Clear the outgoing DTLS queue and terminate the retransmit timer + h.hOut.ClearQueuedMessages() + h.timers.cancel(retransmitTimerLabel) + + // OK, we're not waiting any more. + h.waitingNextFlight = false + } + + // Now pre-emptively arm the ACK timer if it's not armed already. + // We'll automatically dis-arm it at the end of the handshake. + if h.timers.getTimer(ackTimerLabel) == nil { + h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4) + } +} + +func (h *HandshakeContext) receivedEndOfFlight() { + logf(logTypeHandshake, "%p Received the end of the flight", h) + if !h.hIn.datagram { + return + } + + // Empty incoming queue + h.hIn.queued = nil + + // Note that we are waiting for the next flight. + h.waitingNextFlight = true + + // Clear the ACK queue. + h.hIn.recvdRecords = nil + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) +} + +func (h *HandshakeContext) receivedFinalFlight() { + logf(logTypeHandshake, "%p Received final flight", h) + if !h.hIn.datagram { + return + } + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) + + // But send an ACK immediately. + h.sendAck() + +} + +func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool { + logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen) + for _, f := range h.sentFragments { + if !f.acked { + continue + } + + if f.seq != seq { + continue + } + + if f.offset > offset { + continue + } + + // At this point, we know that the stored fragment starts + // at or before what we want to send, so check where the end + // is. + if f.offset+f.fragLength < offset+fraglen { + continue + } + + return true + } + + return false +} diff --git a/extensions_test.go b/extensions_test.go index ad1e3bb..d9d43c3 100644 --- a/extensions_test.go +++ b/extensions_test.go @@ -360,17 +360,17 @@ func TestExtensionFind(t *testing.T) { ks := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} found, err := extListKeyShareIn.Find(&ks) assertNotError(t, err, "Failed to parse valid extension") - assert(t, found, "Failed to find a valid extension") + assertTrue(t, found, "Failed to find a valid extension") // Test find failure on absent extension var sg SupportedGroupsExtension found, err = extListKeyShareIn.Find(&sg) assertNotError(t, err, "Error on missing extension") - assert(t, !found, "Found an extension that's not present") + assertTrue(t, !found, "Found an extension that's not present") // Test find failure on unmarshal failure found, err = extListInvalidIn.Find(&ks) - assert(t, found, "Didn't found an extension that's not valid") + assertTrue(t, found, "Didn't found an extension that's not valid") assertError(t, err, "Parsed an invalid extension") } @@ -394,8 +394,8 @@ func TestExtensionParse(t *testing.T) { found, err := validExtensions.Parse(extensionsIn) assertNotError(t, err, "Failed to parse valid extensions") - assert(t, found[ExtensionTypeKeyShare], "Failed to find key share") - assert(t, found[ExtensionTypeSupportedVersions], "Failed to find supported versions") + assertTrue(t, found[ExtensionTypeKeyShare], "Failed to find key share") + assertTrue(t, found[ExtensionTypeSupportedVersions], "Failed to find supported versions") // Now a version with an error sv.HandshakeType = HandshakeTypeServerHello @@ -644,13 +644,13 @@ func TestPreSharedKeyMarshalUnmarshal(t *testing.T) { // Test finding an identity that is present id := []byte{1, 2, 3, 4} binder, found := pskClientIn.HasIdentity(id) - assert(t, found, "Failed to find present identity") + assertTrue(t, found, "Failed to find present identity") assertByteEquals(t, binder, bytes.Repeat([]byte{0xA0}, 32)) // Test finding an identity that is not present id = []byte{1, 2, 4, 3} _, found = pskClientIn.HasIdentity(id) - assert(t, !found, "Found a not-present identity") + assertTrue(t, !found, "Found a not-present identity") } func TestALPNMarshalUnmarshal(t *testing.T) { diff --git a/frame-reader.go b/frame-reader.go index 54f40ce..4ccfc23 100644 --- a/frame-reader.go +++ b/frame-reader.go @@ -67,7 +67,7 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { f.writeOffset += copied if f.writeOffset < len(f.working) { logf(logTypeVerbose, "Read would have blocked 1") - return nil, nil, WouldBlock + return nil, nil, AlertWouldBlock } // Reset the write offset, because we are now full. f.writeOffset = 0 @@ -94,5 +94,5 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { } logf(logTypeVerbose, "Read would have blocked 2") - return nil, nil, WouldBlock + return nil, nil, AlertWouldBlock } diff --git a/frame-reader_test.go b/frame-reader_test.go index 55c94a2..4ea5efd 100644 --- a/frame-reader_test.go +++ b/frame-reader_test.go @@ -64,7 +64,7 @@ func TestFrameReaderTrickle(t *testing.T) { for i := 0; i <= len(kTestFrame); i += 1 { hdr, body, err = r.process() if i < len(kTestFrame) { - assertEquals(t, err, WouldBlock) + assertEquals(t, err, AlertWouldBlock) assertEquals(t, 0, len(hdr)) assertEquals(t, 0, len(body)) r.addChunk(kTestFrame[i : i+1]) diff --git a/handshake-layer.go b/handshake-layer.go index 888c5f3..de17b30 100644 --- a/handshake-layer.go +++ b/handshake-layer.go @@ -35,7 +35,6 @@ type HandshakeMessage struct { datagram bool offset uint32 // Used for DTLS length uint32 - records []uint64 // Used for DTLS cipher *cipherState } @@ -119,6 +118,7 @@ func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*H } type HandshakeLayer struct { + ctx *HandshakeContext // The handshake we are attached to nonblocking bool // Should we operate in nonblocking mode conn *RecordLayer // Used for reading/writing records frame *frameReader // The buffered frame reader @@ -126,6 +126,7 @@ type HandshakeLayer struct { msgSeq uint32 // The DTLS message sequence number queued []*HandshakeMessage // In/out queue sent []*HandshakeMessage // Sent messages for DTLS + recvdRecords []uint64 // Records we have received. maxFragmentLen int } @@ -152,8 +153,9 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { return int(val), nil } -func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r h.datagram = false h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) @@ -161,8 +163,9 @@ func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { return &h } -func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r h.datagram = true h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) @@ -172,16 +175,25 @@ func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { func (h *HandshakeLayer) readRecord() error { logf(logTypeVerbose, "Trying to read record") - pt, err := h.conn.ReadRecord() + pt, err := h.conn.readRecordAnyEpoch() if err != nil { return err } - if pt.contentType != RecordTypeHandshake && - pt.contentType != RecordTypeAlert { + switch pt.contentType { + case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck: + default: return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) } + if pt.contentType == RecordTypeAck { + if !h.datagram { + return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS") + } + logf(logTypeIO, "read ACK") + return h.ctx.processAck(pt.fragment) + } + if pt.contentType == RecordTypeAlert { logf(logTypeIO, "read alert %v", pt.fragment[1]) if len(pt.fragment) < 2 { @@ -191,6 +203,19 @@ func (h *HandshakeLayer) readRecord() error { return Alert(pt.fragment[1]) } + assert(h.ctx.hIn.conn != nil) + if pt.epoch != h.ctx.hIn.conn.cipher.epoch { + // This is out of order but we're dropping it. + // TODO(ekr@rtfm.com): If server, need to retransmit Finished. + if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData { + return nil + } + + // Anything else shouldn't happen. + return AlertIllegalParameter + } + + h.recvdRecords = append(h.recvdRecords, pt.seq) h.frame.addChunk(pt.fragment) return nil @@ -227,9 +252,13 @@ func (h *HandshakeLayer) noteMessageDelivered(seq uint32) { func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) { if hm.seq < h.msgSeq { - return nil, WouldBlock + return nil, nil } + // TODO(ekr@rtfm.com): Send an ACK immediately if we got something + // out of order. + h.ctx.receivedHandshakeMessage() + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { // TODO(ekr@rtfm.com): Check the length? // This is complete. @@ -259,12 +288,12 @@ func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMe func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { if len(h.queued) == 0 { - return nil, WouldBlock + return nil, nil } hm := h.queued[0] if hm.seq != h.msgSeq { - return nil, WouldBlock + return nil, nil } if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { @@ -307,7 +336,7 @@ func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { } - return nil, WouldBlock + return nil, nil } func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { @@ -315,19 +344,19 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { var err error hm, err := h.checkMessageAvailable() - if err == nil { - return hm, err - } - if err != WouldBlock { + if err != nil { return nil, err } + if hm != nil { + return hm, nil + } for { logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) if h.frame.needed() > 0 { logf(logTypeVerbose, "Trying to read a new record") err = h.readRecord() - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } @@ -336,7 +365,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { if err == nil { break } - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } @@ -370,12 +399,13 @@ func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { return nil } -func (h *HandshakeLayer) SendQueuedMessages() error { +func (h *HandshakeLayer) SendQueuedMessages() (int, error) { logf(logTypeHandshake, "Sending outgoing messages") - err := h.WriteMessages(h.queued) - h.ClearQueuedMessages() // This isn't going to work for DTLS, but we'll - // get there. - return err + count, err := h.WriteMessages(h.queued) + if !h.datagram { + h.ClearQueuedMessages() + } + return count, err } func (h *HandshakeLayer) ClearQueuedMessages() { @@ -383,7 +413,7 @@ func (h *HandshakeLayer) ClearQueuedMessages() { h.queued = nil } -func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (int, error) { +func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) { var buf []byte // Figure out if we're going to want the full header or just @@ -408,17 +438,35 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int } body := hm.body[start : start+bodylen] + // Now see if this chunk has been ACKed. This doesn't produce ideal + // retransmission but is simple. + if h.ctx.fragmentAcked(hm.seq, start, bodylen) { + logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen) + return false, start + bodylen, nil + } + // Encode the data. if hdrlen > 0 { hm2 := *hm hm2.offset = uint32(start) hm2.body = body buf = hm2.Marshal() + hm = &hm2 } else { buf = body } - return start + bodylen, h.conn.writeRecordWithPadding( + if h.datagram { + // Remember that we sent this. + h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{ + hm.seq, + start, + len(body), + h.conn.cipher.combineSeq(true), + false, + }) + } + return true, start + bodylen, h.conn.writeRecordWithPadding( &TLSPlaintext{ contentType: RecordTypeHandshake, fragment: buf, @@ -426,38 +474,46 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int hm.cipher, 0) } -func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) { start := int(0) if len(hm.body) > maxHandshakeMessageLen { - return fmt.Errorf("Tried to write a handshake message that's too long") + return 0, fmt.Errorf("Tried to write a handshake message that's too long") } + written := 0 + wrote := false + // Always make one pass through to allow EOED (which is empty). for { var err error - start, err = h.writeFragment(hm, start, h.maxFragmentLen) + wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen) if err != nil { - return err + return 0, err + } + if wrote { + written++ } if start >= len(hm.body) { break } } - return nil + return written, nil } -func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) { + written := 0 for _, hm := range hms { logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) - err := h.WriteMessage(hm) + wrote, err := h.WriteMessage(hm) if err != nil { - return err + return 0, err } + written += wrote } - return nil + return written, nil } func encodeUint(v uint64, size int, out []byte) []byte { diff --git a/handshake-layer_test.go b/handshake-layer_test.go index bdd7d16..1abc1cd 100644 --- a/handshake-layer_test.go +++ b/handshake-layer_test.go @@ -154,7 +154,7 @@ func TestMessageFromBody(t *testing.T) { chValid := unhex(chValidHex) b := bytes.NewBuffer(nil) - h := NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionRead)) // Test successful conversion hm, err := h.HandshakeMessageFromBody(&chValidIn) @@ -169,6 +169,13 @@ func TestMessageFromBody(t *testing.T) { chValidIn.CipherSuites = chCipherSuites } +func newHandshakeLayerFromBytes(d []byte) *HandshakeLayer { + hc := &HandshakeContext{} + b := bytes.NewBuffer(d) + hc.hIn = NewHandshakeLayerTLS(hc, NewRecordLayerTLS(b, directionRead)) + return hc.hIn +} + func TestReadHandshakeMessage(t *testing.T) { short := unhex(shortHex) long := unhex(longHex) @@ -177,22 +184,19 @@ func TestReadHandshakeMessage(t *testing.T) { nonHandshake := unhex(nonHandshakeHex) // Test successful read of a message in a single record - b := bytes.NewBuffer(short) - h := NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h := newHandshakeLayerFromBytes(short) hm, err := h.ReadMessage() assertNotError(t, err, "Failed to read a short handshake message") assertDeepEquals(t, hm, shortMessageIn) // Test successful read of a message split across records - b = bytes.NewBuffer(long) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = newHandshakeLayerFromBytes(long) hm, err = h.ReadMessage() assertNotError(t, err, "Failed to read a long handshake message") assertDeepEquals(t, hm, longMessageIn) // Test successful read of multiple messages sequentially - b = bytes.NewBuffer(shortLongShort) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = newHandshakeLayerFromBytes(shortLongShort) hm1, err := h.ReadMessage() assertNotError(t, err, "Failed to read first handshake message") assertDeepEquals(t, hm1, shortMessageIn) @@ -204,27 +208,25 @@ func TestReadHandshakeMessage(t *testing.T) { assertDeepEquals(t, hm3, shortMessageIn) // Test read failure on inability to read header - b = bytes.NewBuffer(short[:handshakeHeaderLenTLS-1]) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = newHandshakeLayerFromBytes(short[:handshakeHeaderLenTLS-1]) hm, err = h.ReadMessage() assertError(t, err, "Read handshake message with an incomplete header") // Test read failure on inability to read body - b = bytes.NewBuffer(insufficientData) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = newHandshakeLayerFromBytes(insufficientData) hm, err = h.ReadMessage() assertError(t, err, "Read handshake message with an incomplete body") // Test read failure on receiving a non-handshake record - b = bytes.NewBuffer(nonHandshake) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = newHandshakeLayerFromBytes(nonHandshake) hm, err = h.ReadMessage() assertError(t, err, "Read handshake message from a non-handshake record") } func testWriteHandshakeMessage(h *HandshakeLayer, hm *HandshakeMessage) error { hm.cipher = h.conn.cipher - return h.WriteMessage(hm) + _, err := h.WriteMessage(hm) + return err } func TestWriteHandshakeMessage(t *testing.T) { @@ -233,26 +235,26 @@ func TestWriteHandshakeMessage(t *testing.T) { // Test successful write of single message b := bytes.NewBuffer(nil) - h := NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h := NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite)) err := testWriteHandshakeMessage(h, shortMessageIn) assertNotError(t, err, "Failed to write valid short message") assertByteEquals(t, b.Bytes(), short) // Test successful write of single long message b = bytes.NewBuffer(nil) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite)) err = testWriteHandshakeMessage(h, longMessageIn) assertNotError(t, err, "Failed to write valid long message") assertByteEquals(t, b.Bytes(), long) // Test write failure on message too large b = bytes.NewBuffer(nil) - h = NewHandshakeLayerTLS(NewRecordLayerTLS(b)) + h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(b, directionWrite)) err = testWriteHandshakeMessage(h, tooLongMessageIn) assertError(t, err, "Wrote a message exceeding the length bound") // Test write failure on underlying write failure - h = NewHandshakeLayerTLS(NewRecordLayerTLS(ErrorReadWriter{})) + h = NewHandshakeLayerTLS(&HandshakeContext{}, NewRecordLayerTLS(ErrorReadWriter{}, directionWrite)) err = testWriteHandshakeMessage(h, longMessageIn) assertError(t, err, "Write succeeded despite error in full fragment send") err = testWriteHandshakeMessage(h, shortMessageIn) @@ -261,6 +263,7 @@ func TestWriteHandshakeMessage(t *testing.T) { type testReassembleFixture struct { t *testing.T + c HandshakeContext h *HandshakeLayer r *RecordLayer rd *pipeConn @@ -294,8 +297,11 @@ func newTestReassembleFixture(t *testing.T) *testReassembleFixture { } f.m1 = newHsFragment(m1, 1, 0, 2048) f.rd, f.wr = pipe() - f.r = NewRecordLayerDTLS(f.rd) - f.h = NewHandshakeLayerDTLS(f.r) + + f.r = NewRecordLayerDTLS(f.rd, directionRead) + f.h = NewHandshakeLayerDTLS(&f.c, f.r) + f.c.hIn = f.h + f.c.timers = newTimerSet() f.h.nonblocking = true return &f @@ -310,7 +316,6 @@ func newHsFragment(full []byte, seq uint32, offset uint32, fragLen uint32) *Hand offset, uint32(len(full)), nil, - nil, } } @@ -326,7 +331,7 @@ func (f *testReassembleFixture) addFragment(in *HandshakeMessage, expected *Hand h2, err := f.h.ReadMessage() if expected == nil { assertEquals(f.t, (*HandshakeMessage)(nil), h2) - assertEquals(f.t, WouldBlock, err) + assertEquals(f.t, nil, err) } else { assertNotError(f.t, err, "Error reading handshake") assertEquals(f.t, expected.seq, h2.seq) @@ -338,7 +343,7 @@ func TestHandshakeDTLSInOrder(t *testing.T) { f := newTestReassembleFixture(t) f.addFragment(f.m0, f.m0) - f.addFragment(f.m0, nil) // Should block + f.addFragment(f.m0, nil) f.addFragment(f.m1, f.m1) } diff --git a/log_test.go b/log_test.go index adb942c..39c2db7 100644 --- a/log_test.go +++ b/log_test.go @@ -20,15 +20,15 @@ func TestLogging(t *testing.T) { logSettings = map[string]bool{} env := []string{"MINT_LOG=*"} parseLogEnv(env) - assert(t, logAll, "Failed to parse wildcard log directive") - assert(t, len(logSettings) == 0, "Mistakenly set log settings") + assertTrue(t, logAll, "Failed to parse wildcard log directive") + assertTrue(t, len(logSettings) == 0, "Mistakenly set log settings") logAll = false logSettings = map[string]bool{} env = []string{"MINT_LOG=foo,bar"} parseLogEnv(env) - assert(t, !logAll, "Mistakenly set logAll") - assert(t, logSettings["foo"] && logSettings["bar"], "Failed to parse string log directive") + assertTrue(t, !logAll, "Mistakenly set logAll") + assertTrue(t, logSettings["foo"] && logSettings["bar"], "Failed to parse string log directive") logFunction = testLogFunction logAll = false diff --git a/negotiation.go b/negotiation.go index 4697bbc..d829cc2 100644 --- a/negotiation.go +++ b/negotiation.go @@ -168,10 +168,10 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") } -func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) ( /* using */ bool /* rejected */, bool) { usingEarlyData := gotEarlyData && usingPSK && allowEarlyData logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) - return usingEarlyData + return usingEarlyData, gotEarlyData && !usingEarlyData } func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { diff --git a/negotiation_test.go b/negotiation_test.go index 705a56d..5f571c1 100644 --- a/negotiation_test.go +++ b/negotiation_test.go @@ -100,18 +100,18 @@ func TestPSKNegotiation(t *testing.T) { func TestPSKModeNegotiation(t *testing.T) { // Test that everything that's allowed gets used usingDH, usingPSK := PSKModeNegotiation(true, true, []PSKKeyExchangeMode{PSKModeKE, PSKModeDHEKE}) - assert(t, usingDH, "Unnecessarily disabled DH") - assert(t, usingPSK, "Unnecessarily disabled PSK") + assertTrue(t, usingDH, "Unnecessarily disabled DH") + assertTrue(t, usingPSK, "Unnecessarily disabled PSK") // Test that DH is disabled when not allowed with the PSK usingDH, usingPSK = PSKModeNegotiation(true, true, []PSKKeyExchangeMode{PSKModeKE}) - assert(t, !usingDH, "Should not have enabled DH") - assert(t, usingPSK, "Unnecessarily disabled PSK") + assertTrue(t, !usingDH, "Should not have enabled DH") + assertTrue(t, usingPSK, "Unnecessarily disabled PSK") // Test that the PSK is disabled when DH is required but not possible usingDH, usingPSK = PSKModeNegotiation(false, true, []PSKKeyExchangeMode{PSKModeDHEKE}) - assert(t, !usingDH, "Should not have enabled DH") - assert(t, !usingPSK, "Should not have enabled PSK") + assertTrue(t, !usingDH, "Should not have enabled DH") + assertTrue(t, !usingPSK, "Should not have enabled PSK") } func TestCertificateSelection(t *testing.T) { @@ -142,17 +142,21 @@ func TestCertificateSelection(t *testing.T) { } func TestEarlyDataNegotiation(t *testing.T) { - useEarlyData := EarlyDataNegotiation(true, true, true) - assert(t, useEarlyData, "Did not use early data when allowed") + useEarlyData, rejected := EarlyDataNegotiation(true, true, true) + assertTrue(t, useEarlyData, "Did not use early data when allowed") + assertTrue(t, !rejected, "Rejected when allowed") - useEarlyData = EarlyDataNegotiation(false, true, true) - assert(t, !useEarlyData, "Allowed early data when not using PSK") + useEarlyData, rejected = EarlyDataNegotiation(false, true, true) + assertTrue(t, !useEarlyData, "Allowed early data when not using PSK") + assertTrue(t, rejected, "Rejected not set") - useEarlyData = EarlyDataNegotiation(true, false, true) - assert(t, !useEarlyData, "Allowed early data when not signaled") + useEarlyData, rejected = EarlyDataNegotiation(true, false, true) + assertTrue(t, !useEarlyData, "Allowed early data when not signaled") + assertTrue(t, !rejected, "Rejected when not signaled") - useEarlyData = EarlyDataNegotiation(true, true, false) - assert(t, !useEarlyData, "Allowed early data when not allowed") + useEarlyData, rejected = EarlyDataNegotiation(true, true, false) + assertTrue(t, !useEarlyData, "Allowed early data when not allowed") + assertTrue(t, rejected, "Rejected not set") } func TestCipherSuiteNegotiation(t *testing.T) { diff --git a/record-layer.go b/record-layer.go index 761a868..025e780 100644 --- a/record-layer.go +++ b/record-layer.go @@ -1,7 +1,6 @@ package mint import ( - "bytes" "crypto/cipher" "fmt" "io" @@ -21,6 +20,13 @@ func (err DecryptError) Error() string { return string(err) } +type direction uint8 + +const ( + directionWrite = direction(1) + directionRead = direction(2) +) + // struct { // ContentType type; // ProtocolVersion record_version [0301 for CH, 0303 for others] @@ -31,20 +37,23 @@ type TLSPlaintext struct { // Omitted: record_version (static) // Omitted: length (computed from fragment) contentType RecordType + epoch Epoch + seq uint64 fragment []byte } type cipherState struct { epoch Epoch // DTLS epoch ivLength int // Length of the seq and nonce fields - seq []byte // Zero-padded sequence number + seq uint64 // Zero-padded sequence number iv []byte // Buffer for the IV cipher cipher.AEAD // AEAD cipher } type RecordLayer struct { sync.Mutex - + label string + direction direction version uint16 // The current version number conn io.ReadWriter // The underlying connection frame *frameReader // The buffered frame reader @@ -52,7 +61,9 @@ type RecordLayer struct { cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedError error // Error on the last record read - cipher *cipherState + cipher *cipherState + readCiphers map[Epoch]*cipherState + datagram bool } @@ -76,7 +87,7 @@ func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { } func newCipherStateNull() *cipherState { - return &cipherState{EpochClear, 0, bytes.Repeat([]byte{0}, sequenceNumberLen), nil, nil} + return &cipherState{EpochClear, 0, 0, nil, nil} } func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { @@ -85,11 +96,13 @@ func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) return nil, err } - return &cipherState{epoch, len(iv), bytes.Repeat([]byte{0}, sequenceNumberLen), iv, cipher}, nil + return &cipherState{epoch, len(iv), 0, iv, cipher}, nil } -func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { +func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{false}) r.cipher = newCipherStateNull() @@ -97,11 +110,14 @@ func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { return &r } -func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer { +func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{true}) r.cipher = newCipherStateNull() + r.readCiphers = make(map[Epoch]*cipherState, 0) r.datagram = true return &r } @@ -116,47 +132,55 @@ func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []b return err } r.cipher = cipher + if r.datagram && r.direction == directionRead { + r.readCiphers[epoch] = cipher + } return nil } -func (c *cipherState) formatSeq(datagram bool) []byte { - seq := append([]byte{}, c.seq...) +func (r *RecordLayer) DiscardReadKey(epoch Epoch) { + if !r.datagram { + return + } + + _, ok := r.readCiphers[epoch] + assert(ok) + delete(r.readCiphers, epoch) +} + +func (c *cipherState) combineSeq(datagram bool) uint64 { + seq := c.seq if datagram { - seq[0] = byte(c.epoch >> 8) - seq[1] = byte(c.epoch & 0xff) + seq |= uint64(c.epoch) << 48 } return seq } -func (c *cipherState) computeNonce(seq []byte) []byte { +func (c *cipherState) computeNonce(seq uint64) []byte { nonce := make([]byte, len(c.iv)) copy(nonce, c.iv) - offset := len(c.iv) - len(seq) - for i, b := range seq { - nonce[i+offset] ^= b + s := seq + + offset := len(c.iv) + for i := 0; i < 8; i++ { + nonce[(offset-i)-1] ^= byte(s & 0xff) + s >>= 8 } + logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce) return nonce } func (c *cipherState) incrementSequenceNumber() { - var i int - for i = len(c.seq) - 1; i >= 0; i-- { - c.seq[i]++ - if c.seq[i] != 0 { - break - } - } - - if i < 0 { + if c.seq >= 1<<48 { // Not allowed to let sequence number wrap. // Instead, must renegotiate before it does. - // Not likely enough to bother. - // TODO(ekr@rtfm.com): Check for DTLS here - // because the limit is sooner. + // Not likely enough to bother. This is the + // DTLS limit. panic("TLS: sequence number wraparound") } + c.seq++ } func (c *cipherState) overhead() int { @@ -166,8 +190,9 @@ func (c *cipherState) overhead() int { return c.cipher.Overhead() } -func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, padLen int) *TLSPlaintext { - logf(logTypeIO, "Encrypt seq=[%x]", seq) +func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext { + assert(r.direction == directionWrite) + logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq) // Expand the fragment to hold contentType, padding, and overhead originalLen := len(pt.fragment) plaintextLen := originalLen + 1 + padLen @@ -191,8 +216,9 @@ func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, return out } -func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, error) { - logf(logTypeIO, "Decrypt seq=[%x]", seq) +func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) { + assert(r.direction == directionRead) + logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq) if len(pt.fragment) < r.cipher.overhead() { msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) return nil, 0, DecryptError(msg) @@ -207,7 +233,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, // Decrypt _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) if err != nil { - logf(logTypeIO, "AEAD decryption failure [%x]", pt) + logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt) return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") } @@ -222,6 +248,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, // Truncate the message to remove contentType, padding, overhead out.fragment = out.fragment[:newLen] + out.seq = seq return out, padLen, nil } @@ -230,11 +257,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { var err error for { - pt, err = r.nextRecord() + pt, err = r.nextRecord(false) if err == nil { break } - if !block || err != WouldBlock { + if !block || err != AlertWouldBlock { return 0, err } } @@ -242,7 +269,17 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { } func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { - pt, err := r.nextRecord() + pt, err := r.nextRecord(false) + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) { + pt, err := r.nextRecord(true) // Consume the cached record if there was one r.cachedRecord = nil @@ -251,10 +288,10 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { return pt, err } -func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { +func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) { cipher := r.cipher if r.cachedRecord != nil { - logf(logTypeIO, "Returning cached record") + logf(logTypeIO, "%s Returning cached record", r.label) return r.cachedRecord, r.cachedError } @@ -262,9 +299,10 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { // // 1. We get a frame // 2. We try to read off the socket and get nothing, in which case - // return WouldBlock + // returnAlertWouldBlock // 3. We get an error. - err := WouldBlock + var err error + err = AlertWouldBlock var header, body []byte for err != nil { @@ -272,24 +310,24 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) n, err := r.conn.Read(buf) if err != nil { - logf(logTypeIO, "Error reading, %v", err) + logf(logTypeIO, "%s Error reading, %v", r.label, err) return nil, err } if n == 0 { - return nil, WouldBlock + return nil, AlertWouldBlock } - logf(logTypeIO, "Read %v bytes", n) + logf(logTypeIO, "%s Read %v bytes", r.label, n) buf = buf[:n] r.frame.addChunk(buf) } header, body, err = r.frame.process() - // Loop around on WouldBlock to see if some + // Loop around onAlertWouldBlock to see if some // data is now available. - if err != nil && err != WouldBlock { + if err != nil && err != AlertWouldBlock { return nil, err } } @@ -299,7 +337,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { switch RecordType(header[0]) { default: return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) - case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck: pt.contentType = RecordType(header[0]) } @@ -318,28 +356,47 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { pt.fragment = make([]byte, size) copy(pt.fragment, body) + // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data. + // Attempt to decrypt fragment - if cipher.cipher != nil { - seq := cipher.seq - if r.datagram { - seq = header[3:11] - } - // TODO(ekr@rtfm.com): Handle the wrong epoch. + seq := cipher.seq + if r.datagram { // TODO(ekr@rtfm.com): Handle duplicates. - logf(logTypeIO, "RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), seq, pt.contentType, pt.fragment) + seq, _ = decodeUint(header[3:11], 8) + epoch := Epoch(seq >> 48) + + // Look up the cipher suite from the epoch + if epoch != cipher.epoch { + logf(logTypeIO, "%s Message from non-current epoch: [%v != %v]", r.label, epoch, + cipher.epoch) + if !allowOldEpoch { + return nil, AlertWouldBlock + } + c, ok := r.readCiphers[epoch] + if !ok { + logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch) + return nil, AlertWouldBlock + } + cipher = c + } + } + + if cipher.cipher != nil { + logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment) pt, _, err = r.decrypt(pt, seq) if err != nil { - logf(logTypeIO, "Decryption failed") + logf(logTypeIO, "%s Decryption failed", r.label) return nil, err } } + pt.epoch = cipher.epoch // Check that plaintext length is not too long if len(pt.fragment) > maxFragmentLen { return nil, fmt.Errorf("tls.record: Plaintext size too big") } - logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment) r.cachedRecord = pt cipher.incrementSequenceNumber() @@ -355,10 +412,9 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error } func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { - seq := cipher.formatSeq(r.datagram) - + seq := cipher.combineSeq(r.datagram) if cipher.cipher != nil { - logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) pt = r.encrypt(cipher, seq, pt, padLen) } else if padLen > 0 { return fmt.Errorf("tls.record: Padding can only be done on encrypted records") @@ -376,16 +432,18 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta byte(r.version >> 8), byte(r.version & 0xff), byte(length >> 8), byte(length)} } else { + seqb := make([]byte, 8) + encodeUint(seq, 8, seqb) version := dtlsConvertVersion(r.version) header = []byte{byte(pt.contentType), byte(version >> 8), byte(version & 0xff), - seq[0], seq[1], seq[2], seq[3], - seq[4], seq[5], seq[6], seq[7], - byte(length >> 8), byte(length)} + } + header = append(header, seqb...) + header = append(header, byte(length>>8), byte(length)) } record := append(header, pt.fragment...) - logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) cipher.incrementSequenceNumber() _, err := r.conn.Write(record) diff --git a/record-layer_test.go b/record-layer_test.go index 6d1ed1e..8696fb6 100644 --- a/record-layer_test.go +++ b/record-layer_test.go @@ -25,7 +25,7 @@ func TestRekey(t *testing.T) { key := unhex(keyHex) iv := unhex(ivHex) - r := NewRecordLayerTLS(bytes.NewBuffer(nil)) + r := NewRecordLayerTLS(bytes.NewBuffer(nil), directionWrite) err := r.Rekey(EpochApplicationData, newAESGCM, key, iv) assertNotError(t, err, "Failed to rekey") } @@ -33,7 +33,7 @@ func TestRekey(t *testing.T) { func TestSequenceNumberRollover(t *testing.T) { defer func() { r := recover() - assert(t, r != nil, "failed to panic on sequence number overflow") + assertTrue(t, r != nil, "failed to panic on sequence number overflow") }() key := unhex(keyHex) @@ -41,9 +41,7 @@ func TestSequenceNumberRollover(t *testing.T) { cs, err := newCipherStateAead(EpochApplicationData, newAESGCM, key, iv) assertNotError(t, err, "Couldn't create cipher state") - for i := 0; i < sequenceNumberLen; i++ { - cs.seq[cs.ivLength-i-1] = 0xFF - } + cs.seq = 1 << 48 cs.incrementSequenceNumber() } @@ -51,7 +49,7 @@ func TestReadRecord(t *testing.T) { plaintext := unhex(plaintextHex) // Test that a known-good frame decodes properly - r := NewRecordLayerTLS(bytes.NewBuffer(plaintext)) + r := NewRecordLayerTLS(bytes.NewBuffer(plaintext), directionRead) pt, err := r.ReadRecord() assertNotError(t, err, "Failed to decode valid plaintext") assertEquals(t, pt.contentType, RecordTypeAlert) @@ -59,7 +57,7 @@ func TestReadRecord(t *testing.T) { // Test failure on unkown record type plaintext[0] = 0xFF - r = NewRecordLayerTLS(bytes.NewBuffer(plaintext)) + r = NewRecordLayerTLS(bytes.NewBuffer(plaintext), directionRead) pt, err = r.ReadRecord() assertError(t, err, "Failed to reject record with unknown type") plaintext[0] = 0x15 @@ -68,7 +66,7 @@ func TestReadRecord(t *testing.T) { originalAllowWrongVersionNumber := allowWrongVersionNumber allowWrongVersionNumber = false plaintext[2] = 0x02 - r = NewRecordLayerTLS(bytes.NewBuffer(plaintext)) + r = NewRecordLayerTLS(bytes.NewBuffer(plaintext), directionRead) pt, err = r.ReadRecord() assertError(t, err, "Failed to reject record with incorrect version") plaintext[2] = 0x01 @@ -76,18 +74,18 @@ func TestReadRecord(t *testing.T) { // Test failure on size too big plaintext[3] = 0xFF - r = NewRecordLayerTLS(bytes.NewBuffer(plaintext)) + r = NewRecordLayerTLS(bytes.NewBuffer(plaintext), directionRead) pt, err = r.ReadRecord() assertError(t, err, "Failed to reject record exceeding size limit") plaintext[3] = 0x00 // Test failure on header read failure - r = NewRecordLayerTLS(bytes.NewBuffer(plaintext[:3])) + r = NewRecordLayerTLS(bytes.NewBuffer(plaintext[:3]), directionRead) pt, err = r.ReadRecord() assertError(t, err, "Didn't fail when unable to read header") // Test failure on body read failure - r = NewRecordLayerTLS(bytes.NewBuffer(plaintext[:7])) + r = NewRecordLayerTLS(bytes.NewBuffer(plaintext[:7]), directionRead) pt, err = r.ReadRecord() assertError(t, err, "Didn't fail when unable to read fragment") } @@ -101,7 +99,7 @@ func TestWriteRecord(t *testing.T) { fragment: plaintext[5:], } b := bytes.NewBuffer(nil) - r := NewRecordLayerTLS(b) + r := NewRecordLayerTLS(b, directionWrite) err := r.WriteRecord(pt) assertNotError(t, err, "Failed to write valid record") assertByteEquals(t, b.Bytes(), plaintext) @@ -131,7 +129,7 @@ func TestDecryptRecord(t *testing.T) { ciphertext2 := unhex(ciphertext2Hex) // Test successful decrypt - r := NewRecordLayerTLS(bytes.NewBuffer(ciphertext1)) + r := NewRecordLayerTLS(bytes.NewBuffer(ciphertext1), directionRead) r.Rekey(EpochApplicationData, newAESGCM, key, iv) pt, err := r.ReadRecord() assertNotError(t, err, "Failed to decrypt valid record") @@ -139,7 +137,7 @@ func TestDecryptRecord(t *testing.T) { assertByteEquals(t, pt.fragment, plaintext[5:]) // Test successful decrypt after sequence number change - r = NewRecordLayerTLS(bytes.NewBuffer(ciphertext2)) + r = NewRecordLayerTLS(bytes.NewBuffer(ciphertext2), directionRead) r.Rekey(EpochApplicationData, newAESGCM, key, iv) for i := 0; i < sequenceChange; i++ { r.cipher.incrementSequenceNumber() @@ -151,7 +149,7 @@ func TestDecryptRecord(t *testing.T) { // Test failure on decrypt failure ciphertext1[7] ^= 0xFF - r = NewRecordLayerTLS(bytes.NewBuffer(ciphertext1)) + r = NewRecordLayerTLS(bytes.NewBuffer(ciphertext1), directionRead) r.Rekey(EpochApplicationData, newAESGCM, key, iv) pt, err = r.ReadRecord() assertError(t, err, "Failed to reject invalid record") @@ -168,7 +166,7 @@ func TestEncryptRecord(t *testing.T) { // Test successful encrypt b := bytes.NewBuffer(nil) - r := NewRecordLayerTLS(b) + r := NewRecordLayerTLS(b, directionWrite) r.Rekey(EpochApplicationData, newAESGCM, key, iv) pt := &TLSPlaintext{ contentType: RecordType(plaintext[0]), @@ -180,7 +178,7 @@ func TestEncryptRecord(t *testing.T) { // Test successful encrypt with padding b.Truncate(0) - r = NewRecordLayerTLS(b) + r = NewRecordLayerTLS(b, directionWrite) r.Rekey(EpochApplicationData, newAESGCM, key, iv) pt = &TLSPlaintext{ contentType: RecordType(plaintext[0]), @@ -192,7 +190,7 @@ func TestEncryptRecord(t *testing.T) { // Test successful enc after sequence number change b.Truncate(0) - r = NewRecordLayerTLS(b) + r = NewRecordLayerTLS(b, directionWrite) r.Rekey(EpochApplicationData, newAESGCM, key, iv) for i := 0; i < sequenceChange; i++ { r.cipher.incrementSequenceNumber() @@ -207,7 +205,7 @@ func TestEncryptRecord(t *testing.T) { // Test failure on size too big after encrypt b.Truncate(0) - r = NewRecordLayerTLS(b) + r = NewRecordLayerTLS(b, directionWrite) r.Rekey(EpochApplicationData, newAESGCM, key, iv) pt = &TLSPlaintext{ contentType: RecordType(plaintext[0]), @@ -223,8 +221,8 @@ func TestReadWriteTLS(t *testing.T) { plaintext := unhex(plaintextHex) b := bytes.NewBuffer(nil) - out := NewRecordLayerTLS(b) - in := NewRecordLayerTLS(b) + out := NewRecordLayerTLS(b, directionWrite) + in := NewRecordLayerTLS(b, directionRead) // Unencrypted ptIn := &TLSPlaintext{ @@ -255,9 +253,9 @@ func TestReadWriteDTLS(t *testing.T) { plaintext := unhex(plaintextHex) b := bytes.NewBuffer(nil) - out := NewRecordLayerDTLS(b) + out := NewRecordLayerDTLS(b, directionWrite) out.SetVersion(tls12Version) - in := NewRecordLayerDTLS(b) + in := NewRecordLayerDTLS(b, directionRead) in.SetVersion(tls12Version) // Unencrypted @@ -307,7 +305,7 @@ func TestOverSocket(t *testing.T) { assertNotError(t, err, "Unable to accept") defer conn.Close() - in := NewRecordLayerTLS(conn) + in := NewRecordLayerTLS(conn, directionRead) in.Rekey(EpochApplicationData, newAESGCM, key, iv) pt, err := in.ReadRecord() assertNotError(t, err, "Unable to read record") @@ -319,7 +317,7 @@ func TestOverSocket(t *testing.T) { conn, err := net.Dial("tcp", port) assertNotError(t, err, "Unable to dial") - out := NewRecordLayerTLS(conn) + out := NewRecordLayerTLS(conn, directionWrite) out.Rekey(EpochApplicationData, newAESGCM, key, iv) err = out.WriteRecord(&ptIn) assertNotError(t, err, "Unable to write record") @@ -355,10 +353,10 @@ func TestNonblockingRecord(t *testing.T) { // Add the prefix, which should cause blocking. b := bytes.NewBuffer(ciphertext1[:1]) - r := NewRecordLayerTLS(&NoEofReader{b}) + r := NewRecordLayerTLS(&NoEofReader{b}, directionRead) r.Rekey(EpochApplicationData, newAESGCM, key, iv) pt, err := r.ReadRecord() - assertEquals(t, err, WouldBlock) + assertEquals(t, err, AlertWouldBlock) // Now the rest of the record, which lets us decrypt it b.Write(ciphertext1[1:]) diff --git a/server-state-machine.go b/server-state-machine.go index 0b851f4..d7d987d 100644 --- a/server-state-machine.go +++ b/server-state-machine.go @@ -74,7 +74,7 @@ type cookie struct { type serverStateStart struct { Config *Config conn *Conn - hsCtx HandshakeContext + hsCtx *HandshakeContext } var _ HandshakeState = &serverStateStart{} @@ -361,7 +361,7 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ // Figure out if we're going to do early data var clientEarlyTrafficSecret []byte connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] - connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) + connParams.UsingEarlyData, connParams.RejectedEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) if connParams.UsingEarlyData { h := params.Hash.New() h.Write(clientHello.Marshal()) @@ -379,6 +379,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertNoApplicationProtocol } + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. return serverStateNegotiated{ @@ -445,7 +447,7 @@ func (state *serverStateStart) generateHRR(cs CipherSuite, legacySessionId []byt type serverStateNegotiated struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext dhGroup NamedGroup dhPublic []byte dhSecret []byte @@ -731,7 +733,6 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat } toSend = append(toSend, []HandshakeAction{ RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - ReadEarlyData{}, }...) return nextState, toSend, AlertNoAlert } @@ -739,9 +740,9 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") toSend = append(toSend, []HandshakeAction{ RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, - ReadPastEarlyData{}, }...) - waitFlight2 := serverStateWaitFlight2{ + var nextState HandshakeState + nextState = serverStateWaitFlight2{ Config: state.Config, Params: state.Params, hsCtx: state.hsCtx, @@ -753,13 +754,19 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat serverTrafficSecret: serverTrafficSecret, exporterSecret: exporterSecret, } - return waitFlight2, toSend, AlertNoAlert + if state.Params.RejectedEarlyData { + nextState = serverStateReadPastEarlyData{ + state.hsCtx, + &nextState, + } + } + return nextState, toSend, AlertNoAlert } type serverStateWaitEOED struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -776,6 +783,37 @@ func (state serverStateWaitEOED) State() State { } func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading early data...") + t, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + if err != nil { + logf(logTypeHandshake, "Server Error reading record type (1): %v", err) + return nil, nil, AlertBadRecordMAC + } + + logf(logTypeHandshake, "Server got record type(1): %v", t) + + if t != RecordTypeApplicationData { + break + } + + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in in + // PeekRecordType. + pt, err := state.hsCtx.hIn.conn.ReadRecord() + if err != nil { + logf(logTypeHandshake, "Server error reading early data record: %v", err) + return nil, nil, AlertInternalError + } + + logf(logTypeHandshake, "Server read early data: %x", pt.fragment) + state.hsCtx.earlyData = append(state.hsCtx.earlyData, pt.fragment...) + } + hm, alert := hr.ReadMessage() if alert != AlertNoAlert { return nil, nil, alert @@ -813,10 +851,44 @@ func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState return waitFlight2, toSend, AlertNoAlert } +var _ HandshakeState = &serverStateReadPastEarlyData{} + +type serverStateReadPastEarlyData struct { + hsCtx *HandshakeContext + next *HandshakeState +} + +func (state serverStateReadPastEarlyData) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading past early data...") + // Scan past all records that fail to decrypt + _, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == nil { + break + } + + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + // Continue on DecryptError + _, ok := err.(DecryptError) + if !ok { + return nil, nil, AlertInternalError // Really need something else. + } + } + + return *state.next, nil, AlertNoAlert +} + +func (state serverStateReadPastEarlyData) State() State { + return StateServerReadPastEarlyData +} + type serverStateWaitFlight2 struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -868,7 +940,7 @@ func (state serverStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeSta type serverStateWaitCert struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -940,7 +1012,7 @@ func (state serverStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState type serverStateWaitCV struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte @@ -1023,7 +1095,7 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, type serverStateWaitFinished struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte @@ -1082,6 +1154,8 @@ func (state serverStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS // Compute client traffic keys clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + state.hsCtx.receivedFinalFlight() + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") nextState := stateConnected{ Params: state.Params, diff --git a/state-machine.go b/state-machine.go index 7639c5f..2532262 100644 --- a/state-machine.go +++ b/state-machine.go @@ -17,10 +17,6 @@ type SendQueuedHandshake struct{} type SendEarlyData struct{} -type ReadEarlyData struct{} - -type ReadPastEarlyData struct{} - type RekeyIn struct { epoch Epoch KeySet keySet @@ -50,7 +46,6 @@ type AppExtensionHandler interface { type ConnectionOptions struct { ServerName string NextProtos []string - EarlyData []byte } // ConnectionParameters objects represent the parameters negotiated for a @@ -60,6 +55,7 @@ type ConnectionParameters struct { UsingDH bool ClientSendingEarlyData bool UsingEarlyData bool + RejectedEarlyData bool UsingClientAuth bool CipherSuite CipherSuite @@ -69,7 +65,13 @@ type ConnectionParameters struct { // Working state for the handshake. type HandshakeContext struct { - hIn, hOut *HandshakeLayer + timeoutMS uint32 + timers *timerSet + recvdRecords []uint64 + sentFragments []*SentHandshakeFragment + hIn, hOut *HandshakeLayer + waitingNextFlight bool + earlyData []byte } func (hc *HandshakeContext) SetVersion(version uint16) { @@ -84,7 +86,7 @@ func (hc *HandshakeContext) SetVersion(version uint16) { // stateConnected is symmetric between client and server type stateConnected struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext isClient bool cryptoParams CipherSuiteParams resumptionSecret []byte diff --git a/state-machine_test.go b/state-machine_test.go index ae59c78..e2000e0 100644 --- a/state-machine_test.go +++ b/state-machine_test.go @@ -173,6 +173,8 @@ func TestStateMachineIntegration(t *testing.T) { }, }, + /* Commented out because PeekRecordType() not available without a record layer + // PSK case, with early data "pskWithEarlyData": { clientConfig: &Config{ @@ -218,6 +220,7 @@ func TestStateMachineIntegration(t *testing.T) { }, }, + */ // PSK case, server rejects PSK "pskRejected": { clientConfig: &Config{ @@ -362,9 +365,9 @@ func TestStateMachineIntegration(t *testing.T) { clientState = clientStateStart{ Config: params.clientConfig, Opts: params.clientOptions, - hsCtx: chsCtx, + hsCtx: &chsCtx, } - serverState = serverStateStart{Config: params.serverConfig, hsCtx: shsCtx} + serverState = serverStateStart{Config: params.serverConfig, hsCtx: &shsCtx} t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) t.Logf("Server: %s", reflect.TypeOf(serverState).Name()) @@ -399,7 +402,7 @@ func TestStateMachineIntegration(t *testing.T) { } serverState = nextState serverResponses := messagesFromActions(serverInstr) - assert(t, alert == AlertNoAlert || alert == AlertStatelessRetry, fmt.Sprintf("Alert from server [%v]", alert)) + assertTrue(t, alert == AlertNoAlert || alert == AlertStatelessRetry, fmt.Sprintf("Alert from server [%v]", alert)) serverStateSequence = append(serverStateSequence, serverState) t.Logf("Server: %s", reflect.TypeOf(serverState).Name()) clientHandshakeMessageReader.queue = append(clientHandshakeMessageReader.queue, serverResponses...) @@ -417,7 +420,7 @@ func TestStateMachineIntegration(t *testing.T) { } clientState = nextState clientResponses := messagesFromActions(clientInstr) - assert(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert)) + assertTrue(t, alert == AlertNoAlert, fmt.Sprintf("Alert from client [%v]", alert)) clientStateSequence = append(clientStateSequence, clientState) t.Logf("Client: %s", reflect.TypeOf(clientState).Name()) serverHandshakeMessageReader.queue = append(serverHandshakeMessageReader.queue, clientResponses...)