Skip to content

Commit

Permalink
add some missing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Jan 16, 2025
1 parent ce165bc commit 04c3c2a
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 1 deletion.
2 changes: 1 addition & 1 deletion litestar/handlers/websocket_handlers/route_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def handle(self, connection: WebSocket[Any, Any, Any]) -> None:
parsed_kwargs: dict[str, Any] = {}
cleanup_group: DependencyCleanupGroup | None = None

if handler_kwargs_model.has_kwargs and self.signature_model:
if handler_kwargs_model.has_kwargs:
parsed_kwargs = await handler_kwargs_model.to_kwargs(connection=connection)

if handler_kwargs_model.dependency_batches:
Expand Down
81 changes: 81 additions & 0 deletions tests/unit/test_handlers/test_base_handlers/test_resolution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Awaitable, Callable
from unittest.mock import AsyncMock

from litestar import Controller, Litestar, Router, get
from litestar.di import Provide
from litestar.params import Parameter


def test_resolve_dependencies_without_provide() -> None:
Expand Down Expand Up @@ -52,3 +54,82 @@ async def handler(self) -> None:
"controller": Provide(controller_dependency),
"handler": Provide(handler_dependency),
}


def test_resolve_type_encoders() -> None:
@get("/", type_encoders={int: str})
def handler() -> None:
pass

assert handler.resolve_type_encoders() == {int: str}


def test_resolve_type_decoders() -> None:
type_decoders = [(lambda t: True, lambda v, t: t)]

@get("/", type_decoders=type_decoders)
def handler() -> None:
pass

assert handler.resolve_type_decoders() == type_decoders


def test_resolve_parameters() -> None:
parameters = {"foo": Parameter()}

@get("/")
def handler() -> None:
pass

handler = handler.merge(Router("/", parameters=parameters, route_handlers=[]))
assert handler.resolve_layered_parameters() == handler.parameter_field_definitions


def test_resolve_guards() -> None:
guard = AsyncMock()

@get("/", guards=[guard])
def handler() -> None:
pass

assert handler.resolve_guards() == (guard,)


def test_resolve_dependencies() -> None:
dependency = AsyncMock()

@get("/", dependencies={"foo": dependency})
def handler() -> None:
pass

assert handler.resolve_dependencies() == handler.dependencies


def test_resolve_middleware() -> None:
middleware = AsyncMock()

@get("/", middleware=[middleware])
def handler() -> None:
pass

assert handler.resolve_middleware() == handler.middleware


def test_exception_handlers() -> None:
exception_handler = AsyncMock()

@get("/", exception_handlers={ValueError: exception_handler})
def handler() -> None:
pass

assert handler.resolve_exception_handlers() == {ValueError: exception_handler}


def test_resolve_signature_namespace() -> None:
namespace = {"foo": object()}

@get("/", signature_namespace=namespace)
def handler() -> None:
pass

assert handler.resolve_signature_namespace() == namespace
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import List
from unittest.mock import MagicMock

import pytest

from litestar import Controller, Router, WebSocket, websocket
from litestar.exceptions import ImproperlyConfiguredException
from litestar.testing import create_test_client


Expand Down Expand Up @@ -49,3 +53,12 @@ async def simple_websocket_handler(
ws.send_json({"data": "123"})
data = ws.receive_json()
assert data == {"a": 1, "b": "two", "c": 3.0, "d": ["d"]}


async def test_not_finalized_raises() -> None:
@websocket("/")
async def handler(socket: WebSocket) -> None:
pass

with pytest.raises(ImproperlyConfiguredException, match="handler parameter model not defined"):
await handler.handle(MagicMock())
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from litestar import WebSocket, websocket


def test_resolve_websocket_class() -> None:
@websocket()
async def handler(socket: WebSocket) -> None:
pass

assert handler.resolve_websocket_class() is handler.websocket_class

0 comments on commit 04c3c2a

Please sign in to comment.