diff --git a/interceptors/auth/auth.go b/interceptors/auth/auth.go index b015d60e5..604bb6fe3 100644 --- a/interceptors/auth/auth.go +++ b/interceptors/auth/auth.go @@ -5,11 +5,18 @@ package auth import ( "context" + "errors" middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "google.golang.org/grpc" ) +// ErrNoAuthOverrideMatch is to support partial AuthFuncOverride implementations. +// If your service implements AuthFuncOverride and returns this error, we would +// proceed the authentication using the configured AuthFunc and ignore the error. +// Any other error would be returned directly by the interceptor. +var ErrNoAuthOverrideMatch = errors.New("no AuthFuncOverride match") + // AuthFunc is the pluggable function that performs authentication. // // The passed in `Context` will contain the gRPC metadata.MD object (for header-based authentication) and @@ -37,11 +44,16 @@ func UnaryServerInterceptor(authFunc AuthFunc) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { var newCtx context.Context var err error - if overrideSrv, ok := info.Server.(ServiceAuthFuncOverride); ok { + + overrideSrv, ok := info.Server.(ServiceAuthFuncOverride) + if ok { newCtx, err = overrideSrv.AuthFuncOverride(ctx, info.FullMethod) - } else { + } + + if !ok || errors.Is(err, ErrNoAuthOverrideMatch) { newCtx, err = authFunc(ctx) } + if err != nil { return nil, err } @@ -55,11 +67,16 @@ func StreamServerInterceptor(authFunc AuthFunc) grpc.StreamServerInterceptor { return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { var newCtx context.Context var err error - if overrideSrv, ok := srv.(ServiceAuthFuncOverride); ok { + + overrideSrv, ok := srv.(ServiceAuthFuncOverride) + if ok { newCtx, err = overrideSrv.AuthFuncOverride(stream.Context(), info.FullMethod) - } else { + } + + if !ok || errors.Is(err, ErrNoAuthOverrideMatch) { newCtx, err = authFunc(stream.Context()) } + if err != nil { return err } diff --git a/interceptors/auth/auth_test.go b/interceptors/auth/auth_test.go index d30cc8a57..4b319d6d2 100644 --- a/interceptors/auth/auth_test.go +++ b/interceptors/auth/auth_test.go @@ -147,11 +147,17 @@ func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() { type authOverrideTestService struct { testpb.TestServiceServer - T *testing.T + T *testing.T + OverrideNoMatchingError bool } func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) { assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around") + + if s.OverrideNoMatchingError { + return nil, auth.ErrNoAuthOverrideMatch + } + return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx) } @@ -159,7 +165,11 @@ func TestAuthOverrideTestSuite(t *testing.T) { authFunc := buildDummyAuthFunction("bearer", commonAuthToken) s := &AuthOverrideTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ - TestService: &authOverrideTestService{&assertingPingService{&testpb.TestPingService{}, t}, t}, + TestService: &authOverrideTestService{ + &assertingPingService{&testpb.TestPingService{}, t}, + t, + false, + }, ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(auth.StreamServerInterceptor(authFunc)), grpc.UnaryInterceptor(auth.UnaryServerInterceptor(authFunc)), @@ -169,17 +179,46 @@ func TestAuthOverrideTestSuite(t *testing.T) { suite.Run(t, s) } +func TestAuthOverrideNotMatchingErrTestSuite(t *testing.T) { + authFunc := buildDummyAuthFunction("bearer", commonAuthToken) + s := &AuthOverrideTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + TestService: &authOverrideTestService{ + &assertingPingService{&testpb.TestPingService{}, t}, + t, + true, + }, + ServerOpts: []grpc.ServerOption{ + grpc.StreamInterceptor(auth.StreamServerInterceptor(authFunc)), + grpc.UnaryInterceptor(auth.UnaryServerInterceptor(authFunc)), + }, + }, + WithOverrideNoMatchError: true, + } + suite.Run(t, s) +} + type AuthOverrideTestSuite struct { *testpb.InterceptorTestSuite + WithOverrideNoMatchError bool } func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() { - _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), testpb.GoodPing) + selectedToken := overrideAuthToken + if s.WithOverrideNoMatchError { + selectedToken = commonAuthToken + } + _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", selectedToken), testpb.GoodPing) require.NoError(s.T(), err, "no error must occur") } func (s *AuthOverrideTestSuite) TestStream_PassesAuth() { - stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), testpb.GoodPingList) + selectedToken := overrideAuthToken + if s.WithOverrideNoMatchError { + selectedToken = commonAuthToken + } + + stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", selectedToken), testpb.GoodPingList) require.NoError(s.T(), err, "should not fail on establishing the stream") pong, err := stream.Recv() require.NoError(s.T(), err, "no error must occur") diff --git a/interceptors/auth/doc.go b/interceptors/auth/doc.go index 27cadbc61..c09548a55 100644 --- a/interceptors/auth/doc.go +++ b/interceptors/auth/doc.go @@ -15,7 +15,8 @@ The middleware takes a user-customizable `AuthFunc`, which can be customized to auth information from the request. The extracted information can be put in the `context.Context` of handlers downstream for retrieval. -It also allows for per-service implementation overrides of `AuthFunc`. See `ServiceAuthFuncOverride`. +It also allows for per-service implementation overrides of `AuthFunc`. See `ServiceAuthFuncOverride`. In addition, +it supports partial per service override of the `AuthFunc` by using `ErrNoAuthOverrideMatch`. Please see examples for simple examples of use. */ diff --git a/interceptors/logging/examples/go.mod b/interceptors/logging/examples/go.mod index 984e9c48f..02c9d885d 100644 --- a/interceptors/logging/examples/go.mod +++ b/interceptors/logging/examples/go.mod @@ -11,22 +11,22 @@ require ( github.com/sirupsen/logrus v1.9.0 go.uber.org/zap v1.24.0 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 - google.golang.org/grpc v1.53.0 + google.golang.org/grpc v1.61.1 k8s.io/klog/v2 v2.90.1 ) require ( github.com/go-logfmt/logfmt v0.5.1 // indirect - github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect - google.golang.org/protobuf v1.31.0 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 // indirect + google.golang.org/protobuf v1.32.0 // indirect ) replace github.com/grpc-ecosystem/go-grpc-middleware/v2 => ../../../ diff --git a/interceptors/logging/examples/go.sum b/interceptors/logging/examples/go.sum index ff873eb00..e29393968 100644 --- a/interceptors/logging/examples/go.sum +++ b/interceptors/logging/examples/go.sum @@ -14,6 +14,7 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= @@ -48,25 +49,31 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnL golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w= google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= +google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17/go.mod h1:J7XzRzVy1+IPwWHZUzoD0IccYZIrXILAQpc+Qy9CMhY= google.golang.org/grpc v1.53.0 h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc= google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= +google.golang.org/grpc v1.61.1/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=