diff --git a/src/requestmodel/adapters/requests.py b/src/requestmodel/adapters/requests.py index 6c0e427..f77ddcd 100644 --- a/src/requestmodel/adapters/requests.py +++ b/src/requestmodel/adapters/requests.py @@ -50,4 +50,4 @@ def send(self, client: Session) -> ResponseType: r = self.as_request() self.response = client.send(r.prepare()) self.handle_error(self.response) - return self.response_model.model_validate(self.response.json()) + return self.adapt_type(self.response) diff --git a/src/requestmodel/model.py b/src/requestmodel/model.py index b3c7e9e..0ba86dc 100644 --- a/src/requestmodel/model.py +++ b/src/requestmodel/model.py @@ -1,7 +1,9 @@ +from typing import Any from typing import ClassVar from typing import Generic from typing import Iterator from typing import Optional +from typing import Protocol from typing import Set from typing import Type @@ -13,6 +15,7 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import TypeAdapter +from pydantic._internal._model_construction import ModelMetaclass from typing_extensions import get_type_hints from typing_extensions import override @@ -26,6 +29,10 @@ from .utils import unify_body +class JSONResponse(Protocol): # pragma: no cover + def json(self) -> Any: ... # noqa: E704 + + class BaseRequestModel(BaseModel, Generic[ResponseType]): """Declarative way to define a model""" @@ -86,6 +93,15 @@ def request_args_for_values(self) -> RequestArgs: return request_args + def adapt_type(self, response: JSONResponse) -> ResponseType: + if isinstance(self.response_model, TypeAdapter): + return self.response_model.validate_python(response.json()) + + if isinstance(self.response_model, (BaseModel, ModelMetaclass)): + return self.response_model.model_validate(response.json()) + + raise ValueError("response_model must be a TypeAdapter or a BaseModel") + class RequestModel(BaseRequestModel[ResponseType]): raw_response: Optional[Response] = None @@ -98,18 +114,15 @@ def send(self, client: Client) -> ResponseType: r = self.as_request(client) self.raw_response = client.send(r) self.handle_error(self.raw_response) - if isinstance(self.response_model, TypeAdapter): - return self.response_model.validate_python(self.raw_response.json()) - return self.response_model.model_validate(self.raw_response.json()) + + return self.adapt_type(self.raw_response) async def asend(self, client: AsyncClient) -> ResponseType: """Send the request asynchronously""" r = self.as_request(client) self.raw_response = await client.send(r) self.handle_error(self.raw_response) - if isinstance(self.response_model, TypeAdapter): - return self.response_model.validate_python(self.raw_response.json()) - return self.response_model.model_validate(self.raw_response.json()) + return self.adapt_type(self.raw_response) def as_request(self, client: BaseClient) -> Request: """Transform the properties of the object into a request""" diff --git a/src/requestmodel/typing.py b/src/requestmodel/typing.py index 561526b..76453a2 100644 --- a/src/requestmodel/typing.py +++ b/src/requestmodel/typing.py @@ -1,11 +1,17 @@ from typing import Any from typing import Dict +from typing import List from typing import Type from typing import TypeVar +from typing import Union from pydantic import BaseModel +from pydantic import TypeAdapter from pydantic.fields import FieldInfo -ResponseType = TypeVar("ResponseType", bound=BaseModel) +ResponseType = TypeVar( + "ResponseType", + bound=Union[BaseModel, TypeAdapter[List[BaseModel]], List[BaseModel]], +) RequestArgs = Dict[Type[FieldInfo], Dict[str, Any]] diff --git a/tests/test_request_model.py b/tests/test_request_model.py index 1f6fc6c..fe67532 100644 --- a/tests/test_request_model.py +++ b/tests/test_request_model.py @@ -211,3 +211,12 @@ def test_type_adapter() -> None: response = request.send(client) assert response == [NameModel(name="test")] + + +def test_type_adapter_exception() -> None: + request = TypeAdapterRequest() + TypeAdapterRequest.response_model = int # type: ignore + with pytest.raises( + ValueError, match="response_model must be a TypeAdapter or a BaseModel" + ): + request.adapt_type(SimpleResponse(data="test"))