Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

treewide: increased max message size limit #173

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ type TLSCertResponse struct {
Certificate [][]byte `json:"certificate" cbor:"0,keyasint"`
}

const (
// Set maximum message length to 10 MB
MaxMsgLen = 1024 * 1024 * 10
)

type HashFunction int32

const (
Expand Down Expand Up @@ -188,8 +193,17 @@ func SignerOptsToHash(opts crypto.SignerOpts) (HashFunction, error) {
// Type uint32 -> Type of the payload
// payload []byte -> CBOR-encoded payload
func Receive(conn net.Conn) ([]byte, uint32, error) {

err := conn.(*net.UnixConn).SetReadBuffer(MaxMsgLen)
if err != nil {
return nil, 0, fmt.Errorf("failed to socket write buffer size %v", err)
}

// Read header
buf := make([]byte, 8)

log.Tracef("Reading header length %v", len(buf))

n, err := conn.Read(buf)
if err != nil {
return nil, 0, fmt.Errorf("failed to read header: %w", err)
Expand All @@ -199,18 +213,25 @@ func Receive(conn net.Conn) ([]byte, uint32, error) {
}

// Decode header to get length and type
len := binary.BigEndian.Uint32(buf[0:4])
payloadLen := binary.BigEndian.Uint32(buf[0:4])
msgType := binary.BigEndian.Uint32(buf[4:8])

if payloadLen > MaxMsgLen {
return nil, 0, fmt.Errorf("cannot receive: payload size %v exceeds maximum size %v",
payloadLen, MaxMsgLen)
}

log.Tracef("Decoded header. Expecting type %v, length %v", msgType, payloadLen)

// Read payload
payload := make([]byte, len)
payload := make([]byte, payloadLen)
n, err = conn.Read(payload)
if err != nil {
return nil, 0, fmt.Errorf("failed to read payload: %w", err)
}
if uint32(n) != len {
if uint32(n) != payloadLen {
return nil, 0, fmt.Errorf("failed to read payload (received %v, expected %v bytes)",
n, len)
n, payloadLen)
}

if msgType == TypeError {
Expand All @@ -232,16 +253,33 @@ func Receive(conn net.Conn) ([]byte, uint32, error) {
// Type uint32 -> Type of the payload
// payload []byte -> CBOR-encoded payload
func Send(conn net.Conn, payload []byte, t uint32) error {

if len(payload) > MaxMsgLen {
return fmt.Errorf("cannot send: payload size %v exceeds maximum size %v",
len(payload), MaxMsgLen)
}

err := conn.(*net.UnixConn).SetWriteBuffer(MaxMsgLen)
if err != nil {
return fmt.Errorf("failed to socket write buffer size %v", err)
}

buf := make([]byte, 8)
binary.BigEndian.PutUint32(buf[0:4], uint32(len(payload)))
binary.BigEndian.PutUint32(buf[4:8], t)

log.Tracef("Sending header length %v", len(buf))

n, err := conn.Write(buf)
if err != nil {
return fmt.Errorf("failed to send header: %w", err)
}
if n != len(buf) {
return fmt.Errorf("could only send %v of %v bytes", n, len(buf))
}

log.Tracef("Sending payload type %v length %v", t, uint32(len(payload)))

n, err = conn.Write(payload)
if err != nil {
return fmt.Errorf("failed to send response: %w", err)
Expand Down
20 changes: 16 additions & 4 deletions cmcd/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) {
return
}

api.Send(conn, data, api.TypeAttest)
err = api.Send(conn, data, api.TypeAttest)
if err != nil {
api.SendError(conn, "failed to send: %v", err)
}

log.Debug("Prover: Finished")
}
Expand Down Expand Up @@ -181,7 +184,10 @@ func verify(conn net.Conn, payload []byte, cmc *cmc.Cmc) {
return
}

api.Send(conn, data, api.TypeVerify)
err = api.Send(conn, data, api.TypeVerify)
if err != nil {
api.SendError(conn, "failed to send: %v", err)
}

log.Debug("Verifier: Finished")
}
Expand Down Expand Up @@ -235,7 +241,10 @@ func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) {
return
}

api.Send(conn, data, api.TypeTLSSign)
err = api.Send(conn, data, api.TypeTLSSign)
if err != nil {
api.SendError(conn, "failed to send: %v", err)
}

log.Debug("Performed signing")
}
Expand Down Expand Up @@ -276,7 +285,10 @@ func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) {
return
}

api.Send(conn, data, api.TypeTLSCert)
err = api.Send(conn, data, api.TypeTLSCert)
if err != nil {
api.SendError(conn, "failed to send: %v", err)
}

log.Debug("Obtained TLS cert")
}
9 changes: 7 additions & 2 deletions testtool/coap.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func init() {

func (a CoapApi) generate(c *config) {

log.Tracef("Connecting via CoAP to %v", c.CmcAddr)

// Establish connection
conn, err := udp.Dial(c.CmcAddr)
if err != nil {
Expand Down Expand Up @@ -96,14 +98,14 @@ func (a CoapApi) generate(c *config) {
if err != nil {
log.Fatalf("Failed to save attestation report as %v: %v", c.ReportFile, err)
}
fmt.Println("Wrote attestation report: ", c.ReportFile)
log.Infof("Wrote attestation report: %v", c.ReportFile)

// Save the nonce for the verifier
os.WriteFile(c.NonceFile, nonce, 0644)
if err != nil {
log.Fatalf("Failed to save nonce as %v: %v", c.NonceFile, err)
}
fmt.Println("Wrote nonce: ", c.NonceFile)
log.Infof("Wrote nonce: %v", c.NonceFile)

}

Expand Down Expand Up @@ -168,6 +170,9 @@ func (a CoapApi) iothub(c *config) {

func verifyInternal(addr string, req *api.VerificationRequest,
) (*api.VerificationResponse, error) {

log.Tracef("Connecting via CoAP to %v", addr)

// Establish connection
conn, err := udp.Dial(addr)
if err != nil {
Expand Down
9 changes: 6 additions & 3 deletions testtool/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ package main
import (
"context"
"crypto/rand"
"fmt"
"os"
"time"

Expand All @@ -46,6 +45,8 @@ func (a GrpcApi) generate(c *config) {
ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
defer cancel()

log.Tracef("Connecting via gRPC to %v", c.CmcAddr)

conn, err := grpc.DialContext(ctx, c.CmcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
if err != nil {
log.Fatalf("Failed to connect to cmcd: %v", err)
Expand Down Expand Up @@ -76,14 +77,14 @@ func (a GrpcApi) generate(c *config) {
if err != nil {
log.Fatalf("Failed to save attestation report as %v: %v", c.ReportFile, err)
}
fmt.Println("Wrote attestation report: ", c.ReportFile)
log.Infof("Wrote attestation report: %v", c.ReportFile)

// Save the nonce for the verifier
os.WriteFile(c.NonceFile, nonce, 0644)
if err != nil {
log.Fatalf("Failed to save nonce as %v: %v", c.NonceFile, err)
}
fmt.Println("Wrote nonce: ", c.NonceFile)
log.Infof("Wrote nonce: %v", c.NonceFile)

}

Expand All @@ -93,6 +94,8 @@ func (a GrpcApi) verify(c *config) {
ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
defer cancel()

log.Tracef("Connecting via gRPC to %v", c.CmcAddr)

conn, err := grpc.DialContext(ctx, c.CmcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
if err != nil {
log.Fatalf("Failed to connect to cmcd: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion testtool/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func saveResult(file, addr string, result []byte) error {
// Save the Attestation Result to file
if file != "" {
os.WriteFile(file, out.Bytes(), 0644)
fmt.Println("Wrote file ", file)
log.Infof("Wrote file %v", file)
} else {
log.Debug("No config file specified: will not save attestation report")
}
Expand Down
9 changes: 7 additions & 2 deletions testtool/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func init() {

func (a SocketApi) generate(c *config) {

log.Tracef("Connecting via %v socket to %v", c.Network, c.CmcAddr)

// Establish connection
conn, err := net.Dial(c.Network, c.CmcAddr)
if err != nil {
Expand Down Expand Up @@ -88,14 +90,14 @@ func (a SocketApi) generate(c *config) {
if err != nil {
log.Fatalf("Failed to save attestation report as %v: %v", c.ReportFile, err)
}
fmt.Println("Wrote attestation report: ", c.ReportFile)
log.Infof("Wrote attestation report: %v", c.ReportFile)

// Save the nonce for the verifier
os.WriteFile(c.NonceFile, nonce, 0644)
if err != nil {
log.Fatalf("Failed to save nonce as %v: %v", c.NonceFile, err)
}
fmt.Println("Wrote nonce: ", c.NonceFile)
log.Infof("Wrote nonce: %v", c.NonceFile)

}

Expand Down Expand Up @@ -156,6 +158,9 @@ func (a SocketApi) iothub(c *config) {

func verifySocketRequest(network, addr string, req *api.VerificationRequest,
) (*api.VerificationResponse, error) {

log.Tracef("Connecting via %v socket to %v", network, addr)

// Establish connection
conn, err := net.Dial(network, addr)
if err != nil {
Expand Down
Loading