Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine pagination code, other minor refinements #140

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
token.
- Moved `OrderCollection` construction from the root backend to the `RootRouter`
`get_orders` method.
- Renamed `OpportunityRequest` to `OpportunityPayload` so that would not be confused as
being a subclass of the Starlette/FastAPI Request class.

### Fixed

- Opportunities Search result now has the search body in the `create-order` link.

## [v0.5.0] - 2025-01-08

Expand Down
4 changes: 2 additions & 2 deletions src/stapi_fastapi/backends/product_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from returns.maybe import Maybe
from returns.result import ResultE

from stapi_fastapi.models.opportunity import Opportunity, OpportunityRequest
from stapi_fastapi.models.opportunity import Opportunity, OpportunityPayload
from stapi_fastapi.models.order import Order, OrderPayload
from stapi_fastapi.routers.product_router import ProductRouter

SearchOpportunities = Callable[
[ProductRouter, OpportunityRequest, str | None, int, Request],
[ProductRouter, OpportunityPayload, str | None, int, Request],
Coroutine[Any, Any, ResultE[tuple[list[Opportunity], Maybe[str]]]],
]
"""
Expand Down
12 changes: 9 additions & 3 deletions src/stapi_fastapi/models/opportunity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, TypeVar
from typing import Any, Literal, TypeVar

from geojson_pydantic import Feature, FeatureCollection
from geojson_pydantic.geometries import Geometry
Expand All @@ -16,16 +16,22 @@ class OpportunityProperties(BaseModel):
model_config = ConfigDict(extra="allow")


class OpportunityRequest(BaseModel):
class OpportunityPayload(BaseModel):
datetime: DatetimeInterval
geometry: Geometry
# TODO: validate the CQL2 filter?
filter: CQL2Filter | None = None

next: str | None = None
limit: int = 10

model_config = ConfigDict(strict=True)

def search_body(self) -> dict[str, Any]:
return self.model_dump(mode="json", include={"datetime", "geometry", "filter"})

def body(self) -> dict[str, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the addition of these small helper functions here to help clean up the endpoint business logic.

return self.model_dump(mode="json")


G = TypeVar("G", bound=Geometry)
P = TypeVar("P", bound=OpportunityProperties)
Expand Down
26 changes: 14 additions & 12 deletions src/stapi_fastapi/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from stapi_fastapi.exceptions import ConstraintsException
from stapi_fastapi.models.opportunity import (
OpportunityCollection,
OpportunityRequest,
OpportunityPayload,
)
from stapi_fastapi.models.order import Order, OrderPayload
from stapi_fastapi.models.product import Product
Expand Down Expand Up @@ -163,7 +163,7 @@ def get_product(self, request: Request) -> Product:

async def search_opportunities(
self,
search: OpportunityRequest,
search: OpportunityPayload,
request: Request,
) -> OpportunityCollection:
"""
Expand All @@ -178,13 +178,10 @@ async def search_opportunities(
request,
):
case Success((features, Some(pagination_token))):
links.append(self.order_link(request))
search.next = pagination_token
links.append(
self.pagination_link(request, search.model_dump(mode="json"))
)
links.append(self.order_link(request, search))
links.append(self.pagination_link(request, search, pagination_token))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this pushes the logic for how to construct the link down into the pagination_link method

case Success((features, Nothing)): # noqa: F841
links.append(self.order_link(request))
links.append(self.order_link(request, search))
case Failure(e) if isinstance(e, ConstraintsException):
raise e
case Failure(e):
Expand Down Expand Up @@ -224,7 +221,7 @@ async def create_order(
request,
):
case Success(order):
self.root_router.add_order_links(order, request)
order.links.extend(self.root_router.order_links(order, request))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

improves the semantics of the function to make it more functional -- e.g., it doesn't mutate it's parameters, and instead just returns an object that the caller can then use for mutation

