diff --git a/client/stream.go b/client/stream.go index a184ac7ecc..f6ccf5eb16 100644 --- a/client/stream.go +++ b/client/stream.go @@ -52,6 +52,7 @@ func (kc *kClient) Stream(ctx context.Context, method string, request, response ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil) rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) + rpcinfo.AsMutableRPCConfig(ri.Config()).SetStreamingMode(rpcinfo.StreamingMode(kc.getStreamingMode(ri))) ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) ctx = kc.opt.TracerCtl.DoStart(ctx, ri) diff --git a/pkg/remote/codec/grpc/grpc.go b/pkg/remote/codec/grpc/grpc.go index ef8c640af5..9d5fe4248a 100644 --- a/pkg/remote/codec/grpc/grpc.go +++ b/pkg/remote/codec/grpc/grpc.go @@ -52,7 +52,6 @@ type protobufV2MsgCodec interface { type grpcCodec struct { ThriftCodec remote.PayloadCodec - ServiceInfo *serviceinfo.ServiceInfo } type CodecOption func(c *grpcCodec) @@ -63,12 +62,6 @@ func WithThriftCodec(t remote.PayloadCodec) CodecOption { } } -func WithServiceInfo(info *serviceinfo.ServiceInfo) CodecOption { - return func(c *grpcCodec) { - c.ServiceInfo = info - } -} - // NewGRPCCodec create grpc and protobuf codec func NewGRPCCodec(opts ...CodecOption) remote.Codec { codec := &grpcCodec{} @@ -183,9 +176,15 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo } func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { - ri := rpcinfo.GetRPCInfo(ctx) d, err := decodeGRPCFrame(ctx, in) - if needCheckTrailer(ri, c.ServiceInfo) && err == nil { + // For ClientStreaming, server may return an err(e.g. status) as trailer frame after calling SendAndClose. + // We need to receive this trailer frame. + if message.RPCInfo().Config().StreamingMode() == rpcinfo.ClientStreaming && message.RPCRole() == remote.Client && err == nil { + // Receive trailer frame + // If err == nil, wrong gRPC protocol implementation. + // If err == io.EOF, it means server returns nil, just ignore io.EOF. + // If err != io.EOF, it means server returns status err or BizStatusErr, or other gRPC transport error came out, + // we need to throw it to users. _, err = decodeGRPCFrame(ctx, in) if err == nil { return errors.New("KITEX: grpc client streaming protocol violation: get , want ") @@ -242,16 +241,3 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot func (c *grpcCodec) Name() string { return "grpc" } - -func needCheckTrailer(ri rpcinfo.RPCInfo, svcInfo *serviceinfo.ServiceInfo) bool { - // server-side - if svcInfo == nil { - return false - } - methodInfo := svcInfo.MethodInfo(ri.Invocation().MethodName()) - // is there possibility that methodInfo is nil? - if methodInfo == nil { - return false - } - return methodInfo.StreamingMode() == serviceinfo.StreamingClient -} diff --git a/pkg/remote/trans/nphttp2/client_handler.go b/pkg/remote/trans/nphttp2/client_handler.go index 5bbf213624..9b98df5fbb 100644 --- a/pkg/remote/trans/nphttp2/client_handler.go +++ b/pkg/remote/trans/nphttp2/client_handler.go @@ -40,11 +40,8 @@ func (f *cliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remo func newCliTransHandler(opt *remote.ClientOption) (*cliTransHandler, error) { return &cliTransHandler{ - opt: opt, - codec: grpc.NewGRPCCodec( - grpc.WithThriftCodec(opt.PayloadCodec), - grpc.WithServiceInfo(opt.SvcInfo), - ), + opt: opt, + codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), }, nil } diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index 05b8f4d379..ed3b6fd362 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -80,6 +80,7 @@ type RPCConfig interface { TransportProtocol() transport.Protocol InteractionMode() InteractionMode PayloadCodec() serviceinfo.PayloadCodec + StreamingMode() StreamingMode } // Invocation contains specific information about the call. diff --git a/pkg/rpcinfo/mocks_test.go b/pkg/rpcinfo/mocks_test.go index 781b6b360f..a9d69860c1 100644 --- a/pkg/rpcinfo/mocks_test.go +++ b/pkg/rpcinfo/mocks_test.go @@ -36,6 +36,7 @@ type MockRPCConfig struct { IOBufferSizeFunc func() (r int) TransportProtocolFunc func() transport.Protocol InteractionModeFunc func() (r rpcinfo.InteractionMode) + StreamingModeFunc func() (r rpcinfo.StreamingMode) } func (m *MockRPCConfig) PayloadCodec() serviceinfo.PayloadCodec { @@ -90,6 +91,13 @@ func (m *MockRPCConfig) TransportProtocol() (r transport.Protocol) { return } +func (m *MockRPCConfig) StreamingMode() (r rpcinfo.StreamingMode) { + if m.StreamingModeFunc != nil { + return m.StreamingModeFunc() + } + return +} + type MockRPCStats struct{} func (m *MockRPCStats) Record(context.Context, stats.Event, stats.Status, string) {} diff --git a/pkg/rpcinfo/mutable.go b/pkg/rpcinfo/mutable.go index 9c6a6c8802..73469ae5ac 100644 --- a/pkg/rpcinfo/mutable.go +++ b/pkg/rpcinfo/mutable.go @@ -52,6 +52,7 @@ type MutableRPCConfig interface { CopyFrom(from RPCConfig) ImmutableView() RPCConfig SetPayloadCodec(codec serviceinfo.PayloadCodec) + SetStreamingMode(mode StreamingMode) } // MutableRPCStats is used to change the information in the RPCStats. diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go index 3c0f654ca6..5fa0ef0b9c 100644 --- a/pkg/rpcinfo/rpcconfig.go +++ b/pkg/rpcinfo/rpcconfig.go @@ -56,6 +56,16 @@ const ( Streaming InteractionMode = 2 ) +type StreamingMode int32 + +const ( + None StreamingMode = 0b0000 + Unary StreamingMode = 0b0001 + ClientStreaming StreamingMode = 0b0010 + ServerStreaming StreamingMode = 0b0100 + BidirectionalStreaming StreamingMode = 0b0110 +) + // rpcConfig is a set of configurations used during RPC calls. type rpcConfig struct { readOnlyMask int @@ -66,6 +76,7 @@ type rpcConfig struct { transportProtocol transport.Protocol interactionMode InteractionMode payloadCodec serviceinfo.PayloadCodec + streamingMode StreamingMode } func init() { @@ -193,6 +204,14 @@ func (r *rpcConfig) PayloadCodec() serviceinfo.PayloadCodec { return r.payloadCodec } +func (r *rpcConfig) SetStreamingMode(mode StreamingMode) { + r.streamingMode = mode +} + +func (r *rpcConfig) StreamingMode() StreamingMode { + return r.streamingMode +} + // Clone returns a copy of the current rpcConfig. func (r *rpcConfig) Clone() MutableRPCConfig { r2 := rpcConfigPool.Get().(*rpcConfig)