diff --git a/bind_test.go b/bind_test.go index aa00e191ca..5cc59f6280 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1091,6 +1091,42 @@ func Benchmark_Bind_Body_XML(b *testing.B) { require.Equal(b, "john", d.Name) } +// go test -run Test_Bind_Body_Form_Embedded +func Test_Bind_Body_Form_Embedded(b *testing.T) { + var err error + + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + type EmbeddedDemo struct { + EmbeddedString string `form:"embedded_string"` + EmbeddedStrings []string `form:"embedded_strings"` + } + + type Demo struct { + String string `form:"some_string"` + OtherString string `form:"some_other_string"` + Strings []string `form:"strings"` + OtherStrings []string `form:"other_strings"` + EmbeddedDemo + } + body := []byte("some_string=john%2Clong&some_other_string=long&some_other_string=long&strings=long%2Cjohn&embedded_strings=john%2Clongest&embedded_string=johny%2Cwalker&other_strings=long&other_strings=johny") + c.Request().SetBody(body) + c.Request().Header.SetContentType(MIMEApplicationForm) + c.Request().Header.SetContentLength(len(body)) + d := new(Demo) + + err = c.Bind().Body(d) + + require.NoError(b, err) + require.Equal(b, "john,long", d.String) + require.Equal(b, []string{"long", "john"}, d.Strings) + //! only one value is taken + require.Equal(b, "long", d.OtherString) + require.Equal(b, []string{"long", "johny"}, d.OtherStrings) + require.Equal(b, "johny,walker", d.EmbeddedString) + require.Equal(b, []string{"john", "longest"}, d.EmbeddedStrings) +} + // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form -benchmem -count=4 func Benchmark_Bind_Body_Form(b *testing.B) { var err error diff --git a/binder/cookie.go b/binder/cookie.go index 0f5c650c33..2dd1dc864a 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -26,14 +23,7 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) if err != nil { diff --git a/binder/form.go b/binder/form.go index f45407fe93..7df300e256 100644 --- a/binder/form.go +++ b/binder/form.go @@ -1,7 +1,6 @@ package binder import ( - "reflect" "strings" "github.com/gofiber/utils/v2" @@ -30,14 +29,7 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) if err != nil { diff --git a/binder/header.go b/binder/header.go index 196163694d..3610408137 100644 --- a/binder/header.go +++ b/binder/header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -20,14 +17,7 @@ func (b *headerBinding) Bind(req *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) return parse(b.Name(), out, data) diff --git a/binder/mapping.go b/binder/mapping.go index ea67ace200..e9420a1278 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -172,7 +172,18 @@ func parseParamSquareBrackets(k string) (string, error) { return bb.String(), nil } -func equalFieldType(out any, kind reflect.Kind, key string) bool { +func appendValue(to map[string][]string, rawValue string, out any, k string, bindingName string) { + if strings.Contains(rawValue, ",") && equalFieldType(out, reflect.Slice, k, bindingName) { + values := strings.Split(rawValue, ",") + for i := 0; i < len(values); i++ { + to[k] = append(to[k], values[i]) + } + } else { + to[k] = append(to[k], rawValue) + } +} + +func equalFieldType(out any, kind reflect.Kind, key string, bindingName string) bool { // Get type of interface outTyp := reflect.TypeOf(out).Elem() key = utils.ToLower(key) @@ -196,47 +207,54 @@ func equalFieldType(out any, kind reflect.Kind, key string) bool { if !structField.CanSet() { continue } + // Get field key data typeField := outTyp.Field(i) // Get type of field key structFieldKind := structField.Kind() - // Does the field type equals input? - if structFieldKind != kind { - // Is the field an embedded struct? - if structFieldKind == reflect.Struct { - // Loop over embedded struct fields - for j := 0; j < structField.NumField(); j++ { - structFieldField := structField.Field(j) - - // Can this embedded field be changed? - if !structFieldField.CanSet() { - continue - } - - // Is the embedded struct field type equal to the input? - if structFieldField.Kind() == kind { - return true - } - } - } - continue - } - // Get tag from field if exist - inputFieldName := typeField.Tag.Get(QueryBinder.Name()) - if inputFieldName == "" { - inputFieldName = typeField.Name - } else { - inputFieldName = strings.Split(inputFieldName, ",")[0] - } // Compare field/tag with provided key - if utils.ToLower(inputFieldName) == key { - return true + if getFieldKey(typeField, bindingName) == key { + return structFieldKind == kind + } + + // Is the field an embedded struct? + if typeField.Anonymous { + // Loop over embedded struct fields + for j := 0; j < structField.NumField(); j++ { + if getFieldKey(structField.Type().Field(j), bindingName) != key { + // this is not the field that we are looking for + continue + } + + structFieldField := structField.Field(j) + + // Can this embedded field be changed? + if !structFieldField.CanSet() { + continue + } + + // Is the embedded struct field type equal to the input? + return structFieldField.Kind() == kind + } } } return false } +// Get binding key for a field +func getFieldKey(typeField reflect.StructField, bindingName string) string { + // Get tag from field if exist + inputFieldName := typeField.Tag.Get(bindingName) + if inputFieldName == "" { + inputFieldName = typeField.Name + } else { + inputFieldName = strings.Split(inputFieldName, ",")[0] + } + // Compare field key + return utils.ToLower(inputFieldName) +} + // Get content type from content type header func FilterFlags(content string) string { for i, char := range content { diff --git a/binder/mapping_test.go b/binder/mapping_test.go index e6fc8146f7..1f74664cfe 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -10,25 +10,25 @@ import ( func Test_EqualFieldType(t *testing.T) { var out int - require.False(t, equalFieldType(&out, reflect.Int, "key")) + require.False(t, equalFieldType(&out, reflect.Int, "key", "query")) var dummy struct{ f string } - require.False(t, equalFieldType(&dummy, reflect.String, "key")) + require.False(t, equalFieldType(&dummy, reflect.String, "key", "query")) var dummy2 struct{ f string } - require.False(t, equalFieldType(&dummy2, reflect.String, "f")) + require.False(t, equalFieldType(&dummy2, reflect.String, "f", "query")) var user struct { Name string Address string `query:"address"` Age int `query:"AGE"` } - require.True(t, equalFieldType(&user, reflect.String, "name")) - require.True(t, equalFieldType(&user, reflect.String, "Name")) - require.True(t, equalFieldType(&user, reflect.String, "address")) - require.True(t, equalFieldType(&user, reflect.String, "Address")) - require.True(t, equalFieldType(&user, reflect.Int, "AGE")) - require.True(t, equalFieldType(&user, reflect.Int, "age")) + require.True(t, equalFieldType(&user, reflect.String, "name", "query")) + require.True(t, equalFieldType(&user, reflect.String, "Name", "query")) + require.True(t, equalFieldType(&user, reflect.String, "address", "query")) + require.True(t, equalFieldType(&user, reflect.String, "Address", "query")) + require.True(t, equalFieldType(&user, reflect.Int, "AGE", "query")) + require.True(t, equalFieldType(&user, reflect.Int, "age", "query")) } func Test_ParseParamSquareBrackets(t *testing.T) { diff --git a/binder/query.go b/binder/query.go index 25b69f5bc3..d35d92d22c 100644 --- a/binder/query.go +++ b/binder/query.go @@ -1,7 +1,6 @@ package binder import ( - "reflect" "strings" "github.com/gofiber/utils/v2" @@ -30,14 +29,7 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error { k, err = parseParamSquareBrackets(k) } - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) if err != nil { diff --git a/binder/resp_header.go b/binder/resp_header.go index 0455185ba1..749e98b324 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -20,14 +17,7 @@ func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + appendValue(data, v, out, k, b.Name()) }) return parse(b.Name(), out, data)