Skip to content

Commit

Permalink
feat: small fixes based on PR feedback. Creating more link creation m…
Browse files Browse the repository at this point in the history
…ethods to clean up endpoint business logic. tests: small tweaks to test based on PR feedback
  • Loading branch information
theodorehreuter committed Jan 27, 2025
1 parent c148b7d commit 3a2d9c1
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 94 deletions.
9 changes: 4 additions & 5 deletions src/stapi_fastapi/backends/root_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ async def get_orders(
self, request: Request, next: str | None, limit: int
) -> ResultE[tuple[list[Order], Maybe[str]]]:
"""
Return a list of existing orders and pagination token if applicable
No pagination will return empty string for token
Return a list of existing orders and pagination token if applicable.
"""
...

Expand All @@ -26,8 +25,8 @@ async def get_order(self, order_id: str, request: Request) -> ResultE[Maybe[Orde
Should return returns.results.Success[Order] if order is found.
Should return returns.results.Failure[returns.maybe.Nothing] if the order is
not found or if access is denied.
Should return returns.results.Failure[returns.maybe.Nothing] if the
order is not found or if access is denied.
A Failure[Exception] will result in a 500.
"""
Expand All @@ -37,7 +36,7 @@ async def get_order_statuses(
self, order_id: str, request: Request, next: str | None, limit: int
) -> ResultE[tuple[list[T], Maybe[str]]]:
"""
Get statuses for order with `order_id`.
Get statuses for order with `order_id` and return pagination token if applicable
Should return returns.results.Success[list[OrderStatus]] if order is found.
Expand Down
31 changes: 15 additions & 16 deletions src/stapi_fastapi/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,22 +176,12 @@ async def search_opportunities(
self, search, request, next, limit
):
case Success((features, Some(pagination_token))):
links.append(self.order_link(request, "create-order"))
body = search.model_dump()
links.append(self.order_link(request))
body = search.model_dump(mode="json")
body["next"] = pagination_token
links.append(
Link(
href=str(
request.url.remove_query_params(keys=["next", "limit"])
),
rel="next",
type=TYPE_JSON,
method="POST",
body=body,
)
)
links.append(self.pagination_link(request, body))
case Success((features, Nothing)): # noqa: F841
links.append(self.order_link(request, "create-order"))
links.append(self.order_link(request))
case Failure(e) if isinstance(e, ConstraintsException):
raise e
case Failure(e):
Expand Down Expand Up @@ -249,14 +239,23 @@ async def create_order(
case x:
raise AssertionError(f"Expected code to be unreachable {x}")

def order_link(self, request: Request, suffix: str):
def order_link(self, request: Request):
return Link(
href=str(
request.url_for(
f"{self.root_router.name}:{self.product.id}:{suffix}",
f"{self.root_router.name}:{self.product.id}:create-order",
),
),
rel="create-order",
type=TYPE_JSON,
method="POST",
)

def pagination_link(self, request: Request, body: dict[str, str | dict]):
return Link(
href=str(request.url.remove_query_params(keys=["next", "limit"])),
rel="next",
type=TYPE_JSON,
method="POST",
body=body,
)
88 changes: 32 additions & 56 deletions src/stapi_fastapi/routers/root_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.conformances = conformances
self.openapi_endpoint_name = openapi_endpoint_name
self.docs_endpoint_name = docs_endpoint_name
self.product_ids: list = []
self.product_ids: list[str] = []

# A dict is used to track the product routers so we can ensure
# idempotentcy in case a product is added multiple times, and also to
Expand Down Expand Up @@ -164,15 +164,7 @@ def get_products(
),
]
if end > 0 and end < len(self.product_ids):
links.append(
Link(
href=str(
request.url.include_query_params(next=self.product_ids[end]),
),
rel="next",
type=TYPE_JSON,
)
)
links.append(self.pagination_link(request, self.product_ids[end]))
return ProductsCollection(
products=[
self.product_routers[product_id].get_product(request)
Expand All @@ -184,36 +176,26 @@ def get_products(
async def get_orders(
self, request: Request, next: str | None = None, limit: int = 10
) -> OrderCollection:
# links: list[Link] = []
links: list[Link] = []
match await self.backend.get_orders(request, next, limit):
case Success((orders, Some(pagination_token))):
for order in orders:
order.links.append(self.order_link(request, "get-order", order))
links = [
Link(
href=str(
request.url.include_query_params(next=pagination_token)
),
rel="next",
type=TYPE_JSON,
)
]
order.links.append(self.order_link(request, order))
links.append(self.pagination_link(request, pagination_token))
case Success((orders, Nothing)): # noqa: F841
for order in orders:
order.links.append(self.order_link(request, "get-order", order))
links = []
order.links.append(self.order_link(request, order))
case Failure(ValueError()):
raise NotFoundException(detail="Error finding pagination token")
case Failure(e):
logger.error(
"An error occurred while retrieving orders: %s",
traceback.format_exception(e),
)
if isinstance(e, ValueError):
raise NotFoundException(detail="Error finding pagination token")
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error finding Orders",
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error finding Orders",
)
case _:
raise AssertionError("Expected code to be unreachable")
return OrderCollection(features=orders, links=links)
Expand Down Expand Up @@ -251,34 +233,21 @@ async def get_order_statuses(
links: list[Link] = []
match await self.backend.get_order_statuses(order_id, request, next, limit):
case Success((statuses, Some(pagination_token))):
links.append(
self.order_statuses_link(request, "list-order-statuses", order_id)
)
links.append(
Link(
href=str(
request.url.include_query_params(next=pagination_token)
),
rel="next",
type=TYPE_JSON,
)
)
links.append(self.order_statuses_link(request, order_id))
links.append(self.pagination_link(request, pagination_token))
case Success((statuses, Nothing)): # noqa: F841
links.append(
self.order_statuses_link(request, "list-order-statuses", order_id)
)
links.append(self.order_statuses_link(request, order_id))
case Failure(KeyError()):
raise NotFoundException("Error finding pagination token")
case Failure(e):
logger.error(
"An error occurred while retrieving order statuses: %s",
traceback.format_exception(e),
)
if isinstance(e, KeyError):
raise NotFoundException(detail="Error finding pagination token")
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error finding Order Statuses",
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error finding Order Statuses",
)
case _:
raise AssertionError("Expected code to be unreachable")
return OrderStatuses(statuses=statuses, links=links)
Expand Down Expand Up @@ -314,21 +283,28 @@ def add_order_links(self, order: Order, request: Request):
),
)

def order_link(self, request: Request, link_suffix: str, order: Order):
def order_link(self, request: Request, order: Order):
return Link(
href=str(request.url_for(f"{self.name}:{link_suffix}", order_id=order.id)),
href=str(request.url_for(f"{self.name}:get-order", order_id=order.id)),
rel="self",
type=TYPE_JSON,
)

def order_statuses_link(self, request: Request, link_suffix: str, order_id: str):
def order_statuses_link(self, request: Request, order_id: str):
return Link(
href=str(
request.url_for(
f"{self.name}:{link_suffix}",
f"{self.name}:list-order-statuses",
order_id=order_id,
)
),
rel="self",
type=TYPE_JSON,
)

def pagination_link(self, request: Request, pagination_token: str):
return Link(
href=str(request.url.include_query_params(next=pagination_token)),
rel="next",
type=TYPE_JSON,
)
18 changes: 7 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
from collections.abc import Iterator
from typing import Any, Callable
from urllib.parse import urljoin
Expand Down Expand Up @@ -173,38 +172,35 @@ def pagination_tester(

assert len(resp_body[target]) <= limit
retrieved.extend(resp_body[target])
next_token = next(
(d["href"] for d in resp_body["links"] if d["rel"] == "next"), None
)
next_url = next((d["href"] for d in resp_body["links"] if d["rel"] == "next"), None)

while next_token:
url = copy.deepcopy(next_token)
while next_url:
url = next_url
if method == "POST":
next_token = next(
next_url = next(
(d["body"]["next"] for d in resp_body["links"] if d["rel"] == "next"),
)

res = make_request(stapi_client, url, method, body, next_token, limit)
res = make_request(stapi_client, url, method, body, next_url, limit)
assert res.status_code == status.HTTP_200_OK
assert len(resp_body[target]) <= limit
resp_body = res.json()
retrieved.extend(resp_body[target])

# get url w/ query params for next call if exists, and POST body if necessary
if resp_body["links"]:
next_token = next(
next_url = next(
(d["href"] for d in resp_body["links"] if d["rel"] == "next"), None
)
body = next(
(d.get("body") for d in resp_body["links"] if d.get("body")),
None,
)
else:
next_token = None
next_url = None

assert len(retrieved) == len(expected_returns)
assert retrieved == expected_returns
# assert retrieved[:2] == expected_returns[:2]


def make_request(
Expand Down
8 changes: 3 additions & 5 deletions tests/test_opportunity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
@pytest.fixture
def mock_test_spotlight_opportunities() -> list[Opportunity]:
"""Fixture to create mock data for Opportunities for `test-spotlight-1`."""
now = datetime.now(timezone.utc) # Use timezone-aware datetime
start = now
start = datetime.now(timezone.utc) # Use timezone-aware datetime
end = start + timedelta(days=5)

# Create a list of mock opportunities for the given product
Expand Down Expand Up @@ -60,10 +59,9 @@ def test_search_opportunities_response(
product_backend._opportunities = mock_test_spotlight_opportunities

now = datetime.now(UTC)
start = now
end = start + timedelta(days=5)
end = now + timedelta(days=5)
format = "%Y-%m-%dT%H:%M:%S.%f%z"
start_string = rfc3339_strftime(start, format)
start_string = rfc3339_strftime(now, format)
end_string = rfc3339_strftime(end, format)

request_payload = {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_product_order_parameters_response(


@pytest.mark.parametrize("limit", [0, 1, 2, 4])
def test_product_pagination(
def test_get_products_pagination(
limit: int,
stapi_client: TestClient,
mock_product_test_spotlight,
Expand Down

0 comments on commit 3a2d9c1

Please sign in to comment.