From 3a2d9c16d1e81405c7e9e224c88fe9cfb5b884a6 Mon Sep 17 00:00:00 2001 From: Theodore Reuter Date: Mon, 27 Jan 2025 16:35:21 -0500 Subject: [PATCH] feat: small fixes based on PR feedback. Creating more link creation methods to clean up endpoint business logic. tests: small tweaks to test based on PR feedback --- src/stapi_fastapi/backends/root_backend.py | 9 +-- src/stapi_fastapi/routers/product_router.py | 31 ++++---- src/stapi_fastapi/routers/root_router.py | 88 ++++++++------------- tests/conftest.py | 18 ++--- tests/test_opportunity.py | 8 +- tests/test_product.py | 2 +- 6 files changed, 62 insertions(+), 94 deletions(-) diff --git a/src/stapi_fastapi/backends/root_backend.py b/src/stapi_fastapi/backends/root_backend.py index 4ae9213..fb3d0e6 100644 --- a/src/stapi_fastapi/backends/root_backend.py +++ b/src/stapi_fastapi/backends/root_backend.py @@ -15,8 +15,7 @@ async def get_orders( self, request: Request, next: str | None, limit: int ) -> ResultE[tuple[list[Order], Maybe[str]]]: """ - Return a list of existing orders and pagination token if applicable - No pagination will return empty string for token + Return a list of existing orders and pagination token if applicable. """ ... @@ -26,8 +25,8 @@ async def get_order(self, order_id: str, request: Request) -> ResultE[Maybe[Orde Should return returns.results.Success[Order] if order is found. - Should return returns.results.Failure[returns.maybe.Nothing] if the order is - not found or if access is denied. + Should return returns.results.Failure[returns.maybe.Nothing] if the + order is not found or if access is denied. A Failure[Exception] will result in a 500. """ @@ -37,7 +36,7 @@ async def get_order_statuses( self, order_id: str, request: Request, next: str | None, limit: int ) -> ResultE[tuple[list[T], Maybe[str]]]: """ - Get statuses for order with `order_id`. + Get statuses for order with `order_id` and return pagination token if applicable Should return returns.results.Success[list[OrderStatus]] if order is found. diff --git a/src/stapi_fastapi/routers/product_router.py b/src/stapi_fastapi/routers/product_router.py index 634c46b..1ca0959 100644 --- a/src/stapi_fastapi/routers/product_router.py +++ b/src/stapi_fastapi/routers/product_router.py @@ -176,22 +176,12 @@ async def search_opportunities( self, search, request, next, limit ): case Success((features, Some(pagination_token))): - links.append(self.order_link(request, "create-order")) - body = search.model_dump() + links.append(self.order_link(request)) + body = search.model_dump(mode="json") body["next"] = pagination_token - links.append( - Link( - href=str( - request.url.remove_query_params(keys=["next", "limit"]) - ), - rel="next", - type=TYPE_JSON, - method="POST", - body=body, - ) - ) + links.append(self.pagination_link(request, body)) case Success((features, Nothing)): # noqa: F841 - links.append(self.order_link(request, "create-order")) + links.append(self.order_link(request)) case Failure(e) if isinstance(e, ConstraintsException): raise e case Failure(e): @@ -249,14 +239,23 @@ async def create_order( case x: raise AssertionError(f"Expected code to be unreachable {x}") - def order_link(self, request: Request, suffix: str): + def order_link(self, request: Request): return Link( href=str( request.url_for( - f"{self.root_router.name}:{self.product.id}:{suffix}", + f"{self.root_router.name}:{self.product.id}:create-order", ), ), rel="create-order", type=TYPE_JSON, method="POST", ) + + def pagination_link(self, request: Request, body: dict[str, str | dict]): + return Link( + href=str(request.url.remove_query_params(keys=["next", "limit"])), + rel="next", + type=TYPE_JSON, + method="POST", + body=body, + ) diff --git a/src/stapi_fastapi/routers/root_router.py b/src/stapi_fastapi/routers/root_router.py index 1722340..34577c1 100644 --- a/src/stapi_fastapi/routers/root_router.py +++ b/src/stapi_fastapi/routers/root_router.py @@ -42,7 +42,7 @@ def __init__( self.conformances = conformances self.openapi_endpoint_name = openapi_endpoint_name self.docs_endpoint_name = docs_endpoint_name - self.product_ids: list = [] + self.product_ids: list[str] = [] # A dict is used to track the product routers so we can ensure # idempotentcy in case a product is added multiple times, and also to @@ -164,15 +164,7 @@ def get_products( ), ] if end > 0 and end < len(self.product_ids): - links.append( - Link( - href=str( - request.url.include_query_params(next=self.product_ids[end]), - ), - rel="next", - type=TYPE_JSON, - ) - ) + links.append(self.pagination_link(request, self.product_ids[end])) return ProductsCollection( products=[ self.product_routers[product_id].get_product(request) @@ -184,36 +176,26 @@ def get_products( async def get_orders( self, request: Request, next: str | None = None, limit: int = 10 ) -> OrderCollection: - # links: list[Link] = [] + links: list[Link] = [] match await self.backend.get_orders(request, next, limit): case Success((orders, Some(pagination_token))): for order in orders: - order.links.append(self.order_link(request, "get-order", order)) - links = [ - Link( - href=str( - request.url.include_query_params(next=pagination_token) - ), - rel="next", - type=TYPE_JSON, - ) - ] + order.links.append(self.order_link(request, order)) + links.append(self.pagination_link(request, pagination_token)) case Success((orders, Nothing)): # noqa: F841 for order in orders: - order.links.append(self.order_link(request, "get-order", order)) - links = [] + order.links.append(self.order_link(request, order)) + case Failure(ValueError()): + raise NotFoundException(detail="Error finding pagination token") case Failure(e): logger.error( "An error occurred while retrieving orders: %s", traceback.format_exception(e), ) - if isinstance(e, ValueError): - raise NotFoundException(detail="Error finding pagination token") - else: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error finding Orders", - ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error finding Orders", + ) case _: raise AssertionError("Expected code to be unreachable") return OrderCollection(features=orders, links=links) @@ -251,34 +233,21 @@ async def get_order_statuses( links: list[Link] = [] match await self.backend.get_order_statuses(order_id, request, next, limit): case Success((statuses, Some(pagination_token))): - links.append( - self.order_statuses_link(request, "list-order-statuses", order_id) - ) - links.append( - Link( - href=str( - request.url.include_query_params(next=pagination_token) - ), - rel="next", - type=TYPE_JSON, - ) - ) + links.append(self.order_statuses_link(request, order_id)) + links.append(self.pagination_link(request, pagination_token)) case Success((statuses, Nothing)): # noqa: F841 - links.append( - self.order_statuses_link(request, "list-order-statuses", order_id) - ) + links.append(self.order_statuses_link(request, order_id)) + case Failure(KeyError()): + raise NotFoundException("Error finding pagination token") case Failure(e): logger.error( "An error occurred while retrieving order statuses: %s", traceback.format_exception(e), ) - if isinstance(e, KeyError): - raise NotFoundException(detail="Error finding pagination token") - else: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error finding Order Statuses", - ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error finding Order Statuses", + ) case _: raise AssertionError("Expected code to be unreachable") return OrderStatuses(statuses=statuses, links=links) @@ -314,21 +283,28 @@ def add_order_links(self, order: Order, request: Request): ), ) - def order_link(self, request: Request, link_suffix: str, order: Order): + def order_link(self, request: Request, order: Order): return Link( - href=str(request.url_for(f"{self.name}:{link_suffix}", order_id=order.id)), + href=str(request.url_for(f"{self.name}:get-order", order_id=order.id)), rel="self", type=TYPE_JSON, ) - def order_statuses_link(self, request: Request, link_suffix: str, order_id: str): + def order_statuses_link(self, request: Request, order_id: str): return Link( href=str( request.url_for( - f"{self.name}:{link_suffix}", + f"{self.name}:list-order-statuses", order_id=order_id, ) ), rel="self", type=TYPE_JSON, ) + + def pagination_link(self, request: Request, pagination_token: str): + return Link( + href=str(request.url.include_query_params(next=pagination_token)), + rel="next", + type=TYPE_JSON, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9d97fce..9417116 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import copy from collections.abc import Iterator from typing import Any, Callable from urllib.parse import urljoin @@ -173,18 +172,16 @@ def pagination_tester( assert len(resp_body[target]) <= limit retrieved.extend(resp_body[target]) - next_token = next( - (d["href"] for d in resp_body["links"] if d["rel"] == "next"), None - ) + next_url = next((d["href"] for d in resp_body["links"] if d["rel"] == "next"), None) - while next_token: - url = copy.deepcopy(next_token) + while next_url: + url = next_url if method == "POST": - next_token = next( + next_url = next( (d["body"]["next"] for d in resp_body["links"] if d["rel"] == "next"), ) - res = make_request(stapi_client, url, method, body, next_token, limit) + res = make_request(stapi_client, url, method, body, next_url, limit) assert res.status_code == status.HTTP_200_OK assert len(resp_body[target]) <= limit resp_body = res.json() @@ -192,7 +189,7 @@ def pagination_tester( # get url w/ query params for next call if exists, and POST body if necessary if resp_body["links"]: - next_token = next( + next_url = next( (d["href"] for d in resp_body["links"] if d["rel"] == "next"), None ) body = next( @@ -200,11 +197,10 @@ def pagination_tester( None, ) else: - next_token = None + next_url = None assert len(retrieved) == len(expected_returns) assert retrieved == expected_returns - # assert retrieved[:2] == expected_returns[:2] def make_request( diff --git a/tests/test_opportunity.py b/tests/test_opportunity.py index 914fd37..38794d4 100644 --- a/tests/test_opportunity.py +++ b/tests/test_opportunity.py @@ -18,8 +18,7 @@ @pytest.fixture def mock_test_spotlight_opportunities() -> list[Opportunity]: """Fixture to create mock data for Opportunities for `test-spotlight-1`.""" - now = datetime.now(timezone.utc) # Use timezone-aware datetime - start = now + start = datetime.now(timezone.utc) # Use timezone-aware datetime end = start + timedelta(days=5) # Create a list of mock opportunities for the given product @@ -60,10 +59,9 @@ def test_search_opportunities_response( product_backend._opportunities = mock_test_spotlight_opportunities now = datetime.now(UTC) - start = now - end = start + timedelta(days=5) + end = now + timedelta(days=5) format = "%Y-%m-%dT%H:%M:%S.%f%z" - start_string = rfc3339_strftime(start, format) + start_string = rfc3339_strftime(now, format) end_string = rfc3339_strftime(end, format) request_payload = { diff --git a/tests/test_product.py b/tests/test_product.py index 9e32d5b..cb2a45b 100644 --- a/tests/test_product.py +++ b/tests/test_product.py @@ -68,7 +68,7 @@ def test_product_order_parameters_response( @pytest.mark.parametrize("limit", [0, 1, 2, 4]) -def test_product_pagination( +def test_get_products_pagination( limit: int, stapi_client: TestClient, mock_product_test_spotlight,