diff --git a/src/requestmodel/__init__.py b/src/requestmodel/__init__.py index 8f28246..cbb2b29 100644 --- a/src/requestmodel/__init__.py +++ b/src/requestmodel/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, get_origin +from typing import Any from typing import ClassVar from typing import Dict from typing import Generic @@ -7,13 +7,9 @@ from typing import Set from typing import Type from typing import TypeVar -from typing import get_args -from fastapi._compat import ( - field_annotation_is_scalar, - field_annotation_is_complex, - field_annotation_is_sequence, -) +from fastapi._compat import field_annotation_is_complex +from fastapi._compat import field_annotation_is_sequence from fastapi.utils import get_path_param_names from httpx import AsyncClient from httpx import Client @@ -22,9 +18,11 @@ from httpx._client import BaseClient from pydantic import BaseModel from pydantic import ConfigDict -from pydantic._internal._model_construction import ModelMetaclass from pydantic.fields import FieldInfo -from typing_extensions import get_type_hints, Annotated +from typing_extensions import Annotated +from typing_extensions import get_args +from typing_extensions import get_origin +from typing_extensions import get_type_hints from typing_extensions import override from requestmodel import params @@ -51,6 +49,7 @@ def get_annotated_type( is_complex = field_annotation_is_complex(origin) is_sequence = field_annotation_is_sequence(origin) else: + origin = variable_type is_complex = field_annotation_is_complex(variable_type) is_sequence = field_annotation_is_sequence(variable_type) @@ -69,14 +68,39 @@ def get_annotated_type( if isinstance(annotated_property, scalar_types) and is_complex: # query params do accept lists if not (isinstance(annotated_property, params.Query) and is_sequence): + annotated_name = annotated_property.__class__.__name__ + + # in 3.8 & 3.9 Dict does not have a __name__ + if not hasattr(origin, "__name__"): + origin = get_origin(origin) raise ValueError( - f"`{variable_key}` annotated as {annotated_property.__class__.__name__} " + f"`{variable_key}` annotated as {annotated_name} " f"can only be a scalar, not a `{origin.__name__}`" ) return annotated_property +def flatten_body(request_args: RequestArgs) -> None: + body: Dict[str, Any] = {} + for field_name, field_value in request_args[params.Body].items(): + body[field_name] = field_value + request_args[params.Body] = body + + +def unify_body( + annotated_property: Any, key: str, request_args: RequestArgs, value: Any +) -> None: + if isinstance(value, dict): + if annotated_property.embed: + request_args[type(annotated_property)][key] = value + else: + for nested_key, nested_value in value.items(): + request_args[type(annotated_property)][nested_key] = nested_value + else: + request_args[type(annotated_property)][key] = value + + class RequestModel(BaseModel, Generic[ResponseType]): """Declarative way to define a model""" @@ -94,10 +118,7 @@ def as_request(self, client: BaseClient) -> Request: request_args = self.request_args_for_values() - _params = request_args[params.Query] headers = request_args[params.Header] - cookies = request_args[params.Cookie] - files = request_args[params.File] body = request_args[params.Body] is_json_request = "json" in headers.get("content-type", "") @@ -105,10 +126,10 @@ def as_request(self, client: BaseClient) -> Request: r = Request( method=self.method, url=client._merge_url(self.url.format(**request_args[params.Path])), - params=_params, + params=request_args[params.Query], headers=headers, - cookies=cookies, - files=files, + cookies=request_args[params.Cookie], + files=request_args[params.File], data=body if not is_json_request else None, json=body if is_json_request else None, ) @@ -150,25 +171,11 @@ def request_args_for_values(self) -> RequestArgs: key = key.replace("_", "-") if isinstance(annotated_property, params.Body): - if isinstance(value, dict): - if annotated_property.embed: - request_args[type(annotated_property)][key] = value - else: - for nested_key, nested_value in value.items(): - request_args[type(annotated_property)][ - nested_key - ] = nested_value - else: - request_args[type(annotated_property)][key] = value + unify_body(annotated_property, key, request_args, value) else: request_args[type(annotated_property)][key] = value - body: Dict[str, Any] = {} - - for field_name, field_value in request_args[params.Body].items(): - body[field_name] = field_value - - request_args[params.Body] = body + flatten_body(request_args) return request_args diff --git a/tests/flask_server/__init__.py b/tests/flask_server/__init__.py index 57a604f..53eab50 100644 --- a/tests/flask_server/__init__.py +++ b/tests/flask_server/__init__.py @@ -28,4 +28,4 @@ def submit() -> Response: return jsonify({"errors": form.errors}) -client = TestClient(WSGIMiddleware(app)) # type: ignore +client = TestClient(WSGIMiddleware(app)) # type: ignore[arg-type] diff --git a/tests/test_request_model.py b/tests/test_request_model.py index 7db87b9..092e960 100644 --- a/tests/test_request_model.py +++ b/tests/test_request_model.py @@ -1,19 +1,19 @@ -import typing -from typing import Any, List, Dict +from typing import Any from typing import ClassVar +from typing import Dict +from typing import List from typing import Optional from typing import Type import pytest -from fastapi import params, Query -from fastapi._compat import ( - field_annotation_is_scalar, - field_annotation_is_sequence, - field_annotation_is_complex, -) -from httpx import Client -from pydantic import BaseModel, ValidationError -from typing_extensions import Annotated, get_origin, get_args +from fastapi import params +from fastapi._compat import field_annotation_is_scalar +from fastapi._compat import field_annotation_is_sequence +from pydantic import BaseModel +from pydantic import ValidationError +from typing_extensions import Annotated +from typing_extensions import get_args +from typing_extensions import get_origin from typing_extensions import get_type_hints from requestmodel import RequestModel @@ -94,7 +94,7 @@ class AnnotatedType(RequestModel[Any]): assert isinstance(get_annotated_type("g", hints["g"]), params.Query) -def test_annotated_type(): +def test_annotated_type() -> None: class SimpleResponse(BaseModel): data: str @@ -111,7 +111,7 @@ class SimpleResponse(BaseModel): with pytest.raises( ValueError, - match="`y` annotated as Query can only be a scalar, not a `Dict`", + match="`y` annotated as Query can only be a scalar, not a `(.)ict`", ): get_annotated_type("y", Annotated[Dict[str, str], params.Query()]) @@ -132,16 +132,16 @@ class SimpleResponse(BaseModel): # assert field_annotation_is_complex(SimpleResponse) is True assert field_annotation_is_scalar(Annotated[SimpleResponse, params.Path()]) is True - # assert field_annotation_is_scalar(SimpleResponse) is False - # assert field_annotation_is_scalar(str) is True - # assert field_annotation_is_scalar(Annotated[str, params.Path()]) is True - # assert field_annotation_is_scalar(Annotated[SimpleResponse, params.Path()]) is True + assert field_annotation_is_scalar(SimpleResponse) is False + assert field_annotation_is_scalar(str) is True + assert field_annotation_is_scalar(Annotated[str, params.Path()]) is True + assert field_annotation_is_scalar(Annotated[SimpleResponse, params.Path()]) is True def test_field_annotation_with_constraints() -> None: class SimpleRequest(RequestModel[Any]): - url = "test" - method = "test" + url: ClassVar[str] = "test" + method: ClassVar[str] = "test" data: Annotated[str, params.Query(min_length=8, max_length=10)] with pytest.raises( @@ -152,19 +152,19 @@ class SimpleRequest(RequestModel[Any]): def test_field_unified_body() -> None: class SimpleRequest(RequestModel[Any]): - url = "test" - method = "test" + url: ClassVar[str] = "test" + method: ClassVar[str] = "test" query_list: Annotated[List[int], params.Query()] data_str: Annotated[str, params.Body()] data_int: Annotated[int, params.Body()] data_list: Annotated[List[int], params.Body()] data_dict: Annotated[Dict[str, int], params.Body(embed=True)] - data = dict( + data: Dict[str, Any] = dict( data_str="test", data_int=1, data_list=[0, 1, 2], data_dict={"key": 1925} ) - r = SimpleRequest(**data, query_list=[1, 2, 3]) + r = SimpleRequest(query_list=[1, 2, 3], **data) x = r.request_args_for_values() @@ -178,14 +178,14 @@ class SimpleRequest(RequestModel[Any]): } -def test_get_origin(): +def test_get_origin() -> None: assert get_origin(str) is None assert get_origin(List[str]) is list assert get_origin(Annotated[List[str], params.Query()]) is Annotated assert get_origin(Annotated[str, params.Query()]) is Annotated -def test_get_args(): +def test_get_args() -> None: assert get_args(List[str]) == (str,) p = get_args(Annotated[List[str], params.Query()]) assert p[0] == List[str]