Skip to content

Commit

Permalink
add hooks ok and error (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
acoshift authored Oct 19, 2022
1 parent b716d97 commit b11ed80
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
43 changes: 36 additions & 7 deletions hrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type Manager struct {
Encoder Encoder
ErrorEncoder ErrorEncoder
Validate bool // set to true to validate request after decode using Validatable interface
onErrorFuncs []func(http.ResponseWriter, *http.Request, any, error)
onOKFuncs []func(http.ResponseWriter, *http.Request, any, any)
}

func (m *Manager) decoder() Decoder {
Expand All @@ -43,6 +45,16 @@ func (m *Manager) errorEncoder() ErrorEncoder {
return m.ErrorEncoder
}

// OnError calls f when error
func (m *Manager) OnError(f func(w http.ResponseWriter, r *http.Request, req any, err error)) {
m.onErrorFuncs = append(m.onErrorFuncs, f)
}

// OnOK calls f before encode ok response
func (m *Manager) OnOK(f func(w http.ResponseWriter, r *http.Request, req any, res any)) {
m.onOKFuncs = append(m.onOKFuncs, f)
}

// Validatable interface
type Validatable interface {
Valid() error
Expand Down Expand Up @@ -73,6 +85,14 @@ func setOrPanic(m map[mapIndex]int, k mapIndex, v int) {
m[k] = v
}

func (m *Manager) encodeAndHookError(w http.ResponseWriter, r *http.Request, req any, err error) {
m.errorEncoder()(w, r, err)

for _, f := range m.onErrorFuncs {
f(w, r, req, err)
}
}

// Handler func,
// f must be a function which have at least 2 inputs and 2 outputs.
// first input must be a context.
Expand Down Expand Up @@ -136,9 +156,13 @@ func (m *Manager) Handler(f any) http.Handler {

encoder := m.encoder()
decoder := m.decoder()
errorEncoder := m.errorEncoder()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var (
req any
res any
)

vIn := make([]reflect.Value, numIn)
// inject context
if i, ok := mapIn[miContext]; ok {
Expand All @@ -147,18 +171,18 @@ func (m *Manager) Handler(f any) http.Handler {
// inject request interface
if i, ok := mapIn[miAny]; ok {
rfReq := reflect.New(infType)
req := rfReq.Interface()
req = rfReq.Interface()
err := decoder(r, req)
if err != nil {
errorEncoder(w, r, err)
m.encodeAndHookError(w, r, req, err)
return
}

if m.Validate {
if req, ok := req.(Validatable); ok {
err = req.Valid()
if err != nil {
errorEncoder(w, r, err)
m.encodeAndHookError(w, r, req, err)
return
}
}
Expand All @@ -183,16 +207,21 @@ func (m *Manager) Handler(f any) http.Handler {
if i, ok := mapOut[miError]; ok {
if vErr := vOut[i]; !vErr.IsNil() {
if err, ok := vErr.Interface().(error); ok && err != nil {
errorEncoder(w, r, err)
m.encodeAndHookError(w, r, req, err)
return
}
}
}

// check response
if i, ok := mapOut[miAny]; ok {
encoder(w, r, vOut[i].Interface())
res = vOut[i].Interface()
encoder(w, r, res)
}

// if f is not return response, it may already call from native response writer
// run ok hooks
for _, f := range m.onOKFuncs {
f(w, r, req, res)
}
})
}
22 changes: 22 additions & 0 deletions hrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func TestHandler(t *testing.T) {
successBody := &bytes.Buffer{}
errorBody := &bytes.Buffer{}
invalidBody := &bytes.Buffer{}
onErrorCalled := false
onOKCalled := false

var r *http.Request
var w *httptest.ResponseRecorder
Expand All @@ -45,6 +47,8 @@ func TestHandler(t *testing.T) {
errorBody.WriteString("{\"data\": -1}")
invalidBody.Reset()
invalidBody.WriteString("invalid")
onOKCalled = false
onErrorCalled = false
w = httptest.NewRecorder()
}

Expand All @@ -58,6 +62,12 @@ func TestHandler(t *testing.T) {
},
Validate: true,
}
m.OnOK(func(w http.ResponseWriter, r *http.Request, req any, res any) {
onOKCalled = true
})
m.OnError(func(w http.ResponseWriter, r *http.Request, req any, err error) {
onErrorCalled = true
})

h := m.Handler(func(ctx context.Context, req *requestType) (any, error) {
if req.Data != 1 {
Expand All @@ -70,18 +80,30 @@ func TestHandler(t *testing.T) {
if !callSuccess {
t.Fatalf("success not call")
}
if !onOKCalled {
t.Fatalf("onOK not call")
}
if callError {
t.Fatalf("error should not be called")
}
if onErrorCalled {
t.Fatalf("onError should not be called")
}
}

mustError := func() {
if callSuccess {
t.Fatalf("success should not be called")
}
if onOKCalled {
t.Fatalf("onOK should not be called")
}
if !callError {
t.Fatalf("error not call")
}
if !onErrorCalled {
t.Fatalf("onError not call")
}
}

mustNothing := func() {
Expand Down

0 comments on commit b11ed80

Please sign in to comment.