Skip to content

Commit

Permalink
Multiple enhancements to support DTLS.
Browse files Browse the repository at this point in the history
    This still isn't complete but it's big and I wanted to get it
    landed.

    - Support for timeout and retransmission
    - Support for ACKs
    - Fixed 0-RTT support in nonblocking mode
    - Updated the test system to support parametrized tests and
      used that to simplify and expand the connection tests

    I also reworked the 0-RTT API to support streaming

    - When you are a client, and you are in 0-RTT mode, you can
      Write() when the handshake is not complete
    - When you are a server and have 0-RTT enabled() you can Read()
      at any time. Prior to handshake completion, this reads out of
      the 0-RTT buffer but doesn't cause a network read. After
      handshake completion, it just does a normal read (with the
      0-RTT data buffer having been merged into the main buffer).
      This has the odd side effect that you can only read 0-RTT
      data off the network by doing Handshake(), so the way you
      drive the server is to do:

      for {
         server.Handshake()
         if server connected {
            break
         }
         server.Read() // 0-RTT data
      }

    There are still a number of defects that make DTLS not ready for
    prime time.

    - No timer backoff
    - No MTU backoff
    - I don't properly clean up the out-of-epoch cipher suites
      [This may be serious]
    - Finished isn't triggered properly at the end of handshake
      if loss occurs
    - There are way too few tests

    There are probably also a pile of bugs I don't know about.
  • Loading branch information
ekr committed Feb 23, 2018
1 parent 340be3a commit 1080921
Show file tree
Hide file tree
Showing 20 changed files with 1,399 additions and 468 deletions.
57 changes: 39 additions & 18 deletions client-state-machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type clientStateStart struct {
cookie []byte
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
hsCtx HandshakeContext
hsCtx *HandshakeContext
}

var _ HandshakeState = &clientStateStart{}
Expand Down Expand Up @@ -172,8 +172,10 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
ch.CipherSuites = compatibleSuites

// TODO([email protected]): Check that the ticket can be used for early
// data.
// Signal early data if we're going to do it
if len(state.Opts.EarlyData) > 0 {
if state.Config.AllowEarlyData {
state.Params.ClientSendingEarlyData = true
ed = &EarlyDataExtension{}
err = ch.Extensions.Add(ed)
Expand Down Expand Up @@ -255,9 +257,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
} else if len(state.Opts.EarlyData) > 0 {
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
return nil, nil, AlertInternalError
} else {
clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
if err != nil {
Expand Down Expand Up @@ -291,7 +290,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
if state.Params.ClientSendingEarlyData {
toSend = append(toSend, []HandshakeAction{
RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys},
SendEarlyData{},
}...)
}

Expand All @@ -302,7 +300,7 @@ type clientStateWaitSH struct {
Config *Config
Opts ConnectionOptions
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
OfferedDH map[NamedGroup][]byte
OfferedPSK PreSharedKey
PSK []byte
Expand Down Expand Up @@ -412,6 +410,11 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
body: h.Sum(nil),
}

state.hsCtx.receivedEndOfFlight()

// TODO([email protected]): Need to rekey with cleartext if we are on 0-RTT
// mode. In DTLS, we also need to bump the sequence number.
// This is a pre-existing defect in Mint. Issue #175.
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return clientStateStart{
Config: state.Config,
Expand Down Expand Up @@ -515,7 +518,6 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)

serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)

logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := clientStateWaitEE{
Config: state.Config,
Expand All @@ -530,13 +532,20 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
toSend := []HandshakeAction{
RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
}
// We're definitely not going to have to send anything with
// early data.
if !state.Params.ClientSendingEarlyData {
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)})
}

return nextState, toSend, AlertNoAlert
}

type clientStateWaitEE struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
masterSecret []byte
Expand Down Expand Up @@ -596,6 +605,14 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,

state.handshakeHash.Write(hm.Marshal())

toSend := []HandshakeAction{}

if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData {
// We didn't get 0-RTT, so rekey to handshake.
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
}

if state.Params.UsingPSK {
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
nextState := clientStateWaitFinished{
Expand All @@ -608,7 +625,7 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
return nextState, toSend, AlertNoAlert
}

logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
Expand All @@ -622,13 +639,13 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
return nextState, toSend, AlertNoAlert
}

type clientStateWaitCertCR struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
masterSecret []byte
Expand Down Expand Up @@ -706,7 +723,7 @@ func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta
type clientStateWaitCert struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash

Expand Down Expand Up @@ -760,7 +777,7 @@ func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
type clientStateWaitCV struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash

Expand Down Expand Up @@ -861,7 +878,7 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,

type clientStateWaitFinished struct {
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash

Expand Down Expand Up @@ -933,6 +950,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
toSend := []HandshakeAction{}

if state.Params.UsingEarlyData {
logf(logTypeHandshake, "Sending end of early data")
// Note: We only send EOED if the server is actually going to use the early
// data. Otherwise, it will never see it, and the transcripts will
// mismatch.
Expand All @@ -942,10 +960,11 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS

state.handshakeHash.Write(eoedm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
}

clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys})
// And then rekey to handshake
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
}

