From 4bfc5e845444e1690d505b6f70a4eb54e2d36cd1 Mon Sep 17 00:00:00 2001 From: Simon Ott Date: Mon, 18 Mar 2024 17:47:55 +0000 Subject: [PATCH] attestedtls: send attestation report asynchronously If the attestation report is too large, both the TLS dialer and listener block, as the attestation reports are exchanged at the same time. Send them asynchronously to avoid this. Signed-off-by: Simon Ott --- attestationreport/validationreport.go | 2 +- attestedtls/attestation.go | 76 +++++++++++++++++++-------- attestedtls/backend.go | 49 +++++++++++------ 3 files changed, 87 insertions(+), 40 deletions(-) diff --git a/attestationreport/validationreport.go b/attestationreport/validationreport.go index 5cb7a826..2483d557 100644 --- a/attestationreport/validationreport.go +++ b/attestationreport/validationreport.go @@ -707,7 +707,7 @@ func (r *VerificationResult) PrintErr() { if !a.Success { details := "" if a.Pcr != nil { - details = "PCR%v" + details = fmt.Sprintf("PCR%v", *a.Pcr) } log.Warnf("%v Measurement %v: %v verification failed", details, a.Name, a.Digest) } diff --git a/attestedtls/attestation.go b/attestedtls/attestation.go index 509e3de3..982cd850 100644 --- a/attestedtls/attestation.go +++ b/attestedtls/attestation.go @@ -27,6 +27,7 @@ var id = "0000" var log = logrus.WithField("service", "atls") func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { + ch := make(chan error) //optional: attest Client if cc.Attest == Attest_Mutual || cc.Attest == Attest_Client { @@ -38,13 +39,17 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { } // Send created attestation report to listener - log.Tracef("Sending attestation report length %v to listener", len(resp)) - - err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) - if err != nil { - return fmt.Errorf("failed to send AR: %w", err) - } - log.Trace("Sent AR") + log.Tracef("Dialer: sending attestation report length %v to listener %v", + len(resp), conn.RemoteAddr().String()) + + go func() { + err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) + if err != nil { + ch <- fmt.Errorf("failed to send AR to listener: %w", err) + } + log.Trace("Finished asynchronous sending of attestation report to listener") + ch <- nil + }() } else { //if not sending attestation report, send the attestation mode err := Write([]byte{byte(cc.Attest)}, conn) @@ -54,14 +59,14 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping client-side attestation") } - readvalue, err := readValue(conn, cc.Attest, true) + // Fetch attestation report from listener + report, err := readValue(conn, cc.Attest) if err != nil { return err } //optional: Wait for attestation report from Server if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { - report := readvalue // Verify AR from listener with own channel bindings log.Trace("Verifying attestation report from listener") err = cc.CmcApi.verifyAR(chbindings, report, cc) @@ -72,12 +77,22 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping client-side verification") } + // Finally check if asynchronous sending succeeded + if cc.Attest == Attest_Mutual || cc.Attest == Attest_Client { + err = <-ch + if err != nil { + return fmt.Errorf("failed to write asynchronously: %w", err) + } + } + log.Trace("Attestation successful") return nil } func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { + ch := make(chan error) + // optional: attest server if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { // Obtain own attestation report from local cmcd @@ -87,12 +102,19 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { return fmt.Errorf("could not obtain AR of Listener : %w", err) } - // Send own attestation report to dialer - log.Trace("Sending own attestation report") - err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) - if err != nil { - return fmt.Errorf("failed to send AR: %w", err) - } + // Send own attestation report to dialer. This is done asynchronously to + // avoid blocking if each side sends a large report at the same time + log.Tracef("Listener: Sending attestation report length %v to dialer %v", + len(resp), conn.RemoteAddr().String()) + + go func() { + err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) + if err != nil { + ch <- fmt.Errorf("failed to send AR to dialer: %w", err) + } + ch <- nil + log.Trace("Finished asynchronous sending of attestation report to dialer") + }() } else { //if not sending attestation report, send the attestation mode err := Write([]byte{byte(cc.Attest)}, conn) @@ -102,7 +124,7 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping server-side attestation") } - report, err := readValue(conn, cc.Attest, false) + report, err := readValue(conn, cc.Attest) if err != nil { return err } @@ -119,20 +141,28 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping server-side verification") } + // Finally check if asynchronous sending succeeded + if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { + err = <-ch + if err != nil { + return fmt.Errorf("failed to write asynchronously: %w", err) + } + } + log.Trace("Attestation successful") return nil } -func readValue(conn *tls.Conn, selection AttestSelect, dialer bool) ([]byte, error) { +func readValue(conn *tls.Conn, selection AttestSelect) ([]byte, error) { readvalue, err := Read(conn) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } - selectionStr, errS1 := selectionString(byte(selection)) - if errS1 != nil { - return nil, errS1 + selectionStr, err := selectionString(byte(selection)) + if err != nil { + return nil, err } // the first byte should always be the attestation mode @@ -140,9 +170,9 @@ func readValue(conn *tls.Conn, selection AttestSelect, dialer bool) ([]byte, err log.Debugf("Matching attestation mode: [%v]", selectionStr) } else { reportByte := readvalue[0] - reportStr, errS1 := selectionString(reportByte) - if errS1 != nil { - return nil, errS1 + reportStr, err := selectionString(reportByte) + if err != nil { + return nil, err } return nil, fmt.Errorf("mismatching attestation mode, local set to: [%v], while remote is set to: [%v]", selectionStr, reportStr) } diff --git a/attestedtls/backend.go b/attestedtls/backend.go index 53c789f9..14095318 100644 --- a/attestedtls/backend.go +++ b/attestedtls/backend.go @@ -16,6 +16,7 @@ package attestedtls import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -23,20 +24,36 @@ import ( "time" ) -// Writes byte array to provided channel by first sending length information, then data +// Writes byte array to provided channel by first sending length information, then data. // Used for transmitting the attestation reports between peers func Write(msg []byte, c net.Conn) error { + + length := len(msg) lenbuf := make([]byte, 4) - binary.BigEndian.PutUint32(lenbuf, uint32(len(msg))) + binary.BigEndian.PutUint32(lenbuf, uint32(length)) - buf := append(lenbuf, msg...) + n, err := c.Write(lenbuf) + if err != nil { + return fmt.Errorf("failed to write length to %v: %w", c.RemoteAddr().String(), err) + } + if n != len(lenbuf) { + return fmt.Errorf("could only send %v of %v bytes to %v", n, len(lenbuf), + c.RemoteAddr().String()) + } - _, err := c.Write(buf) + n, err = c.Write(msg) + if err != nil { + return fmt.Errorf("failed to write payload to %v: %w", c.RemoteAddr().String(), err) + } + if n != len(msg) { + return fmt.Errorf("could only send %v of %v bytes to %v", n, len(msg), + c.RemoteAddr().String()) + } return err } -// Receives byte array from provided channel by first receiving length information, then data +// Receives byte array from provided channel by first receiving length information, then data. // Used for transmitting the attestation reports between peers func Read(c net.Conn) ([]byte, error) { start := time.Now() @@ -48,30 +65,30 @@ func Read(c net.Conn) ([]byte, error) { return nil, fmt.Errorf("failed to receive message: no length: %v", err) } - len := binary.BigEndian.Uint32(lenbuf) // Max size of 4GB - log.Trace("TCP Message Length: ", len) + len := int(binary.BigEndian.Uint32(lenbuf)) // Max size of 4GB + log.Tracef("TCP Message to be received: %v", len) if len == 0 { return nil, errors.New("message length is zero") } - // Receive data in chunks of 1024 bytes as the Read function receives a maxium of 65536 bytes + // Receive data in chunks of 1024 bytes as the Read function receives a maxium of 64K bytes // and the buffer must be longer, then append it to the final buffer - tmpbuf := make([]byte, 1024) - buf := make([]byte, 0) - rcvlen := uint32(0) + buf := bytes.NewBuffer(nil) + received := 0 for { - n, err := c.Read(tmpbuf) - rcvlen += uint32(n) + chunk := make([]byte, 64*1024) + n, err := c.Read(chunk) + received += n if err != nil { return nil, fmt.Errorf("failed to receive message: %w", err) } - buf = append(buf, tmpbuf[:n]...) + buf.Write(chunk[:n]) // Abort as soon as we have read the expected data as signaled in the first 4 bytes // of the message - if rcvlen == len { + if received == len { log.Trace("Received message") break } @@ -81,5 +98,5 @@ func Read(c net.Conn) ([]byte, error) { break } } - return buf, nil + return buf.Bytes(), nil }