Skip to content

Commit

Permalink
Fix typing issues for 3.8 & 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
foarsitter committed Oct 23, 2023
1 parent 83cd141 commit 6043e82
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 58 deletions.
71 changes: 39 additions & 32 deletions src/requestmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, get_origin
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
Expand All @@ -7,13 +7,9 @@
from typing import Set
from typing import Type
from typing import TypeVar
from typing import get_args

from fastapi._compat import (
field_annotation_is_scalar,
field_annotation_is_complex,
field_annotation_is_sequence,
)
from fastapi._compat import field_annotation_is_complex
from fastapi._compat import field_annotation_is_sequence
from fastapi.utils import get_path_param_names
from httpx import AsyncClient
from httpx import Client
Expand All @@ -22,9 +18,11 @@
from httpx._client import BaseClient
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic._internal._model_construction import ModelMetaclass
from pydantic.fields import FieldInfo
from typing_extensions import get_type_hints, Annotated
from typing_extensions import Annotated
from typing_extensions import get_args
from typing_extensions import get_origin
from typing_extensions import get_type_hints
from typing_extensions import override

from requestmodel import params
Expand All @@ -51,6 +49,7 @@ def get_annotated_type(
is_complex = field_annotation_is_complex(origin)
is_sequence = field_annotation_is_sequence(origin)
else:
origin = variable_type
is_complex = field_annotation_is_complex(variable_type)
is_sequence = field_annotation_is_sequence(variable_type)

Expand All @@ -69,14 +68,39 @@ def get_annotated_type(
if isinstance(annotated_property, scalar_types) and is_complex:
# query params do accept lists
if not (isinstance(annotated_property, params.Query) and is_sequence):
annotated_name = annotated_property.__class__.__name__

# in 3.8 & 3.9 Dict does not have a __name__
if not hasattr(origin, "__name__"):
origin = get_origin(origin)
raise ValueError(
f"`{variable_key}` annotated as {annotated_property.__class__.__name__} "
f"`{variable_key}` annotated as {annotated_name} "
f"can only be a scalar, not a `{origin.__name__}`"
)

return annotated_property


def flatten_body(request_args: RequestArgs) -> None:
body: Dict[str, Any] = {}
for field_name, field_value in request_args[params.Body].items():
body[field_name] = field_value
request_args[params.Body] = body


def unify_body(
annotated_property: Any, key: str, request_args: RequestArgs, value: Any
) -> None:
if isinstance(value, dict):
if annotated_property.embed:
request_args[type(annotated_property)][key] = value
else:
for nested_key, nested_value in value.items():
request_args[type(annotated_property)][nested_key] = nested_value
else:
request_args[type(annotated_property)][key] = value


class RequestModel(BaseModel, Generic[ResponseType]):
"""Declarative way to define a model"""

Expand All @@ -94,21 +118,18 @@ def as_request(self, client: BaseClient) -> Request:

request_args = self.request_args_for_values()

_params = request_args[params.Query]
headers = request_args[params.Header]
cookies = request_args[params.Cookie]
files = request_args[params.File]
body = request_args[params.Body]

is_json_request = "json" in headers.get("content-type", "")

r = Request(
method=self.method,
url=client._merge_url(self.url.format(**request_args[params.Path])),
params=_params,
params=request_args[params.Query],
headers=headers,
cookies=cookies,
files=files,
cookies=request_args[params.Cookie],
files=request_args[params.File],
data=body if not is_json_request else None,
json=body if is_json_request else None,
)
Expand Down Expand Up @@ -150,25 +171,11 @@ def request_args_for_values(self) -> RequestArgs:
key = key.replace("_", "-")

if isinstance(annotated_property, params.Body):
if isinstance(value, dict):
if annotated_property.embed:
request_args[type(annotated_property)][key] = value
else:
for nested_key, nested_value in value.items():
request_args[type(annotated_property)][
nested_key
] = nested_value
else:
request_args[type(annotated_property)][key] = value
unify_body(annotated_property, key, request_args, value)
else:
request_args[type(annotated_property)][key] = value

body: Dict[str, Any] = {}

for field_name, field_value in request_args[params.Body].items():
body[field_name] = field_value

request_args[params.Body] = body
flatten_body(request_args)

return request_args

Expand Down
2 changes: 1 addition & 1 deletion tests/flask_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def submit() -> Response:
return jsonify({"errors": form.errors})


client = TestClient(WSGIMiddleware(app)) # type: ignore
client = TestClient(WSGIMiddleware(app)) # type: ignore[arg-type]
50 changes: 25 additions & 25 deletions tests/test_request_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import typing
from typing import Any, List, Dict
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import Type

import pytest
from fastapi import params, Query
from fastapi._compat import (
field_annotation_is_scalar,
field_annotation_is_sequence,
field_annotation_is_complex,
)
from httpx import Client
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated, get_origin, get_args
from fastapi import params
from fastapi._compat import field_annotation_is_scalar
from fastapi._compat import field_annotation_is_sequence
from pydantic import BaseModel
from pydantic import ValidationError
from typing_extensions import Annotated
from typing_extensions import get_args
from typing_extensions import get_origin
from typing_extensions import get_type_hints

from requestmodel import RequestModel
Expand Down Expand Up @@ -94,7 +94,7 @@ class AnnotatedType(RequestModel[Any]):
assert isinstance(get_annotated_type("g", hints["g"]), params.Query)


def test_annotated_type():
def test_annotated_type() -> None:
class SimpleResponse(BaseModel):
data: str

Expand All @@ -111,7 +111,7 @@ class SimpleResponse(BaseModel):

with pytest.raises(
ValueError,
match="`y` annotated as Query can only be a scalar, not a `Dict`",
match="`y` annotated as Query can only be a scalar, not a `(.)ict`",
):
get_annotated_type("y", Annotated[Dict[str, str], params.Query()])

Expand All @@ -132,16 +132,16 @@ class SimpleResponse(BaseModel):
# assert field_annotation_is_complex(SimpleResponse) is True
assert field_annotation_is_scalar(Annotated[SimpleResponse, params.Path()]) is True

# assert field_annotation_is_scalar(SimpleResponse) is False
# assert field_annotation_is_scalar(str) is True
# assert field_annotation_is_scalar(Annotated[str, params.Path()]) is True
# assert field_annotation_is_scalar(Annotated[SimpleResponse, params.Path()]) is True
assert field_annotation_is_scalar(SimpleResponse) is False
assert field_annotation_is_scalar(str) is True
assert field_annotation_is_scalar(Annotated[str, params.Path()]) is True
assert field_annotation_is_scalar(Annotated[SimpleResponse, params.Path()]) is True


def test_field_annotation_with_constraints() -> None:
class SimpleRequest(RequestModel[Any]):
url = "test"
method = "test"
url: ClassVar[str] = "test"
method: ClassVar[str] = "test"
data: Annotated[str, params.Query(min_length=8, max_length=10)]

with pytest.raises(
Expand All @@ -152,19 +152,19 @@ class SimpleRequest(RequestModel[Any]):

def test_field_unified_body() -> None:
class SimpleRequest(RequestModel[Any]):
url = "test"
method = "test"
url: ClassVar[str] = "test"
method: ClassVar[str] = "test"
query_list: Annotated[List[int], params.Query()]
data_str: Annotated[str, params.Body()]
data_int: Annotated[int, params.Body()]
data_list: Annotated[List[int], params.Body()]
data_dict: Annotated[Dict[str, int], params.Body(embed=True)]

data = dict(
data: Dict[str, Any] = dict(
data_str="test", data_int=1, data_list=[0, 1, 2], data_dict={"key": 1925}
)

r = SimpleRequest(**data, query_list=[1, 2, 3])
r = SimpleRequest(query_list=[1, 2, 3], **data)

x = r.request_args_for_values()

Expand All @@ -178,14 +178,14 @@ class SimpleRequest(RequestModel[Any]):
}


def test_get_origin():
def test_get_origin() -> None:
assert get_origin(str) is None
assert get_origin(List[str]) is list
assert get_origin(Annotated[List[str], params.Query()]) is Annotated
assert get_origin(Annotated[str, params.Query()]) is Annotated


def test_get_args():
def test_get_args() -> None:
assert get_args(List[str]) == (str,)
p = get_args(Annotated[List[str], params.Query()])
assert p[0] == List[str]
Expand Down

0 comments on commit 6043e82

Please sign in to comment.