diff --git a/interceptors/protovalidate/example_stream_test.go b/interceptors/protovalidate/example_stream_test.go index 01ee2b992..0e3f820f8 100644 --- a/interceptors/protovalidate/example_stream_test.go +++ b/interceptors/protovalidate/example_stream_test.go @@ -10,6 +10,8 @@ import ( protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type StreamService struct { @@ -32,7 +34,13 @@ func ExampleStreamServerInterceptor() { protovalidate_middleware.StreamServerInterceptor(validator, protovalidate_middleware.WithIgnoreMessages( (&testvalidatev1.SendStreamRequest{}).ProtoReflect().Type(), - )), + ), + protovalidate_middleware.WithErrorConverter( + func(err error) error { + return status.Error(codes.InvalidArgument, err.Error()) + }, + ), + ), ), ) svc = &StreamService{} diff --git a/interceptors/protovalidate/example_unary_test.go b/interceptors/protovalidate/example_unary_test.go index a6fe75d6e..0393a0736 100644 --- a/interceptors/protovalidate/example_unary_test.go +++ b/interceptors/protovalidate/example_unary_test.go @@ -11,6 +11,8 @@ import ( protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type UnaryService struct { @@ -34,6 +36,11 @@ func ExampleUnaryServerInterceptor() { protovalidate_middleware.WithIgnoreMessages( (&testvalidatev1.SendRequest{}).ProtoReflect().Type(), ), + protovalidate_middleware.WithErrorConverter( + func(err error) error { + return status.Error(codes.InvalidArgument, err.Error()) + }, + ), ), ), ) diff --git a/interceptors/protovalidate/options.go b/interceptors/protovalidate/options.go index 764b55e74..1bd90aa81 100644 --- a/interceptors/protovalidate/options.go +++ b/interceptors/protovalidate/options.go @@ -8,11 +8,25 @@ package protovalidate import ( "golang.org/x/exp/slices" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/reflect/protoreflect" ) +// DefaultErrorConverter returns InvalidArgument status with error message from validator. +func DefaultErrorConverter(err error) error { + return status.Error(codes.InvalidArgument, err.Error()) +} + +var ( + defaultOptions = &options{ + errorConverter: DefaultErrorConverter, + } +) + type options struct { ignoreMessages []protoreflect.MessageType + errorConverter ErrorConverter } // An Option lets you add options to protovalidate interceptors using With* funcs. @@ -20,6 +34,7 @@ type Option func(*options) func evaluateOpts(opts []Option) *options { optCopy := &options{} + *optCopy = *defaultOptions for _, o := range opts { o(optCopy) } @@ -39,3 +54,15 @@ func (o *options) shouldIgnoreMessage(m protoreflect.MessageType) bool { return m == t }) } + +// ErrorConverter function customize the error returned by protovalidate.Validator. +type ErrorConverter = func(err error) error + +// WithErrorConverter customizes the function for mapping errors. +// +// By default, DefaultErrorConverter used. +func WithErrorConverter(errorConverter ErrorConverter) Option { + return func(o *options) { + o.errorConverter = errorConverter + } +} diff --git a/interceptors/protovalidate/protovalidate.go b/interceptors/protovalidate/protovalidate.go index cf337db11..a7ada3126 100644 --- a/interceptors/protovalidate/protovalidate.go +++ b/interceptors/protovalidate/protovalidate.go @@ -5,34 +5,38 @@ package protovalidate import ( "context" - "errors" "github.com/bufbuild/protovalidate-go" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) +func validateMessage(validator *protovalidate.Validator, o *options, req any) error { + msg := req.(proto.Message) + + if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) { + return nil + } + + if err := validator.Validate(msg); err != nil { + return o.errorConverter(err) + } + + return nil +} + // UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor { + o := evaluateOpts(opts) + return func( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (resp interface{}, err error) { - o := evaluateOpts(opts) - switch msg := req.(type) { - case proto.Message: - if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) { - break - } - if err = validator.Validate(msg); err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) - } - default: - return nil, errors.New("unsupported message type") + if err := validateMessage(validator, o, req); err != nil { + return nil, err } return handler(ctx, req) @@ -41,55 +45,40 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) // StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. func StreamServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.StreamServerInterceptor { + o := evaluateOpts(opts) + return func( srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - ctx := stream.Context() - - wrapped := wrapServerStream(stream) - wrapped.wrappedContext = ctx - wrapped.validator = validator - wrapped.options = evaluateOpts(opts) + wrapped := wrapServerStream(stream, validator, o) return handler(srv, wrapped) } } func (w *wrappedServerStream) RecvMsg(m interface{}) error { - if err := w.ServerStream.RecvMsg(m); err != nil { + if err := validateMessage(w.validator, w.options, m); err != nil { return err } - msg := m.(proto.Message) - if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) { - return nil - } - if err := w.validator.Validate(msg); err != nil { - return status.Error(codes.InvalidArgument, err.Error()) - } - - return nil + return w.ServerStream.RecvMsg(m) } -// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context. +// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows to validate messages. type wrappedServerStream struct { grpc.ServerStream - // wrappedContext is the wrapper's own Context. You can assign it. - wrappedContext context.Context - validator *protovalidate.Validator options *options } -// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context() -func (w *wrappedServerStream) Context() context.Context { - return w.wrappedContext -} - -// wrapServerStream returns a ServerStream that has the ability to overwrite context. -func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream { - return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()} +// wrapServerStream returns a ServerStream that has the ability to validate messages. +func wrapServerStream( + stream grpc.ServerStream, + validator *protovalidate.Validator, + options *options, +) *wrappedServerStream { + return &wrappedServerStream{ServerStream: stream, validator: validator, options: options} } diff --git a/interceptors/protovalidate/protovalidate_test.go b/interceptors/protovalidate/protovalidate_test.go index f626f395c..2a8a9f896 100644 --- a/interceptors/protovalidate/protovalidate_test.go +++ b/interceptors/protovalidate/protovalidate_test.go @@ -5,6 +5,7 @@ package protovalidate_test import ( "context" + "fmt" "log" "net" "testing" @@ -19,12 +20,15 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" - "google.golang.org/protobuf/reflect/protoreflect" ) +func customErrorConverter(err error) error { + return fmt.Errorf("my custom wrapper: %w", err) +} + func TestUnaryServerInterceptor(t *testing.T) { validator, err := protovalidate.New() - assert.Nil(t, err) + assert.NoError(t, err) interceptor := protovalidate_middleware.UnaryServerInterceptor(validator) @@ -38,7 +42,7 @@ func TestUnaryServerInterceptor(t *testing.T) { } resp, err := interceptor(context.TODO(), testvalidate.GoodUnaryRequest, info, handler) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, resp, "good") }) @@ -62,9 +66,24 @@ func TestUnaryServerInterceptor(t *testing.T) { } resp, err := interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, resp, "good") }) + + interceptor = protovalidate_middleware.UnaryServerInterceptor(validator, + protovalidate_middleware.WithErrorConverter(customErrorConverter), + ) + + t.Run("custom_error_converter", func(t *testing.T) { + info := &grpc.UnaryServerInfo{ + FullMethod: "FakeMethod", + } + + _, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) + assert.Error(t, err) + assert.Equal(t, codes.Unknown, status.Code(err)) + assert.EqualError(t, err, "my custom wrapper: validation error:\n - message: value must be a valid email address [string.email]") + }) } type server struct { @@ -84,7 +103,7 @@ func (g *server) SendStream( const bufSize = 1024 * 1024 -func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *grpc.ClientConn { +func startGrpcServer(t *testing.T, opts ...protovalidate_middleware.Option) *grpc.ClientConn { lis := bufconn.Listen(bufSize) validator, err := protovalidate.New() @@ -92,9 +111,7 @@ func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) * s := grpc.NewServer( grpc.StreamInterceptor( - protovalidate_middleware.StreamServerInterceptor(validator, - protovalidate_middleware.WithIgnoreMessages(ignoreMessages...), - ), + protovalidate_middleware.StreamServerInterceptor(validator, opts...), ), ) testvalidatev1.RegisterTestValidateServiceServer(s, &server{}) @@ -133,7 +150,7 @@ func TestStreamServerInterceptor(t *testing.T) { ) _, err := client.SendStream(context.Background(), testvalidate.GoodStreamRequest) - assert.Nil(t, err) + assert.NoError(t, err) }) t.Run("invalid_email", func(t *testing.T) { @@ -151,13 +168,31 @@ func TestStreamServerInterceptor(t *testing.T) { t.Run("invalid_email_ignored", func(t *testing.T) { client := testvalidatev1.NewTestValidateServiceClient( - startGrpcServer(t, testvalidate.BadStreamRequest.ProtoReflect().Type()), + startGrpcServer( + t, + protovalidate_middleware.WithIgnoreMessages(testvalidate.BadStreamRequest.ProtoReflect().Type()), + ), ) out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest) - assert.Nil(t, err) + assert.NoError(t, err) _, err = out.Recv() - assert.Nil(t, err) + assert.NoError(t, err) + }) + + t.Run("custom_error_converter", func(t *testing.T) { + client := testvalidatev1.NewTestValidateServiceClient( + startGrpcServer(t, protovalidate_middleware.WithErrorConverter(customErrorConverter)), + ) + + out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest) + assert.NoError(t, err) + + _, err = out.Recv() + assert.Error(t, err) + st, _ := status.FromError(err) + assert.Equal(t, codes.Unknown, st.Code()) + assert.Equal(t, "my custom wrapper: validation error:\n - message: value must be a valid email address [string.email]", st.Message()) }) }