From b6443861e07dbf19248f35354fa787ce1c2a6f3e Mon Sep 17 00:00:00 2001 From: deelawn Date: Mon, 9 Dec 2024 13:25:40 -0800 Subject: [PATCH] add custom field tag support --- field_tag.go | 4 ++++ marshaler.go | 42 +++++++++++++++++++++++++----------------- marshaler_test.go | 24 ++++++++++++++++++++++++ unmarshaler.go | 30 +++++++++++++++++++++--------- unmarshaler_test.go | 15 +++++++++++++++ 5 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 field_tag.go diff --git a/field_tag.go b/field_tag.go new file mode 100644 index 00000000..3f0fa417 --- /dev/null +++ b/field_tag.go @@ -0,0 +1,4 @@ +package toml + +// tomlFieldTag is the default field tag name when encoding and decoding TOML schema. +const tomlFieldTag = "toml" diff --git a/marshaler.go b/marshaler.go index 161acd93..b4b1af2a 100644 --- a/marshaler.go +++ b/marshaler.go @@ -43,6 +43,7 @@ type Encoder struct { indentSymbol string indentTables bool marshalJsonNumbers bool + fieldTag string } // NewEncoder returns a new Encoder that writes to w. @@ -50,6 +51,7 @@ func NewEncoder(w io.Writer) *Encoder { return &Encoder{ w: w, indentSymbol: " ", + fieldTag: tomlFieldTag, } } @@ -100,6 +102,12 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder { return enc } +// SetFieldTag replaces the current field tag with the specified field tag. +func (enc *Encoder) SetFieldTag(fieldTag string) *Encoder { + enc.fieldTag = fieldTag + return enc +} + // Encode writes a TOML representation of v to the stream. // // If v cannot be represented to TOML it returns an error. @@ -148,7 +156,7 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder { // # Struct tags // // The encoding of each public struct field can be customized by the format -// string in the "toml" key of the struct field's tag. This follows +// string in the "toml" (or custom) key of the struct field's tag. This follows // encoding/json's convention. The format string starts with the name of the // field, optionally followed by a comma-separated list of options. The name may // be empty in order to provide options without overriding the default name. @@ -164,7 +172,7 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder { // The "commented" option prefixes the value and all its children with a comment // symbol. // -// In addition to the "toml" tag struct tag, a "comment" tag can be used to emit +// In addition to the "toml" (or custom ) tag struct tag, a "comment" tag can be used to emit // a TOML comment before the value being annotated. Comments are ignored inside // inline tables. For array tables, the comment is only present before the first // element of the array. @@ -380,8 +388,8 @@ func isNil(v reflect.Value) bool { } } -func shouldOmitEmpty(options valueOptions, v reflect.Value) bool { - return options.omitempty && isEmptyValue(v) +func (enc *Encoder) shouldOmitEmpty(options valueOptions, v reflect.Value) bool { + return options.omitempty && enc.isEmptyValue(v) } func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) { @@ -418,10 +426,10 @@ func (enc *Encoder) commented(commented bool, b []byte) []byte { return b } -func isEmptyValue(v reflect.Value) bool { +func (enc *Encoder) isEmptyValue(v reflect.Value) bool { switch v.Kind() { case reflect.Struct: - return isEmptyStruct(v) + return enc.isEmptyStruct(v) case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return v.Len() == 0 case reflect.Bool: @@ -438,7 +446,7 @@ func isEmptyValue(v reflect.Value) bool { return false } -func isEmptyStruct(v reflect.Value) bool { +func (enc *Encoder) isEmptyStruct(v reflect.Value) bool { // TODO: merge with walkStruct and cache. typ := v.Type() for i := 0; i < typ.NumField(); i++ { @@ -449,7 +457,7 @@ func isEmptyStruct(v reflect.Value) bool { continue } - tag := fieldType.Tag.Get("toml") + tag := fieldType.Tag.Get(enc.fieldTag) // special field name to skip field if tag == "-" { @@ -458,7 +466,7 @@ func isEmptyStruct(v reflect.Value) bool { f := v.Field(i) - if !isEmptyValue(f) { + if !enc.isEmptyValue(f) { return false } } @@ -715,7 +723,7 @@ func (t *table) pushTable(k string, v reflect.Value, options valueOptions) { t.tables = append(t.tables, entry{Key: k, Value: v, Options: options}) } -func walkStruct(ctx encoderCtx, t *table, v reflect.Value) { +func (enc *Encoder) walkStruct(ctx encoderCtx, t *table, v reflect.Value) { // TODO: cache this typ := v.Type() for i := 0; i < typ.NumField(); i++ { @@ -726,7 +734,7 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) { continue } - tag := fieldType.Tag.Get("toml") + tag := fieldType.Tag.Get(enc.fieldTag) // special field name to skip field if tag == "-" { @@ -743,9 +751,9 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) { if k == "" { if fieldType.Anonymous { if fieldType.Type.Kind() == reflect.Struct { - walkStruct(ctx, t, f) + enc.walkStruct(ctx, t, f) } else if fieldType.Type.Kind() == reflect.Ptr && !f.IsNil() && f.Elem().Kind() == reflect.Struct { - walkStruct(ctx, t, f.Elem()) + enc.walkStruct(ctx, t, f.Elem()) } continue } else { @@ -775,7 +783,7 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) { func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { var t table - walkStruct(ctx, &t, v) + enc.walkStruct(ctx, &t, v) return enc.encodeTable(b, ctx, t) } @@ -879,7 +887,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro hasNonEmptyKV := false for _, kv := range t.kvs { - if shouldOmitEmpty(kv.Options, kv.Value) { + if enc.shouldOmitEmpty(kv.Options, kv.Value) { continue } hasNonEmptyKV = true @@ -898,7 +906,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro first := true for _, table := range t.tables { - if shouldOmitEmpty(table.Options, table.Value) { + if enc.shouldOmitEmpty(table.Options, table.Value) { continue } if first { @@ -932,7 +940,7 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte first := true for _, kv := range t.kvs { - if shouldOmitEmpty(kv.Options, kv.Value) { + if enc.shouldOmitEmpty(kv.Options, kv.Value) { continue } diff --git a/marshaler_test.go b/marshaler_test.go index f88e49af..f033837c 100644 --- a/marshaler_test.go +++ b/marshaler_test.go @@ -1080,6 +1080,30 @@ Bad = '' assert.Equal(t, expected, string(b)) } +func TestEncoderSetFieldTag(t *testing.T) { + type fieldTagTestType struct { + Hellllooo string `json:"hello"` + Nnnummbbeerrr int `json:"number"` + } + + var buf bytes.Buffer + enc := toml.NewEncoder(&buf) + enc = enc.SetFieldTag("json") + + input := fieldTagTestType{ + Hellllooo: "hi", + Nnnummbbeerrr: 1, + } + err := enc.Encode(input) + require.NoError(t, err) + + expValue := `hello = 'hi' +number = 1 +` + + require.Equal(t, expValue, buf.String()) +} + func TestIssue436(t *testing.T) { data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`) diff --git a/unmarshaler.go b/unmarshaler.go index 189be525..9a721cc9 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -21,7 +21,7 @@ import ( // // It is a shortcut for Decoder.Decode() with the default options. func Unmarshal(data []byte, v interface{}) error { - d := decoder{} + d := decoder{fieldTag: tomlFieldTag} d.p.Reset(data) return d.FromParser(v) } @@ -36,11 +36,14 @@ type Decoder struct { // toggles unmarshaler interface unmarshalerInterface bool + + // fieldTag is the tag name used to decode struct fields. + fieldTag string } // NewDecoder creates a new Decoder that will read from r. func NewDecoder(r io.Reader) *Decoder { - return &Decoder{r: r} + return &Decoder{r: r, fieldTag: tomlFieldTag} } // DisallowUnknownFields causes the Decoder to return an error when the @@ -73,6 +76,11 @@ func (d *Decoder) EnableUnmarshalerInterface() *Decoder { return d } +func (d *Decoder) SetFieldTag(fieldTag string) *Decoder { + d.fieldTag = fieldTag + return d +} + // Decode the whole content of r into v. // // By default, values in the document that don't exist in the target Go value @@ -125,6 +133,7 @@ func (d *Decoder) Decode(v interface{}) error { Enabled: d.strict, }, unmarshalerInterface: d.unmarshalerInterface, + fieldTag: d.fieldTag, } dec.p.Reset(b) @@ -166,6 +175,9 @@ type decoder struct { // Current context for the error. errorContext *errorContext + + // fieldTag is the tag name used to decode struct fields. + fieldTag string } type errorContext struct { @@ -512,7 +524,7 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h v.SetMapIndex(mk, mv) } case reflect.Struct: - path, found := structFieldPath(v, string(key.Node().Data)) + path, found := d.structFieldPath(v, string(key.Node().Data)) if !found { d.skipUntilTable = true return reflect.Value{}, nil @@ -1152,7 +1164,7 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node v.SetMapIndex(mk, mv) } case reflect.Struct: - path, found := structFieldPath(v, string(key.Node().Data)) + path, found := d.structFieldPath(v, string(key.Node().Data)) if !found { d.skipUntilTable = true break @@ -1261,7 +1273,7 @@ type fieldPathsMap = map[string][]int var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap -func structFieldPath(v reflect.Value, name string) ([]int, bool) { +func (d *decoder) structFieldPath(v reflect.Value, name string) ([]int, bool) { t := v.Type() cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap) @@ -1270,7 +1282,7 @@ func structFieldPath(v reflect.Value, name string) ([]int, bool) { if !ok { fieldPaths = map[string][]int{} - forEachField(t, nil, func(name string, path []int) { + d.forEachField(t, nil, func(name string, path []int) { fieldPaths[name] = path // extra copy for the case-insensitive match fieldPaths[strings.ToLower(name)] = path @@ -1291,7 +1303,7 @@ func structFieldPath(v reflect.Value, name string) ([]int, bool) { return path, ok } -func forEachField(t reflect.Type, path []int, do func(name string, path []int)) { +func (d *decoder) forEachField(t reflect.Type, path []int, do func(name string, path []int)) { n := t.NumField() for i := 0; i < n; i++ { f := t.Field(i) @@ -1304,7 +1316,7 @@ func forEachField(t reflect.Type, path []int, do func(name string, path []int)) fieldPath := append(path, i) fieldPath = fieldPath[:len(fieldPath):len(fieldPath)] - name := f.Tag.Get("toml") + name := f.Tag.Get(d.fieldTag) if name == "-" { continue } @@ -1320,7 +1332,7 @@ func forEachField(t reflect.Type, path []int, do func(name string, path []int)) } if t2.Kind() == reflect.Struct { - forEachField(t2, fieldPath, do) + d.forEachField(t2, fieldPath, do) } continue } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 3cbd81d1..7ec51392 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -112,6 +112,21 @@ func TestDecodeReaderError(t *testing.T) { require.Error(t, err) } +func TestDecoderSetFieldTag(t *testing.T) { + type fieldTagTestType struct { + A string `json:"aaa"` + B string `json:"bbb"` + } + + d := toml.NewDecoder(strings.NewReader("aaa = \"foo\"\nbbb = \"bar\"")) + d.SetFieldTag("json") + var output fieldTagTestType + err := d.Decode(&output) + require.NoError(t, err) + assert.Equal(t, "foo", output.A) + assert.Equal(t, "bar", output.B) +} + // nolint:funlen func TestUnmarshal_Integers(t *testing.T) { examples := []struct {