diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..430a05c --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.DS_Store +.idea/ +.vscode/ +*.tmp +*.log diff --git a/.travis.yml b/.travis.yml index ce4f6b9..618ae99 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: go go: - - 1.9.x - - 1.10.x + - 1.12.x before_install: - go get github.com/mattn/goveralls diff --git a/README.md b/README.md index fadc53d..9f60201 100644 --- a/README.md +++ b/README.md @@ -28,11 +28,11 @@ Convert RPC style function into http.Handler ### Create new hrpc Manager ```go -m := hrpc.New(hrpc.Config{ - RequestDecoder: func(r *http.Request, dst interface{}) error { +m := hrpc.Manager{ + Decoder: func(r *http.Request, dst interface{}) error { return json.NewDecoder(r.Body).Decode(dst) }, - ResponseEncoder: func(w http.ResponseWriter, r *http.Request, res interface{}) { + Encoder: func(w http.ResponseWriter, r *http.Request, res interface{}) { w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(res) }, @@ -45,7 +45,7 @@ m := hrpc.New(hrpc.Config{ json.NewEncoder(w).Encode(res) }, Validate: true, -}) +} ``` ### RPC style function @@ -55,6 +55,14 @@ type UserRequest struct { ID int `json:"id"` } +func (req *UserRequest) Valid() error { + // Valid will be called when decode, if set validate to true + if req.ID <= 0 { + return fmt.Errorf("invalid id") + } + return nil +} + type UserResponse struct { ID int `json:"id"` Username string `json:"username"` diff --git a/go.mod b/go.mod index 5bcfba6..e131514 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/acoshift/hrpc + +go 1.12 diff --git a/hrpc.go b/hrpc.go index d35df6e..811aadb 100644 --- a/hrpc.go +++ b/hrpc.go @@ -5,37 +5,47 @@ import ( "reflect" ) -// Manager type +// Decoder is the request decoder +type Decoder func(*http.Request, interface{}) error + +// Encoder is the response encoder +type Encoder func(http.ResponseWriter, *http.Request, interface{}) + +// ErrorEncoder is the error response encoder +type ErrorEncoder func(http.ResponseWriter, *http.Request, error) + +// Manager is the hrpc manager type Manager struct { - c Config + Decoder Decoder + Encoder Encoder + ErrorEncoder ErrorEncoder + Validate bool // set to true to validate request after decode using Validatable interface } -// Config is the hrpc config -type Config struct { - RequestDecoder func(*http.Request, interface{}) error - ResponseEncoder func(http.ResponseWriter, *http.Request, interface{}) - ErrorEncoder func(http.ResponseWriter, *http.Request, error) - Validate bool // set to true to validate request after decode using Validatable interface +func (m *Manager) decoder() Decoder { + if m.Decoder == nil { + return func(*http.Request, interface{}) error { return nil } + } + return m.Decoder } -// Validatable interface -type Validatable interface { - Validate() error +func (m *Manager) encoder() Encoder { + if m.Encoder == nil { + return func(http.ResponseWriter, *http.Request, interface{}) {} + } + return m.Encoder } -// New creates new manager -func New(config Config) *Manager { - m := &Manager{config} - if config.RequestDecoder == nil { - m.c.RequestDecoder = func(*http.Request, interface{}) error { return nil } +func (m *Manager) errorEncoder() ErrorEncoder { + if m.ErrorEncoder == nil { + return func(http.ResponseWriter, *http.Request, error) {} } - if config.ResponseEncoder == nil { - m.c.ResponseEncoder = func(http.ResponseWriter, *http.Request, interface{}) {} - } - if config.ErrorEncoder == nil { - m.c.ErrorEncoder = func(http.ResponseWriter, *http.Request, error) {} - } - return m + return m.ErrorEncoder +} + +// Validatable interface +type Validatable interface { + Valid() error } type mapIndex int @@ -120,6 +130,10 @@ func (m *Manager) Handler(f interface{}) http.Handler { } } + encoder := m.encoder() + decoder := m.decoder() + errorEncoder := m.errorEncoder() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { vIn := make([]reflect.Value, numIn) // inject context @@ -130,17 +144,17 @@ func (m *Manager) Handler(f interface{}) http.Handler { if i, ok := mapIn[miInterface]; ok { rfReq := reflect.New(typ) req := rfReq.Interface() - err := m.c.RequestDecoder(r, req) + err := decoder(r, req) if err != nil { - m.c.ErrorEncoder(w, r, err) + errorEncoder(w, r, err) return } - if m.c.Validate { + if m.Validate { if req, ok := req.(Validatable); ok { - err = req.Validate() + err = req.Valid() if err != nil { - m.c.ErrorEncoder(w, r, err) + errorEncoder(w, r, err) return } } @@ -161,14 +175,14 @@ func (m *Manager) Handler(f interface{}) http.Handler { if i, ok := mapOut[miError]; ok { if vErr := vOut[i]; !vErr.IsNil() { if err, ok := vErr.Interface().(error); ok && err != nil { - m.c.ErrorEncoder(w, r, err) + errorEncoder(w, r, err) return } } } // check response if i, ok := mapOut[miInterface]; ok { - m.c.ResponseEncoder(w, r, vOut[i].Interface()) + encoder(w, r, vOut[i].Interface()) } // if f is not return response, it may already call from native response writer diff --git a/hrpc_test.go b/hrpc_test.go index 8fdad4e..ba6c688 100644 --- a/hrpc_test.go +++ b/hrpc_test.go @@ -12,7 +12,7 @@ import ( "testing" ) -func jsonRequestDecoder(r *http.Request, dst interface{}) error { +func jsonDecoder(r *http.Request, dst interface{}) error { return json.NewDecoder(r.Body).Decode(dst) } @@ -20,7 +20,7 @@ type requestType struct { Data int } -func (req *requestType) Validate() error { +func (req *requestType) Valid() error { if req.Data < 0 { return errors.New("invalid data") } @@ -48,16 +48,16 @@ func TestHandler(t *testing.T) { w = httptest.NewRecorder() } - m := New(Config{ - RequestDecoder: jsonRequestDecoder, - ResponseEncoder: func(w http.ResponseWriter, r *http.Request, res interface{}) { + m := Manager{ + Decoder: jsonDecoder, + Encoder: func(w http.ResponseWriter, r *http.Request, res interface{}) { callSuccess = true }, ErrorEncoder: func(w http.ResponseWriter, r *http.Request, err error) { callError = true }, Validate: true, - }) + } h := m.Handler(func(ctx context.Context, req *requestType) (interface{}, error) { if req.Data != 1 { @@ -146,7 +146,7 @@ func TestHandler(t *testing.T) { } func TestDefault(t *testing.T) { - m := New(Config{}) + m := Manager{} i := 0 h := m.Handler(func(ctx context.Context, req *requestType) (interface{}, error) { if i == 0 { @@ -171,7 +171,7 @@ func TestInvalidF(t *testing.T) { t.Fatal("should panic") } } - m := New(Config{}) + m := Manager{} func() { defer p() m.Handler(1) @@ -191,11 +191,11 @@ func TestInvalidF(t *testing.T) { } func ExampleManager() { - m := New(Config{ - RequestDecoder: func(r *http.Request, dst interface{}) error { + m := Manager{ + Decoder: func(r *http.Request, dst interface{}) error { return json.NewDecoder(r.Body).Decode(dst) }, - ResponseEncoder: func(w http.ResponseWriter, r *http.Request, res interface{}) { + Encoder: func(w http.ResponseWriter, r *http.Request, res interface{}) { w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(res) }, @@ -208,7 +208,7 @@ func ExampleManager() { json.NewEncoder(w).Encode(res) }, Validate: true, - }) + } http.Handle("/user.get", m.Handler(func(ctx context.Context, req *struct { ID string `json:"id"`