if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest
Expand Down Expand Up @@ -1045,6 +1064,8 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys},
}...)

state.hsCtx.receivedEndOfFlight()

logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
nextState := stateConnected{
Params: state.Params,
Expand Down
12 changes: 12 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
RecordTypeAck RecordType = 25
)

// enum {...} HandshakeType;
Expand Down Expand Up @@ -166,6 +167,8 @@ const (
type State uint8

const (
StateInit = 0

// states valid for the client
StateClientStart State = iota
StateClientWaitSH
Expand All @@ -179,6 +182,7 @@ const (
StateServerStart State = iota
StateServerRecvdCH
StateServerNegotiated
StateServerReadPastEarlyData
StateServerWaitEOED
StateServerWaitFlight2
StateServerWaitCert
Expand Down Expand Up @@ -209,6 +213,8 @@ func (s State) String() string {
return "Server RECVD_CH"
case StateServerNegotiated:
return "Server NEGOTIATED"
case StateServerReadPastEarlyData:
return "Server READ_PAST_EARLY_DATA"
case StateServerWaitEOED:
return "Server WAIT_EOED"
case StateServerWaitFlight2:
Expand Down Expand Up @@ -250,3 +256,9 @@ func (e Epoch) label() string {
}
return "Application data (updated)"
}

func assert(b bool) {
if !b {
panic("Assertion failed")
}
}
89 changes: 79 additions & 10 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func unhex(h string) []byte {
return b
}

func assert(t *testing.T, test bool, msg string) {
func assertTrue(t *testing.T, test bool, msg string) {
t.Helper()
prefix := string("")
for i := 1; ; i++ {
Expand All @@ -34,40 +34,40 @@ func assert(t *testing.T, test bool, msg string) {

func assertError(t *testing.T, err error, msg string) {
t.Helper()
assert(t, err != nil, msg)
assertTrue(t, err != nil, msg)
}

func assertNotError(t *testing.T, err error, msg string) {
t.Helper()
if err != nil {
msg += ": " + err.Error()
}
assert(t, err == nil, msg)
assertTrue(t, err == nil, msg)
}

func assertNil(t *testing.T, x interface{}, msg string) {
t.Helper()
assert(t, x == nil, msg)
assertTrue(t, x == nil, msg)
}

func assertNotNil(t *testing.T, x interface{}, msg string) {
t.Helper()
assert(t, x != nil, msg)
assertTrue(t, x != nil, msg)
}

func assertEquals(t *testing.T, a, b interface{}) {
t.Helper()
assert(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
assertTrue(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
}

func assertByteEquals(t *testing.T, a, b []byte) {
t.Helper()
assert(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
assertTrue(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
}

func assertNotByteEquals(t *testing.T, a, b []byte) {
t.Helper()
assert(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
assertTrue(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
}

func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {
Expand All @@ -81,12 +81,81 @@ func assertCipherSuiteParamsEquals(t *testing.T, a, b CipherSuiteParams) {

func assertDeepEquals(t *testing.T, a, b interface{}) {
t.Helper()
assert(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b))
assertTrue(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b))
}

func assertSameType(t *testing.T, a, b interface{}) {
t.Helper()
A := reflect.TypeOf(a)
B := reflect.TypeOf(b)
assert(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name()))
assertTrue(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name()))
}

// Utilities for parametrized tests
// Represents the configuration for a given test instance.
type testInstanceState interface {
set(string, string)
get(string) string
}

// This complicated mixin stuff is in case we want to extend this to
// pass other parameters that aren't srings. Arguably YAGNI.
type testInstanceStateMixin struct {
params map[string]string
}

func (m *testInstanceStateMixin) set(k string, v string) {
m.params[k] = v
}

func (m *testInstanceStateMixin) get(k string) string {
return m.params[k]
}

func newTestInstanceStateMixin() testInstanceStateMixin {
return testInstanceStateMixin{make(map[string]string, 0)}
}

type testInstanceStateBase struct {
testInstanceStateMixin
}

func newTestInstanceStateBase() testInstanceState {
return &testInstanceStateBase{newTestInstanceStateMixin()}
}

// Helper function.
func runParametrizedInner(t *testing.T, name string, state testInstanceState, inparams []testParameter, f parametrizedTest) {
next := inparams[1:]

param := inparams[0]
for _, p := range param.vals {
state.set(param.name, p)
var n string
if len(name) > 0 {
n = name + "/"
}
n = n + param.name + "=" + p

if len(next) == 0 {
t.Run(n, func(t *testing.T) {
f(t, n, state)
})
continue
}
runParametrizedInner(t, n, state, next, f)
}
}

// Nominally public API.
type testParameter struct {
name string
vals []string
}

type parametrizedTest func(t *testing.T, name string, p testInstanceState)

// This is the function you call.
func runParametrizedTest(t *testing.T, inparams []testParameter, f parametrizedTest) {
runParametrizedInner(t, "", newTestInstanceStateBase(), inparams, f)
}
Loading

0 comments on commit 1080921

Please sign in to comment.