Skip to content

Commit

Permalink
RSDK-6883 rpc: Replace RS256 with EdDSA as default signing algo (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels authored Mar 21, 2024
1 parent b8cf57a commit 7172d3d
Show file tree
Hide file tree
Showing 12 changed files with 616 additions and 154 deletions.
4 changes: 2 additions & 2 deletions etc/setup_priv_key.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ cd test_keys
if [[ -f "private-key.pem" && -f "pkcs8.key" ]]; then
exit 0;
fi
openssl genrsa -out private-key.pem 4096
openssl genpkey -algorithm ed25519 -out private-key.pem
openssl pkcs8 -topk8 -inform PEM -outform PEM -nocrypt -in private-key.pem -out pkcs8.key
openssl rsa -in private-key.pem -outform PEM -pubout -out public-key.pem
openssl pkey -in private-key.pem -pubout -out public-key.pem
14 changes: 6 additions & 8 deletions jwks/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package jwks

import (
"context"
"crypto/rsa"
"errors"
"fmt"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -151,7 +151,7 @@ func NewStaticJWKKeyProvider(keyset KeySet) KeyProvider {
}
}

func publicKeyFromKeySet(keyset KeySet, kid, alg string) (*rsa.PublicKey, error) {
func publicKeyFromKeySet(keyset KeySet, kid, alg string) (interface{}, error) {
key, ok := keyset.LookupKeyID(kid)
if !ok {
return nil, errors.New("kid header does not exist in keyset")
Expand All @@ -161,11 +161,9 @@ func publicKeyFromKeySet(keyset KeySet, kid, alg string) (*rsa.PublicKey, error)
return nil, errors.New("key from kid has different signing alg")
}

var pubKey rsa.PublicKey
err := key.Raw(&pubKey)
if err != nil {
return nil, errors.New("invalid key type")
var pubKey interface{}
if err := key.Raw(&pubKey); err != nil {
return nil, fmt.Errorf("error getting raw key: %w", err)
}

return &pubKey, nil
return pubKey, nil
}
31 changes: 25 additions & 6 deletions rpc/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpc

import (
"context"
"crypto/ed25519"
"crypto/rsa"
//nolint:gosec // using for fingerprint
"crypto/sha1"
Expand Down Expand Up @@ -116,8 +117,8 @@ func (p TokenVerificationKeyProviderFunc) Close(ctx context.Context) error {
return nil
}

// MakePublicKeyProvider returns a TokenVerificationKeyProvider that provides a public key for JWT verification.
func MakePublicKeyProvider(pubKey *rsa.PublicKey) TokenVerificationKeyProvider {
// MakeRSAPublicKeyProvider returns a TokenVerificationKeyProvider that provides a public key for JWT verification.
func MakeRSAPublicKeyProvider(pubKey *rsa.PublicKey) TokenVerificationKeyProvider {
return TokenVerificationKeyProviderFunc(
func(ctx context.Context, token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
Expand All @@ -129,6 +130,19 @@ func MakePublicKeyProvider(pubKey *rsa.PublicKey) TokenVerificationKeyProvider {
)
}

// MakeEd25519PublicKeyProvider returns a TokenVerificationKeyProvider that provides a public key for JWT verification.
func MakeEd25519PublicKeyProvider(pubKey ed25519.PublicKey) TokenVerificationKeyProvider {
return TokenVerificationKeyProviderFunc(
func(ctx context.Context, token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("unexpected signing method %q", token.Method.Alg())
}

return pubKey, nil
},
)
}

// MakeOIDCKeyProvider returns a TokenVerificationKeyProvider that dynamically looks up a public key for
// JWT verification by inspecting the JWT's kid field. The given issuer is used to discover the JWKs
// used for verification. This issuer is expected to follow the OIDC Discovery protocol.
Expand Down Expand Up @@ -256,6 +270,12 @@ type Credentials struct {
Payload string `json:"payload"`
}

type credAuthHandlers struct {
AuthHandler AuthHandler
EntityDataLoader EntityDataLoader
TokenVerificationKeyProvider TokenVerificationKeyProvider
}

// RSAPublicKeyThumbprint returns SHA1 of the public key's modulus Base64 URL encoded without padding.
func RSAPublicKeyThumbprint(key *rsa.PublicKey) (string, error) {
//nolint:gosec // using for fingerprint
Expand All @@ -268,8 +288,7 @@ func RSAPublicKeyThumbprint(key *rsa.PublicKey) (string, error) {
return base64.RawURLEncoding.EncodeToString(thumbPrint.Sum(nil)), nil
}

type credAuthHandlers struct {
AuthHandler AuthHandler
EntityDataLoader EntityDataLoader
TokenVerificationKeyProvider TokenVerificationKeyProvider
// ED25519PublicKeyThumbprint returns the base64 encoded public key.
func ED25519PublicKeyThumbprint(key ed25519.PublicKey) string {
return base64.RawURLEncoding.EncodeToString(key)
}
45 changes: 42 additions & 3 deletions rpc/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpc

import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"testing"
Expand Down Expand Up @@ -173,11 +174,11 @@ func TestTokenVerificationKeyProviderFunc(t *testing.T) {
<-capCtx
}

func TestWithPublicKeyProvider(t *testing.T) {
privKey, err := rsa.GenerateKey(rand.Reader, generatedRSAKeyBits)
func TestWithRSAPublicKeyProvider(t *testing.T) {
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
test.That(t, err, test.ShouldBeNil)
pubKey := &privKey.PublicKey
provider := MakePublicKeyProvider(pubKey)
provider := MakeRSAPublicKeyProvider(pubKey)

token := jwt.NewWithClaims(jwt.SigningMethodRS256, JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Expand Down Expand Up @@ -212,6 +213,44 @@ func TestWithPublicKeyProvider(t *testing.T) {
test.That(t, err, test.ShouldBeNil)
}

func TestWithEd25519PublicKeyProvider(t *testing.T) {
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
test.That(t, err, test.ShouldBeNil)
provider := MakeEd25519PublicKeyProvider(pubKey)

token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: uuid.NewString(),
Audience: jwt.ClaimStrings{"does not matter"},
},
AuthCredentialsType: CredentialsType("fake"),
})

verificationKey, err := provider.TokenVerificationKey(context.Background(), token)
test.That(t, err, test.ShouldBeNil)

badToken := jwt.NewWithClaims(jwt.SigningMethodHS256, JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: uuid.NewString(),
Audience: jwt.ClaimStrings{"does not matter"},
},
AuthCredentialsType: CredentialsType("fake"),
})

_, err = provider.TokenVerificationKey(context.Background(), badToken)
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "unexpected signing method")

tokenString, err := token.SignedString(privKey)
test.That(t, err, test.ShouldBeNil)

var claims JWTClaims
_, err = jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (interface{}, error) {
return verificationKey, nil
})
test.That(t, err, test.ShouldBeNil)
}

func TestRSAPublicKeyThumbprint(t *testing.T) {
privKey1, err := rsa.GenerateKey(rand.Reader, 512)
test.That(t, err, test.ShouldBeNil)
Expand Down
54 changes: 32 additions & 22 deletions rpc/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package rpc

import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -103,7 +104,7 @@ func testDial(t *testing.T, signalingCallQueue WebRTCCallQueue, logger golog.Log
httpListenerExternal, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)

privKeyExternal, err := rsa.GenerateKey(rand.Reader, generatedRSAKeyBits)
pubKeyExternal, privKeyExternal, err := ed25519.GenerateKey(rand.Reader)
test.That(t, err, test.ShouldBeNil)

externalSignalingHosts := make([]string, len(hosts))
Expand Down Expand Up @@ -145,7 +146,7 @@ func testDial(t *testing.T, signalingCallQueue WebRTCCallQueue, logger golog.Log
}
return nil, errors.New("this auth does not work yet")
})),
WithExternalAuthPublicKeyTokenVerifier(&privKeyExternal.PublicKey),
WithExternalAuthEd25519PublicKeyTokenVerifier(pubKeyExternal),
)
test.That(t, err, test.ShouldBeNil)

Expand All @@ -168,6 +169,8 @@ func testDial(t *testing.T, signalingCallQueue WebRTCCallQueue, logger golog.Log

var authToFail bool
acceptedFakeWithKeyEnts := []string{"someotherthing", httpListenerExternal.Addr().String()}
keyOpt, keyID := WithAuthED25519PrivateKey(privKeyExternal)
test.That(t, keyID, test.ShouldEqual, base64.RawURLEncoding.EncodeToString(privKeyExternal.Public().(ed25519.PublicKey)))
rpcServerExternal, err := NewServer(
logger,
WithAuthHandler("fakeExtWithKey", AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
Expand All @@ -186,7 +189,7 @@ func testDial(t *testing.T, signalingCallQueue WebRTCCallQueue, logger golog.Log
}
return map[string]string{}, nil
})),
WithAuthRSAPrivateKey(privKeyExternal),
keyOpt,
WithAuthenticateToHandler(func(ctx context.Context, entity string) (map[string]string, error) {
if authToFail {
return nil, errors.New("darn")
Expand Down Expand Up @@ -492,19 +495,20 @@ func TestDialExternalAuth(t *testing.T) {
httpListenerExternal2, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)

privKeyInternal, err := rsa.GenerateKey(rand.Reader, generatedRSAKeyBits)
pubKeyInternal, privKeyInternal, err := ed25519.GenerateKey(rand.Reader)
test.That(t, err, test.ShouldBeNil)
privKeyExternal, err := rsa.GenerateKey(rand.Reader, generatedRSAKeyBits)
pubKeyExternal, privKeyExternal, err := ed25519.GenerateKey(rand.Reader)
test.That(t, err, test.ShouldBeNil)
privKeyExternal2, err := rsa.GenerateKey(rand.Reader, generatedRSAKeyBits)
pubKeyExternal2, privKeyExternal2, err := ed25519.GenerateKey(rand.Reader)
test.That(t, err, test.ShouldBeNil)

internalAudience := []string{"int-aud2", "int-aud1", "int-aud3"}
keyOpt, _ := WithAuthED25519PrivateKey(privKeyInternal)
rpcServerInternal, err := NewServer(
logger,
// we are both some UUID and somesub as far as an audience goes
WithAuthAudience(internalAudience...),
WithAuthRSAPrivateKey(privKeyInternal),
keyOpt,
WithWebRTCServerOptions(WebRTCServerOptions{
Enable: true,
InternalSignalingHosts: []string{"yeehaw", internalAddr},
Expand All @@ -522,7 +526,7 @@ func TestDialExternalAuth(t *testing.T) {
return claims.Entity(), nil
})),
WithTokenVerificationKeyProvider(CredentialsTypeExternal,
MakePublicKeyProvider(&privKeyExternal.PublicKey),
MakeEd25519PublicKeyProvider(pubKeyExternal),
),
)
test.That(t, err, test.ShouldBeNil)
Expand All @@ -541,6 +545,7 @@ func TestDialExternalAuth(t *testing.T) {

var authToFail bool
acceptedFakeWithKeyEnts := []string{"someotherthing", httpListenerExternal.Addr().String()}
keyOptExternal, _ := WithAuthED25519PrivateKey(privKeyExternal)
rpcServerExternal, err := NewServer(
logger,
WithWebRTCServerOptions(WebRTCServerOptions{
Expand All @@ -566,7 +571,7 @@ func TestDialExternalAuth(t *testing.T) {
}
return map[string]string{}, nil
})),
WithAuthRSAPrivateKey(privKeyExternal),
keyOptExternal,
WithAuthenticateToHandler(func(ctx context.Context, entity string) (map[string]string, error) {
if authToFail {
return nil, errors.New("darn")
Expand All @@ -586,6 +591,7 @@ func TestDialExternalAuth(t *testing.T) {
)
test.That(t, err, test.ShouldBeNil)

keyOptExternal2, _ := WithAuthED25519PrivateKey(privKeyExternal2)
rpcServerExternal2, err := NewServer(
logger,
WithAuthHandler("fake", AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
Expand All @@ -594,7 +600,7 @@ func TestDialExternalAuth(t *testing.T) {
WithAuthHandler("fakeWithKey", AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
return map[string]string{}, nil
})),
WithAuthRSAPrivateKey(privKeyExternal2),
keyOptExternal2,
WithAuthenticateToHandler(func(ctx context.Context, entity string) (map[string]string, error) {
var ok bool
for _, ent := range internalAudience {
Expand Down Expand Up @@ -873,7 +879,7 @@ func TestDialExternalAuth(t *testing.T) {

t.Run("with signaling external auth material", func(t *testing.T) {
// rpcServerExternal.InstanceNames()[0] is the implicit audience
accessToken := signTestAuthToken(t, privKeyExternal, rpcServerExternal.InstanceNames()[0], "sub1", "fake")
accessToken := signTestAuthToken(t, pubKeyExternal, privKeyExternal, rpcServerExternal.InstanceNames()[0], "sub1", "fake")
opts := []DialOption{
WithInsecure(),
WithExternalAuthInsecure(),
Expand All @@ -897,7 +903,7 @@ func TestDialExternalAuth(t *testing.T) {

t.Run("with external auth material for external auth and signaler", func(t *testing.T) {
internalExternalAuthSrv.fail = false
accessToken := signTestAuthToken(t, privKeyInternal, "int-aud1", "sub1", "fake")
accessToken := signTestAuthToken(t, pubKeyInternal, privKeyInternal, "int-aud1", "sub1", "fake")
opts := []DialOption{
WithInsecure(),
WithExternalAuthInsecure(),
Expand All @@ -910,7 +916,7 @@ func TestDialExternalAuth(t *testing.T) {

t.Run("with external auth material for external auth and signaler with invalid key", func(t *testing.T) {
internalExternalAuthSrv.fail = false
accessToken := signTestAuthToken(t, privKeyExternal2, "aud1", "sub1", "fake")
accessToken := signTestAuthToken(t, pubKeyExternal2, privKeyExternal2, "aud1", "sub1", "fake")
opts := []DialOption{
WithInsecure(),
WithExternalAuthInsecure(),
Expand All @@ -923,7 +929,7 @@ func TestDialExternalAuth(t *testing.T) {
gStatus, ok := status.FromError(err)
test.That(t, ok, test.ShouldBeTrue)
test.That(t, gStatus.Code(), test.ShouldEqual, codes.Unauthenticated)
test.That(t, gStatus.Message(), test.ShouldContainSubstring, "crypto/rsa: verification error")
test.That(t, gStatus.Message(), test.ShouldContainSubstring, " this server did not sign this JWT")
})
})

Expand Down Expand Up @@ -1635,7 +1641,7 @@ type externalAuthServer struct {
expectedEnt string
expectedAud []string
noMetadata bool
privKey *rsa.PrivateKey
privKey ed25519.PrivateKey
}

func (svc *externalAuthServer) AuthenticateTo(
Expand Down Expand Up @@ -1668,7 +1674,7 @@ func (svc *externalAuthServer) AuthenticateTo(

entity := MustContextAuthEntity(ctx).Entity

token := jwt.NewWithClaims(jwt.SigningMethodRS256, JWTClaims{
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: entity,
Audience: jwt.ClaimStrings{req.Entity},
Expand All @@ -1687,10 +1693,16 @@ func (svc *externalAuthServer) AuthenticateTo(
}, nil
}

func signTestAuthToken(t *testing.T, privKey *rsa.PrivateKey, aud, ent string, credType CredentialsType) string {
func signTestAuthToken(
t *testing.T,
pubKey ed25519.PublicKey,
privKey ed25519.PrivateKey,
aud, ent string,
credType CredentialsType,
) string {
t.Helper()

token := jwt.NewWithClaims(jwt.SigningMethodRS256, JWTClaims{
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: ent,
Audience: jwt.ClaimStrings{aud},
Expand All @@ -1699,9 +1711,7 @@ func signTestAuthToken(t *testing.T, privKey *rsa.PrivateKey, aud, ent string, c
AuthMetadata: map[string]string{},
})

var err error
token.Header["kid"], err = RSAPublicKeyThumbprint(&privKey.PublicKey)
test.That(t, err, test.ShouldBeNil)
token.Header["kid"] = base64.RawURLEncoding.EncodeToString(pubKey)

tokenString, err := token.SignedString(privKey)
test.That(t, err, test.ShouldBeNil)
Expand Down
Loading

0 comments on commit 7172d3d

Please sign in to comment.