diff --git a/go.mod b/go.mod index 1190476..c4605d7 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/loopholelabs/frpc-go go 1.22 require ( - github.com/loopholelabs/frisbee-go v0.9.2 + github.com/loopholelabs/frisbee-go v0.10.0 github.com/loopholelabs/logging v0.3.0 github.com/loopholelabs/polyglot/v2 v2.0.2 github.com/loopholelabs/testing v0.2.3 diff --git a/go.sum b/go.sum index a56c635..cc49b90 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/loopholelabs/common v0.4.9 h1:9MPUYlZZ/qx3Kt8LXgXxcSXthrM91od8026c4DlGpAU= github.com/loopholelabs/common v0.4.9/go.mod h1:Wop5srN1wYT+mdQ9gZ+kn2I9qKAyVd0FB48pThwIa9M= -github.com/loopholelabs/frisbee-go v0.9.2 h1:TiihZMh/aq82olM3JEY0h5WiSHOji9lheaQ26V35zK0= -github.com/loopholelabs/frisbee-go v0.9.2/go.mod h1:RwKglItbNcQq9UW6Vm2aOwOWXR6wTxUNjvuXcCIyJkU= +github.com/loopholelabs/frisbee-go v0.10.0 h1:eqdDqm44V23GMxhjDL9OBz2Fxecsu42M0KHM8kQObBQ= +github.com/loopholelabs/frisbee-go v0.10.0/go.mod h1:RwKglItbNcQq9UW6Vm2aOwOWXR6wTxUNjvuXcCIyJkU= github.com/loopholelabs/logging v0.3.0 h1:Rfo9fGdBk4nwbNdiLrVUF7JkYoC67dvGDodkRe81xF0= github.com/loopholelabs/logging v0.3.0/go.mod h1:uRDUydiqPqKbZkb0WoQ3dfyAcJ2iOMhxdEafZssLVv0= github.com/loopholelabs/polyglot/v2 v2.0.2 h1:v308fg2ZKSvkKDnWgBnDvvmiu4YypCxcDe5Ih5GUVnY= diff --git a/pkg/generator/generator.go b/pkg/generator/generator.go index cf48dd7..4349c3b 100644 --- a/pkg/generator/generator.go +++ b/pkg/generator/generator.go @@ -114,7 +114,6 @@ func (g *Generator) Generate(req *pluginpb.CodeGeneratorRequest) (res *pluginpb. "package": packageName, "requiredImports": requiredImports, "serviceImports": serviceImports, - "methodImports": methodImports, "streamMethodImports": streamMethodImports, "numServices": numServices, "numMethods": numMethods, diff --git a/pkg/generator/imports.go b/pkg/generator/imports.go index 4cd30f7..2b3bdc0 100644 --- a/pkg/generator/imports.go +++ b/pkg/generator/imports.go @@ -6,11 +6,12 @@ var ( requiredImports = []string{ "errors", "net", + "sync", + "context", "github.com/loopholelabs/polyglot/v2", } serviceImports = []string{ - "context", "crypto/tls", "github.com/loopholelabs/frisbee-go", "github.com/loopholelabs/frisbee-go/pkg/packet", @@ -21,8 +22,4 @@ var ( "sync/atomic", "io", } - - methodImports = []string{ - "sync", - } ) diff --git a/pkg/generator/test/server.go b/pkg/generator/test/server.go index d0d151f..8cd8a43 100644 --- a/pkg/generator/test/server.go +++ b/pkg/generator/test/server.go @@ -23,7 +23,7 @@ func (s svc) Echo(_ context.Context, request *Request) (*Response, error) { }}, nil } -func (s svc) EchoStream(srv *EchoStreamServer) error { +func (s svc) EchoStream(_ context.Context, srv *EchoStreamServer) error { for { request, err := srv.Recv() if err == io.EOF { @@ -52,7 +52,7 @@ func (s svc) Testy(_ context.Context, _ *SearchResponse) (*StockPricesWrapper, e panic("not implemented") } -func (s svc) Search(req *SearchResponse, srv *SearchServer) error { +func (s svc) Search(_ context.Context, req *SearchResponse, srv *SearchServer) error { assert.Equal(s.t, 1, len(req.Results)) for i := 0; i < 10; i++ { err := srv.Send(&Response{Message: "Hello World", Test: &Data{ @@ -64,7 +64,7 @@ func (s svc) Search(req *SearchResponse, srv *SearchServer) error { return srv.CloseSend() } -func (s svc) Upload(srv *UploadServer) error { +func (s svc) Upload(_ context.Context, srv *UploadServer) error { received := 0 for { res, err := srv.Recv() diff --git a/pkg/generator/test/test.frpc.go b/pkg/generator/test/test.frpc.go index 0a2ebeb..a924719 100644 --- a/pkg/generator/test/test.frpc.go +++ b/pkg/generator/test/test.frpc.go @@ -1,21 +1,20 @@ -// Code generated by fRPC Go v0.9.2, DO NOT EDIT. +// Code generated by fRPC Go v0.10.0, DO NOT EDIT. // source: test.proto package test import ( + "context" "errors" "github.com/loopholelabs/polyglot/v2" "net" + "sync" - "context" "crypto/tls" "github.com/loopholelabs/frisbee-go" "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/loopholelabs/logging/types" - "sync" - "io" "sync/atomic" ) @@ -1468,12 +1467,12 @@ func (x *StockPricesSuperWrap) decode(d *polyglot.BufferDecoder) error { type EchoService interface { Echo(context.Context, *Request) (*Response, error) - EchoStream(srv *EchoStreamServer) error + EchoStream(context.Context, *EchoStreamServer) error Testy(context.Context, *SearchResponse) (*StockPricesWrapper, error) - Search(req *SearchResponse, srv *SearchServer) error + Search(context.Context, *SearchResponse, *SearchServer) error - Upload(srv *UploadServer) error + Upload(context.Context, *UploadServer) error } const connectionContextKey int = 1000 @@ -1512,12 +1511,12 @@ func (x *RPCStreamOpen) decode(d *polyglot.BufferDecoder) error { } type Server struct { - *frisbee.Server - onClosed func(*frisbee.Async, error) + server *frisbee.Server + wg sync.WaitGroup } func NewServer(echoService EchoService, tlsConfig *tls.Config, logger types.Logger) (*Server, error) { - var s *Server + s := new(Server) table := make(frisbee.HandlerTable) table[10] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { @@ -1560,21 +1559,20 @@ func NewServer(echoService EchoService, tlsConfig *tls.Config, logger types.Logg } return } - var fsrv *frisbee.Server var err error if tlsConfig != nil { - fsrv, err = frisbee.NewServer(table, frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger)) + s.server, err = frisbee.NewServer(table, context.Background(), frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger)) if err != nil { return nil, err } } else { - fsrv, err = frisbee.NewServer(table, frisbee.WithLogger(logger)) + s.server, err = frisbee.NewServer(table, context.Background(), frisbee.WithLogger(logger)) if err != nil { return nil, err } } - fsrv.SetStreamHandler(func(conn *frisbee.Async, stream *frisbee.Stream) { + s.server.SetStreamHandler(func(ctx context.Context, stream *frisbee.Stream) { p, err := stream.ReadPacket() if err != nil { return @@ -1587,34 +1585,51 @@ func NewServer(echoService EchoService, tlsConfig *tls.Config, logger types.Logg } switch open.operation { case 11: - s.createEchoStreamServer(echoService, stream) + s.createEchoStreamServer(ctx, echoService, stream) case 13: - s.createSearchServer(echoService, stream) + s.createSearchServer(ctx, echoService, stream) case 14: - s.createUploadServer(echoService, stream) + s.createUploadServer(ctx, echoService, stream) } }) - fsrv.ConnContext = func(ctx context.Context, conn *frisbee.Async) context.Context { + s.server.ConnContext = func(ctx context.Context, conn *frisbee.Async) context.Context { return context.WithValue(ctx, connectionContextKey, conn) } - s, err = &Server{ - Server: fsrv, - }, nil - fsrv.SetOnClosed(func(async *frisbee.Async, err error) { - if s.onClosed != nil { - s.onClosed(async, err) - } - }) - return s, err + return s, nil } func (s *Server) SetOnClosed(f func(*frisbee.Async, error)) error { - if f == nil { - return frisbee.OnClosedNil + return s.server.SetOnClosed(f) +} + +func (s *Server) SetPreWrite(f func()) error { + return s.server.SetPreWrite(f) +} + +func (s *Server) SetConcurrency(concurrency uint64) { + s.server.SetConcurrency(concurrency) +} + +func (s *Server) Start(addr string) error { + return s.server.Start(addr) +} + +func (s *Server) StartWithListener(listener net.Listener) error { + return s.server.StartWithListener(listener) +} + +func (s *Server) ServeConn(conn net.Conn) { + s.server.ServeConn(conn) +} + +func (s *Server) Shutdown() error { + err := s.server.Shutdown() + if err != nil { + return err } - s.onClosed = f + s.wg.Wait() return nil } @@ -1626,7 +1641,7 @@ type EchoStreamServer struct { closed *atomic.Bool } -func (s *Server) createEchoStreamServer(echoService EchoService, stream *frisbee.Stream) { +func (s *Server) createEchoStreamServer(ctx context.Context, echoService EchoService, stream *frisbee.Stream) { srv := &EchoStreamServer{ stream: stream, } @@ -1656,8 +1671,9 @@ func (s *Server) createEchoStreamServer(echoService EchoService, stream *frisbee return srv.stream.WritePacket(p) } + s.wg.Add(1) go func() { - err := echoService.EchoStream(srv) + err := echoService.EchoStream(ctx, srv) if err != nil { res := Response{error: err} res.flags = SetErrorFlag(res.flags, true) @@ -1665,6 +1681,7 @@ func (s *Server) createEchoStreamServer(echoService EchoService, stream *frisbee } else { srv.CloseSend() } + s.wg.Done() }() } @@ -1702,7 +1719,7 @@ type SearchServer struct { closed *atomic.Bool } -func (s *Server) createSearchServer(echoService EchoService, stream *frisbee.Stream) { +func (s *Server) createSearchServer(ctx context.Context, echoService EchoService, stream *frisbee.Stream) { srv := &SearchServer{ stream: stream, } @@ -1721,9 +1738,10 @@ func (s *Server) createSearchServer(echoService EchoService, stream *frisbee.Str } req := NewSearchResponse() err = req.Decode((*incoming.Content).Bytes()[:incoming.Metadata.ContentLength]) + s.wg.Add(1) go func() { - err := echoService.Search(req, srv) + err := echoService.Search(ctx, req, srv) if err != nil { res := Response{error: err} res.flags = SetErrorFlag(res.flags, true) @@ -1731,6 +1749,7 @@ func (s *Server) createSearchServer(echoService EchoService, stream *frisbee.Str } else { srv.CloseSend() } + s.wg.Done() }() } @@ -1761,7 +1780,7 @@ type UploadServer struct { closed *atomic.Bool } -func (s *Server) createUploadServer(echoService EchoService, stream *frisbee.Stream) { +func (s *Server) createUploadServer(ctx context.Context, echoService EchoService, stream *frisbee.Stream) { srv := &UploadServer{ stream: stream, } @@ -1791,8 +1810,9 @@ func (s *Server) createUploadServer(echoService EchoService, stream *frisbee.Str return srv.stream.WritePacket(p) } + s.wg.Add(1) go func() { - err := echoService.Upload(srv) + err := echoService.Upload(ctx, srv) if err != nil { res := Response{error: err} res.flags = SetErrorFlag(res.flags, true) @@ -1800,6 +1820,7 @@ func (s *Server) createUploadServer(echoService EchoService, stream *frisbee.Str } else { srv.CloseSend() } + s.wg.Done() }() } diff --git a/templates/imports.templ b/templates/imports.templ index 71cc312..377564e 100644 --- a/templates/imports.templ +++ b/templates/imports.templ @@ -8,11 +8,6 @@ import ( "{{$im}}" {{end -}} {{end}} -{{ if .numMethods }} -{{range $im := .methodImports -}} - "{{$im}}" -{{end -}} -{{end}} {{ if .numStreamMethods -}} {{range $im := .streamMethodImports -}} "{{$im}}" diff --git a/templates/interfaces.templ b/templates/interfaces.templ index cbefa37..30aac98 100644 --- a/templates/interfaces.templ +++ b/templates/interfaces.templ @@ -11,9 +11,9 @@ type {{ CamelCaseName .Name }} interface { {{ range $i, $v := MakeIterable .Methods.Len -}} {{ $method := $.Methods.Get $i -}} {{ if $method.IsStreamingClient }} - {{ CamelCaseName $method.Name }} (srv *{{ CamelCaseName $method.Name }}Server) error + {{ CamelCaseName $method.Name }} (context.Context, *{{ CamelCaseName $method.Name }}Server) error {{ else if $method.IsStreamingServer }} - {{ CamelCaseName $method.Name }} (req *{{ CamelCase $method.Input.FullName }}, srv *{{ CamelCaseName $method.Name }}Server) error + {{ CamelCaseName $method.Name }} (context.Context, *{{ CamelCase $method.Input.FullName }}, *{{ CamelCaseName $method.Name }}Server) error {{else -}} {{ CamelCaseName $method.Name }} (context.Context, *{{ CamelCase $method.Input.FullName }}) (*{{ CamelCase $method.Output.FullName }}, error) {{end -}} diff --git a/templates/server.templ b/templates/server.templ index 65b5fd5..20dadb5 100644 --- a/templates/server.templ +++ b/templates/server.templ @@ -29,30 +29,31 @@ func (x *RPCStreamOpen) decode(d *polyglot.BufferDecoder) error { {{ end -}} type Server struct { - *frisbee.Server - onClosed func(*frisbee.Async, error) + server *frisbee.Server + {{ if .numStreamMethods -}} + wg sync.WaitGroup + {{end -}} } func NewServer({{ GetServerFields .services }}, tlsConfig *tls.Config, logger types.Logger) (*Server, error) { - var s *Server + s := new(Server) table := make(frisbee.HandlerTable) {{template "serverhandlers" .services -}} - var fsrv *frisbee.Server var err error if tlsConfig != nil { - fsrv, err = frisbee.NewServer(table, frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger)) + s.server, err = frisbee.NewServer(table, context.Background(), frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger)) if err != nil { return nil, err } } else { - fsrv, err = frisbee.NewServer(table, frisbee.WithLogger(logger)) + s.server, err = frisbee.NewServer(table, context.Background(), frisbee.WithLogger(logger)) if err != nil { return nil, err } } {{ if .numStreamMethods -}} - fsrv.SetStreamHandler(func(conn *frisbee.Async, stream *frisbee.Stream) { + s.server.SetStreamHandler(func(ctx context.Context, stream *frisbee.Stream) { p, err := stream.ReadPacket() if err != nil { return @@ -72,7 +73,7 @@ func NewServer({{ GetServerFields .services }}, tlsConfig *tls.Config, logger ty {{ $opIndex := call $counter -}} {{ if or $method.IsStreamingClient $method.IsStreamingServer -}} case {{ $opIndex }}: - s.create{{ CamelCaseName $method.Name }}Server({{ FirstLowerCase (CamelCaseName $service.Name) }}, stream) + s.create{{ CamelCaseName $method.Name }}Server(ctx, {{ FirstLowerCase (CamelCaseName $service.Name) }}, stream) {{end -}} {{end -}} {{end -}} @@ -81,26 +82,45 @@ func NewServer({{ GetServerFields .services }}, tlsConfig *tls.Config, logger ty {{ end -}} - fsrv.ConnContext = func (ctx context.Context, conn *frisbee.Async) context.Context { + s.server.ConnContext = func (ctx context.Context, conn *frisbee.Async) context.Context { return context.WithValue(ctx, connectionContextKey, conn) } - s, err = &Server{ - Server: fsrv, - }, nil - fsrv.SetOnClosed(func(async *frisbee.Async, err error) { - if s.onClosed != nil { - s.onClosed(async, err) - } - }) - return s, err + return s, nil } func (s *Server) SetOnClosed(f func(*frisbee.Async, error)) error { - if f == nil { - return frisbee.OnClosedNil + return s.server.SetOnClosed(f) +} + +func (s *Server) SetPreWrite(f func()) error { + return s.server.SetPreWrite(f) +} + +func (s *Server) SetConcurrency(concurrency uint64) { + s.server.SetConcurrency(concurrency) +} + +func (s *Server) Start(addr string) error { + return s.server.Start(addr) +} + +func (s *Server) StartWithListener(listener net.Listener) error { + return s.server.StartWithListener(listener) +} + +func (s *Server) ServeConn(conn net.Conn) { + s.server.ServeConn(conn) +} + +func (s *Server) Shutdown() error { + err := s.server.Shutdown() + if err != nil { + return err } - s.onClosed = f + {{ if .numStreamMethods -}} + s.wg.Wait() + {{end -}} return nil } @@ -123,7 +143,7 @@ func (s *Server) SetOnClosed(f func(*frisbee.Async, error)) error { closed *atomic.Bool } - func (s *Server) create{{ CamelCaseName $method.Name}}Server ({{ FirstLowerCase (CamelCaseName $service.Name) }} {{ CamelCaseName $service.Name }}, stream *frisbee.Stream) { + func (s *Server) create{{ CamelCaseName $method.Name}}Server (ctx context.Context, {{ FirstLowerCase (CamelCaseName $service.Name) }} {{ CamelCaseName $service.Name }}, stream *frisbee.Stream) { srv := &{{ CamelCaseName $method.Name }}Server{ stream: stream, } @@ -164,12 +184,12 @@ func (s *Server) SetOnClosed(f func(*frisbee.Async, error)) error { req := New{{ CamelCase $method.Input.FullName }}() err = req.Decode((*incoming.Content).Bytes()[:incoming.Metadata.ContentLength]) {{ end -}} - + s.wg.Add(1) go func() { {{ if $method.IsStreamingClient -}} - err := {{ FirstLowerCaseName $service.Name }}.{{ CamelCaseName $method.Name }}(srv) + err := {{ FirstLowerCaseName $service.Name }}.{{ CamelCaseName $method.Name }}(ctx, srv) {{ else }} - err := {{ FirstLowerCaseName $service.Name }}.{{ CamelCaseName $method.Name }}(req, srv) + err := {{ FirstLowerCaseName $service.Name }}.{{ CamelCaseName $method.Name }}(ctx, req, srv) {{ end -}} if err != nil { res := {{ CamelCase $method.Output.FullName }}{error: err} @@ -178,6 +198,7 @@ func (s *Server) SetOnClosed(f func(*frisbee.Async, error)) error { } else { srv.CloseSend() } + s.wg.Done() }() } diff --git a/version/current_version b/version/current_version index e6e6db4..bf057db 100644 --- a/version/current_version +++ b/version/current_version @@ -1 +1 @@ -v0.9.2 +v0.10.0