diff --git a/assert/assertions.go b/assert/assertions.go index bc15101b0..368ca74ff 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -48,6 +48,28 @@ type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool // Comparison is a custom function that returns true on success and false on failure type Comparison func() (success bool) +// List of basic types. This is specifically a list of types whose underlying value is not printed +// when using the %#v format specifier to print a pointer to the value. +var basicTypes = []reflect.Kind{ + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr, + reflect.Float32, + reflect.Float64, + reflect.Complex64, + reflect.Complex128, + reflect.String, + reflect.Bool, +} + /* Helper functions */ @@ -512,22 +534,44 @@ func samePointers(first, second interface{}) bool { // to a type conversion in the Go grammar. func formatUnequalValues(expected, actual interface{}) (e string, a string) { if reflect.TypeOf(expected) != reflect.TypeOf(actual) { - return fmt.Sprintf("%T(%s)", expected, truncatingFormat(expected)), - fmt.Sprintf("%T(%s)", actual, truncatingFormat(actual)) + return fmt.Sprintf("%T(%s)", expected, truncatingFormat(expected, false)), + fmt.Sprintf("%T(%s)", actual, truncatingFormat(actual, false)) } switch expected.(type) { case time.Duration: return fmt.Sprintf("%v", expected), fmt.Sprintf("%v", actual) } - return truncatingFormat(expected), truncatingFormat(actual) + + return truncatingFormat(expected, true), truncatingFormat(actual, true) } // truncatingFormat formats the data and truncates it if it's too long. // // This helps keep formatted error messages lines from exceeding the // bufio.MaxScanTokenSize max line length that the go testing framework imposes. -func truncatingFormat(data interface{}) string { - value := fmt.Sprintf("%#v", data) +// +// If the `printValueOfBasicTypes` flag is set to true, the underlying value of +// pointers of basic types (see `basicTypes` for list of basic types) will be +// returned in addition to the pointer address. For non-basic types, only the pointer +// address will be returned. +// If the `printValueOfBasicTypes` flag is set to false, only the pointer address +// will be returned, regardless of whether `data` is a basic type or not. +func truncatingFormat(data interface{}, printValueOfBasicTypes bool) string { + var value string + if printValueOfBasicTypes && reflect.ValueOf(data).Kind() == reflect.Ptr { + v := reflect.ValueOf(data).Elem() + for _, t := range basicTypes { + if v.Kind() == t { + value = fmt.Sprintf("%#v %v", data, v.Interface()) + break + } + } + } + + if value == "" { + value = fmt.Sprintf("%#v", data) + } + max := bufio.MaxScanTokenSize - 100 // Give us some space the type info too if needed. if len(value) > max { value = value[0:max] + "<... truncated>" diff --git a/assert/assertions_test.go b/assert/assertions_test.go index d2a25c245..0e5cdbb7a 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -2896,11 +2896,11 @@ func Test_validateEqualArgs(t *testing.T) { func Test_truncatingFormat(t *testing.T) { original := strings.Repeat("a", bufio.MaxScanTokenSize-102) - result := truncatingFormat(original) + result := truncatingFormat(original, false) Equal(t, fmt.Sprintf("%#v", original), result, "string should not be truncated") original = original + "x" - result = truncatingFormat(original) + result = truncatingFormat(original, false) NotEqual(t, fmt.Sprintf("%#v", original), result, "string should have been truncated.") if !strings.HasSuffix(result, "<... truncated>") { @@ -2908,6 +2908,29 @@ func Test_truncatingFormat(t *testing.T) { } } +func Test_truncatingFormat_PrintValueOfBasicTypes(t *testing.T) { + i := 5 + iPtr := &i + result := truncatingFormat(iPtr, true) + if !strings.HasSuffix(result, "5") { + t.Error("format should have printed the underlying value of int pointer") + } + + b := false + bPtr := &b + result = truncatingFormat(bPtr, true) + if !strings.HasSuffix(result, "false") { + t.Error("format should have printed the underlying value of bool pointer") + } + + s := "testingString" + sPtr := &s + result = truncatingFormat(sPtr, true) + if !strings.HasSuffix(result, "testingString") { + t.Error("format should have printed the underlying value of string pointer") + } +} + // parseLabeledOutput does the inverse of labeledOutput - it takes a formatted // output string and turns it back into a slice of labeledContent. func parseLabeledOutput(output string) []labeledContent {