diff --git a/mock/mock.go b/mock/mock.go index d6694ed78..ce446db04 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -557,6 +557,10 @@ const ( Anything = "mock.Anything" ) +var ( + errorType = reflect.TypeOf((*error)(nil)).Elem() +) + // AnythingOfTypeArgument is a string that contains the type of an argument // for use when type checking. Used in Diff and Assert. type AnythingOfTypeArgument string @@ -578,6 +582,10 @@ type argumentMatcher struct { } func (f argumentMatcher) Matches(argument interface{}) bool { + return f.match(argument) == nil +} + +func (f argumentMatcher) match(argument interface{}) error { expectType := f.fn.Type().In(0) expectTypeNilSupported := false switch expectType.Kind() { @@ -598,25 +606,52 @@ func (f argumentMatcher) Matches(argument interface{}) bool { } if argType == nil || argType.AssignableTo(expectType) { result := f.fn.Call([]reflect.Value{arg}) - return result[0].Bool() + + var matchError error + switch { + case result[0].Type().Kind() == reflect.Bool: + if !result[0].Bool() { + matchError = fmt.Errorf("not matched by %s", f) + } + case result[0].Type().Implements(errorType): + if !result[0].IsNil() { + matchError = result[0].Interface().(error) + } + default: + panic(fmt.Errorf("matcher function of unknown type: %s", result[0].Type().Kind())) + } + + return matchError } - return false + return fmt.Errorf("unexpected type for %s", f) } func (f argumentMatcher) String() string { - return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name()) + return fmt.Sprintf("func(%s) %s", f.fn.Type().In(0).String(), f.fn.Type().Out(0).String()) +} + +func (f argumentMatcher) GoString() string { + return fmt.Sprintf("MatchedBy(%s)", f) } // MatchedBy can be used to match a mock call based on only certain properties // from a complex struct or some calculation. It takes a function that will be -// evaluated with the called argument and will return true when there's a match -// and false otherwise. +// evaluated with the called argument and will return either a boolean (true +// when there's a match and false otherwise) or an error (nil when there's a +// match and error holding the failure message otherwise). +// +// Examples: +// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) // -// Example: -// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) +// m.On("Do", MatchedBy(func(req *http.Request) (err error) { +// if req.Host != "example.com" { +// err = errors.New("host was not example.com") +// } +// return +// }) // // |fn|, must be a function accepting a single argument (of the expected type) -// which returns a bool. If |fn| doesn't match the required signature, +// which returns a bool or error. If |fn| doesn't match the required signature, // MatchedBy() panics. func MatchedBy(fn interface{}) argumentMatcher { fnType := reflect.TypeOf(fn) @@ -627,8 +662,9 @@ func MatchedBy(fn interface{}) argumentMatcher { if fnType.NumIn() != 1 { panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn)) } - if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool { - panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) + + if fnType.NumOut() != 1 || (fnType.Out(0).Kind() != reflect.Bool && !fnType.Out(0).Implements(errorType)) { + panic(fmt.Sprintf("assert: arguments: %s does not return a bool or a error", fn)) } return argumentMatcher{fn: reflect.ValueOf(fn)} @@ -688,11 +724,11 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { } if matcher, ok := expected.(argumentMatcher); ok { - if matcher.Matches(actual) { + if matchError := matcher.match(actual); matchError == nil { output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) } else { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) + output = fmt.Sprintf("%s\t%d: FAIL: %s %s\n", output, i, actualFmt, matchError) } } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { diff --git a/mock/mock_test.go b/mock/mock_test.go index 2608f5a36..9fb14a393 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -1259,7 +1259,7 @@ func Test_Arguments_Diff_WithArgMatcher(t *testing.T) { diff, count = args.Diff([]interface{}{"string", false, true}) assert.Equal(t, 1, count) - assert.Contains(t, diff, `(bool=false) not matched by func(int) bool`) + assert.Contains(t, diff, `(bool=false) unexpected type for func(int) bool`) diff, count = args.Diff([]interface{}{"string", 123, false}) assert.Contains(t, diff, `(int=123) matched by func(int) bool`) @@ -1269,6 +1269,31 @@ func Test_Arguments_Diff_WithArgMatcher(t *testing.T) { assert.Contains(t, diff, `No differences.`) } +func Test_Arguments_Diff_WithArgMatcherReturningError(t *testing.T) { + matchFn := func(a int) (err error) { + if a != 123 { + err = errors.New("did not match") + } + return + } + var args = Arguments([]interface{}{"string", MatchedBy(matchFn), true}) + + diff, count := args.Diff([]interface{}{"string", 124, true}) + assert.Equal(t, 1, count) + assert.Contains(t, diff, `(int=124) did not match`) + + diff, count = args.Diff([]interface{}{"string", false, true}) + assert.Equal(t, 1, count) + assert.Contains(t, diff, `(bool=false) unexpected type for func(int) error`) + + diff, count = args.Diff([]interface{}{"string", 123, false}) + assert.Contains(t, diff, `(int=123) matched by func(int) error`) + + diff, count = args.Diff([]interface{}{"string", 123, true}) + assert.Equal(t, 0, count) + assert.Contains(t, diff, `No differences.`) +} + func Test_Arguments_Assert(t *testing.T) { var args = Arguments([]interface{}{"string", 123, true}) @@ -1445,7 +1470,7 @@ func TestArgumentMatcherToPrintMismatch(t *testing.T) { defer func() { if r := recover(); r != nil { matchingExp := regexp.MustCompile( - `\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(int=1\) not matched by func\(int\) bool`) + `\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: MatchedBy\(func\(int\) bool\)\s+Diff:.*\(int=1\) not matched by func\(int\) bool`) assert.Regexp(t, matchingExp, r) } }()