diff --git a/transport/grpc/interceptor.go b/transport/grpc/interceptor.go index 7d5ce50..beb0991 100644 --- a/transport/grpc/interceptor.go +++ b/transport/grpc/interceptor.go @@ -2,6 +2,7 @@ package grpc import ( "context" + ic "github.com/nextmicro/next/internal/context" "google.golang.org/grpc" @@ -34,7 +35,7 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { h := func(ctx context.Context, req interface{}) (interface{}, error) { return handler(ctx, req) } - if next := s.middleware.Match(tr.Operation()); len(next) > 0 { + if next := s.matcher.Match(tr.Operation()); len(next) > 0 { h = middleware.Chain(next...)(h) } reply, err := h(ctx, req) diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 33fe43e..d2f519a 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -7,9 +7,12 @@ import ( "net/url" "time" + chain "github.com/go-kratos/kratos/v2/middleware" + conf "github.com/nextmicro/next/config" "github.com/nextmicro/next/internal/endpoint" "github.com/nextmicro/next/internal/host" "github.com/nextmicro/next/internal/matcher" + middleware2 "github.com/nextmicro/next/middleware" "google.golang.org/grpc" "google.golang.org/grpc/admin" @@ -69,7 +72,7 @@ func Logger(_ log.Logger) ServerOption { // Middleware with server middleware. func Middleware(m ...middleware.Middleware) ServerOption { return func(s *Server) { - s.middleware.Use(m...) + s.matcher.Use(m...) } } @@ -126,7 +129,8 @@ type Server struct { address string endpoint *url.URL timeout time.Duration - middleware matcher.Matcher + matcher matcher.Matcher + middleware []middleware.Middleware unaryInts []grpc.UnaryServerInterceptor streamInts []grpc.StreamServerInterceptor grpcOpts []grpc.ServerOption @@ -139,16 +143,24 @@ type Server struct { // NewServer creates a gRPC server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ - baseCtx: context.Background(), - network: "tcp", - address: ":0", - timeout: 1 * time.Second, - health: health.NewServer(), - middleware: matcher.New(), + baseCtx: context.Background(), + network: "tcp", + address: ":0", + timeout: 1 * time.Second, + health: health.NewServer(), + matcher: matcher.New(), } for _, o := range opts { o(srv) } + serverMs := srv.buildMiddlewareOptions() + // server middleware first + if len(serverMs) > 0 { + userMs := srv.middleware + srv.middleware = append(serverMs, userMs...) + } + srv.matcher.Use(srv.middleware...) + unaryInts := []grpc.UnaryServerInterceptor{ srv.unaryServerInterceptor(), } @@ -184,13 +196,35 @@ func NewServer(opts ...ServerOption) *Server { return srv } +// buildMiddlewareOptions builds the http server options. +func (s *Server) buildMiddlewareOptions() []middleware.Middleware { + cfg := conf.ApplicationConfig().GetServer().GetGrpc() + if cfg.GetAddr() != "" { + s.address = cfg.GetAddr() + } + if cfg.GetNetwork() != "" { + s.network = cfg.GetNetwork() + } + if cfg.GetTimeout().AsDuration() != 0 { + s.timeout = cfg.GetTimeout().AsDuration() + } + + ms := make([]chain.Middleware, 0, len(cfg.GetMiddlewares())) + if cfg != nil && cfg.GetMiddlewares() != nil { + serverMs, _ := middleware2.BuildMiddleware("server", cfg.GetMiddlewares()) + ms = append(ms, serverMs...) + } + + return ms +} + // Use uses a service middleware with selector. // selector: // - '/*' // - '/helloworld.v1.Greeter/*' // - '/helloworld.v1.Greeter/SayHello' func (s *Server) Use(selector string, m ...middleware.Middleware) { - s.middleware.Add(selector, m...) + s.matcher.Add(selector, m...) } // Endpoint return a real address to registry endpoint. diff --git a/transport/http/server.go b/transport/http/server.go index 18377d7..2c9b6ef 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -209,6 +209,9 @@ func (s *Server) buildMiddlewareOptions() []middleware.Middleware { if cfg.GetAddr() != "" { s.address = cfg.GetAddr() } + if cfg.GetNetwork() != "" { + s.network = cfg.GetNetwork() + } if cfg.GetTimeout().AsDuration() != 0 { s.timeout = cfg.GetTimeout().AsDuration() }