diff --git a/attestationreport/validationreport.go b/attestationreport/validationreport.go index 5cb7a826..2483d557 100644 --- a/attestationreport/validationreport.go +++ b/attestationreport/validationreport.go @@ -707,7 +707,7 @@ func (r *VerificationResult) PrintErr() { if !a.Success { details := "" if a.Pcr != nil { - details = "PCR%v" + details = fmt.Sprintf("PCR%v", *a.Pcr) } log.Warnf("%v Measurement %v: %v verification failed", details, a.Name, a.Digest) } diff --git a/attestedtls/attestation.go b/attestedtls/attestation.go index 509e3de3..982cd850 100644 --- a/attestedtls/attestation.go +++ b/attestedtls/attestation.go @@ -27,6 +27,7 @@ var id = "0000" var log = logrus.WithField("service", "atls") func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { + ch := make(chan error) //optional: attest Client if cc.Attest == Attest_Mutual || cc.Attest == Attest_Client { @@ -38,13 +39,17 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { } // Send created attestation report to listener - log.Tracef("Sending attestation report length %v to listener", len(resp)) - - err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) - if err != nil { - return fmt.Errorf("failed to send AR: %w", err) - } - log.Trace("Sent AR") + log.Tracef("Dialer: sending attestation report length %v to listener %v", + len(resp), conn.RemoteAddr().String()) + + go func() { + err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) + if err != nil { + ch <- fmt.Errorf("failed to send AR to listener: %w", err) + } + log.Trace("Finished asynchronous sending of attestation report to listener") + ch <- nil + }() } else { //if not sending attestation report, send the attestation mode err := Write([]byte{byte(cc.Attest)}, conn) @@ -54,14 +59,14 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping client-side attestation") } - readvalue, err := readValue(conn, cc.Attest, true) + // Fetch attestation report from listener + report, err := readValue(conn, cc.Attest) if err != nil { return err } //optional: Wait for attestation report from Server if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { - report := readvalue // Verify AR from listener with own channel bindings log.Trace("Verifying attestation report from listener") err = cc.CmcApi.verifyAR(chbindings, report, cc) @@ -72,12 +77,22 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping client-side verification") } + // Finally check if asynchronous sending succeeded + if cc.Attest == Attest_Mutual || cc.Attest == Attest_Client { + err = <-ch + if err != nil { + return fmt.Errorf("failed to write asynchronously: %w", err) + } + } + log.Trace("Attestation successful") return nil } func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { + ch := make(chan error) + // optional: attest server if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { // Obtain own attestation report from local cmcd @@ -87,12 +102,19 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { return fmt.Errorf("could not obtain AR of Listener : %w", err) } - // Send own attestation report to dialer - log.Trace("Sending own attestation report") - err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) - if err != nil { - return fmt.Errorf("failed to send AR: %w", err) - } + // Send own attestation report to dialer. This is done asynchronously to + // avoid blocking if each side sends a large report at the same time + log.Tracef("Listener: Sending attestation report length %v to dialer %v", + len(resp), conn.RemoteAddr().String()) + + go func() { + err = Write(append([]byte{byte(cc.Attest)}, resp...), conn) + if err != nil { + ch <- fmt.Errorf("failed to send AR to dialer: %w", err) + } + ch <- nil + log.Trace("Finished asynchronous sending of attestation report to dialer") + }() } else { //if not sending attestation report, send the attestation mode err := Write([]byte{byte(cc.Attest)}, conn) @@ -102,7 +124,7 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping server-side attestation") } - report, err := readValue(conn, cc.Attest, false) + report, err := readValue(conn, cc.Attest) if err != nil { return err } @@ -119,20 +141,28 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { log.Debug("Skipping server-side verification") } + // Finally check if asynchronous sending succeeded + if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { + err = <-ch + if err != nil { + return fmt.Errorf("failed to write asynchronously: %w", err) + } + } + log.Trace("Attestation successful") return nil } -func readValue(conn *tls.Conn, selection AttestSelect, dialer bool) ([]byte, error) { +func readValue(conn *tls.Conn, selection AttestSelect) ([]byte, error) { readvalue, err := Read(conn) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } - selectionStr, errS1 := selectionString(byte(selection)) - if errS1 != nil { - return nil, errS1 + selectionStr, err := selectionString(byte(selection)) + if err != nil { + return nil, err } // the first byte should always be the attestation mode @@ -140,9 +170,9 @@ func readValue(conn *tls.Conn, selection AttestSelect, dialer bool) ([]byte, err log.Debugf("Matching attestation mode: [%v]", selectionStr) } else { reportByte := readvalue[0] - reportStr, errS1 := selectionString(reportByte) - if errS1 != nil { - return nil, errS1 + reportStr, err := selectionString(reportByte) + if err != nil { + return nil, err } return nil, fmt.Errorf("mismatching attestation mode, local set to: [%v], while remote is set to: [%v]", selectionStr, reportStr) } diff --git a/attestedtls/backend.go b/attestedtls/backend.go index 53c789f9..14095318 100644 --- a/attestedtls/backend.go +++ b/attestedtls/backend.go @@ -16,6 +16,7 @@ package attestedtls import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -23,20 +24,36 @@ import ( "time" ) -// Writes byte array to provided channel by first sending length information, then data +// Writes byte array to provided channel by first sending length information, then data. // Used for transmitting the attestation reports between peers func Write(msg []byte, c net.Conn) error { + + length := len(msg) lenbuf := make([]byte, 4) - binary.BigEndian.PutUint32(lenbuf, uint32(len(msg))) + binary.BigEndian.PutUint32(lenbuf, uint32(length)) - buf := append(lenbuf, msg...) + n, err := c.Write(lenbuf) + if err != nil { + return fmt.Errorf("failed to write length to %v: %w", c.RemoteAddr().String(), err) + } + if n != len(lenbuf) { + return fmt.Errorf("could only send %v of %v bytes to %v", n, len(lenbuf), + c.RemoteAddr().String()) + } - _, err := c.Write(buf) + n, err = c.Write(msg) + if err != nil { + return fmt.Errorf("failed to write payload to %v: %w", c.RemoteAddr().String(), err) + } + if n != len(msg) { + return fmt.Errorf("could only send %v of %v bytes to %v", n, len(msg), + c.RemoteAddr().String()) + } return err } -// Receives byte array from provided channel by first receiving length information, then data +// Receives byte array from provided channel by first receiving length information, then data. // Used for transmitting the attestation reports between peers func Read(c net.Conn) ([]byte, error) { start := time.Now() @@ -48,30 +65,30 @@ func Read(c net.Conn) ([]byte, error) { return nil, fmt.Errorf("failed to receive message: no length: %v", err) } - len := binary.BigEndian.Uint32(lenbuf) // Max size of 4GB - log.Trace("TCP Message Length: ", len) + len := int(binary.BigEndian.Uint32(lenbuf)) // Max size of 4GB + log.Tracef("TCP Message to be received: %v", len) if len == 0 { return nil, errors.New("message length is zero") } - // Receive data in chunks of 1024 bytes as the Read function receives a maxium of 65536 bytes + // Receive data in chunks of 1024 bytes as the Read function receives a maxium of 64K bytes // and the buffer must be longer, then append it to the final buffer - tmpbuf := make([]byte, 1024) - buf := make([]byte, 0) - rcvlen := uint32(0) + buf := bytes.NewBuffer(nil) + received := 0 for { - n, err := c.Read(tmpbuf) - rcvlen += uint32(n) + chunk := make([]byte, 64*1024) + n, err := c.Read(chunk) + received += n if err != nil { return nil, fmt.Errorf("failed to receive message: %w", err) } - buf = append(buf, tmpbuf[:n]...) + buf.Write(chunk[:n]) // Abort as soon as we have read the expected data as signaled in the first 4 bytes // of the message - if rcvlen == len { + if received == len { log.Trace("Received message") break } @@ -81,5 +98,5 @@ func Read(c net.Conn) ([]byte, error) { break } } - return buf, nil + return buf.Bytes(), nil }