From 134c61d565f4a602cd63caaafb2365fd4daa223e Mon Sep 17 00:00:00 2001 From: bradub Date: Fri, 5 Jan 2024 17:33:54 +0200 Subject: [PATCH] feat(pagination): add `Page` item type that aggregates Items and Page Info --- pagination/psql_pagination.go | 28 +++++++++++++++------------- pagination/psql_pagination_test.go | 10 +++++----- pagination/types.go | 16 ++++++++++++++-- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/pagination/psql_pagination.go b/pagination/psql_pagination.go index bdc379b..fabbd0b 100644 --- a/pagination/psql_pagination.go +++ b/pagination/psql_pagination.go @@ -7,15 +7,14 @@ import ( "time" "gorm.io/gorm" - "gorm.io/gorm/schema" "k8s.io/utils/ptr" ) -var _ Paginator[schema.Tabler] = (*PSQLPaginator[schema.Tabler])(nil) +var _ Paginator[Tabler] = (*PSQLPaginator[Tabler])(nil) // PSQLPaginator implements the Paginator interface for // PostgreSQL databases. -type PSQLPaginator[T schema.Tabler] struct { +type PSQLPaginator[T Tabler] struct { DB *gorm.DB } @@ -24,26 +23,26 @@ type PSQLPaginator[T schema.Tabler] struct { func (p PSQLPaginator[T]) ListItems( _ context.Context, paginationParams Arguments, -) ([]PaginatedItem[T], PageInfo, error) { +) (*Page[T], error) { var err error if paginationParams.After != nil { paginationParams.afterCursor, err = (&Cursor{}).SetString(*paginationParams.After) if err != nil { - return nil, PageInfo{}, fmt.Errorf("decode after cursor: %w", err) + return nil, fmt.Errorf("decode after cursor: %w", err) } } if paginationParams.Before != nil { paginationParams.beforeCursor, err = (&Cursor{}).SetString(*paginationParams.Before) if err != nil { - return nil, PageInfo{}, fmt.Errorf("decode before cursor: %w", err) + return nil, fmt.Errorf("decode before cursor: %w", err) } } items, err := queryItems[T](p.DB, paginationParams) if err != nil { - return nil, PageInfo{}, fmt.Errorf("query items: %w", err) + return nil, fmt.Errorf("query items: %w", err) } var model T @@ -59,7 +58,7 @@ func (p PSQLPaginator[T]) ListItems( for i := range items { cursor, err := computeItemCursor(items[i]) if err != nil { - return nil, PageInfo{}, fmt.Errorf("compute cursor: %w", err) + return nil, fmt.Errorf("compute cursor: %w", err) } if i == 0 { @@ -83,7 +82,7 @@ func (p PSQLPaginator[T]) ListItems( endCursor, ) if err != nil { - return nil, PageInfo{}, fmt.Errorf("get page info: %w", err) + return nil, fmt.Errorf("get page info: %w", err) } if len(paginatedItems) > 0 { @@ -91,7 +90,10 @@ func (p PSQLPaginator[T]) ListItems( pageInfo.EndCursor = ptr.To(paginatedItems[len(paginatedItems)-1].Cursor) } - return paginatedItems, pageInfo, nil + return &Page[T]{ + Items: paginatedItems, + Info: pageInfo, + }, nil } func queryItems[T any](ses *gorm.DB, pagination Arguments) ([]T, error) { @@ -130,7 +132,7 @@ func queryItems[T any](ses *gorm.DB, pagination Arguments) ([]T, error) { return items, nil } -func getPageInfo[T schema.Tabler]( +func getPageInfo[T Tabler]( db *gorm.DB, pagination Arguments, startCursor *Cursor, @@ -143,7 +145,7 @@ func getPageInfo[T schema.Tabler]( return getBackwardPaginationPageInfo[T](db, pagination, startCursor) } -func getForwardPaginationPageInfo[T schema.Tabler]( +func getForwardPaginationPageInfo[T Tabler]( db *gorm.DB, pagination Arguments, endCursor *Cursor, @@ -194,7 +196,7 @@ func getForwardPaginationPageInfo[T schema.Tabler]( return pageInfo, nil } -func getBackwardPaginationPageInfo[T schema.Tabler]( +func getBackwardPaginationPageInfo[T Tabler]( db *gorm.DB, pagination Arguments, startCursor *Cursor, diff --git a/pagination/psql_pagination_test.go b/pagination/psql_pagination_test.go index 09d2eac..c3457aa 100644 --- a/pagination/psql_pagination_test.go +++ b/pagination/psql_pagination_test.go @@ -288,25 +288,25 @@ func TestListPSQLPaginatedItems(t *testing.T) { req := require.New(t) - paginatedUsers, pageInfo, err := psqlPaginator.ListItems( + page, err := psqlPaginator.ListItems( ctx, test.params, ) test.expectedError(t, err) - t.Logf("\nexpected users: %+v\nactual users: %+v", test.expectedUsers, paginatedUsers) + t.Logf("\nexpected users: %+v\nactual users: %+v", test.expectedUsers, page.Items) for i := range test.expectedUsers { req.Equal( test.expectedUsers[i], - paginatedUsers[i].Item, + page.Items[i].Item, "id: %s, i: %d", - paginatedUsers[i].Item.ID, + page.Items[i].Item.ID, i, ) } - req.Equal(test.expectedPageInfo, pageInfo) + req.Equal(test.expectedPageInfo, page.Info) }) } } diff --git a/pagination/types.go b/pagination/types.go index 1e26d38..35f661e 100644 --- a/pagination/types.go +++ b/pagination/types.go @@ -5,8 +5,14 @@ import ( "errors" "reflect" "time" + + "gorm.io/gorm/schema" ) +// Tabler is an interface that must be implemented by all models. +// Currently it's an alias for Gorm's schema.Tabler interface. +type Tabler = schema.Tabler + // Arguments represents pagination arguments. // The arguments can be used to paginate forward or backward. // @@ -90,6 +96,12 @@ var timeKind = reflect.TypeOf(time.Time{}).Kind() // // When before: cursor is used, the edge closest to cursor must come last in the result edges. // When after: cursor is used, the edge closest to cursor must come first in the result edges. -type Paginator[T any] interface { - ListItems(ctx context.Context, pagination Arguments) ([]PaginatedItem[T], PageInfo, error) +type Paginator[T Tabler] interface { + ListItems(ctx context.Context, pagination Arguments) (*Page[T], error) +} + +// Page represents a paginated result set. +type Page[T Tabler] struct { + Items []PaginatedItem[T] + Info PageInfo }