Skip to content

Commit

Permalink
RSDK-7010 Invalidate JWT access token upon Unauthenticated error (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis authored Apr 16, 2024
1 parent 8ae9552 commit 0a38aae
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 4 deletions.
18 changes: 18 additions & 0 deletions rpc/client_interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (

"go.opencensus.io/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// UnaryClientTracingInterceptor adds the current Span's metadata to the context.
Expand Down Expand Up @@ -45,3 +47,19 @@ func contextWithSpanMetadata(ctx context.Context) context.Context {
)
return ctx
}

// UnaryClientInvalidAuthInterceptor clears the access token stored on creds in
// the event of an "Unauthenticated" or "PermissionDenied" gRPC error to force
// re-auth.
func UnaryClientInvalidAuthInterceptor(creds *perRPCJWTCredentials) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if c := status.Code(err); c == codes.Unauthenticated || c == codes.PermissionDenied &&
creds != nil {
creds.accessToken = ""
}
return err
}
}
13 changes: 9 additions & 4 deletions rpc/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,13 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg
if dOpts.unaryInterceptor != nil {
unaryInterceptors = append(unaryInterceptors, dOpts.unaryInterceptor)
}
unaryInterceptor := grpc_middleware.ChainUnaryClient(unaryInterceptors...)
dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(unaryInterceptor))

var streamInterceptors []grpc.StreamClientInterceptor
streamInterceptors = append(streamInterceptors, grpc_zap.StreamClientInterceptor(grpcLogger))
streamInterceptors = append(streamInterceptors, StreamClientTracingInterceptor())
if dOpts.streamInterceptor != nil {
streamInterceptors = append(streamInterceptors, dOpts.streamInterceptor)
}
streamInterceptor := grpc_middleware.ChainStreamClient(streamInterceptors...)
dialOpts = append(dialOpts, grpc.WithStreamInterceptor(streamInterceptor))

var connPtr *ClientConn
var closeCredsFunc func() error
Expand Down Expand Up @@ -335,8 +331,17 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg
connPtr = &rpcCreds.conn
}
dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(rpcCreds))
unaryInterceptors = append(unaryInterceptors, UnaryClientInvalidAuthInterceptor(rpcCreds))
// InvalidAuthInterceptor will not work for streaming calls; we can only
// intercept the creation of a stream, and the ensuring of authentication
// server-side happens per RPC request (per usage of the stream).
}

unaryInterceptor := grpc_middleware.ChainUnaryClient(unaryInterceptors...)
dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(unaryInterceptor))
streamInterceptor := grpc_middleware.ChainStreamClient(streamInterceptors...)
dialOpts = append(dialOpts, grpc.WithStreamInterceptor(streamInterceptor))

var conn ClientConn
var cached bool
var err error
Expand Down
75 changes: 75 additions & 0 deletions rpc/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,78 @@ func TestWithStreamInterceptor(t *testing.T) {
)
test.That(t, interceptedCount, test.ShouldEqual, 1)
}

func TestInvalidAuth(t *testing.T) {
logger := golog.NewTestLogger(t)

var timesAuthed, timesAuthEnsured int
fakeAuthHandler := AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
timesAuthed++
return map[string]string{}, nil
})
fakeEnsureAuthedHandler := func(ctx context.Context) (context.Context, error) {
timesAuthEnsured++
// Fail until client has reauthenticated (timesAuthed > 1).
if timesAuthed <= 1 {
return nil, status.Error(codes.Unauthenticated, "bad")
}
return context.Background(), nil
}
rpcServer, err := NewServer(
logger,
WithAuthHandler(CredentialsTypeAPIKey, fakeAuthHandler),
WithEnsureAuthedHandler(fakeEnsureAuthedHandler),
)
test.That(t, err, test.ShouldBeNil)

err = rpcServer.RegisterServiceServer(
context.Background(),
&pb.EchoService_ServiceDesc,
&echoserver.Server{},
pb.RegisterEchoServiceHandlerFromEndpoint,
)
test.That(t, err, test.ShouldBeNil)

httpListener, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)

