Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for custom field tags #977

Open
wants to merge 1 commit into
base: v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions field_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package toml

// tomlFieldTag is the default field tag name when encoding and decoding TOML schema.
const tomlFieldTag = "toml"
42 changes: 25 additions & 17 deletions marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ type Encoder struct {
indentSymbol string
indentTables bool
marshalJsonNumbers bool
fieldTag string
}

// NewEncoder returns a new Encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: w,
indentSymbol: " ",
fieldTag: tomlFieldTag,
}
}

Expand Down Expand Up @@ -100,6 +102,12 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder {
return enc
}

// SetFieldTag replaces the current field tag with the specified field tag.
func (enc *Encoder) SetFieldTag(fieldTag string) *Encoder {
enc.fieldTag = fieldTag
return enc
}

// Encode writes a TOML representation of v to the stream.
//
// If v cannot be represented to TOML it returns an error.
Expand Down Expand Up @@ -148,7 +156,7 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder {
// # Struct tags
//
// The encoding of each public struct field can be customized by the format
// string in the "toml" key of the struct field's tag. This follows
// string in the "toml" (or custom) key of the struct field's tag. This follows
// encoding/json's convention. The format string starts with the name of the
// field, optionally followed by a comma-separated list of options. The name may
// be empty in order to provide options without overriding the default name.
Expand All @@ -164,7 +172,7 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder {
// The "commented" option prefixes the value and all its children with a comment
// symbol.
//
// In addition to the "toml" tag struct tag, a "comment" tag can be used to emit
// In addition to the "toml" (or custom ) tag struct tag, a "comment" tag can be used to emit
// a TOML comment before the value being annotated. Comments are ignored inside
// inline tables. For array tables, the comment is only present before the first
// element of the array.
Expand Down Expand Up @@ -380,8 +388,8 @@ func isNil(v reflect.Value) bool {
}
}

func shouldOmitEmpty(options valueOptions, v reflect.Value) bool {
return options.omitempty && isEmptyValue(v)
func (enc *Encoder) shouldOmitEmpty(options valueOptions, v reflect.Value) bool {
return options.omitempty && enc.isEmptyValue(v)
}

func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) {
Expand Down Expand Up @@ -418,10 +426,10 @@ func (enc *Encoder) commented(commented bool, b []byte) []byte {
return b
}

func isEmptyValue(v reflect.Value) bool {
func (enc *Encoder) isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Struct:
return isEmptyStruct(v)
return enc.isEmptyStruct(v)
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
Expand All @@ -438,7 +446,7 @@ func isEmptyValue(v reflect.Value) bool {
return false
}

func isEmptyStruct(v reflect.Value) bool {
func (enc *Encoder) isEmptyStruct(v reflect.Value) bool {
// TODO: merge with walkStruct and cache.
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
Expand All @@ -449,7 +457,7 @@ func isEmptyStruct(v reflect.Value) bool {
continue
}

tag := fieldType.Tag.Get("toml")
tag := fieldType.Tag.Get(enc.fieldTag)

// special field name to skip field
if tag == "-" {
Expand All @@ -458,7 +466,7 @@ func isEmptyStruct(v reflect.Value) bool {

f := v.Field(i)

if !isEmptyValue(f) {
if !enc.isEmptyValue(f) {
return false
}
}
Expand Down Expand Up @@ -715,7 +723,7 @@ func (t *table) pushTable(k string, v reflect.Value, options valueOptions) {
t.tables = append(t.tables, entry{Key: k, Value: v, Options: options})
}

func walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
func (enc *Encoder) walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
// TODO: cache this
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
Expand All @@ -726,7 +734,7 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
continue
}

tag := fieldType.Tag.Get("toml")
tag := fieldType.Tag.Get(enc.fieldTag)

// special field name to skip field
if tag == "-" {
Expand All @@ -743,9 +751,9 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
if k == "" {
if fieldType.Anonymous {
if fieldType.Type.Kind() == reflect.Struct {
walkStruct(ctx, t, f)
enc.walkStruct(ctx, t, f)
} else if fieldType.Type.Kind() == reflect.Ptr && !f.IsNil() && f.Elem().Kind() == reflect.Struct {
walkStruct(ctx, t, f.Elem())
enc.walkStruct(ctx, t, f.Elem())
}
continue
} else {
Expand Down Expand Up @@ -775,7 +783,7 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
var t table

walkStruct(ctx, &t, v)
enc.walkStruct(ctx, &t, v)

return enc.encodeTable(b, ctx, t)
}
Expand Down Expand Up @@ -879,7 +887,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro

hasNonEmptyKV := false
for _, kv := range t.kvs {
if shouldOmitEmpty(kv.Options, kv.Value) {
if enc.shouldOmitEmpty(kv.Options, kv.Value) {
continue
}
hasNonEmptyKV = true
Expand All @@ -898,7 +906,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro

first := true
for _, table := range t.tables {
if shouldOmitEmpty(table.Options, table.Value) {
if enc.shouldOmitEmpty(table.Options, table.Value) {
continue
}
if first {
Expand Down Expand Up @@ -932,7 +940,7 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte

first := true
for _, kv := range t.kvs {
if shouldOmitEmpty(kv.Options, kv.Value) {
if enc.shouldOmitEmpty(kv.Options, kv.Value) {
continue
}

Expand Down
24 changes: 24 additions & 0 deletions marshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,30 @@ Bad = ''
assert.Equal(t, expected, string(b))
}

func TestEncoderSetFieldTag(t *testing.T) {
type fieldTagTestType struct {
Hellllooo string `json:"hello"`
Nnnummbbeerrr int `json:"number"`
}

var buf bytes.Buffer
enc := toml.NewEncoder(&buf)
enc = enc.SetFieldTag("json")

input := fieldTagTestType{
Hellllooo: "hi",
Nnnummbbeerrr: 1,
}
err := enc.Encode(input)
require.NoError(t, err)

expValue := `hello = 'hi'
number = 1
`

require.Equal(t, expValue, buf.String())
}

func TestIssue436(t *testing.T) {
data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`)

Expand Down
30 changes: 21 additions & 9 deletions unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
//
// It is a shortcut for Decoder.Decode() with the default options.
func Unmarshal(data []byte, v interface{}) error {
d := decoder{}
d := decoder{fieldTag: tomlFieldTag}
d.p.Reset(data)
return d.FromParser(v)
}
Expand All @@ -36,11 +36,14 @@ type Decoder struct {

// toggles unmarshaler interface
unmarshalerInterface bool

// fieldTag is the tag name used to decode struct fields.
fieldTag string
}

// NewDecoder creates a new Decoder that will read from r.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
return &Decoder{r: r, fieldTag: tomlFieldTag}
}

// DisallowUnknownFields causes the Decoder to return an error when the
Expand Down Expand Up @@ -73,6 +76,11 @@ func (d *Decoder) EnableUnmarshalerInterface() *Decoder {
return d
}

func (d *Decoder) SetFieldTag(fieldTag string) *Decoder {
d.fieldTag = fieldTag
return d
}

// Decode the whole content of r into v.
//
// By default, values in the document that don't exist in the target Go value
Expand Down Expand Up @@ -125,6 +133,7 @@ func (d *Decoder) Decode(v interface{}) error {
Enabled: d.strict,
},
unmarshalerInterface: d.unmarshalerInterface,
fieldTag: d.fieldTag,
}
dec.p.Reset(b)

Expand Down Expand Up @@ -166,6 +175,9 @@ type decoder struct {

// Current context for the error.
errorContext *errorContext

// fieldTag is the tag name used to decode struct fields.
fieldTag string
}

type errorContext struct {
Expand Down Expand Up @@ -512,7 +524,7 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h
v.SetMapIndex(mk, mv)
}
case reflect.Struct:
path, found := structFieldPath(v, string(key.Node().Data))
path, found := d.structFieldPath(v, string(key.Node().Data))
if !found {
d.skipUntilTable = true
return reflect.Value{}, nil
Expand Down Expand Up @@ -1152,7 +1164,7 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
v.SetMapIndex(mk, mv)
}
case reflect.Struct:
path, found := structFieldPath(v, string(key.Node().Data))
path, found := d.structFieldPath(v, string(key.Node().Data))
if !found {
d.skipUntilTable = true
break
Expand Down Expand Up @@ -1261,7 +1273,7 @@ type fieldPathsMap = map[string][]int

var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap

func structFieldPath(v reflect.Value, name string) ([]int, bool) {
func (d *decoder) structFieldPath(v reflect.Value, name string) ([]int, bool) {
t := v.Type()

cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap)
Expand All @@ -1270,7 +1282,7 @@ func structFieldPath(v reflect.Value, name string) ([]int, bool) {
if !ok {
fieldPaths = map[string][]int{}

forEachField(t, nil, func(name string, path []int) {
d.forEachField(t, nil, func(name string, path []int) {
fieldPaths[name] = path
// extra copy for the case-insensitive match
fieldPaths[strings.ToLower(name)] = path
Expand All @@ -1291,7 +1303,7 @@ func structFieldPath(v reflect.Value, name string) ([]int, bool) {
return path, ok
}

func forEachField(t reflect.Type, path []int, do func(name string, path []int)) {
func (d *decoder) forEachField(t reflect.Type, path []int, do func(name string, path []int)) {
n := t.NumField()
for i := 0; i < n; i++ {
f := t.Field(i)
Expand All @@ -1304,7 +1316,7 @@ func forEachField(t reflect.Type, path []int, do func(name string, path []int))
fieldPath := append(path, i)
fieldPath = fieldPath[:len(fieldPath):len(fieldPath)]

name := f.Tag.Get("toml")
name := f.Tag.Get(d.fieldTag)
if name == "-" {
continue
}
Expand All @@ -1320,7 +1332,7 @@ func forEachField(t reflect.Type, path []int, do func(name string, path []int))
}

if t2.Kind() == reflect.Struct {
forEachField(t2, fieldPath, do)
d.forEachField(t2, fieldPath, do)
}
continue
}
Expand Down
15 changes: 15 additions & 0 deletions unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,21 @@ func TestDecodeReaderError(t *testing.T) {
require.Error(t, err)
}

func TestDecoderSetFieldTag(t *testing.T) {
type fieldTagTestType struct {
A string `json:"aaa"`
B string `json:"bbb"`
}

d := toml.NewDecoder(strings.NewReader("aaa = \"foo\"\nbbb = \"bar\""))
d.SetFieldTag("json")
var output fieldTagTestType
err := d.Decode(&output)
require.NoError(t, err)
assert.Equal(t, "foo", output.A)
assert.Equal(t, "bar", output.B)
}

// nolint:funlen
func TestUnmarshal_Integers(t *testing.T) {
examples := []struct {
Expand Down