diff --git a/rpc/client_interceptors.go b/rpc/client_interceptors.go index adae57a7..98ee5041 100644 --- a/rpc/client_interceptors.go +++ b/rpc/client_interceptors.go @@ -6,7 +6,9 @@ import ( "go.opencensus.io/trace" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // UnaryClientTracingInterceptor adds the current Span's metadata to the context. @@ -45,3 +47,19 @@ func contextWithSpanMetadata(ctx context.Context) context.Context { ) return ctx } + +// UnaryClientInvalidAuthInterceptor clears the access token stored on creds in +// the event of an "Unauthenticated" or "PermissionDenied" gRPC error to force +// re-auth. +func UnaryClientInvalidAuthInterceptor(creds *perRPCJWTCredentials) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, + cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, + ) error { + err := invoker(ctx, method, req, reply, cc, opts...) + if c := status.Code(err); c == codes.Unauthenticated || c == codes.PermissionDenied && + creds != nil { + creds.accessToken = "" + } + return err + } +} diff --git a/rpc/dialer.go b/rpc/dialer.go index ee9d02b1..041e0d46 100644 --- a/rpc/dialer.go +++ b/rpc/dialer.go @@ -290,8 +290,6 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg if dOpts.unaryInterceptor != nil { unaryInterceptors = append(unaryInterceptors, dOpts.unaryInterceptor) } - unaryInterceptor := grpc_middleware.ChainUnaryClient(unaryInterceptors...) - dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(unaryInterceptor)) var streamInterceptors []grpc.StreamClientInterceptor streamInterceptors = append(streamInterceptors, grpc_zap.StreamClientInterceptor(grpcLogger)) @@ -299,8 +297,6 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg if dOpts.streamInterceptor != nil { streamInterceptors = append(streamInterceptors, dOpts.streamInterceptor) } - streamInterceptor := grpc_middleware.ChainStreamClient(streamInterceptors...) - dialOpts = append(dialOpts, grpc.WithStreamInterceptor(streamInterceptor)) var connPtr *ClientConn var closeCredsFunc func() error @@ -335,8 +331,17 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg connPtr = &rpcCreds.conn } dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(rpcCreds)) + unaryInterceptors = append(unaryInterceptors, UnaryClientInvalidAuthInterceptor(rpcCreds)) + // InvalidAuthInterceptor will not work for streaming calls; we can only + // intercept the creation of a stream, and the ensuring of authentication + // server-side happens per RPC request (per usage of the stream). } + unaryInterceptor := grpc_middleware.ChainUnaryClient(unaryInterceptors...) + dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(unaryInterceptor)) + streamInterceptor := grpc_middleware.ChainStreamClient(streamInterceptors...) + dialOpts = append(dialOpts, grpc.WithStreamInterceptor(streamInterceptor)) + var conn ClientConn var cached bool var err error diff --git a/rpc/dialer_test.go b/rpc/dialer_test.go index b79e112e..e3d6a47a 100644 --- a/rpc/dialer_test.go +++ b/rpc/dialer_test.go @@ -542,3 +542,78 @@ func TestWithStreamInterceptor(t *testing.T) { ) test.That(t, interceptedCount, test.ShouldEqual, 1) } + +func TestInvalidAuth(t *testing.T) { + logger := golog.NewTestLogger(t) + + var timesAuthed, timesAuthEnsured int + fakeAuthHandler := AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) { + timesAuthed++ + return map[string]string{}, nil + }) + fakeEnsureAuthedHandler := func(ctx context.Context) (context.Context, error) { + timesAuthEnsured++ + // Fail until client has reauthenticated (timesAuthed > 1). + if timesAuthed <= 1 { + return nil, status.Error(codes.Unauthenticated, "bad") + } + return context.Background(), nil + } + rpcServer, err := NewServer( + logger, + WithAuthHandler(CredentialsTypeAPIKey, fakeAuthHandler), + WithEnsureAuthedHandler(fakeEnsureAuthedHandler), + ) + test.That(t, err, test.ShouldBeNil) + + err = rpcServer.RegisterServiceServer( + context.Background(), + &pb.EchoService_ServiceDesc, + &echoserver.Server{}, + pb.RegisterEchoServiceHandlerFromEndpoint, + ) + test.That(t, err, test.ShouldBeNil) + + httpListener, err := net.Listen("tcp", "localhost:0") + test.That(t, err, test.ShouldBeNil) + + errChan := make(chan error, 1) + go func() { + errChan <- rpcServer.Serve(httpListener) + }() + defer func() { + test.That(t, rpcServer.Stop(), test.ShouldBeNil) + err = <-errChan + test.That(t, err, test.ShouldBeNil) + }() + + conn, err := Dial( + context.Background(), + httpListener.Addr().String(), + logger, + WithForceDirectGRPC(), + WithInsecure(), + WithDialDebug(), + // Have to pass some creds to pass "authentication required" check. + WithEntityCredentials("foo", Credentials{Type: CredentialsTypeAPIKey, Payload: "bar"}), + ) + test.That(t, err, test.ShouldBeNil) + defer func() { + test.That(t, conn.Close(), test.ShouldBeNil) + }() + + // Test that first echo fails and server's `Authenticate` is called once. + client := pb.NewEchoServiceClient(conn) + _, err = client.Echo(context.Background(), &pb.EchoRequest{Message: "hello"}) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "Unauthenticated") + test.That(t, timesAuthed, test.ShouldEqual, 1) + test.That(t, timesAuthEnsured, test.ShouldEqual, 1) + + // Test that second echo succeeds and `Authenticate` is called again (reauth). + echoResp, err := client.Echo(context.Background(), &pb.EchoRequest{Message: "hello"}) + test.That(t, err, test.ShouldBeNil) + test.That(t, echoResp.GetMessage(), test.ShouldEqual, "hello") + test.That(t, timesAuthed, test.ShouldEqual, 2) + test.That(t, timesAuthEnsured, test.ShouldEqual, 2) +} diff --git a/rpc/server.go b/rpc/server.go index ffec9684..83cd8a25 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -138,6 +138,7 @@ type simpleServer struct { tlsAuthHandler func(ctx context.Context, entities ...string) error authHandlersForCreds map[CredentialsType]credAuthHandlers authToHandler AuthenticateToHandler + ensureAuthedHandler func(ctx context.Context) (context.Context, error) // authAudience is the JWT audience (aud) that will be used/expected // for our service. @@ -296,6 +297,7 @@ func NewServer(logger golog.Logger, opts ...ServerOption) (Server, error) { authToHandler: sOpts.authToHandler, authAudience: sOpts.authAudience, authIssuer: sOpts.authIssuer, + ensureAuthedHandler: sOpts.ensureAuthedHandler, exemptMethods: make(map[string]bool), publicMethods: make(map[string]bool), tlsConfig: sOpts.tlsConfig, diff --git a/rpc/server_auth.go b/rpc/server_auth.go index 351b67e6..2150ae12 100644 --- a/rpc/server_auth.go +++ b/rpc/server_auth.go @@ -284,6 +284,11 @@ func (ss *simpleServer) tryAuth(ctx context.Context) (context.Context, error) { } func (ss *simpleServer) ensureAuthed(ctx context.Context) (context.Context, error) { + // Use handler if set (only used for testing). + if ss.ensureAuthedHandler != nil { + return ss.ensureAuthedHandler(ctx) + } + tokenString, err := tokenFromContext(ctx) if err != nil { // check TLS state diff --git a/rpc/server_options.go b/rpc/server_options.go index 3f053e4c..2d41a3d9 100644 --- a/rpc/server_options.go +++ b/rpc/server_options.go @@ -67,6 +67,10 @@ type serverOptions struct { // stats monitoring on the connections. statsHandler stats.Handler + // ensureAuthedHandler is the callback used to ensure that the context of an + // incoming RPC request is properly authenticated. + ensureAuthedHandler func(ctx context.Context) (context.Context, error) + unknownStreamDesc *grpc.StreamDesc } @@ -358,6 +362,17 @@ func WithAuthHandler(forType CredentialsType, handler AuthHandler) ServerOption }) } +// WithEnsureAuthedHandler returns a ServerOptions which adds custom logic for +// the ensuring of authentication on each incoming request. Can only be used +// in testing environments (will produce an error when ensuring authentication +// otherwise). +func WithEnsureAuthedHandler(eah func(ctx context.Context) (context.Context, error)) ServerOption { + return newFuncServerOption(func(o *serverOptions) error { + o.ensureAuthedHandler = eah + return nil + }) +} + // WithEntityDataLoader returns a ServerOption which adds an entity data loader // associated to the given credential type to use for loading data after the signed // access token has been verified.