diff --git a/mangum/adapter.py b/mangum/adapter.py index 60670e4..ff41519 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -76,7 +76,7 @@ def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict: if self.lifespan in ("auto", "on"): lifespan_cycle = LifespanCycle(self.app, self.lifespan) stack.enter_context(lifespan_cycle) - scope |= {"state": lifespan_cycle.lifespan_state.copy()} + scope.update({"state": lifespan_cycle.lifespan_state.copy()}) http_cycle = HTTPCycle(scope, handler.body) http_response = http_cycle(self.app) diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index fc95787..6ab6b37 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -3,8 +3,8 @@ import logging from io import BytesIO -from mangum.types import ASGI, Message, Scope, Response from mangum.exceptions import UnexpectedMessage +from mangum.types import ASGI, Message, Response, Scope class HTTPCycleState(enum.Enum): @@ -82,11 +82,17 @@ async def receive(self) -> Message: return await self.app_queue.get() # pragma: no cover async def send(self, message: Message) -> None: - if self.state is HTTPCycleState.REQUEST and message["type"] == "http.response.start": + if ( + self.state is HTTPCycleState.REQUEST + and message["type"] == "http.response.start" + ): self.status = message["status"] self.headers = message.get("headers", []) self.state = HTTPCycleState.RESPONSE - elif self.state is HTTPCycleState.RESPONSE and message["type"] == "http.response.body": + elif ( + self.state is HTTPCycleState.RESPONSE + and message["type"] == "http.response.body" + ): body = message.get("body", b"") more_body = message.get("more_body", False) self.buffer.write(body) diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index 193ebcc..8d7fe05 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -2,10 +2,10 @@ import enum import logging from types import TracebackType -from typing import Optional, Type, Any +from typing import Any, Dict, Optional, Type +from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage from mangum.types import ASGI, LifespanMode, Message -from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage class LifespanCycleState(enum.Enum): @@ -62,7 +62,7 @@ def __init__(self, app: ASGI, lifespan: LifespanMode) -> None: self.startup_event: asyncio.Event = asyncio.Event() self.shutdown_event: asyncio.Event = asyncio.Event() self.logger = logging.getLogger("mangum.lifespan") - self.lifespan_state: dict[str, Any] = {} + self.lifespan_state: Dict[str, Any] = {} def __enter__(self) -> None: """Runs the event loop for application startup.""" diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index ac09aa4..12a98ef 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -1,15 +1,13 @@ import logging import pytest - +from quart import Quart from starlette.applications import Starlette from starlette.responses import PlainTextResponse from mangum import Mangum from mangum.exceptions import LifespanFailure -from quart import Quart - @pytest.mark.parametrize( "mock_aws_api_gateway_event,lifespan", @@ -211,6 +209,58 @@ async def app(scope, receive, send): handler(mock_aws_api_gateway_event, {}) +@pytest.mark.parametrize( + "mock_aws_api_gateway_event,lifespan_state,lifespan", + [ + (["GET", None, None], {"test_key": "test_value"}, "auto"), + (["GET", None, None], {"test_key": "test_value"}, "on"), + ], + indirect=["mock_aws_api_gateway_event"], +) +def test_lifespan_state(mock_aws_api_gateway_event, lifespan_state, lifespan) -> None: + startup_complete = False + shutdown_complete = False + + async def app(scope, receive, send): + nonlocal startup_complete, shutdown_complete + + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + scope["state"].update(lifespan_state) + await send({"type": "lifespan.startup.complete"}) + startup_complete = True + elif message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + shutdown_complete = True + return + + if scope["type"] == "http": + assert lifespan_state.items() <= scope["state"].items() + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain; charset=utf-8"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + handler = Mangum(app, lifespan=lifespan) + response = handler(mock_aws_api_gateway_event, {}) + + assert startup_complete + assert shutdown_complete + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": "text/plain; charset=utf-8"}, + "multiValueHeaders": {}, + "body": "Hello, world!", + } + + @pytest.mark.parametrize("mock_aws_api_gateway_event", [["GET", None, None]], indirect=True) def test_starlette_lifespan(mock_aws_api_gateway_event) -> None: startup_complete = False