From e2506f7d7702004418da73d366bd403456a6e75f Mon Sep 17 00:00:00 2001 From: actatum Date: Thu, 7 Dec 2023 15:03:32 -0500 Subject: [PATCH] Micro update (#14) * use nats micro package * upgrade to go1.21 and use slog for the logger middleware * move protoc-gen-stormrpc to cmd folder, fix passing of headers to handler function and context, fix endpoint name to pass micro endpoint name validation * move protoc-gen-stormrpc to cmd folder, fix passing of headers to handler function and context, fix endpoint name to pass micro endpoint name validation * cleaning up some of the codegen internals, adding some comments * bump github actions test matrix to go1.21 * bump nats server version, bump protobuf go library version * trying to track down this data race and determine if its my code or nats.go * using my fork of nats.go with fix for race condition until it can get merged into upstream * fix linting errors * update nats.go dependency to my fork's main branch * it works --- .github/workflows/actions.yaml | 2 +- .github/workflows/main.yaml | 2 +- README.md | 128 +++++--- client.go | 2 +- client_test.go | 48 +-- examples/protogen/genproto.sh | 4 +- examples/protogen/pb/echo.pb.go | 2 +- examples/protogen/server/main.go | 5 +- examples/simple/server/main.go | 7 +- go.mod | 31 +- go.sum | 59 ++-- headers.go | 1 + internal/gen/gen.go | 25 +- middleware/logging.go | 37 ++- middleware/logging_test.go | 80 +++-- options.go | 2 +- protoc-gen-stormrpc/main.go | 19 -- prototest/protoc.sh | 2 +- server.go | 199 ++++++++---- server_test.go | 539 +++++++++++++++---------------- 20 files changed, 642 insertions(+), 552 deletions(-) delete mode 100644 protoc-gen-stormrpc/main.go diff --git a/.github/workflows/actions.yaml b/.github/workflows/actions.yaml index c745765..9b24ff3 100644 --- a/.github/workflows/actions.yaml +++ b/.github/workflows/actions.yaml @@ -5,7 +5,7 @@ jobs: test: strategy: matrix: - go-version: [1.19.x, 1.20.x] + go-version: [1.21.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 16ff8fe..5ddda52 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -7,7 +7,7 @@ jobs: test: strategy: matrix: - go-version: [1.19.x, 1.20.x] + go-version: [1.21.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: diff --git a/README.md b/README.md index 9c1e26e..f054a89 100644 --- a/README.md +++ b/README.md @@ -6,29 +6,31 @@ [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/actatum/stormrpc) [![Release](https://img.shields.io/github/release/actatum/stormrpc.svg)](https://github.com/actatum/stormrpc/releases/latest) - StormRPC is an abstraction or wrapper on [`NATS`] Request/Reply messaging capabilities. It provides some convenient features including: -* **Middleware** +- **Middleware** + + Middleware are decorators around `HandlerFunc`s. Some middleware are available within the package including `RequestID`, `Tracing` (via OpenTelemetry) `Logger` and `Recoverer`. + +- **Body encoding and decoding** + + Marshalling and unmarshalling request bodies to structs. JSON, Protobuf, and Msgpack are supported out of the box. - Middleware are decorators around `HandlerFunc`s. Some middleware are available within the package including `RequestID`, `Tracing` (via OpenTelemetry) `Logger` and `Recoverer`. -* **Body encoding and decoding** +- **Deadline propagation** - Marshalling and unmarshalling request bodies to structs. JSON, Protobuf, and Msgpack are supported out of the box. -* **Deadline propagation** + Request deadlines are propagated from client to server so both ends will stop processing once the deadline has passed. - Request deadlines are propagated from client to server so both ends will stop processing once the deadline has passed. -* **Error propagation** +- **Error propagation** - Responses have an `Error` attribute and these are propagated across the wire without needing to tweak your request/response schemas. + Responses have an `Error` attribute and these are propagated across the wire without needing to tweak your request/response schemas. ## Installation ### Runtime Library -The runtime library package ```github.com/actatum/stormrpc``` contains common types like ```stormrpc.Error```, ```stormrpc.Client``` and ```stormrpc.Server```. If you aren't generating servers and clients from protobuf definitions you only need to import the stormrpc package. +The runtime library package `github.com/actatum/stormrpc` contains common types like `stormrpc.Error`, `stormrpc.Client` and `stormrpc.Server`. If you aren't generating servers and clients from protobuf definitions you only need to import the stormrpc package. ```bash go get github.com/actatum/stormrpc @@ -36,19 +38,19 @@ go get github.com/actatum/stormrpc ### Code Generator -You need to install ```go``` and the ```protoc``` compiler on your system. Then, install the protoc plugins ```protoc-gen-stormrpc``` and ```protoc-gen-go``` to generate Go code. +You need to install `go` and the `protoc` compiler on your system. Then, install the protoc plugins `protoc-gen-stormrpc` and `protoc-gen-go` to generate Go code. ```bash -go install github.com/actatum/stormrpc/protoc-gen-stormrpc@latest +go install github.com/actatum/stormrpc/cmd/protoc-gen-stormrpc@latest go install google.golang.org/protobuf/cmd/protoc-gen-go@latest ``` To generate client and server stubs use the following command + ```bash protoc --go_out=$output_dir --stormrpc_out=$output_dir $input_proto_file ``` - Code generation examples can be found [here](https://github.com/actatum/stormrpc/tree/main/examples/protogen) ## Basic Usage @@ -59,53 +61,73 @@ Code generation examples can be found [here](https://github.com/actatum/stormrpc package main import ( - "context" - "log" - "os" - "os/signal" - "syscall" - "time" - - "github.com/actatum/stormrpc" - "github.com/nats-io/nats.go" + "context" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/actatum/stormrpc" + "github.com/nats-io/nats-server/v2/server" ) func echo(ctx context.Context, req stormrpc.Request) stormrpc.Response { - var b any - if err := req.Decode(&b); err != nil { - return stormrpc.NewErrorResponse(req.Reply, err) - } + var b any + if err := req.Decode(&b); err != nil { + return stormrpc.NewErrorResponse(req.Reply, err) + } - resp, err := stormrpc.NewResponse(req.Reply, b) - if err != nil { - return stormrpc.NewErrorResponse(req.Reply, err) - } + resp, err := stormrpc.NewResponse(req.Reply, b) + if err != nil { + return stormrpc.NewErrorResponse(req.Reply, err) + } - return resp + return resp } func main() { - srv, err := stormrpc.NewServer("echo", nats.DefaultURL) - if err != nil { - log.Fatal(err) - } - srv.Handle("echo", echo) - - go func() { - _ = srv.Run() - }() - log.Printf("👋 Listening on %v", srv.Subjects()) - - done := make(chan os.Signal, 1) - signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) - <-done - log.Printf("💀 Shutting down") - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err = srv.Shutdown(ctx); err != nil { - log.Fatal(err) - } + ns, err := server.NewServer(&server.Options{ + Port: 40897, + }) + if err != nil { + log.Fatal(err) + } + ns.Start() + defer func() { + ns.Shutdown() + ns.WaitForShutdown() + }() + + if !ns.ReadyForConnections(1 * time.Second) { + log.Fatal("timeout waiting for nats server") + } + + srv, err := stormrpc.NewServer(&stormrpc.ServerConfig{ + NatsURL: ns.ClientURL(), + Name: "echo", + }) + if err != nil { + log.Fatal(err) + } + + srv.Handle("echo", echo) + + go func() { + _ = srv.Run() + }() + log.Printf("👋 Listening on %v", srv.Subjects()) + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + <-done + log.Printf("💀 Shutting down") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = srv.Shutdown(ctx); err != nil { + log.Fatal(err) + } } ``` @@ -156,4 +178,4 @@ func main() { ``` [`nats.go`]: https://github.com/nats-io/nats.go -[`NATS`]: https://docs.nats.io/ \ No newline at end of file +[`NATS`]: https://docs.nats.io/ diff --git a/client.go b/client.go index 456bca9..fabc6f0 100644 --- a/client.go +++ b/client.go @@ -15,7 +15,7 @@ type Client struct { } // NewClient returns a new instance of a Client. -func NewClient(natsURL string, opts ...ClientOption) (*Client, error) { +func NewClient(natsURL string, _ ...ClientOption) (*Client, error) { nc, err := nats.Connect(natsURL) if err != nil { return nil, err diff --git a/client_test.go b/client_test.go index daac78b..0157d06 100644 --- a/client_test.go +++ b/client_test.go @@ -50,7 +50,7 @@ func TestNewClient(t *testing.T) { } func TestClient_Do(t *testing.T) { - t.Parallel() + // t.Parallel() ns, err := server.NewServer(&server.Options{ Port: 41397, @@ -69,12 +69,15 @@ func TestClient_Do(t *testing.T) { return } - t.Run("deadline exceeded", func(t *testing.T) { - t.Parallel() + clientURL := ns.ClientURL() + t.Run("deadline exceeded", func(t *testing.T) { timeout := 50 * time.Millisecond subject := strconv.Itoa(rand.Int()) - srv, err := NewServer("test", ns.ClientURL()) + srv, err := NewServer(&ServerConfig{ + NatsURL: clientURL, + Name: "test", + }) if err != nil { t.Fatal(err) } @@ -89,7 +92,7 @@ func TestClient_Do(t *testing.T) { _ = srv.Shutdown(context.Background()) }) - client, err := NewClient(ns.ClientURL()) + client, err := NewClient(clientURL) if err != nil { t.Fatal(err) } @@ -112,11 +115,12 @@ func TestClient_Do(t *testing.T) { }) t.Run("rpc error", func(t *testing.T) { - t.Parallel() - timeout := 50 * time.Millisecond subject := strconv.Itoa(rand.Int()) - srv, err := NewServer("test", ns.ClientURL()) + srv, err := NewServer(&ServerConfig{ + NatsURL: clientURL, + Name: "test", + }) if err != nil { t.Fatal(err) } @@ -130,7 +134,7 @@ func TestClient_Do(t *testing.T) { _ = srv.Shutdown(context.Background()) }) - client, err := NewClient(ns.ClientURL()) + client, err := NewClient(clientURL) if err != nil { t.Fatal(err) } @@ -158,11 +162,9 @@ func TestClient_Do(t *testing.T) { }) t.Run("no servers", func(t *testing.T) { - t.Parallel() - subject := strconv.Itoa(rand.Int()) - client, err := NewClient(ns.ClientURL()) + client, err := NewClient(clientURL) if err != nil { t.Fatal(err) } @@ -194,9 +196,7 @@ func TestClient_Do(t *testing.T) { }) t.Run("request option errors", func(t *testing.T) { - t.Parallel() - - client, err := NewClient(ns.ClientURL()) + client, err := NewClient(clientURL) if err != nil { t.Fatal(err) } @@ -225,11 +225,12 @@ func TestClient_Do(t *testing.T) { }) t.Run("successful request", func(t *testing.T) { - t.Parallel() - timeout := 50 * time.Millisecond subject := strconv.Itoa(rand.Int()) - srv, err := NewServer("test", ns.ClientURL()) + srv, err := NewServer(&ServerConfig{ + NatsURL: clientURL, + Name: "test", + }) if err != nil { t.Fatal(err) } @@ -248,7 +249,7 @@ func TestClient_Do(t *testing.T) { _ = srv.Shutdown(context.Background()) }) - client, err := NewClient(ns.ClientURL()) + client, err := NewClient(clientURL) if err != nil { t.Fatal(err) } @@ -276,12 +277,13 @@ func TestClient_Do(t *testing.T) { }) t.Run("successful request w/headers option", func(t *testing.T) { - t.Parallel() - apiKey := uuid.NewString() timeout := 50 * time.Millisecond subject := strconv.Itoa(rand.Int()) - srv, err := NewServer("test", ns.ClientURL()) + srv, err := NewServer(&ServerConfig{ + NatsURL: clientURL, + Name: "test", + }) if err != nil { t.Fatal(err) } @@ -303,7 +305,7 @@ func TestClient_Do(t *testing.T) { _ = srv.Shutdown(context.Background()) }) - client, err := NewClient(ns.ClientURL()) + client, err := NewClient(clientURL) if err != nil { t.Fatal(err) } diff --git a/examples/protogen/genproto.sh b/examples/protogen/genproto.sh index 65f31cb..6546144 100755 --- a/examples/protogen/genproto.sh +++ b/examples/protogen/genproto.sh @@ -1,2 +1,2 @@ -go install ../../protoc-gen-stormrpc -protoc --go_out=./pb --stormrpc_out=./pb pb/echo.proto \ No newline at end of file +go install ../../cmd/protoc-gen-stormrpc +protoc --go_out=./pb --stormrpc_out=./pb pb/echo.proto diff --git a/examples/protogen/pb/echo.pb.go b/examples/protogen/pb/echo.pb.go index a93c605..bacb256 100644 --- a/examples/protogen/pb/echo.pb.go +++ b/examples/protogen/pb/echo.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.31.0 // protoc v3.6.1 // source: pb/echo.proto diff --git a/examples/protogen/server/main.go b/examples/protogen/server/main.go index aa61ddb..b9db34b 100644 --- a/examples/protogen/server/main.go +++ b/examples/protogen/server/main.go @@ -41,7 +41,10 @@ func main() { log.Fatal("timeout waiting for nats server") } - srv, err := stormrpc.NewServer("echo", ns.ClientURL(), stormrpc.WithErrorHandler(logError)) + srv, err := stormrpc.NewServer(&stormrpc.ServerConfig{ + NatsURL: ns.ClientURL(), + Name: "echo", + }, stormrpc.WithErrorHandler(logError)) if err != nil { log.Fatal(err) } diff --git a/examples/simple/server/main.go b/examples/simple/server/main.go index d6e6f8e..2a89494 100644 --- a/examples/simple/server/main.go +++ b/examples/simple/server/main.go @@ -33,7 +33,7 @@ func main() { if err != nil { log.Fatal(err) } - go ns.Start() + ns.Start() defer func() { ns.Shutdown() ns.WaitForShutdown() @@ -43,7 +43,10 @@ func main() { log.Fatal("timeout waiting for nats server") } - srv, err := stormrpc.NewServer("echo", ns.ClientURL()) + srv, err := stormrpc.NewServer(&stormrpc.ServerConfig{ + NatsURL: ns.ClientURL(), + Name: "echo", + }) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index 883c81e..d395167 100644 --- a/go.mod +++ b/go.mod @@ -1,33 +1,34 @@ module github.com/actatum/stormrpc -go 1.20 +go 1.21 + +// replace github.com/nats-io/nats.go => ../nats.go + +replace github.com/nats-io/nats.go => github.com/actatum/nats.go v1.31.1-0.20231207185944-7538a5cd8e3f require ( github.com/google/uuid v1.3.0 - github.com/nats-io/nats-server/v2 v2.9.15 - github.com/nats-io/nats.go v1.24.0 + github.com/nats-io/nats-server/v2 v2.10.7 + github.com/nats-io/nats.go v1.31.1-0.20231201130123-4af26aae2522 github.com/vmihailenco/msgpack/v5 v5.3.5 go.opentelemetry.io/otel v1.14.0 go.opentelemetry.io/otel/sdk v1.7.0 go.opentelemetry.io/otel/trace v1.14.0 - go.uber.org/zap v1.24.0 - google.golang.org/protobuf v1.29.0 + google.golang.org/protobuf v1.31.0 ) require ( - github.com/benbjohnson/clock v1.1.0 // indirect github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/golang/protobuf v1.5.2 // indirect - github.com/klauspost/compress v1.16.0 // indirect + github.com/klauspost/compress v1.17.4 // indirect github.com/minio/highwayhash v1.0.2 // indirect - github.com/nats-io/jwt/v2 v2.3.0 // indirect - github.com/nats-io/nkeys v0.3.0 // indirect + github.com/nats-io/jwt/v2 v2.5.3 // indirect + github.com/nats-io/nkeys v0.4.6 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - go.uber.org/atomic v1.10.0 // indirect - go.uber.org/multierr v1.10.0 // indirect - golang.org/x/crypto v0.7.0 // indirect - golang.org/x/sys v0.6.0 // indirect - golang.org/x/time v0.3.0 // indirect + golang.org/x/crypto v0.16.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/time v0.5.0 // indirect ) + +// v1.31.1-0.20231207012943-0e824f8b9d26 diff --git a/go.sum b/go.sum index 6a43f48..d50c436 100644 --- a/go.sum +++ b/go.sum @@ -1,41 +1,39 @@ -github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= -github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/actatum/nats.go v1.31.1-0.20231207185944-7538a5cd8e3f h1:KQ4ivpBYqvuNwb3FRTsipUIX+uk6B9FE1t65V7a2orI= +github.com/actatum/nats.go v1.31.1-0.20231207185944-7538a5cd8e3f/go.mod h1:uCwt8khnwboRrH1RbNzJh9C/GEnXnnwkcB/bUoz8eJs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw1HU4= -github.com/klauspost/compress v1.16.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= -github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.15 h1:MuwEJheIwpvFgqvbs20W8Ish2azcygjf4Z0liVu2I4c= -github.com/nats-io/nats-server/v2 v2.9.15/go.mod h1:QlCTy115fqpx4KSOPFIxSV7DdI6OxtZsGOL1JLdeRlE= -github.com/nats-io/nats.go v1.24.0 h1:CRiD8L5GOQu/DcfkmgBcTTIQORMwizF+rPk6T0RaHVQ= -github.com/nats-io/nats.go v1.24.0/go.mod h1:dVQF+BK3SzUZpwyzHedXsvH3EO38aVKuOPkkHlv5hXA= -github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= -github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= +github.com/nats-io/jwt/v2 v2.5.3 h1:/9SWvzc6hTfamcgXJ3uYRpgj+QuY2aLNqRiqrKcrpEo= +github.com/nats-io/jwt/v2 v2.5.3/go.mod h1:iysuPemFcc7p4IoYots3IuELSI4EDe9Y0bQMe+I3Bf4= +github.com/nats-io/nats-server/v2 v2.10.7 h1:f5VDy+GMu7JyuFA0Fef+6TfulfCs5nBTgq7MMkFJx5Y= +github.com/nats-io/nats-server/v2 v2.10.7/go.mod h1:V2JHOvPiPdtfDXTuEUsthUnCvSDeFrK4Xn9hRo6du7c= +github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= +github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= @@ -48,32 +46,19 @@ go.opentelemetry.io/otel/sdk v1.7.0/go.mod h1:uTEOTwaqIVuTGiJN7ii13Ibp75wJmYUDe3 go.opentelemetry.io/otel/trace v1.7.0/go.mod h1:fzLSB9nqR2eXzxPXb2JW9IKE+ScyXA48yyE4TNvoHqU= go.opentelemetry.io/otel/trace v1.14.0 h1:wp2Mmvj41tDsyAJXiWDWpfNsOiIyd38fy85pyKcFq/M= go.opentelemetry.io/otel/trace v1.14.0/go.mod h1:8avnQLK+CG77yNLUae4ea2JDQ6iT+gozhnZjy/rw9G8= -go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= -go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= -go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= -go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= -go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= -go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= -golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.29.0 h1:44S3JjaKmLEE4YIkjzexaP+NzZsudE3Zin5Njn/pYX0= -google.golang.org/protobuf v1.29.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/headers.go b/headers.go index 2970298..c3f8c4b 100644 --- a/headers.go +++ b/headers.go @@ -10,6 +10,7 @@ import ( ) const ( + // errorHeader will be deprecated in a future update in favor of 'Nats-Service-Error' and 'Nats-Service-Error-Code'. errorHeader = "stormrpc-error" deadlineHeader = "stormrpc-deadline" ) diff --git a/internal/gen/gen.go b/internal/gen/gen.go index 2ee7589..aeb94aa 100644 --- a/internal/gen/gen.go +++ b/internal/gen/gen.go @@ -29,23 +29,23 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated g.P() g.P("package ", file.GoPackageName) g.P() - GenerateFileContent(gen, file, g) + GenerateFileContent(file, g) return g } // GenerateFileContent generates the stormrpc service definitions, excluding the package statement. -func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { +func GenerateFileContent(file *protogen.File, g *protogen.GeneratedFile) { if len(file.Services) == 0 { return } g.P() for _, service := range file.Services { - genService(gen, file, g, service) + genService(g, service) } } -func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) { +func genService(g *protogen.GeneratedFile, service *protogen.Service) { clientName := service.GoName + "Client" g.P("// ", clientName, " is the client API for ", service.GoName, " service.") @@ -55,10 +55,10 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated g.P("//") g.P(deprectationComment) } - g.Annotate(clientName, service.Location) + g.AnnotateSymbol(clientName, protogen.Annotation{Location: service.Location}) g.P("type ", clientName, " interface {") for _, method := range service.Methods { - g.Annotate(clientName+"."+method.GoName, method.Location) + g.AnnotateSymbol(clientName+"."+method.GoName, protogen.Annotation{Location: method.Location}) if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { g.P(deprectationComment) } @@ -86,7 +86,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated for _, method := range service.Methods { if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { // Unary RPC method - genClientMethod(gen, file, g, method, methodIndex) + genClientMethod(g, method) methodIndex++ } } @@ -98,10 +98,10 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated g.P("//") g.P(deprectationComment) } - g.Annotate(serverType, service.Location) + g.AnnotateSymbol(serverType, protogen.Annotation{Location: service.Location}) g.P("type ", serverType, " interface {") for _, method := range service.Methods { - g.Annotate(serverType+"."+method.GoName, method.Location) + g.AnnotateSymbol(serverType+"."+method.GoName, protogen.Annotation{Location: method.Location}) if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { g.P(deprectationComment) } @@ -125,7 +125,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated // Server handler implementations. var handlerNames []string for _, method := range service.Methods { - hname := genServerHandler(gen, file, g, method) + hname := genServerHandler(g, method) handlerNames = append(handlerNames, hname) } @@ -161,11 +161,8 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string } func genClientMethod( - gen *protogen.Plugin, - file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, - index int, ) { service := method.Parent @@ -207,8 +204,6 @@ func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string } func genServerHandler( - gen *protogen.Plugin, - file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, ) string { diff --git a/middleware/logging.go b/middleware/logging.go index 9a8a161..e11382e 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -3,16 +3,16 @@ package middleware import ( "context" + "log/slog" "time" "github.com/actatum/stormrpc" "go.opentelemetry.io/otel/trace" - "go.uber.org/zap" ) // Logger logs request scoped information such as request id, trace information, and request duration. // This middleware should be applied after RequestID, and Tracing. -func Logger(l *zap.Logger) func(next stormrpc.HandlerFunc) stormrpc.HandlerFunc { +func Logger(l *slog.Logger) func(next stormrpc.HandlerFunc) stormrpc.HandlerFunc { return func(next stormrpc.HandlerFunc) stormrpc.HandlerFunc { return func(ctx context.Context, r stormrpc.Request) stormrpc.Response { span := trace.SpanFromContext(ctx) @@ -22,21 +22,34 @@ func Logger(l *zap.Logger) func(next stormrpc.HandlerFunc) stormrpc.HandlerFunc resp := next(ctx, r) - fields := []zap.Field{ - zap.String("id", id), - zap.String("trace_id", span.SpanContext().TraceID().String()), - zap.String("span_id", span.SpanContext().SpanID().String()), - zap.String("duration", time.Since(start).String()), - } + attrs := make([]slog.Attr, 0) + level := slog.LevelInfo + msg := "Success" if resp.Err != nil { + msg = "Server Error" + level = slog.LevelError code := stormrpc.CodeFromErr(resp.Err) - fields = append(fields, zap.String("code", code.String())) - l.Error("Server Error", fields...) - } else { - l.Info("Success", fields...) + attrs = append(attrs, slog.Group( + "error", + slog.String("message", resp.Err.Error()), + slog.String("code", code.String()), + )) } + attrs = append(attrs, slog.Group("request", + slog.String("id", id), + slog.String("trace_id", span.SpanContext().TraceID().String()), + slog.String("duration", time.Since(start).String()), + )) + + l.LogAttrs( + ctx, + level, + msg, + attrs..., + ) + return resp } } diff --git a/middleware/logging_test.go b/middleware/logging_test.go index 9b4e2f3..409e343 100644 --- a/middleware/logging_test.go +++ b/middleware/logging_test.go @@ -2,18 +2,37 @@ package middleware import ( + "bytes" "context" + "encoding/json" "fmt" + "log/slog" "testing" + "time" "github.com/actatum/stormrpc" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "go.uber.org/zap/zaptest" ) +type logOutput struct { + Time time.Time `json:"time"` + Level slog.Level `json:"level"` + Msg string `json:"msg"` + Request struct { + ID string `json:"id"` + TraceID string `json:"trace_id"` + Duration string `json:"duration"` + } `json:"request"` + Error struct { + Message string `json:"message"` + Code string `json:"code"` + } `json:"error"` +} + func TestLogger(t *testing.T) { t.Run("success response", func(t *testing.T) { + buf := &bytes.Buffer{} + logger := slog.New(slog.NewJSONHandler(buf, nil)) + req, _ := stormrpc.NewRequest("test", map[string]string{"hi": "there"}) handler := stormrpc.HandlerFunc(func(ctx context.Context, r stormrpc.Request) stormrpc.Response { resp, err := stormrpc.NewResponse(r.Reply, map[string]string{"hello": "world"}) @@ -22,38 +41,45 @@ func TestLogger(t *testing.T) { } return resp }) - l := zaptest.NewLogger(t, zaptest.WrapOptions(zap.Hooks(func(e zapcore.Entry) error { - if e.Level != zap.InfoLevel { - t.Errorf("e.Level got = %v, want %v", e.Level, zap.InfoLevel) - } - if e.Message != "Success" { - t.Errorf("e.Message got = %v, want %v", e.Message, "Success") - } - return nil - }))) - h := Logger(l)(handler) - resp := h(context.Background(), req) - fmt.Println(resp) + h := RequestID(Logger(logger)(handler)) + _ = h(context.Background(), req) + + var out logOutput + if err := json.Unmarshal(buf.Bytes(), &out); err != nil { + t.Fatal(err) + } + + if out.Level != slog.LevelInfo { + t.Errorf("got level = %v, want %v", out.Level, slog.LevelInfo) + } else if out.Msg != "Success" { + t.Errorf("got msg = %v, want %v", out.Msg, "Success") + } }) t.Run("error response", func(t *testing.T) { + buf := &bytes.Buffer{} + logger := slog.New(slog.NewJSONHandler(buf, nil)) + req, _ := stormrpc.NewRequest("test", map[string]string{"hi": "there"}) handler := stormrpc.HandlerFunc(func(ctx context.Context, r stormrpc.Request) stormrpc.Response { return stormrpc.NewErrorResponse(r.Reply, fmt.Errorf("some error")) }) - l := zaptest.NewLogger(t, zaptest.WrapOptions(zap.Hooks(func(e zapcore.Entry) error { - if e.Level != zap.ErrorLevel { - t.Errorf("e.Level got = %v, want %v", e.Level, zap.ErrorLevel) - } - if e.Message != "Server Error" { - t.Errorf("e.Message got = %v, want %v", e.Message, "Server Error") - } - return nil - }))) - h := Logger(l)(handler) - resp := h(context.Background(), req) - fmt.Println(resp) + h := RequestID(Logger(logger)(handler)) + _ = h(context.Background(), req) + + var out logOutput + if err := json.Unmarshal(buf.Bytes(), &out); err != nil { + t.Fatal(err) + } + + if out.Level != slog.LevelError { + t.Errorf("got level = %v, want %v", out.Level, slog.LevelError) + } else if out.Msg != "Server Error" { + t.Errorf("got msg = %v, want %v", out.Msg, "Server Error") + } else if out.Error.Code != stormrpc.ErrorCodeUnknown.String() { + t.Errorf("got error code = %v, want %v", out.Error.Code, stormrpc.ErrorCodeUnknown.String()) + } }) } diff --git a/options.go b/options.go index 3ac73f2..b608950 100644 --- a/options.go +++ b/options.go @@ -27,7 +27,7 @@ func (o *HeaderCallOption) before(c *callOptions) error { return nil } -func (o *HeaderCallOption) after(c *callOptions) {} +func (o *HeaderCallOption) after(_ *callOptions) {} // WithHeaders returns a CallOption that appends the given headers to the request. func WithHeaders(h map[string]string) CallOption { diff --git a/protoc-gen-stormrpc/main.go b/protoc-gen-stormrpc/main.go deleted file mode 100644 index 332f1bc..0000000 --- a/protoc-gen-stormrpc/main.go +++ /dev/null @@ -1,19 +0,0 @@ -// Package main provides the executable function for the protoc-gen-stormrpc binary. -package main - -import ( - stormrpcgen "github.com/actatum/stormrpc/internal/gen" - "google.golang.org/protobuf/compiler/protogen" -) - -func main() { - protogen.Options{}.Run(func(gen *protogen.Plugin) error { - for _, f := range gen.Files { - if !f.Generate { - continue - } - stormrpcgen.GenerateFile(gen, f) - } - return nil - }) -} diff --git a/prototest/protoc.sh b/prototest/protoc.sh index 52421d5..bb87d15 100755 --- a/prototest/protoc.sh +++ b/prototest/protoc.sh @@ -1,3 +1,3 @@ -go install ./protoc-gen-stormrpc +go install ./cmd/protoc-gen-stormrpc protoc --proto_path prototest -I=. prototest/test.proto \ --stormrpc_out=./prototest/gen_out --go_out=./prototest diff --git a/server.go b/server.go index eb27612..b479025 100644 --- a/server.go +++ b/server.go @@ -3,61 +3,105 @@ package stormrpc import ( "context" + "strings" + "sync" "time" "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/micro" ) var defaultServerTimeout = 5 * time.Second +// ServerConfig is used to configure required fields for a StormRPC server. +// If any fields aren't present a default value will be used. +type ServerConfig struct { + NatsURL string + Name string + Version string + + errorHandler ErrorHandler +} + +func (s *ServerConfig) setDefaults() { + if s.NatsURL == "" { + s.NatsURL = nats.DefaultURL + } + if s.Name == "" { + s.Name = "service" + } + if s.Version == "" { + s.Version = "0.1.0" + } + if s.errorHandler == nil { + s.errorHandler = func(ctx context.Context, err error) {} + } +} + // Server represents a stormRPC server. It contains all functionality for handling RPC requests. type Server struct { + mu sync.Mutex nc *nats.Conn - name string shutdownSignal chan struct{} handlerFuncs map[string]HandlerFunc errorHandler ErrorHandler timeout time.Duration mw []Middleware + + running bool + + svc micro.Service } // NewServer returns a new instance of a Server. -func NewServer(name, natsURL string, opts ...ServerOption) (*Server, error) { - options := serverOptions{ - errorHandler: func(ctx context.Context, err error) {}, - } +func NewServer(cfg *ServerConfig, opts ...ServerOption) (*Server, error) { + cfg.setDefaults() for _, o := range opts { - o.apply(&options) + o.apply(cfg) } - nc, err := nats.Connect(natsURL) + nc, err := nats.Connect(cfg.NatsURL) + if err != nil { + return nil, err + } + + mc := micro.Config{ + Name: cfg.Name, + Version: cfg.Version, + } + if cfg.errorHandler != nil { + mc.ErrorHandler = func(s micro.Service, n *micro.NATSError) { + ctx, cancel := context.WithTimeout(context.Background(), defaultServerTimeout) + defer cancel() + cfg.errorHandler(ctx, n) + } + } + + svc, err := micro.AddService(nc, mc) if err != nil { return nil, err } return &Server{ nc: nc, - name: name, shutdownSignal: make(chan struct{}), handlerFuncs: make(map[string]HandlerFunc), timeout: defaultServerTimeout, - errorHandler: options.errorHandler, + errorHandler: cfg.errorHandler, + running: false, + svc: svc, }, nil } -type serverOptions struct { - errorHandler ErrorHandler -} - // ServerOption represents functional options for configuring a stormRPC Server. type ServerOption interface { - apply(*serverOptions) + apply(*ServerConfig) } type errorHandlerOption ErrorHandler -func (h errorHandlerOption) apply(opts *serverOptions) { +func (h errorHandlerOption) apply(opts *ServerConfig) { opts.errorHandler = ErrorHandler(h) } @@ -77,36 +121,54 @@ type ErrorHandler func(context.Context, error) // Handle registers a new HandlerFunc on the server. func (s *Server) Handle(subject string, fn HandlerFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.handlerFuncs[subject] = fn } // Run listens on the configured subjects. func (s *Server) Run() error { + s.mu.Lock() s.applyMiddlewares() - for k := range s.handlerFuncs { - _, err := s.nc.QueueSubscribe(k, s.name, s.handler) - if err != nil { + + for sub, fn := range s.handlerFuncs { + if err := s.createMicroEndpoint(sub, fn); err != nil { return err } } + s.running = true + s.mu.Unlock() + <-s.shutdownSignal return nil } // Shutdown stops the server. func (s *Server) Shutdown(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.svc.Stop(); err != nil { + return err + } + if err := s.nc.FlushWithContext(ctx); err != nil { return err } s.nc.Close() + s.running = false s.shutdownSignal <- struct{}{} return nil } // Subjects returns a list of all subjects with registered handler funcs. func (s *Server) Subjects() []string { + s.mu.Lock() + defer s.mu.Unlock() + subs := make([]string, 0, len(s.handlerFuncs)) for k := range s.handlerFuncs { subs = append(subs, k) @@ -117,7 +179,12 @@ func (s *Server) Subjects() []string { // Use applies all given middleware globally across all handlers. func (s *Server) Use(mw ...Middleware) { - s.mw = mw + s.mu.Lock() + defer s.mu.Unlock() + + if !s.running { + s.mw = append(s.mw, mw...) + } } func (s *Server) applyMiddlewares() { @@ -130,45 +197,59 @@ func (s *Server) applyMiddlewares() { } } -// handler serves the request to the specific request handler based on subject. -// wildcard subjects are not supported as you'll need to register a handler func for each -// rpc the server supports. -func (s *Server) handler(msg *nats.Msg) { - fn := s.handlerFuncs[msg.Subject] - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - dl := parseDeadlineHeader(msg.Header) - if !dl.IsZero() { // if deadline is present use it - ctx, cancel = context.WithDeadline(context.Background(), dl) - defer cancel() - } else { - ctx, cancel = context.WithTimeout(ctx, s.timeout) - defer cancel() - } - - req := Request{ - Msg: msg, - } - // pass headers into context for use in protobuf generated servers. - ctx = newContextWithHeaders(ctx, req.Header) - - resp := fn(ctx, req) - - if resp.Err != nil { - if resp.Header == nil { - resp.Header = nats.Header{} - } - setErrorHeader(resp.Header, resp.Err) - err := msg.RespondMsg(resp.Msg) - if err != nil { - s.errorHandler(ctx, err) - } - } +// createMicroEndpoint registers a HandlerFunc as a micro Endpoint +// allowing for automatic service discovery and observability. +func (s *Server) createMicroEndpoint(subject string, handlerFunc HandlerFunc) error { + return s.svc.AddEndpoint( + nameFromSubject(subject), + micro.ContextHandler(context.Background(), func(ctx context.Context, r micro.Request) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ctx = newContextWithHeaders(ctx, nats.Header(r.Headers())) + + dl := parseDeadlineHeader(nats.Header(r.Headers())) + if !dl.IsZero() { // if deadline is present use it + ctx, cancel = context.WithDeadline(ctx, dl) + defer cancel() + } else { + ctx, cancel = context.WithTimeout(ctx, s.timeout) + defer cancel() + } + + resp := handlerFunc(ctx, Request{ + Msg: &nats.Msg{ + Subject: r.Subject(), + Header: nats.Header(r.Headers()), + Data: r.Data(), + }, + }) + + if resp.Err != nil { + if resp.Header == nil { + resp.Header = nats.Header{} + } + setErrorHeader(resp.Header, resp.Err) + + err := r.Error( + CodeFromErr(resp.Err).String(), + MessageFromErr(resp.Err), + nil, + micro.WithHeaders(micro.Headers(resp.Header)), + ) + if err != nil { + s.errorHandler(ctx, err) + } + } + + err := r.Respond(resp.Data, micro.WithHeaders(micro.Headers(resp.Header))) + if err != nil { + s.errorHandler(ctx, err) + } + }), micro.WithEndpointSubject(subject)) +} - err := msg.RespondMsg(resp.Msg) - if err != nil { - s.errorHandler(ctx, err) - } +// If a subject contains '.' delimiters replace them with '_' for the endpoint name. +func nameFromSubject(subj string) string { + return strings.ReplaceAll(subj, ".", "_") } diff --git a/server_test.go b/server_test.go index 8a29e1e..cebb17e 100644 --- a/server_test.go +++ b/server_test.go @@ -3,11 +3,8 @@ package stormrpc import ( "context" - "errors" "fmt" - "math/rand" "reflect" - "strconv" "testing" "time" @@ -30,9 +27,8 @@ func (t *testErrorHandler) clear() { func TestNewServer(t *testing.T) { teh := &testErrorHandler{} type args struct { - name string - natsURL string - opts []ServerOption + cfg *ServerConfig + opts []ServerOption } tests := []struct { name string @@ -44,12 +40,13 @@ func TestNewServer(t *testing.T) { { name: "defaults", args: args{ - name: "name", - natsURL: "nats://localhost:40897", - opts: nil, + cfg: &ServerConfig{ + Name: "name", + NatsURL: "nats://localhost:40897", + }, + opts: nil, }, want: &Server{ - name: "name", timeout: defaultServerTimeout, mw: nil, errorHandler: func(ctx context.Context, err error) {}, @@ -60,14 +57,15 @@ func TestNewServer(t *testing.T) { { name: "with error handler opt", args: args{ - name: "name", - natsURL: "nats://localhost:40897", + cfg: &ServerConfig{ + Name: "name", + NatsURL: "nats://localhost:40897", + }, opts: []ServerOption{ WithErrorHandler(teh.handle), }, }, want: &Server{ - name: "name", timeout: defaultServerTimeout, mw: nil, errorHandler: teh.handle, @@ -78,9 +76,11 @@ func TestNewServer(t *testing.T) { { name: "no nats running", args: args{ - name: "name", - natsURL: "nats://localhost:40897", - opts: nil, + cfg: &ServerConfig{ + Name: "name", + NatsURL: "nats://localhost:40897", + }, + opts: nil, }, want: nil, runNats: false, @@ -91,25 +91,10 @@ func TestNewServer(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Cleanup(teh.clear) if tt.runNats { - ns, err := server.NewServer(&server.Options{ - Port: 40897, - }) - if err != nil { - t.Fatal(err) - } - go ns.Start() - t.Cleanup(func() { - ns.Shutdown() - ns.WaitForShutdown() - }) - - if !ns.ReadyForConnections(1 * time.Second) { - t.Error("timeout waiting for nats server") - return - } + startNatsServer(t) } - got, err := NewServer(tt.args.name, tt.args.natsURL, tt.args.opts...) + got, err := NewServer(tt.args.cfg, tt.args.opts...) if (err != nil) != tt.wantErr { t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr) return @@ -119,9 +104,7 @@ func TestNewServer(t *testing.T) { return } - if got.name != tt.want.name { - t.Errorf("NewServer() name = %v, want %v", got.name, tt.want.name) - } else if got.timeout != tt.want.timeout { + if got.timeout != tt.want.timeout { t.Errorf("NewServer() timeout = %v, want %v", got.timeout, tt.want.timeout) } else if (got.errorHandler == nil) != (tt.want.errorHandler == nil) { t.Errorf("NewServer() errorHandler = %v, want %v", got.errorHandler == nil, tt.want.errorHandler == nil) @@ -144,24 +127,12 @@ func TestNewServer(t *testing.T) { } func TestServer_RunAndShutdown(t *testing.T) { - ns, err := server.NewServer(&server.Options{ - Port: 40897, - }) - if err != nil { - t.Fatal(err) - } - go ns.Start() - t.Cleanup(func() { - ns.Shutdown() - ns.WaitForShutdown() - }) - - if !ns.ReadyForConnections(1 * time.Second) { - t.Error("timeout waiting for nats server") - return - } + clientURL := startNatsServer(t) - srv, err := NewServer("test", ns.ClientURL()) + srv, err := NewServer(&ServerConfig{ + NatsURL: clientURL, + Name: "test", + }) if err != nil { t.Fatal(err) } @@ -186,261 +157,231 @@ func TestServer_RunAndShutdown(t *testing.T) { } } -func TestServer_handler(t *testing.T) { - ns, err := server.NewServer(&server.Options{ - Port: 40897, - }) - if err != nil { - t.Fatal(err) +func TestServer_Run(t *testing.T) { + type args struct { + ctx context.Context + req Request } - go ns.Start() - t.Cleanup(func() { - ns.Shutdown() - ns.WaitForShutdown() - }) - - if !ns.ReadyForConnections(1 * time.Second) { - t.Error("timeout waiting for nats server") - return + type endpoint struct { + name string + handler HandlerFunc } - - t.Run("successful handle", func(t *testing.T) { - t.Parallel() - - srv, err := NewServer("test", ns.ClientURL()) - if err != nil { - t.Fatal(err) - } - - subject := strconv.Itoa(rand.Int()) - srv.Handle(subject, func(ctx context.Context, r Request) Response { - _, ok := ctx.Deadline() - if !ok { - t.Error("context should have deadline") - } - return Response{ - Msg: &nats.Msg{ - Subject: r.Reply, - Data: []byte(`{"response":"1"}`), + tests := []struct { + name string + endpoints []endpoint + args args + wantErr bool + }{ + { + name: "ok", + endpoints: []endpoint{ + { + name: "test", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { + _, ok := ctx.Deadline() + if !ok { + t.Error("context should have deadline") + } + return Response{ + Msg: &nats.Msg{ + Subject: r.Reply, + Data: []byte(`{"response":"1"}`), + }, + Err: nil, + } + }), }, - Err: nil, - } - }) - - runCh := make(chan error) - go func(ch chan error) { - runErr := srv.Run() - runCh <- runErr - }(runCh) - time.Sleep(250 * time.Millisecond) - - client, err := NewClient(ns.ClientURL()) - if err != nil { - t.Fatal(err) - } - - req, err := NewRequest(subject, map[string]string{"x": "D"}) - if err != nil { - t.Fatal(err) - } - resp := client.Do(context.Background(), req) - if resp.Err != nil { - t.Fatal(resp.Err) - } - - var result map[string]string - if err = resp.Decode(&result); err != nil { - t.Fatal(err) - } - - if result["response"] != "1" { - t.Fatalf("got = %v, want %v", result["response"], "1") - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - if err = srv.Shutdown(ctx); err != nil { - t.Fatal(err) - } - - err = <-runCh - if err != nil { - t.Fatal(err) - } - }) - - t.Run("context deadline exceeded", func(t *testing.T) { - t.Parallel() + }, + args: args{ + ctx: ctxWithTimeout(t, 5*time.Second), + req: mustNewRequest(t, "test", map[string]string{"hello": "world"}), + }, + wantErr: false, + }, + { + name: "context deadline exceeded", + endpoints: []endpoint{ + { + name: "test", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { + _, ok := ctx.Deadline() + if !ok { + t.Error("context should have deadline") + } - srv, err := NewServer("test", ns.ClientURL()) - if err != nil { - t.Fatal(err) - } + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return NewErrorResponse(r.Reply, Error{ + Code: ErrorCodeDeadlineExceeded, + Message: ctx.Err().Error(), + }) + case <-ticker.C: + return NewErrorResponse(r.Reply, fmt.Errorf("somethings wrong")) + } + } + }), + }, + }, + args: args{ + ctx: ctxWithTimeout(t, 2*time.Second), + req: mustNewRequest(t, "test", map[string]string{"hello": "world"}), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientURL := startNatsServer(t) - subject := strconv.Itoa(rand.Int()) - srv.Handle(subject, func(ctx context.Context, r Request) Response { - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return NewErrorResponse(r.Reply, Error{ - Code: ErrorCodeDeadlineExceeded, - Message: ctx.Err().Error(), - }) - case <-ticker.C: - return NewErrorResponse(r.Reply, fmt.Errorf("somethings wrong")) - } + srv, err := NewServer(&ServerConfig{ + NatsURL: clientURL, + Name: "test", + }) + if err != nil { + t.Fatal(err) } - }) - - runCh := make(chan error) - go func(ch chan error) { - runErr := srv.Run() - runCh <- runErr - }(runCh) - time.Sleep(250 * time.Millisecond) - - client, err := NewClient(ns.ClientURL()) - if err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - req, err := NewRequest(subject, map[string]string{"x": "D"}) - if err != nil { - t.Fatal(err) - } - resp := client.Do(ctx, req) - var e *Error - ok := errors.As(resp.Err, &e) - if !ok { - t.Fatalf("expected error to be of type Error, got %T", resp.Err) - } - if e.Code != ErrorCodeDeadlineExceeded { - t.Fatalf("e.Code got = %v, want %v", e.Code, ErrorCodeDeadlineExceeded) - } else if e.Message != context.DeadlineExceeded.Error() { - t.Fatalf("e.Message got = %v, want %v", e.Message, context.DeadlineExceeded.Error()) - } - - if err = srv.Shutdown(ctx); err != nil { - t.Fatal(err) - } - err = <-runCh - if err != nil { - t.Fatal(err) - } - }) - - t.Run("context deadline longer than default timeout", func(t *testing.T) { - t.Parallel() - - srv, err := NewServer("test", ns.ClientURL()) - if err != nil { - t.Fatal(err) - } + for _, ep := range tt.endpoints { + srv.Handle(ep.name, ep.handler) + } - timeout := 7 * time.Second + errs := make(chan error) + go func(srv *Server, errs chan error) { + errs <- srv.Run() + }(srv, errs) - subject := strconv.Itoa(rand.Int()) - srv.Handle(subject, func(ctx context.Context, r Request) Response { - dl, ok := ctx.Deadline() - if !ok { - t.Error("context should have deadline") + client, err := NewClient(clientURL) + if err != nil { + t.Fatal(err) } - var req map[string]time.Time - _ = r.Decode(&req) + resp := client.Do(tt.args.ctx, tt.args.req) + if (resp.Err != nil) != tt.wantErr { + t.Errorf("Client.Do() error = %v, wantErr %v", resp.Err, tt.wantErr) + } - if req["default"].After(dl) { - t.Errorf("req[default] got = %v, want before %v", req["default"], dl) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if err = srv.Shutdown(ctx); err != nil { + t.Fatal(err) } - var resp Response - resp, err = NewResponse(r.Reply, map[string]string{"success": "ok"}) + err = <-errs if err != nil { - return NewErrorResponse(r.Reply, err) + t.Fatal(err) } - - return resp }) - - runCh := make(chan error) - go func(ch chan error) { - runErr := srv.Run() - runCh <- runErr - }(runCh) - time.Sleep(250 * time.Millisecond) - - client, err := NewClient(ns.ClientURL()) - if err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - ctxWithDefaultServerTimeout, cancel2 := context.WithTimeout(ctx, srv.timeout) - defer cancel2() - - defaultDeadline, _ := ctxWithDefaultServerTimeout.Deadline() - - req, err := NewRequest(subject, map[string]time.Time{"default": defaultDeadline}) - if err != nil { - t.Fatal(err) - } - _ = client.Do(ctx, req) - - if err = srv.Shutdown(ctx); err != nil { - t.Fatal(err) - } - - err = <-runCh - if err != nil { - t.Fatal(err) - } - }) + } } func TestServer_Handle(t *testing.T) { - s := Server{ - handlerFuncs: make(map[string]HandlerFunc), - } - - t.Run("OK", func(t *testing.T) { - s.Handle("testing", func(ctx context.Context, r Request) Response { return Response{} }) + clientURL := startNatsServer(t) - if _, ok := s.handlerFuncs["testing"]; !ok { - t.Fatal("expected key testing to contain a handler func") - } + s, err := NewServer(&ServerConfig{ + Name: "test", + NatsURL: clientURL, }) + if err != nil { + t.Fatal(err) + } + s.Handle("testing", func(ctx context.Context, r Request) Response { return Response{} }) + if err != nil { + t.Fatal(err) + } + + _, ok := s.handlerFuncs["testing"] + if !ok { + t.Errorf("expected handler to exist for subject %s", "testing") + } } func TestServer_Subjects(t *testing.T) { - s := Server{ - handlerFuncs: make(map[string]HandlerFunc), + type endpoint struct { + name string + handler HandlerFunc + } + tests := []struct { + name string + endpoints []endpoint + want []string + }{ + { + name: "ok", + endpoints: []endpoint{ + { + name: "test", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + }, + want: []string{"test"}, + }, + { + name: "multiple endpoints", + endpoints: []endpoint{ + { + name: "1", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + { + name: "2", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + { + name: "3", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + }, + want: []string{"1", "2", "3"}, + }, + { + name: "duplicate endpoints", + endpoints: []endpoint{ + { + name: "1", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + { + name: "1", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + { + name: "2", + handler: HandlerFunc(func(ctx context.Context, r Request) Response { return Response{} }), + }, + }, + want: []string{"1", "2"}, + }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientURL := startNatsServer(t) - s.Handle("testing", func(ctx context.Context, r Request) Response { return Response{} }) - s.Handle("testing", func(ctx context.Context, r Request) Response { return Response{} }) - s.Handle("1, 2, 3", func(ctx context.Context, r Request) Response { return Response{} }) + srv, err := NewServer(&ServerConfig{ + Name: "test", + NatsURL: clientURL, + }) + if err != nil { + t.Fatal(err) + } - want := []string{"testing", "1, 2, 3"} + for _, ep := range tt.endpoints { + srv.Handle(ep.name, ep.handler) + } - got := s.Subjects() + got := srv.Subjects() - if !sameStringSlice(got, want) { - t.Fatalf("got = %v, want %v", got, want) + if !sameStringSlice(got, tt.want) { + t.Fatalf("got = %v, want %v", got, tt.want) + } + }) } } func TestServer_Use(t *testing.T) { type fields struct { - nc *nats.Conn - name string shutdownSignal chan struct{} handlerFuncs map[string]HandlerFunc errorHandler ErrorHandler @@ -474,8 +415,6 @@ func TestServer_Use(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Server{ - nc: tt.fields.nc, - name: tt.fields.name, shutdownSignal: tt.fields.shutdownSignal, handlerFuncs: tt.fields.handlerFuncs, errorHandler: tt.fields.errorHandler, @@ -484,8 +423,8 @@ func TestServer_Use(t *testing.T) { } s.Use(tt.args.mw...) - if !reflect.DeepEqual(tt.args.mw, s.mw) { - t.Fatalf("got = %v, want %v", s.mw, tt.args.mw) + if len(tt.args.mw) != len(s.mw) { + t.Fatalf("expected slices to be the same length got = %v, want %v", s.mw, tt.args.mw) } }) } @@ -493,8 +432,6 @@ func TestServer_Use(t *testing.T) { func TestServer_applyMiddlewares(t *testing.T) { type fields struct { - nc *nats.Conn - name string shutdownSignal chan struct{} handlerFuncs map[string]HandlerFunc errorHandler ErrorHandler @@ -534,8 +471,6 @@ func TestServer_applyMiddlewares(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Server{ - nc: tt.fields.nc, - name: tt.fields.name, shutdownSignal: tt.fields.shutdownSignal, handlerFuncs: tt.fields.handlerFuncs, errorHandler: tt.fields.errorHandler, @@ -557,6 +492,48 @@ func TestServer_applyMiddlewares(t *testing.T) { } } +func startNatsServer(tb testing.TB) string { + tb.Helper() + + ns, err := server.NewServer(&server.Options{ + Port: 40897, + }) + if err != nil { + tb.Fatal(err) + } + + ns.Start() + + tb.Cleanup(func() { + ns.Shutdown() + ns.WaitForShutdown() + }) + + if !ns.ReadyForConnections(1 * time.Second) { + tb.Fatal("timeout waiting for nats server") + } + + return ns.ClientURL() +} + +func mustNewRequest(tb testing.TB, subject string, body any, opts ...RequestOption) Request { + req, err := NewRequest(subject, body, opts...) + if err != nil { + tb.Fatal(err) + } + + return req +} + +func ctxWithTimeout(tb testing.TB, timeout time.Duration) context.Context { + tb.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + tb.Cleanup(cancel) + + return ctx +} + func sameStringSlice(x, y []string) bool { if len(x) != len(y) { return false