Skip to content

Commit

Permalink
Fix recorded path in mounted Starlette apps (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
itssimon authored Jan 1, 2025
1 parent 2e167e2 commit ac428c6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 38 deletions.
2 changes: 1 addition & 1 deletion apitally/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def get_path(request: Request) -> Optional[str]:
for route in request.app.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return route.path
return request.scope.get("root_path", "") + route.path
return None

def get_consumer(self, request: Request) -> Optional[ApitallyConsumer]:
Expand Down
86 changes: 49 additions & 37 deletions tests/test_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def app(request: FixtureRequest, module_mocker: MockerFixture) -> Starlett
def get_starlette_app() -> Starlette:
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.routing import Mount, Route

from apitally.starlette import ApitallyConsumer, ApitallyMiddleware, RequestLoggingConfig

Expand Down Expand Up @@ -79,16 +79,22 @@ def task_func_with_error():
def identify_consumer(request: Request) -> Optional[ApitallyConsumer]:
return ApitallyConsumer("test", name="Test")

routes = [
Route("/foo/", foo),
Route("/foo/{bar}/", foo_bar),
Route("/bar/", bar, methods=["POST"]),
Route("/baz/", baz, methods=["POST"]),
Route("/val/", val),
Route("/stream/", stream),
Route("/task/", task, methods=["POST"]),
]
app = Starlette(routes=routes)
sub_app = Starlette(
routes=[
Route("/foo", foo),
Route("/foo/{bar}", foo_bar),
Route("/bar", bar, methods=["POST"]),
Route("/baz", baz, methods=["POST"]),
Route("/val", val),
]
)
app = Starlette(
routes=[
Mount("/api", sub_app),
Route("/stream", stream),
Route("/task", task, methods=["POST"]),
]
)
app.add_middleware(
ApitallyMiddleware,
client_id=CLIENT_ID,
Expand All @@ -104,7 +110,7 @@ def identify_consumer(request: Request) -> Optional[ApitallyConsumer]:


def get_fastapi_app() -> Starlette:
from fastapi import FastAPI, Query
from fastapi import APIRouter, FastAPI, Query
from fastapi.responses import PlainTextResponse, StreamingResponse

from apitally.fastapi import ApitallyConsumer, ApitallyMiddleware, RequestLoggingConfig
Expand All @@ -125,43 +131,47 @@ def identify_consumer(request: Request) -> Optional[ApitallyConsumer]:
identify_consumer_callback=identify_consumer,
)

@app.get("/foo/")
router = APIRouter()

@router.get("/foo")
def foo():
return "foo"

@app.get("/foo/{bar}/")
@router.get("/foo/{bar}")
def foo_bar(bar: str):
return PlainTextResponse(f"foo: {bar}")

@app.post("/bar/")
@router.post("/bar")
async def bar(request: Request):
body = await request.body()
return PlainTextResponse("bar: " + body.decode())

@app.post("/baz/")
@router.post("/baz")
def baz():
raise ValueError("baz")

@app.get("/val/")
@router.get("/val")
def val(foo: int = Query()):
return "val"

@app.get("/stream/")
@app.get("/stream")
def stream():
def stream_response():
yield b"foo"
yield b"bar"

return StreamingResponse(stream_response())

@app.post("/task/")
@app.post("/task")
def task(background_tasks: BackgroundTasks):
def task_func_with_error():
raise ValueError("task")

background_tasks.add_task(task_func_with_error)
return "ok"

app.include_router(router, prefix="/api")

return app


Expand All @@ -171,23 +181,23 @@ def test_middleware_requests_ok(app: Starlette, mocker: MockerFixture):
mock = mocker.patch("apitally.client.requests.RequestCounter.add_request")
client = TestClient(app)

response = client.get("/foo/")
response = client.get("/api/foo/")
assert response.status_code == 200
mock.assert_called_once()
assert mock.call_args is not None
assert mock.call_args.kwargs["consumer"] == "test"
assert mock.call_args.kwargs["method"] == "GET"
assert mock.call_args.kwargs["path"] == "/foo/"
assert mock.call_args.kwargs["path"] == "/api/foo"
assert mock.call_args.kwargs["status_code"] == 200
assert mock.call_args.kwargs["response_time"] > 0

response = client.get("/foo/123/")
response = client.get("/api/foo/123/")
assert response.status_code == 200
assert mock.call_count == 2
assert mock.call_args is not None
assert mock.call_args.kwargs["path"] == "/foo/{bar}/"
assert mock.call_args.kwargs["path"] == "/api/foo/{bar}"

response = client.post("/bar/")
response = client.post("/api/bar/")
assert response.status_code == 200
assert mock.call_count == 3
assert mock.call_args is not None
Expand All @@ -207,12 +217,12 @@ def test_middleware_requests_error(app: Starlette, mocker: MockerFixture):
mock2 = mocker.patch("apitally.client.server_errors.ServerErrorCounter.add_server_error")
client = TestClient(app, raise_server_exceptions=False)

response = client.post("/baz/")
response = client.post("/api/baz")
assert response.status_code == 500
mock1.assert_called_once()
assert mock1.call_args is not None
assert mock1.call_args.kwargs["method"] == "POST"
assert mock1.call_args.kwargs["path"] == "/baz/"
assert mock1.call_args.kwargs["path"] == "/api/baz"
assert mock1.call_args.kwargs["status_code"] == 500
assert mock1.call_args.kwargs["response_time"] > 0

Expand All @@ -222,7 +232,7 @@ def test_middleware_requests_error(app: Starlette, mocker: MockerFixture):
assert isinstance(exception, ValueError)

# Throws a ValueError in a background task, but returns 200
response = client.post("/task/")
response = client.post("/task")
assert response.status_code == 200
assert mock1.call_count == 2
assert mock1.call_args is not None
Expand All @@ -236,7 +246,7 @@ def test_middleware_requests_unhandled(app: Starlette, mocker: MockerFixture):
mock = mocker.patch("apitally.client.requests.RequestCounter.add_request")
client = TestClient(app)

response = client.post("/xxx/")
response = client.post("/xxx")
assert response.status_code == 404
mock.assert_not_called()

Expand All @@ -248,15 +258,15 @@ def test_middleware_validation_error(app: Starlette, mocker: MockerFixture):
client = TestClient(app)

# Validation error as foo must be an integer
response = client.get("/val?foo=bar")
response = client.get("/api/val?foo=bar")
assert response.status_code == 422

# FastAPI only
if response.headers["Content-Type"] == "application/json":
mock.assert_called_once()
assert mock.call_args is not None
assert mock.call_args.kwargs["method"] == "GET"
assert mock.call_args.kwargs["path"] == "/val/"
assert mock.call_args.kwargs["path"] == "/api/val"
assert len(mock.call_args.kwargs["detail"]) == 1
assert mock.call_args.kwargs["detail"][0]["loc"] == ["query", "foo"]

Expand All @@ -269,13 +279,13 @@ def test_middleware_request_logging(app: Starlette, mocker: MockerFixture):
mock = mocker.patch("apitally.client.request_logging.RequestLogger.log_request")
client = TestClient(app)

response = client.get("/foo/123/?foo=bar", headers={"Test-Header": "test"})
response = client.get("/api/foo/123?foo=bar", headers={"Test-Header": "test"})
assert response.status_code == 200
mock.assert_called_once()
assert mock.call_args is not None
assert mock.call_args.kwargs["request"]["method"] == "GET"
assert mock.call_args.kwargs["request"]["path"] == "/foo/{bar}/"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/foo/123/?foo=bar"
assert mock.call_args.kwargs["request"]["path"] == "/api/foo/{bar}"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/api/foo/123?foo=bar"
assert ("test-header", "test") in mock.call_args.kwargs["request"]["headers"]
assert mock.call_args.kwargs["request"]["consumer"] == "test"
assert mock.call_args.kwargs["response"]["status_code"] == 200
Expand All @@ -284,18 +294,18 @@ def test_middleware_request_logging(app: Starlette, mocker: MockerFixture):
assert mock.call_args.kwargs["response"]["size"] > 0
assert mock.call_args.kwargs["response"]["body"] == b"foo: 123"

response = client.post("/bar/", content=b"foo")
response = client.post("/api/bar", content=b"foo")
assert response.status_code == 200
assert mock.call_count == 2
assert mock.call_args is not None
assert mock.call_args.kwargs["request"]["method"] == "POST"
assert mock.call_args.kwargs["request"]["path"] == "/bar/"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/bar/"
assert mock.call_args.kwargs["request"]["path"] == "/api/bar"
assert mock.call_args.kwargs["request"]["url"] == "http://testserver/api/bar"
assert mock.call_args.kwargs["request"]["body"] == b"foo"
assert mock.call_args.kwargs["response"]["body"] == b"bar: foo"

mocker.patch("apitally.starlette.MAX_BODY_SIZE", 2)
response = client.post("/bar/", content=b"foo")
response = client.post("/api/bar", content=b"foo")
assert response.status_code == 200
assert mock.call_count == 3
assert mock.call_args is not None
Expand All @@ -312,6 +322,8 @@ def test_get_startup_data(app: Starlette, mocker: MockerFixture):

data = _get_startup_data(app=app.middleware_stack, app_version="1.2.3", openapi_url=None)
assert len(data["paths"]) == 7
assert {"method": "get", "path": "/api/foo"} in data["paths"]
assert {"method": "get", "path": "/stream"} in data["paths"]
assert data["versions"]["starlette"]
assert data["versions"]["app"] == "1.2.3"
assert data["client"] == "python:starlette"

0 comments on commit ac428c6

Please sign in to comment.