Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JellyTony committed Oct 17, 2023
1 parent 70e7f2d commit a7663df
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
3 changes: 2 additions & 1 deletion transport/grpc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package grpc

import (
"context"

ic "github.com/nextmicro/next/internal/context"

"google.golang.org/grpc"
Expand Down Expand Up @@ -34,7 +35,7 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(ctx, req)
}
if next := s.middleware.Match(tr.Operation()); len(next) > 0 {
if next := s.matcher.Match(tr.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}
reply, err := h(ctx, req)
Expand Down
52 changes: 43 additions & 9 deletions transport/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"net/url"
"time"

chain "github.com/go-kratos/kratos/v2/middleware"
conf "github.com/nextmicro/next/config"
"github.com/nextmicro/next/internal/endpoint"
"github.com/nextmicro/next/internal/host"
"github.com/nextmicro/next/internal/matcher"
middleware2 "github.com/nextmicro/next/middleware"

"google.golang.org/grpc"
"google.golang.org/grpc/admin"
Expand Down Expand Up @@ -69,7 +72,7 @@ func Logger(_ log.Logger) ServerOption {
// Middleware with server middleware.
func Middleware(m ...middleware.Middleware) ServerOption {
return func(s *Server) {
s.middleware.Use(m...)
s.matcher.Use(m...)
}
}

Expand Down Expand Up @@ -126,7 +129,8 @@ type Server struct {
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
matcher matcher.Matcher
middleware []middleware.Middleware
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
Expand All @@ -139,16 +143,24 @@ type Server struct {
// NewServer creates a gRPC server by options.
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
matcher: matcher.New(),
}
for _, o := range opts {
o(srv)
}
serverMs := srv.buildMiddlewareOptions()
// server middleware first
if len(serverMs) > 0 {
userMs := srv.middleware
srv.middleware = append(serverMs, userMs...)
}
srv.matcher.Use(srv.middleware...)

unaryInts := []grpc.UnaryServerInterceptor{
srv.unaryServerInterceptor(),
}
Expand Down Expand Up @@ -184,13 +196,35 @@ func NewServer(opts ...ServerOption) *Server {
return srv
}

// buildMiddlewareOptions builds the http server options.
func (s *Server) buildMiddlewareOptions() []middleware.Middleware {
cfg := conf.ApplicationConfig().GetServer().GetGrpc()
if cfg.GetAddr() != "" {
s.address = cfg.GetAddr()
}
if cfg.GetNetwork() != "" {
s.network = cfg.GetNetwork()
}
if cfg.GetTimeout().AsDuration() != 0 {
s.timeout = cfg.GetTimeout().AsDuration()
}

ms := make([]chain.Middleware, 0, len(cfg.GetMiddlewares()))
if cfg != nil && cfg.GetMiddlewares() != nil {
serverMs, _ := middleware2.BuildMiddleware("server", cfg.GetMiddlewares())
ms = append(ms, serverMs...)
}

return ms
}

// Use uses a service middleware with selector.
// selector:
// - '/*'
// - '/helloworld.v1.Greeter/*'
// - '/helloworld.v1.Greeter/SayHello'
func (s *Server) Use(selector string, m ...middleware.Middleware) {
s.middleware.Add(selector, m...)
s.matcher.Add(selector, m...)
}

// Endpoint return a real address to registry endpoint.
Expand Down
3 changes: 3 additions & 0 deletions transport/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ func (s *Server) buildMiddlewareOptions() []middleware.Middleware {
if cfg.GetAddr() != "" {
s.address = cfg.GetAddr()
}
if cfg.GetNetwork() != "" {
s.network = cfg.GetNetwork()
}
if cfg.GetTimeout().AsDuration() != 0 {
s.timeout = cfg.GetTimeout().AsDuration()
}
Expand Down

0 comments on commit a7663df

Please sign in to comment.