diff --git a/pagination/cursor.go b/pagination/cursor.go index 79595af..9dc9c43 100644 --- a/pagination/cursor.go +++ b/pagination/cursor.go @@ -64,7 +64,7 @@ func bytesToUUIDString(b []byte) string { } func isUUID(v reflect.Value) bool { - if v.Kind() != reflect.Slice { + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { return false } if v.Type().Elem().Kind() != reflect.Uint8 { @@ -100,7 +100,7 @@ func computeItemCursor(obj any) (Cursor, error) { switch idField.Kind() { case reflect.String: cursorID = idField.String() - case reflect.Slice: + case reflect.Slice, reflect.Array: // Check if the slice is a UUID if isUUID(idField) { cursorID = bytesToUUIDString(idField.Bytes()) @@ -116,9 +116,11 @@ func computeItemCursor(obj any) (Cursor, error) { return Cursor{}, fmt.Errorf("%w: ID", ErrCursorFieldNotFound) default: return Cursor{}, fmt.Errorf( - "%w: ID: expected: %s, actual: %s", + "%w: ID: expected: %s/%s/%s, actual: %s", ErrCursorInvalidValueType, reflect.String, + reflect.Slice, + reflect.Array, idField.Kind(), ) } diff --git a/pagination/cursor_internal_test.go b/pagination/cursor_internal_test.go index 3e8f43f..f1e99eb 100644 --- a/pagination/cursor_internal_test.go +++ b/pagination/cursor_internal_test.go @@ -87,7 +87,7 @@ func TestComputeCursor(t *testing.T) { // nolint: revive expectedCursor: "M2Y2ZThkNWEtYjk3Mi00Y2I3LWE3NDEtY2UwM2ZlNzkxNDM5OjIwMjMtMTItMjBUMTM6NTY6MDNa", }, - "UUID": { + "UUIDSlice": { item: ptr.To(struct { ID []byte CreatedAt *time.Time @@ -112,6 +112,31 @@ func TestComputeCursor(t *testing.T) { }, expectedCursor: "", }, + "UUIDArray": { + item: ptr.To(struct { + ID [16]byte + CreatedAt *time.Time + }{ + ID: [16]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0}, + CreatedAt: timeMustParse(time.RFC3339, "2023-12-20T13:56:03Z"), + }), + expectedError: require.NoError, + // nolint: revive + expectedCursor: "MTIzNDU2NzgtOWFiYy1kZWYwLTEyMzQtNTY3ODlhYmNkZWYwOjIwMjMtMTItMjBUMTM6NTY6MDNa", + }, + "arrayNotUUID": { + item: ptr.To(struct { + ID [15]byte + CreatedAt *time.Time + }{ + ID: [15]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde}, + CreatedAt: timeMustParse(time.RFC3339, "2023-12-20T13:56:03Z"), + }), + expectedError: func(t require.TestingT, err error, i ...any) { + require.ErrorIs(t, err, ErrCursorInvalidValueType) + }, + expectedCursor: "", + }, "CustomID": { item: ptr.To(StructWithCustomID{ IDFieldName: "some_id", diff --git a/pagination/psql_pagination.go b/pagination/psql_pagination.go index 41dcedb..a84e5a4 100644 --- a/pagination/psql_pagination.go +++ b/pagination/psql_pagination.go @@ -3,6 +3,7 @@ package pagination import ( "context" "fmt" + "reflect" "slices" "time" @@ -56,7 +57,22 @@ func (p PSQLPaginator[T]) ListItems( ) for i := range items { - cursor, err := computeItemCursor(items[i]) + var ( + cursor Cursor + err error + ) + // Need to pass address of item, otherwise it will not work for structs + // containing array, as they will be marked unaddressable. + v := reflect.ValueOf(items[i]) + if v.Kind() == reflect.Ptr { + cursor, err = computeItemCursor(items[i]) + } else { + // If obj is not a pointer, create a pointer to it + ptr := reflect.New(v.Type()) + ptr.Elem().Set(v) + cursor, err = computeItemCursor(ptr.Interface()) + } + if err != nil { return nil, fmt.Errorf("compute cursor: %w", err) } diff --git a/pagination/psql_pagination_test.go b/pagination/psql_pagination_test.go index 9f559a9..d626dfe 100644 --- a/pagination/psql_pagination_test.go +++ b/pagination/psql_pagination_test.go @@ -28,6 +28,17 @@ func (*user) TableName() string { return "users" } +type machine struct { + ID uuid.UUID `gorm:"column:id;type:uuid;primaryKey"` + Name *string `gorm:"column:name;type:text"` + // nolint: revive + CreatedAt time.Time `gorm:"column:created_at;type:timestamp with time zone;not null;default:now()"` +} + +func (machine) TableName() string { + return "machines" +} + func userToCursor(u *user) *string { return ptr.To((&pagination.Cursor{ ID: u.ID, @@ -50,6 +61,11 @@ func setupPsql(t *testing.T) *gorm.DB { name TEXT, created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() ); + CREATE TABLE machines ( + id UUID PRIMARY KEY, + name TEXT, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); ` psqlContainer := psqldocker.NewContainer( @@ -397,3 +413,36 @@ func TestListPSQLPaginatedItemsWithWhereCondtion(t *testing.T) { EndCursor: userToCursor(users[2]), }) } + +func TestListPSQLNonPointer(t *testing.T) { + req := require.New(t) + ctx := context.Background() + + db := setupPsql(t) + + name := "macbook" + + macbook := machine{ + Name: &name, + CreatedAt: time.Now().Add(time.Duration(1) * time.Second), + } + + err := db.Create(&macbook).Error + req.NoError(err) + + psqlPaginator := pagination.PSQLPaginator[machine]{ + DB: db, + } + + page, err := psqlPaginator.ListItems( + ctx, + pagination.Arguments{ + First: ptr.To(1), + }, + ) + require.NoError(t, err) + + require.Len(t, page.Items, 1) + require.NotNil(t, page.Items[0].Item.Name) + require.Equal(t, "macbook", *(page.Items[0].Item.Name)) +}