From ed9a54fe6efc9b263abe0320ac33e72d9e158efe Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Wed, 4 Sep 2024 15:51:00 +0800 Subject: [PATCH] fix(gRPC): retrieve status or biz error for non-ServerStreaming --- client/stream.go | 1 + client/stream_test.go | 60 +++++++++++++++++++++++++++++++++++ pkg/remote/codec/grpc/grpc.go | 17 ++++++++++ pkg/rpcinfo/interface.go | 1 + pkg/rpcinfo/mocks_test.go | 8 +++++ pkg/rpcinfo/mutable.go | 1 + pkg/rpcinfo/rpcconfig.go | 9 ++++++ 7 files changed, 97 insertions(+) diff --git a/client/stream.go b/client/stream.go index a184ac7ecc..161abb24bd 100644 --- a/client/stream.go +++ b/client/stream.go @@ -52,6 +52,7 @@ func (kc *kClient) Stream(ctx context.Context, method string, request, response ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil) rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) + rpcinfo.AsMutableRPCConfig(ri.Config()).SetStreamingMode(kc.getStreamingMode(ri)) ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) ctx = kc.opt.TracerCtl.DoStart(ctx, ri) diff --git a/client/stream_test.go b/client/stream_test.go index 9c07647d2a..63e1098406 100644 --- a/client/stream_test.go +++ b/client/stream_test.go @@ -29,6 +29,7 @@ import ( mocksnet "github.com/cloudwego/kitex/internal/mocks/net" mock_remote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/remotecli" @@ -637,3 +638,62 @@ func Test_isRPCError(t *testing.T) { test.Assert(t, isRPCError(errors.New("error"))) }) } + +func Test_kClient_Stream_SetStreamingMode(t *testing.T) { + testcases := []struct { + method string + mode serviceinfo.StreamingMode + }{ + { + method: "None", + mode: serviceinfo.StreamingNone, + }, + { + method: "Unary", + mode: serviceinfo.StreamingUnary, + }, + { + method: "ClientStreaming", + mode: serviceinfo.StreamingClient, + }, + { + method: "ServerStreaming", + mode: serviceinfo.StreamingServer, + }, + { + method: "BidiStreaming", + mode: serviceinfo.StreamingBidirectional, + }, + } + info := &serviceinfo.ServiceInfo{Methods: make(map[string]serviceinfo.MethodInfo)} + for _, tc := range testcases { + info.Methods[tc.method] = serviceinfo.NewMethodInfo( + nil, nil, nil, false, + serviceinfo.WithStreamingMode(tc.mode), + ) + } + + for _, tc := range testcases { + t.Run(tc.method, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + opts := append(newOpts(ctrl), WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri.Config().StreamingMode() == tc.mode) + return nil + } + })) + + kc := &kClient{ + opt: client.NewOptions(opts), + svcInfo: info, + } + + _ = kc.init() + + err := kc.Stream(context.Background(), tc.method, req, resp) + test.Assert(t, err == nil, err) + }) + } +} diff --git a/pkg/remote/codec/grpc/grpc.go b/pkg/remote/codec/grpc/grpc.go index 5cbc94fc2a..5a5428c320 100644 --- a/pkg/remote/codec/grpc/grpc.go +++ b/pkg/remote/codec/grpc/grpc.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "github.com/cloudwego/fastpb" "google.golang.org/protobuf/proto" @@ -176,6 +177,22 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { d, err := decodeGRPCFrame(ctx, in) + // For ClientStreaming, server may return an err(e.g. status) as trailer frame after calling SendAndClose. + // We need to receive this trailer frame. + if message.RPCInfo().Config().StreamingMode() == serviceinfo.StreamingClient && message.RPCRole() == remote.Client && err == nil { + // Receive trailer frame + // If err == nil, wrong gRPC protocol implementation. + // If err == io.EOF, it means server returns nil, just ignore io.EOF. + // If err != io.EOF, it means server returns status err or BizStatusErr, or other gRPC transport error came out, + // we need to throw it to users. + _, err = decodeGRPCFrame(ctx, in) + if err == nil { + return errors.New("KITEX: grpc client streaming protocol violation: get , want ") + } + if err == io.EOF { + err = nil + } + } if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil { // record recv size, even when err != nil (0 is recorded to the lastRecvSize) rpcStats.IncrRecvSize(uint64(len(d))) diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index 05b8f4d379..f9b5b2b482 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -80,6 +80,7 @@ type RPCConfig interface { TransportProtocol() transport.Protocol InteractionMode() InteractionMode PayloadCodec() serviceinfo.PayloadCodec + StreamingMode() serviceinfo.StreamingMode } // Invocation contains specific information about the call. diff --git a/pkg/rpcinfo/mocks_test.go b/pkg/rpcinfo/mocks_test.go index 781b6b360f..3888ce5ee6 100644 --- a/pkg/rpcinfo/mocks_test.go +++ b/pkg/rpcinfo/mocks_test.go @@ -36,6 +36,7 @@ type MockRPCConfig struct { IOBufferSizeFunc func() (r int) TransportProtocolFunc func() transport.Protocol InteractionModeFunc func() (r rpcinfo.InteractionMode) + StreamingModeFunc func() (r serviceinfo.StreamingMode) } func (m *MockRPCConfig) PayloadCodec() serviceinfo.PayloadCodec { @@ -90,6 +91,13 @@ func (m *MockRPCConfig) TransportProtocol() (r transport.Protocol) { return } +func (m *MockRPCConfig) StreamingMode() (r serviceinfo.StreamingMode) { + if m.StreamingModeFunc != nil { + return m.StreamingModeFunc() + } + return +} + type MockRPCStats struct{} func (m *MockRPCStats) Record(context.Context, stats.Event, stats.Status, string) {} diff --git a/pkg/rpcinfo/mutable.go b/pkg/rpcinfo/mutable.go index 9c6a6c8802..035a5b30a0 100644 --- a/pkg/rpcinfo/mutable.go +++ b/pkg/rpcinfo/mutable.go @@ -52,6 +52,7 @@ type MutableRPCConfig interface { CopyFrom(from RPCConfig) ImmutableView() RPCConfig SetPayloadCodec(codec serviceinfo.PayloadCodec) + SetStreamingMode(mode serviceinfo.StreamingMode) } // MutableRPCStats is used to change the information in the RPCStats. diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go index 3c0f654ca6..478a1fd8e4 100644 --- a/pkg/rpcinfo/rpcconfig.go +++ b/pkg/rpcinfo/rpcconfig.go @@ -66,6 +66,7 @@ type rpcConfig struct { transportProtocol transport.Protocol interactionMode InteractionMode payloadCodec serviceinfo.PayloadCodec + streamingMode serviceinfo.StreamingMode } func init() { @@ -193,6 +194,14 @@ func (r *rpcConfig) PayloadCodec() serviceinfo.PayloadCodec { return r.payloadCodec } +func (r *rpcConfig) SetStreamingMode(mode serviceinfo.StreamingMode) { + r.streamingMode = mode +} + +func (r *rpcConfig) StreamingMode() serviceinfo.StreamingMode { + return r.streamingMode +} + // Clone returns a copy of the current rpcConfig. func (r *rpcConfig) Clone() MutableRPCConfig { r2 := rpcConfigPool.Get().(*rpcConfig)