diff --git a/hpke/aead.go b/hpke/aead.go index baf147bb4..0e81be523 100644 --- a/hpke/aead.go +++ b/hpke/aead.go @@ -2,6 +2,7 @@ package hpke import ( "crypto/cipher" + "encoding/base64" "fmt" ) @@ -101,3 +102,54 @@ func (c *openContext) Open(ct, aad []byte) ([]byte, error) { } return pt, nil } + +// SealWithNonce takes a plaintext pt and associated data aad, and returns a ciphertext +// that is base64-encoded along with the nonce used for encryption. The nonce is auto-incremented +// and is prepended to the ciphertext before base64-encoding. +func (c *sealContext) SealWithNonce(pt, aad []byte) ([]byte, error) { + nonce := c.calcNonce() + ct := c.AEAD.Seal(nil, nonce, pt, aad) + err := c.increment() + if err != nil { + for i := range ct { + ct[i] = 0 + } + return nil, err + } + + // prepended the nonce to the ciphertext. + ct = append(nonce, ct...) + + // base64-encoded the ciphertext. + encodedCt := make([]byte, base64.RawStdEncoding.EncodedLen(len(ct))) + base64.RawStdEncoding.Encode(encodedCt, ct) + return encodedCt, nil +} + +// OpenWithNonce takes a base64-encoded ciphertext ct and associated data aad, +// and returns the corresponding plaintext. It assumes that the nonce is prepended +// to the ciphertext before base64-encoding. +func (c *openContext) OpenWithNonce(encodedCt, aad []byte) ([]byte, error) { + // base64-decodes the ciphertext. + decodedCt := make([]byte, base64.RawStdEncoding.DecodedLen(len(encodedCt))) + n, err := base64.RawStdEncoding.Decode(decodedCt, encodedCt) + if err != nil { + return nil, err + } + + decodedCt = decodedCt[:n] + + // The nonce is extracted from the ciphertext. + Nn := c.AEAD.NonceSize() + if len(decodedCt) < Nn { + return nil, fmt.Errorf("invalid ciphertext") + } + + // decrypts the ciphertext. + pt, err := c.AEAD.Open(nil, decodedCt[:Nn], decodedCt[Nn:], aad) + if err != nil { + return nil, err + } + + return pt, nil +}