Skip to content

Commit

Permalink
reuse serviceinfo.StreamingMode
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Sep 5, 2024
1 parent cbdff65 commit 3f8df3e
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 19 deletions.
2 changes: 1 addition & 1 deletion client/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (kc *kClient) Stream(ctx context.Context, method string, request, response
ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil)

rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming)
rpcinfo.AsMutableRPCConfig(ri.Config()).SetStreamingMode(rpcinfo.StreamingMode(kc.getStreamingMode(ri)))
rpcinfo.AsMutableRPCConfig(ri.Config()).SetStreamingMode(kc.getStreamingMode(ri))
ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)

ctx = kc.opt.TracerCtl.DoStart(ctx, ri)
Expand Down
60 changes: 60 additions & 0 deletions client/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package client
import (
"context"
"errors"
"github.com/cloudwego/kitex/pkg/endpoint"
"io"
"testing"

Expand Down Expand Up @@ -637,3 +638,62 @@ func Test_isRPCError(t *testing.T) {
test.Assert(t, isRPCError(errors.New("error")))
})
}

func Test_kClient_Stream_SetStreamingMode(t *testing.T) {
testcases := []struct {
method string
mode serviceinfo.StreamingMode
}{
{
method: "None",
mode: serviceinfo.StreamingNone,
},
{
method: "Unary",
mode: serviceinfo.StreamingUnary,
},
{
method: "ClientStreaming",
mode: serviceinfo.StreamingClient,
},
{
method: "ServerStreaming",
mode: serviceinfo.StreamingServer,
},
{
method: "BidiStreaming",
mode: serviceinfo.StreamingBidirectional,
},
}
info := &serviceinfo.ServiceInfo{Methods: make(map[string]serviceinfo.MethodInfo)}
for _, tc := range testcases {
info.Methods[tc.method] = serviceinfo.NewMethodInfo(
nil, nil, nil, false,
serviceinfo.WithStreamingMode(tc.mode),
)
}

for _, tc := range testcases {
t.Run(tc.method, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
opts := append(newOpts(ctrl), WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) (err error) {
ri := rpcinfo.GetRPCInfo(ctx)
test.Assert(t, ri.Config().StreamingMode() == tc.mode)
return nil
}
}))

kc := &kClient{
opt: client.NewOptions(opts),
svcInfo: info,
}

_ = kc.init()

err := kc.Stream(context.Background(), tc.method, req, resp)
test.Assert(t, err == nil, err)
})
}
}
2 changes: 1 addition & 1 deletion pkg/remote/codec/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot
d, err := decodeGRPCFrame(ctx, in)
// For ClientStreaming, server may return an err(e.g. status) as trailer frame after calling SendAndClose.
// We need to receive this trailer frame.
if message.RPCInfo().Config().StreamingMode() == rpcinfo.ClientStreaming && message.RPCRole() == remote.Client && err == nil {
if message.RPCInfo().Config().StreamingMode() == serviceinfo.StreamingClient && message.RPCRole() == remote.Client && err == nil {
// Receive trailer frame
// If err == nil, wrong gRPC protocol implementation.
// If err == io.EOF, it means server returns nil, just ignore io.EOF.
Expand Down
2 changes: 1 addition & 1 deletion pkg/rpcinfo/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type RPCConfig interface {
TransportProtocol() transport.Protocol
InteractionMode() InteractionMode
PayloadCodec() serviceinfo.PayloadCodec
StreamingMode() StreamingMode
StreamingMode() serviceinfo.StreamingMode
}

// Invocation contains specific information about the call.
Expand Down
4 changes: 2 additions & 2 deletions pkg/rpcinfo/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type MockRPCConfig struct {
IOBufferSizeFunc func() (r int)
TransportProtocolFunc func() transport.Protocol
InteractionModeFunc func() (r rpcinfo.InteractionMode)
StreamingModeFunc func() (r rpcinfo.StreamingMode)
StreamingModeFunc func() (r serviceinfo.StreamingMode)
}

func (m *MockRPCConfig) PayloadCodec() serviceinfo.PayloadCodec {
Expand Down Expand Up @@ -91,7 +91,7 @@ func (m *MockRPCConfig) TransportProtocol() (r transport.Protocol) {
return
}

func (m *MockRPCConfig) StreamingMode() (r rpcinfo.StreamingMode) {
func (m *MockRPCConfig) StreamingMode() (r serviceinfo.StreamingMode) {
if m.StreamingModeFunc != nil {
return m.StreamingModeFunc()
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/rpcinfo/mutable.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type MutableRPCConfig interface {
CopyFrom(from RPCConfig)
ImmutableView() RPCConfig
SetPayloadCodec(codec serviceinfo.PayloadCodec)
SetStreamingMode(mode StreamingMode)
SetStreamingMode(mode serviceinfo.StreamingMode)
}

// MutableRPCStats is used to change the information in the RPCStats.
Expand Down
16 changes: 3 additions & 13 deletions pkg/rpcinfo/rpcconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,6 @@ const (
Streaming InteractionMode = 2
)

type StreamingMode int32

const (
None StreamingMode = 0b0000
Unary StreamingMode = 0b0001
ClientStreaming StreamingMode = 0b0010
ServerStreaming StreamingMode = 0b0100
BidirectionalStreaming StreamingMode = 0b0110
)

// rpcConfig is a set of configurations used during RPC calls.
type rpcConfig struct {
readOnlyMask int
Expand All @@ -76,7 +66,7 @@ type rpcConfig struct {
transportProtocol transport.Protocol
interactionMode InteractionMode
payloadCodec serviceinfo.PayloadCodec
streamingMode StreamingMode
streamingMode serviceinfo.StreamingMode
}

func init() {
Expand Down Expand Up @@ -204,11 +194,11 @@ func (r *rpcConfig) PayloadCodec() serviceinfo.PayloadCodec {
return r.payloadCodec
}

func (r *rpcConfig) SetStreamingMode(mode StreamingMode) {
func (r *rpcConfig) SetStreamingMode(mode serviceinfo.StreamingMode) {
r.streamingMode = mode
}

func (r *rpcConfig) StreamingMode() StreamingMode {
func (r *rpcConfig) StreamingMode() serviceinfo.StreamingMode {
return r.streamingMode
}

Expand Down

0 comments on commit 3f8df3e

Please sign in to comment.