diff --git a/attestedtls/attestation.go b/attestedtls/attestation.go index c8d55af..8c2bc5d 100644 --- a/attestedtls/attestation.go +++ b/attestedtls/attestation.go @@ -65,6 +65,8 @@ func atlsHandshakeStart(conn *tls.Conn, chbindings []byte, cc CmcConfig, endpoin } // Wait for attestation request to be received + log.Debugf("Prover %v: waiting for atls handshake request mode %v from %v", + ownAddr, cc.Attest.String(), peerAddr) req, err := receiveAtlsRequest(conn) if err != nil { return fmt.Errorf("prover %v: failed to receive attestation request from %v: %w", @@ -98,16 +100,15 @@ func atlsHandshakeStart(conn *tls.Conn, chbindings []byte, cc CmcConfig, endpoin Error: reportErr, } - // Send created atls handshake response to listener + // Send created atls handshake response to listener asynchronously to avoid TCP deadlock log.Debugf("Prover %v: sending atls handshake response with report length %v to %v", ownAddr, len(ownReport), peerAddr) - - err = sendAtlsResponse(conn, report) - if err != nil { - return fmt.Errorf("prover %v: failed to send AR to %v: %w", ownAddr, peerAddr, err) - } - log.Debugf("Prover %v: finished sending atls handshake response to %v", - ownAddr, peerAddr) + errChan := make(chan error) + go func() { + err := sendAtlsResponse(conn, report) + errChan <- err + close(errChan) + }() log.Debugf("Verifier %v: waiting for atls handshake response from %v", ownAddr, peerAddr) resp, err := receiveAtlsResponse(conn) @@ -122,6 +123,15 @@ func atlsHandshakeStart(conn *tls.Conn, chbindings []byte, cc CmcConfig, endpoin log.Debugf("Verifier %v: received atls handshake response mode %v from %v", ownAddr, resp.Attest.String(), peerAddr) + // Wait until sending of own handshake response is finished + log.Debugf("Prover %v: Waiting for atls handshake response sending to be completed", ownAddr) + err = <-errChan + if err != nil { + return fmt.Errorf("prover %v: failed to send atls handshake response: %w", ownAddr, err) + } + log.Debugf("Prover %v: finished sending atls handshake response to %v", + ownAddr, peerAddr) + // Check that configured attestation mode matches peers attestation mode err = checkAttestationMode(cc.Attest, req.Attest) if err != nil { @@ -152,12 +162,17 @@ func aTlsHandshakeComplete(conn *tls.Conn, handshakeError error) error { if handshakeError != nil { complete.Error = handshakeError.Error() } + + log.Debugf("Prover %v: sending atls handshake complete to %v", + conn.LocalAddr().String(), conn.RemoteAddr().String()) err := sendAtlsComplete(conn, complete) if err != nil { return fmt.Errorf("%v: failed to send handshake complete to %v: %w", conn.LocalAddr().String(), conn.RemoteAddr().String(), err) } + log.Debugf("Verifier %v: waiting for atls handshake complete from %v", + conn.LocalAddr().String(), conn.RemoteAddr().String()) resp, err := receiveAtlsComplete(conn) if err != nil { return fmt.Errorf("%v: failed to receive handshake complete from %v: %w",