Skip to content

Commit

Permalink
feat(pagination): add support for UUID array type
Browse files Browse the repository at this point in the history
  • Loading branch information
andreipurposeinplay committed Aug 2, 2024
1 parent 4365cb9 commit d5ed635
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 5 deletions.
8 changes: 5 additions & 3 deletions pagination/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func bytesToUUIDString(b []byte) string {
}

func isUUID(v reflect.Value) bool {
if v.Kind() != reflect.Slice {
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return false
}
if v.Type().Elem().Kind() != reflect.Uint8 {
Expand Down Expand Up @@ -100,7 +100,7 @@ func computeItemCursor(obj any) (Cursor, error) {
switch idField.Kind() {
case reflect.String:
cursorID = idField.String()
case reflect.Slice:
case reflect.Slice, reflect.Array:
// Check if the slice is a UUID
if isUUID(idField) {
cursorID = bytesToUUIDString(idField.Bytes())
Expand All @@ -116,9 +116,11 @@ func computeItemCursor(obj any) (Cursor, error) {
return Cursor{}, fmt.Errorf("%w: ID", ErrCursorFieldNotFound)
default:
return Cursor{}, fmt.Errorf(
"%w: ID: expected: %s, actual: %s",
"%w: ID: expected: %s/%s/%s, actual: %s",
ErrCursorInvalidValueType,
reflect.String,
reflect.Slice,
reflect.Array,
idField.Kind(),
)
}
Expand Down
27 changes: 26 additions & 1 deletion pagination/cursor_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestComputeCursor(t *testing.T) {
// nolint: revive
expectedCursor: "M2Y2ZThkNWEtYjk3Mi00Y2I3LWE3NDEtY2UwM2ZlNzkxNDM5OjIwMjMtMTItMjBUMTM6NTY6MDNa",
},
"UUID": {
"UUIDSlice": {
item: ptr.To(struct {
ID []byte
CreatedAt *time.Time
Expand All @@ -112,6 +112,31 @@ func TestComputeCursor(t *testing.T) {
},
expectedCursor: "",
},
"UUIDArray": {
item: ptr.To(struct {
ID [16]byte
CreatedAt *time.Time
}{
ID: [16]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0},
CreatedAt: timeMustParse(time.RFC3339, "2023-12-20T13:56:03Z"),
}),
expectedError: require.NoError,
// nolint: revive
expectedCursor: "MTIzNDU2NzgtOWFiYy1kZWYwLTEyMzQtNTY3ODlhYmNkZWYwOjIwMjMtMTItMjBUMTM6NTY6MDNa",
},
"arrayNotUUID": {
item: ptr.To(struct {
ID [15]byte
CreatedAt *time.Time
}{
ID: [15]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde},
CreatedAt: timeMustParse(time.RFC3339, "2023-12-20T13:56:03Z"),
}),
expectedError: func(t require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, ErrCursorInvalidValueType)
},
expectedCursor: "",
},
"CustomID": {
item: ptr.To(StructWithCustomID{
IDFieldName: "some_id",
Expand Down
18 changes: 17 additions & 1 deletion pagination/psql_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pagination
import (
"context"
"fmt"
"reflect"
"slices"
"time"

Expand Down Expand Up @@ -56,7 +57,22 @@ func (p PSQLPaginator[T]) ListItems(
)

for i := range items {
cursor, err := computeItemCursor(items[i])
var (
cursor Cursor
err error
)
// Need to pass address of item, otherwise it will not work for structs
// containing array, as they will be marked unaddressable.
v := reflect.ValueOf(items[i])
if v.Kind() == reflect.Ptr {
cursor, err = computeItemCursor(items[i])
} else {
// If obj is not a pointer, create a pointer to it
ptr := reflect.New(v.Type())
ptr.Elem().Set(v)
cursor, err = computeItemCursor(ptr.Interface())
}

if err != nil {
return nil, fmt.Errorf("compute cursor: %w", err)
}
Expand Down
49 changes: 49 additions & 0 deletions pagination/psql_pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ func (*user) TableName() string {
return "users"
}

type machine struct {
ID uuid.UUID `gorm:"column:id;type:uuid;primaryKey"`
Name *string `gorm:"column:name;type:text"`
// nolint: revive
CreatedAt time.Time `gorm:"column:created_at;type:timestamp with time zone;not null;default:now()"`
}

func (machine) TableName() string {
return "machines"
}

func userToCursor(u *user) *string {
return ptr.To((&pagination.Cursor{
ID: u.ID,
Expand All @@ -50,6 +61,11 @@ func setupPsql(t *testing.T) *gorm.DB {
name TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE TABLE machines (
id UUID PRIMARY KEY,
name TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
`

psqlContainer := psqldocker.NewContainer(
Expand Down Expand Up @@ -397,3 +413,36 @@ func TestListPSQLPaginatedItemsWithWhereCondtion(t *testing.T) {
EndCursor: userToCursor(users[2]),
})
}

func TestListPSQLNonPointer(t *testing.T) {
req := require.New(t)
ctx := context.Background()

db := setupPsql(t)

name := "macbook"

macbook := machine{
Name: &name,
CreatedAt: time.Now().Add(time.Duration(1) * time.Second),
}

err := db.Create(&macbook).Error
req.NoError(err)

psqlPaginator := pagination.PSQLPaginator[machine]{
DB: db,
}

page, err := psqlPaginator.ListItems(
ctx,
pagination.Arguments{
First: ptr.To(1),
},
)
require.NoError(t, err)

require.Len(t, page.Items, 1)
require.NotNil(t, page.Items[0].Item.Name)
require.Equal(t, "macbook", *(page.Items[0].Item.Name))
}

0 comments on commit d5ed635

Please sign in to comment.