Skip to content

Commit

Permalink
feat(pagination): add Page item type that aggregates Items and Page…
Browse files Browse the repository at this point in the history
… Info
  • Loading branch information
bradub committed Jan 5, 2024
1 parent 4636b02 commit 134c61d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 20 deletions.
28 changes: 15 additions & 13 deletions pagination/psql_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ import (
"time"

"gorm.io/gorm"
"gorm.io/gorm/schema"
"k8s.io/utils/ptr"
)

var _ Paginator[schema.Tabler] = (*PSQLPaginator[schema.Tabler])(nil)
var _ Paginator[Tabler] = (*PSQLPaginator[Tabler])(nil)

// PSQLPaginator implements the Paginator interface for
// PostgreSQL databases.
type PSQLPaginator[T schema.Tabler] struct {
type PSQLPaginator[T Tabler] struct {
DB *gorm.DB
}

Expand All @@ -24,26 +23,26 @@ type PSQLPaginator[T schema.Tabler] struct {
func (p PSQLPaginator[T]) ListItems(
_ context.Context,
paginationParams Arguments,
) ([]PaginatedItem[T], PageInfo, error) {
) (*Page[T], error) {
var err error

if paginationParams.After != nil {
paginationParams.afterCursor, err = (&Cursor{}).SetString(*paginationParams.After)
if err != nil {
return nil, PageInfo{}, fmt.Errorf("decode after cursor: %w", err)
return nil, fmt.Errorf("decode after cursor: %w", err)
}
}

if paginationParams.Before != nil {
paginationParams.beforeCursor, err = (&Cursor{}).SetString(*paginationParams.Before)
if err != nil {
return nil, PageInfo{}, fmt.Errorf("decode before cursor: %w", err)
return nil, fmt.Errorf("decode before cursor: %w", err)
}
}

items, err := queryItems[T](p.DB, paginationParams)
if err != nil {
return nil, PageInfo{}, fmt.Errorf("query items: %w", err)
return nil, fmt.Errorf("query items: %w", err)
}

var model T
Expand All @@ -59,7 +58,7 @@ func (p PSQLPaginator[T]) ListItems(
for i := range items {
cursor, err := computeItemCursor(items[i])
if err != nil {
return nil, PageInfo{}, fmt.Errorf("compute cursor: %w", err)
return nil, fmt.Errorf("compute cursor: %w", err)
}

if i == 0 {
Expand All @@ -83,15 +82,18 @@ func (p PSQLPaginator[T]) ListItems(
endCursor,
)
if err != nil {
return nil, PageInfo{}, fmt.Errorf("get page info: %w", err)
return nil, fmt.Errorf("get page info: %w", err)
}

if len(paginatedItems) > 0 {
pageInfo.StartCursor = ptr.To(paginatedItems[0].Cursor)
pageInfo.EndCursor = ptr.To(paginatedItems[len(paginatedItems)-1].Cursor)
}

return paginatedItems, pageInfo, nil
return &Page[T]{
Items: paginatedItems,
Info: pageInfo,
}, nil
}

func queryItems[T any](ses *gorm.DB, pagination Arguments) ([]T, error) {
Expand Down Expand Up @@ -130,7 +132,7 @@ func queryItems[T any](ses *gorm.DB, pagination Arguments) ([]T, error) {
return items, nil
}

func getPageInfo[T schema.Tabler](
func getPageInfo[T Tabler](
db *gorm.DB,
pagination Arguments,
startCursor *Cursor,
Expand All @@ -143,7 +145,7 @@ func getPageInfo[T schema.Tabler](
return getBackwardPaginationPageInfo[T](db, pagination, startCursor)
}

func getForwardPaginationPageInfo[T schema.Tabler](
func getForwardPaginationPageInfo[T Tabler](
db *gorm.DB,
pagination Arguments,
endCursor *Cursor,
Expand Down Expand Up @@ -194,7 +196,7 @@ func getForwardPaginationPageInfo[T schema.Tabler](
return pageInfo, nil
}

func getBackwardPaginationPageInfo[T schema.Tabler](
func getBackwardPaginationPageInfo[T Tabler](
db *gorm.DB,
pagination Arguments,
startCursor *Cursor,
Expand Down
10 changes: 5 additions & 5 deletions pagination/psql_pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,25 +288,25 @@ func TestListPSQLPaginatedItems(t *testing.T) {

req := require.New(t)

paginatedUsers, pageInfo, err := psqlPaginator.ListItems(
page, err := psqlPaginator.ListItems(
ctx,
test.params,
)
test.expectedError(t, err)

t.Logf("\nexpected users: %+v\nactual users: %+v", test.expectedUsers, paginatedUsers)
t.Logf("\nexpected users: %+v\nactual users: %+v", test.expectedUsers, page.Items)

for i := range test.expectedUsers {
req.Equal(
test.expectedUsers[i],
paginatedUsers[i].Item,
page.Items[i].Item,
"id: %s, i: %d",
paginatedUsers[i].Item.ID,
page.Items[i].Item.ID,
i,
)
}

req.Equal(test.expectedPageInfo, pageInfo)
req.Equal(test.expectedPageInfo, page.Info)
})
}
}
16 changes: 14 additions & 2 deletions pagination/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ import (
"errors"
"reflect"
"time"

"gorm.io/gorm/schema"
)

// Tabler is an interface that must be implemented by all models.
// Currently it's an alias for Gorm's schema.Tabler interface.
type Tabler = schema.Tabler

// Arguments represents pagination arguments.
// The arguments can be used to paginate forward or backward.
//
Expand Down Expand Up @@ -90,6 +96,12 @@ var timeKind = reflect.TypeOf(time.Time{}).Kind()
//
// When before: cursor is used, the edge closest to cursor must come last in the result edges.
// When after: cursor is used, the edge closest to cursor must come first in the result edges.
type Paginator[T any] interface {
ListItems(ctx context.Context, pagination Arguments) ([]PaginatedItem[T], PageInfo, error)
type Paginator[T Tabler] interface {
ListItems(ctx context.Context, pagination Arguments) (*Page[T], error)
}

// Page represents a paginated result set.
type Page[T Tabler] struct {
Items []PaginatedItem[T]
Info PageInfo
}

0 comments on commit 134c61d

Please sign in to comment.