location = str(self.root_router.generate_order_href(request, order.id))
response.headers["Location"] = location
return order
Expand All @@ -242,7 +239,7 @@ async def create_order(
case x:
raise AssertionError(f"Expected code to be unreachable {x}")

def order_link(self, request: Request):
def order_link(self, request: Request, opp_req: OpportunityPayload):
return Link(
href=str(
request.url_for(
Expand All @@ -252,11 +249,16 @@ def order_link(self, request: Request):
rel="create-order",
type=TYPE_JSON,
method="POST",
body=opp_req.search_body(),
)

def pagination_link(self, request: Request, body: dict[str, str | dict]):
def pagination_link(
self, request: Request, opp_req: OpportunityPayload, pagination_token: str
):
body = opp_req.body()
body["next"] = pagination_token
return Link(
href=str(request.url.remove_query_params(keys=["next", "limit"])),
href=str(request.url),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there won't be query params on this, so they don't need to be removed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah! Right you are

rel="next",
type=TYPE_JSON,
method="POST",
Expand Down
42 changes: 18 additions & 24 deletions src/stapi_fastapi/routers/root_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def get_products(
),
]
if end > 0 and end < len(self.product_ids):
links.append(self.pagination_link(request, self.product_ids[end]))
links.append(self.pagination_link(request, self.product_ids[end], limit))
return ProductsCollection(
products=[
self.product_routers[product_id].get_product(request)
Expand All @@ -182,13 +182,14 @@ async def get_orders(
) -> OrderCollection:
links: list[Link] = []
match await self._get_orders(next, limit, request):
case Success((orders, Some(pagination_token))):
case Success((orders, maybe_pagination_token)):
for order in orders:
order.links.append(self.order_link(request, order))
links.append(self.pagination_link(request, pagination_token))
case Success((orders, Nothing)): # noqa: F841
for order in orders:
order.links.append(self.order_link(request, order))
order.links.extend(self.order_links(order, request))
match maybe_pagination_token:
case Some(x):
links.append(self.pagination_link(request, x, limit))
case Maybe.empty:
pass
case Failure(ValueError()):
raise NotFoundException(detail="Error finding pagination token")
case Failure(e):
Expand All @@ -210,7 +211,7 @@ async def get_order(self: Self, order_id: str, request: Request) -> Order:
"""
match await self._get_order(order_id, request):
case Success(Some(order)):
self.add_order_links(order, request)
order.links.extend(self.order_links(order, request))
return order
case Success(Maybe.empty):
raise NotFoundException("Order not found")
Expand Down Expand Up @@ -238,7 +239,7 @@ async def get_order_statuses(
match await self._get_order_statuses(order_id, next, limit, request):
case Success((statuses, Some(pagination_token))):
links.append(self.order_statuses_link(request, order_id))
links.append(self.pagination_link(request, pagination_token))
links.append(self.pagination_link(request, pagination_token, limit))
case Success((statuses, Nothing)): # noqa: F841
links.append(self.order_statuses_link(request, order_id))
case Failure(KeyError()):
Expand Down Expand Up @@ -271,28 +272,19 @@ def generate_order_statuses_href(
) -> URL:
return request.url_for(f"{self.name}:list-order-statuses", order_id=order_id)

def add_order_links(self, order: Order, request: Request):
order.links.append(
def order_links(self, order: Order, request: Request) -> list[Link]:
return [
Link(
href=str(self.generate_order_href(request, order.id)),
rel="self",
type=TYPE_GEOJSON,
)
)
order.links.append(
),
Link(
href=str(self.generate_order_statuses_href(request, order.id)),
rel="monitor",
type=TYPE_JSON,
),
)

def order_link(self, request: Request, order: Order):
return Link(
href=str(request.url_for(f"{self.name}:get-order", order_id=order.id)),
rel="self",
type=TYPE_JSON,
)
]

def order_statuses_link(self, request: Request, order_id: str):
return Link(
Expand All @@ -306,9 +298,11 @@ def order_statuses_link(self, request: Request, order_id: str):
type=TYPE_JSON,
)

def pagination_link(self, request: Request, pagination_token: str):
def pagination_link(self, request: Request, pagination_token: str, limit: int):
return Link(
href=str(request.url.include_query_params(next=pagination_token)),
href=str(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include the limit explicitly, so it's not just relying on whatever the default is

request.url.include_query_params(next=pagination_token, limit=limit)
),
rel="next",
type=TYPE_JSON,
)
7 changes: 6 additions & 1 deletion tests/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
mock_get_order_statuses,
mock_get_orders,
)
from tests.shared import InMemoryOrderDB, mock_product_test_spotlight
from tests.shared import (
InMemoryOrderDB,
mock_product_test_satellite_provider,
mock_product_test_spotlight,
)


@asynccontextmanager
Expand All @@ -35,5 +39,6 @@ async def lifespan(app: FastAPI) -> AsyncIterator[dict[str, Any]]:
conformances=[CORE],
)
root_router.add_product(mock_product_test_spotlight)
root_router.add_product(mock_product_test_satellite_provider)
app: FastAPI = FastAPI(lifespan=lifespan)
app.include_router(root_router, prefix="")
4 changes: 2 additions & 2 deletions tests/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from stapi_fastapi.models.opportunity import (
Opportunity,
OpportunityRequest,
OpportunityPayload,
)
from stapi_fastapi.models.order import (
Order,
Expand Down Expand Up @@ -76,7 +76,7 @@ async def mock_get_order_statuses(

async def mock_search_opportunities(
product_router: ProductRouter,
search: OpportunityRequest,
search: OpportunityPayload,
next: str | None,
limit: int,
request: Request,
Expand Down
29 changes: 17 additions & 12 deletions tests/shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Literal, Self
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

from fastapi import status
Expand Down Expand Up @@ -133,7 +134,7 @@ def create_mock_opportunity() -> Opportunity:

def pagination_tester(
stapi_client: TestClient,
endpoint: str,
url: str,
method: str,
limit: int,
target: str,
Expand All @@ -142,7 +143,7 @@ def pagination_tester(
) -> None:
retrieved = []

res = make_request(stapi_client, endpoint, method, body, None, limit)
res = make_request(stapi_client, url, method, body, limit)
assert res.status_code == status.HTTP_200_OK
resp_body = res.json()

Expand All @@ -151,15 +152,16 @@ def pagination_tester(
next_url = next((d["href"] for d in resp_body["links"] if d["rel"] == "next"), None)

while next_url:
url = next_url
if method == "POST":
body = next(
(d["body"] for d in resp_body["links"] if d["rel"] == "next"), None
)

res = make_request(stapi_client, url, method, body, next_url, limit)
assert res.status_code == status.HTTP_200_OK
res = make_request(stapi_client, next_url, method, body, limit)

assert res.status_code == status.HTTP_200_OK, res.status_code
assert len(resp_body[target]) <= limit

resp_body = res.json()
retrieved.extend(resp_body[target])

Expand All @@ -177,22 +179,25 @@ def pagination_tester(

def make_request(
stapi_client: TestClient,
endpoint: str,
url: str,
method: str,
body: dict | None,
next_token: str | None,
limit: int,
) -> Response:
"""request wrapper for pagination tests"""

match method:
case "GET":
if next_token: # extract pagination token
next_token = next_token.split("next=")[1]
params = {"next": next_token, "limit": limit}
res = stapi_client.get(endpoint, params=params)
o = urlparse(url)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re-wrote this to actually parse the url, since a simple split on next= doesn't work if there are any other parameters, e.g., next=foo&limit=1, since the next_token value gets set to foo&limit=1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh shoot wish I had written the tests that would've caught that case in the first place

base_url = f"{o.scheme}://{o.netloc}{o.path}"
parsed_qs = parse_qs(o.query)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't know about this parse_qs function, that would've been helpful early only womp womp. The more you know.

params = {}
if "next" in parsed_qs:
params["next"] = parsed_qs["next"][0]
params["limit"] = int(parsed_qs.get("limit", [None])[0] or limit)
res = stapi_client.get(base_url, params=params)
case "POST":
res = stapi_client.post(endpoint, json=body)
res = stapi_client.post(url, json=body)
case _:
fail(f"method {method} not supported in make request")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_opportunity.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_search_opportunities_pagination(

pagination_tester(
stapi_client=stapi_client,
endpoint=f"/products/{product_id}/opportunities",
url=f"/products/{product_id}/opportunities",
method="POST",
limit=limit,
target="features",
Expand Down
16 changes: 10 additions & 6 deletions tests/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,20 @@ def test_get_orders_pagination(
limit, setup_orders_pagination, create_order_payloads, stapi_client: TestClient
) -> None:
expected_returns = []
if limit != 0:
if limit > 0:
for order in setup_orders_pagination:
json_link = copy.deepcopy(order["links"][0])
json_link["type"] = "application/json"
order["links"].append(json_link)
self_link = copy.deepcopy(order["links"][0])
order["links"].append(self_link)
monitor_link = copy.deepcopy(order["links"][0])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was missing from the /orders endpoint Orders and was added, so the test was updated

monitor_link["rel"] = "monitor"
monitor_link["type"] = "application/json"
monitor_link["href"] = monitor_link["href"] + "/statuses"
order["links"].append(monitor_link)
expected_returns.append(order)

pagination_tester(
stapi_client=stapi_client,
endpoint="/orders",
url="/orders",
method="GET",
limit=limit,
target="features",
Expand Down Expand Up @@ -233,7 +237,7 @@ def test_get_order_status_pagination(

pagination_tester(
stapi_client=stapi_client,
endpoint=f"/orders/{order_id}/statuses",
url=f"/orders/{order_id}/statuses",
method="GET",
limit=limit,
target="statuses",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_get_products_pagination(

pagination_tester(
stapi_client=stapi_client,
endpoint="/products",
url="/products",
method="GET",
limit=limit,
target="products",
Expand Down