Skip to content

Commit

Permalink
Only update struct fields when values change
Browse files Browse the repository at this point in the history
  • Loading branch information
jackcook committed Oct 5, 2022
1 parent 72b2133 commit 2cefdca
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 22 deletions.
38 changes: 31 additions & 7 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ func bindStruct(prefix, id string, data map[string]string, typ reflect.Type, val
// Set ID
val := res.Elem()
fi := reflect.Indirect(val).FieldByName("ID")
fi.SetString(id)

if fi.String() != id {
fi.SetString(id)
}

structField.Set(res)
} else {
Expand Down Expand Up @@ -177,7 +180,10 @@ func bindStruct(prefix, id string, data map[string]string, typ reflect.Type, val
// Set ID
val := ptr.Elem()
fi := reflect.Indirect(val).FieldByName("ID")
fi.SetString(itemID)

if fi.String() != itemID {
fi.SetString(itemID)
}

arr = reflect.Append(arr, ptr)
} else {
Expand Down Expand Up @@ -265,7 +271,13 @@ func setFieldWithKind(valueKind reflect.Kind, val string, structField reflect.Va
case reflect.Struct:
switch structField.Type() {
case reflect.TypeOf(time.Now()):
timeInt, _ := strconv.Atoi(val)
timeInt, _ := strconv.ParseInt(val, 10, 64)
existingTimeInt := structField.MethodByName("Unix").Call([]reflect.Value{})[0].Int()

if existingTimeInt == timeInt {
return nil
}

timeVal := time.Unix(int64(timeInt), 0)
structField.Set(reflect.ValueOf(timeVal))
default:
Expand All @@ -289,7 +301,10 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
return err
}

field.SetInt(val)
if val != field.Int() {
field.SetInt(val)
}

return nil
}

Expand All @@ -304,7 +319,10 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
return err
}

field.SetUint(val)
if val != field.Uint() {
field.SetUint(val)
}

return nil
}

Expand All @@ -319,7 +337,10 @@ func setBoolField(value string, field reflect.Value) error {
return err
}

field.SetBool(val)
if val != field.Bool() {
field.SetBool(val)
}

return nil
}

Expand All @@ -334,6 +355,9 @@ func setFloatField(value string, bitSize int, field reflect.Value) error {
return err
}

field.SetFloat(val)
if val != field.Float() {
field.SetFloat(val)
}

return nil
}
32 changes: 19 additions & 13 deletions load.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@ import (
// struct, and then setting base attributes (such as ID). See Store for more
// information on creating structs for grocery. Load can be used like so:
//
// type Item struct {
// grocery.Base
// Name string `grocery:"name"`
// }
// type Item struct {
// grocery.Base
// Name string `grocery:"name"`
// }
//
// itemID := "asdf"
// item := new(Item)
// db.Load(itemID, item)
// itemID := "asdf"
// item := new(Item)
// db.Load(itemID, item)
func Load(id string, ptr interface{}) error {
if reflect.TypeOf(ptr).Kind() != reflect.Ptr || reflect.TypeOf(ptr).Elem().Kind() != reflect.Struct {
return errors.New("ptr must be a struct pointer")
}

// Get prefix for the struct (e.g. 'answer:' from Answer)
// Get prefix for the struct (e.g. 'item:' from Item)
prefix := strings.ToLower(reflect.TypeOf(ptr).Elem().Name())

// Load object data
res, _ := C.HGetAll(ctx, prefix + ":" + id).Result()
res, _ := C.HGetAll(ctx, prefix+":"+id).Result()

if err := bind(prefix, id, res, ptr); err != nil {
return err
Expand All @@ -38,7 +38,10 @@ func Load(id string, ptr interface{}) error {
// Set the ID before returning
val := reflect.ValueOf(ptr)
fi := reflect.Indirect(val).FieldByName("ID")
fi.SetString(id)

if fi.String() != id {
fi.SetString(id)
}

// Call post-load hook
postLoad := reflect.ValueOf(ptr).MethodByName("PostLoad")
Expand All @@ -61,15 +64,15 @@ func LoadAll[T any](ids []string, values *[]T) error {
return errors.New("len(ids) must be greater than zero")
}

// Get prefix for the struct (e.g. 'answer:' from Answer)
// Get prefix for the struct (e.g. 'item:' from Item)
prefix := strings.ToLower(reflect.ValueOf(values).Elem().Index(0).Type().Name())

// Pipeline all HGetAll commands
pip := C.Pipeline()
cmds := make([]*redis.StringStringMapCmd, len(ids))

for i, id := range ids {
cmds[i] = pip.HGetAll(ctx, prefix + ":" + id)
cmds[i] = pip.HGetAll(ctx, prefix+":"+id)
}

pip.Exec(ctx)
Expand All @@ -85,7 +88,10 @@ func LoadAll[T any](ids []string, values *[]T) error {
// Set ID
val := reflect.ValueOf(itemPtr).Elem()
fi := reflect.Indirect(val).FieldByName("ID")
fi.SetString(ids[i])

if fi.String() != ids[i] {
fi.SetString(ids[i])
}

// Call post-load hook
postLoad := reflect.ValueOf(itemPtr).MethodByName("PostLoad")
Expand Down
5 changes: 4 additions & 1 deletion update.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ func updateInternal(id string, ptr interface{}, opts *UpdateOptions) error {
if opts.isStore {
// Set ID
fi := reflect.Indirect(val).FieldByName("ID")
fi.SetString(id)

if fi.String() != id {
fi.SetString(id)
}

// Set createdAt timestamp
pip.HSet(ctx, prefix+":"+id, "createdAt", time.Now().Unix())
Expand Down
2 changes: 1 addition & 1 deletion update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestSetTimeValue(t *testing.T) {
t.Error(err)
}

if model.TimeVal != m.TimeVal {
if m.TimeVal.Sub(model.TimeVal).Abs() > time.Second {
t.Errorf("TestSetTimeValue FAILED, initial value was not set correctly")
}
}

0 comments on commit 2cefdca

Please sign in to comment.