Skip to content

Commit

Permalink
Embed fastapi methods and move fastapi to dev dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
foarsitter committed Feb 11, 2024
1 parent fa96314 commit c192f35
Show file tree
Hide file tree
Showing 10 changed files with 911 additions and 625 deletions.
7 changes: 6 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Nox sessions."""

import os
import shlex
import shutil
Expand Down Expand Up @@ -156,6 +157,7 @@ def mypy(session: Session) -> None:
"python-multipart",
"flask-wtf",
"a2wsgi",
"fastapi",
"requests",
"types-requests",
)
Expand All @@ -175,6 +177,7 @@ def tests(session: Session) -> None:
"python-multipart",
"flask-wtf",
"a2wsgi",
"fastapi",
"pygments",
"requests",
)
Expand All @@ -195,6 +198,7 @@ def coverage(session: Session) -> None:
if not session.posargs and any(Path().glob(".coverage.*")):
session.run("coverage", "combine")

session.run("coverage", "debug", "config")
session.run("coverage", *args)


Expand All @@ -208,8 +212,9 @@ def typeguard(session: Session) -> None:
"python-multipart",
"flask-wtf",
"a2wsgi",
"requests",
"fastapi",
"typeguard",
"requests",
"pygments",
)
session.run("pytest", f"--typeguard-packages={package}", *session.posargs)
Expand Down
1,096 changes: 506 additions & 590 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ Changelog = "https://github.com/foarsitter/requestmodel/releases"

[tool.poetry.dependencies]
python = ">=3.8"
fastapi = "*"
httpx = "*"
pydantic = "^2.0.0"

[tool.poetry.group.requests.dependencies]
requests = "^2.31.0"
Expand Down Expand Up @@ -51,6 +51,7 @@ sphinx-autobuild = ">=2021.3.14"
typeguard = ">=2.13.3"
xdoctest = { extras = ["colors"], version = ">=0.15.10" }
myst-parser = { version = ">=0.16.1" }
fastapi = "^0.109.2"

[tool.coverage.paths]
source = ["src", "*/site-packages"]
Expand All @@ -59,6 +60,7 @@ tests = ["tests", "*/tests"]
[tool.coverage.run]
branch = true
source = ["requestmodel", "tests"]
omit = ["src/requestmodel/encoders.py", '**/encoders.py']

[tool.coverage.report]
show_missing = true
Expand Down
1 change: 0 additions & 1 deletion src/requestmodel/encoders.py

This file was deleted.

324 changes: 324 additions & 0 deletions src/requestmodel/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
"""
Copied from fastapi.encoders
"""

import dataclasses
import datetime
import re
import types
from collections import defaultdict
from collections import deque
from dataclasses import is_dataclass
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address
from ipaddress import IPv4Interface
from ipaddress import IPv4Network
from ipaddress import IPv6Address
from ipaddress import IPv6Interface
from ipaddress import IPv6Network
from pathlib import Path
from pathlib import PurePath
from re import Pattern
from types import GeneratorType
from typing import Any
from typing import Callable
from typing import Deque
from typing import Dict
from typing import FrozenSet
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
from uuid import UUID

from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.color import Color
from pydantic.networks import AnyUrl
from pydantic.networks import NameEmail
from pydantic.types import SecretBytes
from pydantic.types import SecretStr
from pydantic_core import Url as Url
from typing_extensions import get_args
from typing_extensions import get_origin


IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]


# Taken from Pydantic v1 as is
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()


# Taken from Pydantic v1 as is
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
Encodes a Decimal as int of there's no exponent, otherwise float
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
where a integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
"""
if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
return int(dec_value)
else:
return float(dec_value)


ENCODERS_BY_TYPE: Dict[Any, Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
Url: str,
AnyUrl: str,
}


def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]],
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples


encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)


def jsonable_encoder( # noqa: C901
obj: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
"""
Convert any object to something that can be encoded in JSON.
This is used internally by FastAPI to make sure anything you return can be
encoded as JSON before it is sent to the client.
You can also use it yourself, for example to convert objects before saving them
in a database that supports only JSON.
Read more about it in the
[FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/).
""" # noqa: B950

if include is not None and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
if isinstance(obj, BaseModel):
obj_dict = obj.model_dump(
mode="json",
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
if include is not None:
allowed_keys &= set(include)
if exclude is not None:
allowed_keys -= set(exclude)
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and key in allowed_keys
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list

if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)

try:
data = dict(obj)
except Exception as e:
errors: List[Exception] = [e]
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
return jsonable_encoder(
data,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)


UnionType = getattr(types, "UnionType", Union)
sequence_annotation_to_type = {
Sequence: list,
List: list,
list: list,
Tuple: tuple,
tuple: tuple,
Set: set,
set: set,
FrozenSet: frozenset,
frozenset: frozenset,
Deque: deque,
deque: deque,
}
sequence_types = tuple(sequence_annotation_to_type.keys())


def _annotation_is_sequence(annotation: Any) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types)


def _annotation_is_complex(annotation: Any) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)


def field_annotation_is_complex(
annotation: Any,
) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))

return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
or hasattr(origin, "__pydantic_core_schema__")
or hasattr(origin, "__get_pydantic_core_schema__")
)


def field_annotation_is_sequence(annotation: Any) -> bool:
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
get_origin(annotation)
)


def get_path_param_names(path: str) -> Set[str]:
return set(re.findall("{(.*?)}", path))
Loading

0 comments on commit c192f35

Please sign in to comment.