Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
FurqanHabibi committed Sep 26, 2024
1 parent f7d6dae commit cb6c3f2
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 10 deletions.
2 changes: 1 addition & 1 deletion mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
if self.lifespan in ("auto", "on"):
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)
scope |= {"state": lifespan_cycle.lifespan_state.copy()}
scope.update({"state": lifespan_cycle.lifespan_state.copy()})

http_cycle = HTTPCycle(scope, handler.body)
http_response = http_cycle(self.app)
Expand Down
12 changes: 9 additions & 3 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from io import BytesIO

from mangum.types import ASGI, Message, Scope, Response
from mangum.exceptions import UnexpectedMessage
from mangum.types import ASGI, Message, Response, Scope


class HTTPCycleState(enum.Enum):
Expand Down Expand Up @@ -82,11 +82,17 @@ async def receive(self) -> Message:
return await self.app_queue.get() # pragma: no cover

async def send(self, message: Message) -> None:
if self.state is HTTPCycleState.REQUEST and message["type"] == "http.response.start":
if (
self.state is HTTPCycleState.REQUEST
and message["type"] == "http.response.start"
):
self.status = message["status"]
self.headers = message.get("headers", [])
self.state = HTTPCycleState.RESPONSE
elif self.state is HTTPCycleState.RESPONSE and message["type"] == "http.response.body":
elif (
self.state is HTTPCycleState.RESPONSE
and message["type"] == "http.response.body"
):
body = message.get("body", b"")
more_body = message.get("more_body", False)
self.buffer.write(body)
Expand Down
6 changes: 3 additions & 3 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import enum
import logging
from types import TracebackType
from typing import Optional, Type, Any
from typing import Any, Dict, Optional, Type

from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage
from mangum.types import ASGI, LifespanMode, Message
from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage


class LifespanCycleState(enum.Enum):
Expand Down Expand Up @@ -62,7 +62,7 @@ def __init__(self, app: ASGI, lifespan: LifespanMode) -> None:
self.startup_event: asyncio.Event = asyncio.Event()
self.shutdown_event: asyncio.Event = asyncio.Event()
self.logger = logging.getLogger("mangum.lifespan")
self.lifespan_state: dict[str, Any] = {}
self.lifespan_state: Dict[str, Any] = {}

def __enter__(self) -> None:
"""Runs the event loop for application startup."""
Expand Down
56 changes: 53 additions & 3 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import logging

import pytest

from quart import Quart
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse

from mangum import Mangum
from mangum.exceptions import LifespanFailure

from quart import Quart


@pytest.mark.parametrize(
"mock_aws_api_gateway_event,lifespan",
Expand Down Expand Up @@ -211,6 +209,58 @@ async def app(scope, receive, send):
handler(mock_aws_api_gateway_event, {})


@pytest.mark.parametrize(
"mock_aws_api_gateway_event,lifespan_state,lifespan",
[
(["GET", None, None], {"test_key": "test_value"}, "auto"),
(["GET", None, None], {"test_key": "test_value"}, "on"),
],
indirect=["mock_aws_api_gateway_event"],
)
def test_lifespan_state(mock_aws_api_gateway_event, lifespan_state, lifespan) -> None:
startup_complete = False
shutdown_complete = False

async def app(scope, receive, send):
nonlocal startup_complete, shutdown_complete

if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
scope["state"].update(lifespan_state)
await send({"type": "lifespan.startup.complete"})
startup_complete = True
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
shutdown_complete = True
return

if scope["type"] == "http":
assert lifespan_state.items() <= scope["state"].items()
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain; charset=utf-8"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})

handler = Mangum(app, lifespan=lifespan)
response = handler(mock_aws_api_gateway_event, {})

assert startup_complete
assert shutdown_complete
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": "text/plain; charset=utf-8"},
"multiValueHeaders": {},
"body": "Hello, world!",
}


@pytest.mark.parametrize("mock_aws_api_gateway_event", [["GET", None, None]], indirect=True)
def test_starlette_lifespan(mock_aws_api_gateway_event) -> None:
startup_complete = False
Expand Down

0 comments on commit cb6c3f2

Please sign in to comment.