errChan := make(chan error, 1)
go func() {
errChan <- rpcServer.Serve(httpListener)
}()
defer func() {
test.That(t, rpcServer.Stop(), test.ShouldBeNil)
err = <-errChan
test.That(t, err, test.ShouldBeNil)
}()

conn, err := Dial(
context.Background(),
httpListener.Addr().String(),
logger,
WithForceDirectGRPC(),
WithInsecure(),
WithDialDebug(),
// Have to pass some creds to pass "authentication required" check.
WithEntityCredentials("foo", Credentials{Type: CredentialsTypeAPIKey, Payload: "bar"}),
)
test.That(t, err, test.ShouldBeNil)
defer func() {
test.That(t, conn.Close(), test.ShouldBeNil)
}()

// Test that first echo fails and server's `Authenticate` is called once.
client := pb.NewEchoServiceClient(conn)
_, err = client.Echo(context.Background(), &pb.EchoRequest{Message: "hello"})
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "Unauthenticated")
test.That(t, timesAuthed, test.ShouldEqual, 1)
test.That(t, timesAuthEnsured, test.ShouldEqual, 1)

// Test that second echo succeeds and `Authenticate` is called again (reauth).
echoResp, err := client.Echo(context.Background(), &pb.EchoRequest{Message: "hello"})
test.That(t, err, test.ShouldBeNil)
test.That(t, echoResp.GetMessage(), test.ShouldEqual, "hello")
test.That(t, timesAuthed, test.ShouldEqual, 2)
test.That(t, timesAuthEnsured, test.ShouldEqual, 2)
}
2 changes: 2 additions & 0 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ type simpleServer struct {
tlsAuthHandler func(ctx context.Context, entities ...string) error
authHandlersForCreds map[CredentialsType]credAuthHandlers
authToHandler AuthenticateToHandler
ensureAuthedHandler func(ctx context.Context) (context.Context, error)

// authAudience is the JWT audience (aud) that will be used/expected
// for our service.
Expand Down Expand Up @@ -296,6 +297,7 @@ func NewServer(logger golog.Logger, opts ...ServerOption) (Server, error) {
authToHandler: sOpts.authToHandler,
authAudience: sOpts.authAudience,
authIssuer: sOpts.authIssuer,
ensureAuthedHandler: sOpts.ensureAuthedHandler,
exemptMethods: make(map[string]bool),
publicMethods: make(map[string]bool),
tlsConfig: sOpts.tlsConfig,
Expand Down
5 changes: 5 additions & 0 deletions rpc/server_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ func (ss *simpleServer) tryAuth(ctx context.Context) (context.Context, error) {
}

func (ss *simpleServer) ensureAuthed(ctx context.Context) (context.Context, error) {
// Use handler if set (only used for testing).
if ss.ensureAuthedHandler != nil {
return ss.ensureAuthedHandler(ctx)
}

tokenString, err := tokenFromContext(ctx)
if err != nil {
// check TLS state
Expand Down
15 changes: 15 additions & 0 deletions rpc/server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type serverOptions struct {
// stats monitoring on the connections.
statsHandler stats.Handler

// ensureAuthedHandler is the callback used to ensure that the context of an
// incoming RPC request is properly authenticated.
ensureAuthedHandler func(ctx context.Context) (context.Context, error)

unknownStreamDesc *grpc.StreamDesc
}

Expand Down Expand Up @@ -358,6 +362,17 @@ func WithAuthHandler(forType CredentialsType, handler AuthHandler) ServerOption
})
}

// WithEnsureAuthedHandler returns a ServerOptions which adds custom logic for
// the ensuring of authentication on each incoming request. Can only be used
// in testing environments (will produce an error when ensuring authentication
// otherwise).
func WithEnsureAuthedHandler(eah func(ctx context.Context) (context.Context, error)) ServerOption {
return newFuncServerOption(func(o *serverOptions) error {
o.ensureAuthedHandler = eah
return nil
})
}

// WithEntityDataLoader returns a ServerOption which adds an entity data loader
// associated to the given credential type to use for loading data after the signed
// access token has been verified.
Expand Down

0 comments on commit 0a38aae

Please sign in to comment.