Skip to content

Commit

Permalink
support interface overriding
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Dec 9, 2023
1 parent 75ae4b2 commit fadf5e1
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 56 deletions.
41 changes: 21 additions & 20 deletions protocol/triple/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ func NewServer() *Server {
// Start TRIPLE server
func (s *Server) Start(invoker protocol.Invoker, info *server.ServiceInfo) {
var (
addr string
URL *common.URL
hanOpts []tri.HandlerOption
addr string
URL *common.URL
)
URL = invoker.GetURL()
addr = URL.Location
Expand Down Expand Up @@ -103,12 +102,13 @@ func (s *Server) Start(invoker protocol.Invoker, info *server.ServiceInfo) {
//srv.TLSConfig = cfg
// todo:// move tls config to handleService

hanOpts = getHanOpts(URL)
hanOpts := getHanOpts(URL)
intfName := URL.Interface()
if info != nil {
s.handleServiceWithInfo(invoker, info, hanOpts...)
s.saveServiceInfo(info)
s.handleServiceWithInfo(intfName, invoker, info, hanOpts...)
s.saveServiceInfo(intfName, info)
} else {
s.compatHandleService(URL, hanOpts...)
s.compatHandleService(intfName, hanOpts...)
}
reflection.Register(s)

Expand All @@ -134,11 +134,12 @@ func (s *Server) RefreshService(invoker protocol.Invoker, info *server.ServiceIn
panic(fmt.Sprintf("Unsupported serialization: %s", serialization))
}
hanOpts = getHanOpts(URL)
intfName := URL.Interface()
if info != nil {
s.handleServiceWithInfo(invoker, info, hanOpts...)
s.saveServiceInfo(info)
s.handleServiceWithInfo(intfName, invoker, info, hanOpts...)
s.saveServiceInfo(intfName, info)
} else {
s.compatHandleService(URL, hanOpts...)
s.compatHandleService(intfName, hanOpts...)
}
}

Expand Down Expand Up @@ -199,14 +200,14 @@ func waitTripleExporter(providerServices map[string]*config.ServiceConfig) {

// *Important*, this function is responsible for being compatible with old triple-gen code
// compatHandleService registers handler based on ServiceConfig and provider service.
func (s *Server) compatHandleService(url *common.URL, opts ...tri.HandlerOption) {
func (s *Server) compatHandleService(intefaceName string, opts ...tri.HandlerOption) {
providerServices := config.GetProviderConfig().Services
if len(providerServices) == 0 {
logger.Info("Provider service map is null")
}
//waitTripleExporter(providerServices)
for key, providerService := range providerServices {
if providerService.Interface != url.Interface() {
if providerService.Interface != intefaceName {
continue
}
// todo(DMwangnima): judge protocol type
Expand All @@ -230,25 +231,25 @@ func (s *Server) compatHandleService(url *common.URL, opts ...tri.HandlerOption)

// inject invoker, it has all invocation logics
ds.XXX_SetProxyImpl(invoker)
s.compatRegisterHandler(ds, opts...)
s.compatRegisterHandler(intefaceName, ds, opts...)
}
}

func (s *Server) compatRegisterHandler(svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) {
func (s *Server) compatRegisterHandler(interfaceName string, svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) {
desc := svc.XXX_ServiceDesc()
// init unary handlers
for _, method := range desc.Methods {
// please refer to protocol/triple/internal/proto/triple_gen/greettriple for procedure examples
// error could be ignored because base is empty string
procedure := joinProcedure(desc.ServiceName, method.MethodName)
procedure := joinProcedure(interfaceName, method.MethodName)
_ = s.triServer.RegisterCompatUnaryHandler(procedure, svc, tri.MethodHandler(method.Handler), opts...)
}

// init stream handlers
for _, stream := range desc.Streams {
// please refer to protocol/triple/internal/proto/triple_gen/greettriple for procedure examples
// error could be ignored because base is empty string
procedure := joinProcedure(desc.ServiceName, stream.StreamName)
procedure := joinProcedure(interfaceName, stream.StreamName)
var typ tri.StreamType
switch {
case stream.ClientStreams && stream.ServerStreams:
Expand All @@ -263,10 +264,10 @@ func (s *Server) compatRegisterHandler(svc dubbo3.Dubbo3GrpcService, opts ...tri
}

// handleServiceWithInfo injects invoker and create handler based on ServiceInfo
func (s *Server) handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, opts ...tri.HandlerOption) {
func (s *Server) handleServiceWithInfo(interfaceName string, invoker protocol.Invoker, info *server.ServiceInfo, opts ...tri.HandlerOption) {
for _, method := range info.Methods {
m := method
procedure := joinProcedure(info.InterfaceName, method.Name)
procedure := joinProcedure(interfaceName, method.Name)
switch m.Type {
case constant.CallUnary:
_ = s.triServer.RegisterUnaryHandler(
Expand Down Expand Up @@ -324,7 +325,7 @@ func (s *Server) handleServiceWithInfo(invoker protocol.Invoker, info *server.Se
}
}

func (s *Server) saveServiceInfo(info *server.ServiceInfo) {
func (s *Server) saveServiceInfo(interfaceName string, info *server.ServiceInfo) {
ret := grpc.ServiceInfo{}
ret.Methods = make([]grpc.MethodInfo, 0, len(info.Methods))
for _, method := range info.Methods {
Expand All @@ -349,7 +350,7 @@ func (s *Server) saveServiceInfo(info *server.ServiceInfo) {
ret.Metadata = info
s.mu.Lock()
defer s.mu.Unlock()
s.services[info.InterfaceName] = ret
s.services[interfaceName] = ret
}

func (s *Server) GetServiceInfo() map[string]grpc.ServiceInfo {
Expand Down
3 changes: 2 additions & 1 deletion protocol/triple/triple.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func (tp *TripleProtocol) exportForTest(invoker protocol.Invoker, info *server.S
tp.SetExporterMap(serviceKey, exporter)
logger.Infof("[TRIPLE Protocol] Export service: %s", url.String())
tp.openServer(invoker, info)
health.SetServingStatusServing(url.Service())
return exporter
}

Expand All @@ -98,8 +99,8 @@ func (tp *TripleProtocol) openServer(invoker protocol.Invoker, info *server.Serv
}

srv := NewServer()
tp.serverMap[url.Location] = srv
srv.Start(invoker, info)
tp.serverMap[url.Location] = srv
}

// Refer a remote triple service
Expand Down
5 changes: 4 additions & 1 deletion protocol/triple/triple_protocol/handler_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,12 @@ func generateCompatUnaryHandlerFunc(
}

func compatError(err error) (*Error, bool) {
if err == nil {
return nil, false
}
s, ok := status.FromError(err)
if !ok {
return nil, ok
return nil, false
}

triErr := NewError(Code(s.Code()), errors.New(s.Message()))
Expand Down
18 changes: 6 additions & 12 deletions protocol/triple/triple_protocol/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ func (s *Server) RegisterUnaryHandler(
if !ok {
hdl = NewUnaryHandler(procedure, reqInitFunc, unary, options...)
s.handlers[procedure] = hdl
s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateUnaryHandlerFunc(procedure, reqInitFunc, unary, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}

s.mux.Handle(procedure, hdl)

return nil
}

Expand All @@ -67,14 +66,13 @@ func (s *Server) RegisterClientStreamHandler(
if !ok {
hdl = NewClientStreamHandler(procedure, stream, options...)
s.handlers[procedure] = hdl
s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateClientStreamHandlerFunc(procedure, stream, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}

s.mux.Handle(procedure, hdl)

return nil
}

Expand All @@ -88,14 +86,13 @@ func (s *Server) RegisterServerStreamHandler(
if !ok {
hdl = NewServerStreamHandler(procedure, reqInitFunc, stream, options...)
s.handlers[procedure] = hdl
s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateServerStreamHandlerFunc(procedure, reqInitFunc, stream, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}

s.mux.Handle(procedure, hdl)

return nil
}

Expand All @@ -108,14 +105,13 @@ func (s *Server) RegisterBidiStreamHandler(
if !ok {
hdl = NewBidiStreamHandler(procedure, stream, options...)
s.handlers[procedure] = hdl
s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateBidiStreamHandlerFunc(procedure, stream, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}

s.mux.Handle(procedure, hdl)

return nil
}

Expand All @@ -129,14 +125,13 @@ func (s *Server) RegisterCompatUnaryHandler(
if !ok {
hdl = NewCompatUnaryHandler(procedure, srv, unary, options...)
s.handlers[procedure] = hdl
s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateCompatUnaryHandlerFunc(procedure, srv, unary, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}

s.mux.Handle(procedure, hdl)

return nil
}

Expand All @@ -151,14 +146,13 @@ func (s *Server) RegisterCompatStreamHandler(
if !ok {
hdl = NewCompatStreamHandler(procedure, srv, typ, streamFunc, options...)
s.handlers[procedure] = hdl
s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateCompatStreamHandlerFunc(procedure, srv, streamFunc, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}

s.mux.Handle(procedure, hdl)

return nil
}

Expand Down
46 changes: 25 additions & 21 deletions protocol/triple/triple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ import (
)

const (
triplePort = "21000"
dubbo3Port = "21001"
listenAddr = "0.0.0.0"
localAddr = "127.0.0.1"
name = "triple"
group = "g1"
version = "v1"
triplePort = "21000"
dubbo3Port = "21001"
listenAddr = "0.0.0.0"
localAddr = "127.0.0.1"
name = "triple"
group = "g1"
version = "v1"
customTripleInterfaceName = "apache.dubbo.org.triple"
customDubbo3InterfaceName = "apache.dubbo.org.dubbo3"
)

type tripleInvoker struct {
Expand Down Expand Up @@ -103,6 +105,8 @@ func runTripleServer(interfaceName string, group string, version string, addr st
common.WithPath(interfaceName),
common.WithLocation(addr),
common.WithPort(triplePort),
common.WithProtocol(TRIPLE),
common.WithInterface(interfaceName),
)
url.SetParam(constant.GroupKey, group)
url.SetParam(constant.VersionKey, version)
Expand All @@ -121,7 +125,7 @@ func runTripleServer(interfaceName string, group string, version string, addr st
func runOldTripleServer(interfaceName string, group string, version string, addr string, desc *grpc_go.ServiceDesc, svc common.RPCService) {
url := common.NewURLWithOptions(
// todo(DMwangnima): figure this out
common.WithPath(desc.ServiceName),
common.WithPath(interfaceName),
common.WithLocation(addr),
common.WithPort(dubbo3Port),
common.WithProtocol(TRIPLE),
Expand All @@ -143,38 +147,38 @@ func runOldTripleServer(interfaceName string, group string, version string, addr
Build()).
Build())
config.SetProviderService(svc)
common.ServiceMap.Register(desc.ServiceName, TRIPLE, group, version, svc)
common.ServiceMap.Register(interfaceName, TRIPLE, group, version, svc)
invoker := extension.GetProxyFactory("default").GetInvoker(url)
GetProtocol().(*TripleProtocol).exportForTest(invoker, nil)
}

func TestMain(m *testing.M) {
runTripleServer(
greettriple.GreetServiceName,
customTripleInterfaceName,
"",
"",
listenAddr,
&greettriple.GreetService_ServiceInfo,
new(api.GreetTripleServer),
)
runTripleServer(
greettriple.GreetServiceName,
customTripleInterfaceName,
group,
version,
listenAddr,
&greettriple.GreetService_ServiceInfo,
new(api.GreetTripleServerGroup1Version1),
)
runOldTripleServer(
dubbo3_greet.GreetService_ServiceDesc.ServiceName,
customDubbo3InterfaceName,
"",
"",
listenAddr,
&dubbo3_greet.GreetService_ServiceDesc,
new(dubbo3_api.GreetDubbo3Server),
)
runOldTripleServer(
dubbo3_greet.GreetService_ServiceDesc.ServiceName,
customDubbo3InterfaceName,
group,
version,
listenAddr,
Expand Down Expand Up @@ -431,46 +435,46 @@ func TestInvoke(t *testing.T) {
}

t.Run("triple2triple", func(t *testing.T) {
invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
invoker, err := tripleInvokerInit(localAddr, triplePort, customTripleInterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, "")
})
t.Run("triple2triple_Group1Version1", func(t *testing.T) {
invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
invoker, err := tripleInvokerInit(localAddr, triplePort, customTripleInterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, api.GroupVersionIdentifier)
})
t.Run("triple2dubbo3", func(t *testing.T) {
invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
invoker, err := tripleInvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, "")
})
t.Run("triple2dubbo3_Group1Version1", func(t *testing.T) {
invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
invoker, err := tripleInvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, dubbo3_api.GroupVersionIdentifier)
})
t.Run("dubbo32triple", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, "", "", svc)
invoker, err := dubbo3InvokerInit(localAddr, triplePort, customTripleInterfaceName, "", "", svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, "")
})
t.Run("dubbo32triple_Group1Version1", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, group, version, svc)
invoker, err := dubbo3InvokerInit(localAddr, triplePort, customTripleInterfaceName, group, version, svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, api.GroupVersionIdentifier)
})
t.Run("dubbo32dubbo3", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, "", "", svc)
invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, "", "", svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, "")
})
t.Run("dubbo32dubbo3_Group1Version1", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, group, version, svc)
invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, group, version, svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, dubbo3_api.GroupVersionIdentifier)
})
Expand Down
Loading

0 comments on commit fadf5e1

Please sign in to comment.