diff --git a/advanced_alchemy/service/__init__.py b/advanced_alchemy/service/__init__.py index 0bdfda2b..926d1072 100644 --- a/advanced_alchemy/service/__init__.py +++ b/advanced_alchemy/service/__init__.py @@ -16,8 +16,15 @@ ModelDictT, ModelDTOT, is_dict, + is_dict_with_field, + is_dict_without_field, is_msgspec_model, + is_msgspec_model_with_field, + is_msgspec_model_without_field, is_pydantic_model, + is_pydantic_model_with_field, + is_pydantic_model_without_field, + schema_dump, ) __all__ = ( @@ -34,8 +41,15 @@ "find_filter", "ResultConverter", "is_dict", + "is_dict_with_field", + "is_dict_without_field", "is_msgspec_model", + "is_pydantic_model_with_field", + "is_msgspec_model_without_field", "is_pydantic_model", + "is_msgspec_model_with_field", + "is_pydantic_model_without_field", + "schema_dump", "LoadSpec", "model_from_dict", "ModelT", diff --git a/advanced_alchemy/service/typing.py b/advanced_alchemy/service/typing.py index 2e4c39bd..e59b4158 100644 --- a/advanced_alchemy/service/typing.py +++ b/advanced_alchemy/service/typing.py @@ -22,7 +22,6 @@ from typing_extensions import TypeAlias, TypeGuard -from advanced_alchemy.exceptions import AdvancedAlchemyError from advanced_alchemy.filters import StatementFilter # noqa: TCH001 from advanced_alchemy.repository.typing import ModelT @@ -37,6 +36,8 @@ class BaseModel(Protocol): # type: ignore[no-redef] # pragma: nocover """Placeholder Implementation""" + model_fields: ClassVar[dict[str, Any]] + def model_dump(*args: Any, **kwargs: Any) -> dict[str, Any]: """Placeholder""" return {} @@ -108,24 +109,34 @@ def is_dict_without_field(v: Any, field_name: str) -> TypeGuard[dict[str, Any]]: def is_pydantic_model_with_field(v: Any, field_name: str) -> TypeGuard[BaseModel]: - return PYDANTIC_INSTALLED and isinstance(v, BaseModel) and field_name in v.model_fields # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType] + return is_pydantic_model(v) and field_name in v.model_fields + + +def is_pydantic_model_without_field(v: Any, field_name: str) -> TypeGuard[BaseModel]: + return not is_pydantic_model_with_field(v, field_name) def is_msgspec_model_with_field(v: Any, field_name: str) -> TypeGuard[Struct]: - return MSGSPEC_INSTALLED and isinstance(v, Struct) and field_name in v.__struct_fields__ + return is_msgspec_model(v) and field_name in v.__struct_fields__ + + +def is_msgspec_model_without_field(v: Any, field_name: str) -> TypeGuard[Struct]: + return not is_msgspec_model_with_field(v, field_name) -def schema_to_dict(v: Any, exclude_unset: bool = True) -> dict[str, Any]: - if is_dict(v): - return v - if is_pydantic_model(v): - return v.model_dump(exclude_unset=exclude_unset) - if is_msgspec_model(v) and exclude_unset: - return {f: val for f in v.__struct_fields__ if (val := getattr(v, f, None)) != UNSET} - if is_msgspec_model(v) and not exclude_unset: - return {f: getattr(v, f, None) for f in v.__struct_fields__} - msg = f"Unable to convert model to dictionary for '{type(v)}' types" - raise AdvancedAlchemyError(msg) +def schema_dump( + data: dict[str, Any] | ModelT | Struct | BaseModel, + exclude_unset: bool = True, +) -> dict[str, Any] | ModelT: + if is_dict(data): + return data + if is_pydantic_model(data): + return data.model_dump(exclude_unset=exclude_unset) + if is_msgspec_model(data) and exclude_unset: + return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET} + if is_msgspec_model(data) and not exclude_unset: + return {f: getattr(data, f, None) for f in data.__struct_fields__} + return cast("ModelT", data) __all__ = ( @@ -143,9 +154,12 @@ def schema_to_dict(v: Any, exclude_unset: bool = True) -> dict[str, Any]: "UNSET", "is_dict", "is_dict_with_field", + "is_dict_without_field", "is_msgspec_model", "is_pydantic_model_with_field", + "is_msgspec_model_without_field", "is_pydantic_model", "is_msgspec_model_with_field", - "schema_to_dict", + "is_pydantic_model_without_field", + "schema_dump", ) diff --git a/tests/fixtures/bigint/services.py b/tests/fixtures/bigint/services.py index fe650ef5..a65e24ec 100644 --- a/tests/fixtures/bigint/services.py +++ b/tests/fixtures/bigint/services.py @@ -8,7 +8,7 @@ SQLAlchemyAsyncRepositoryService, SQLAlchemySyncRepositoryService, ) -from advanced_alchemy.service.typing import ModelDictT, is_dict_with_field, is_dict_without_field, schema_to_dict +from advanced_alchemy.service.typing import ModelDictT, is_dict_with_field, is_dict_without_field, schema_dump from tests.fixtures.bigint.models import ( BigIntAuthor, BigIntBook, @@ -228,7 +228,7 @@ async def to_model( data: ModelDictT[BigIntSlugBook], operation: str | None = None, ) -> BigIntSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = await self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": @@ -250,7 +250,7 @@ def to_model( data: ModelDictT[BigIntSlugBook], operation: str | None = None, ) -> BigIntSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": @@ -272,7 +272,7 @@ async def to_model( data: ModelDictT[BigIntSlugBook], operation: str | None = None, ) -> BigIntSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = await self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": @@ -294,7 +294,7 @@ def to_model( data: ModelDictT[BigIntSlugBook], operation: str | None = None, ) -> BigIntSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": diff --git a/tests/fixtures/uuid/services.py b/tests/fixtures/uuid/services.py index 1fec4e68..6b0d039a 100644 --- a/tests/fixtures/uuid/services.py +++ b/tests/fixtures/uuid/services.py @@ -12,7 +12,7 @@ PydanticOrMsgspecT, is_dict_with_field, is_dict_without_field, - schema_to_dict, + schema_dump, ) from tests.fixtures.uuid.models import ( UUIDAuthor, @@ -229,7 +229,7 @@ async def to_model( data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT, operation: str | None = None, ) -> UUIDSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = await self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": @@ -250,7 +250,7 @@ def to_model( data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT, operation: str | None = None, ) -> UUIDSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": @@ -272,7 +272,7 @@ async def to_model( data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT, operation: str | None = None, ) -> UUIDSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = await self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": @@ -294,7 +294,7 @@ def to_model( data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT, operation: str | None = None, ) -> UUIDSlugBook: - data = schema_to_dict(data) + data = schema_dump(data) if is_dict_without_field(data, "slug") and operation == "create": data["slug"] = self.repository.get_available_slug(data["title"]) if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update": diff --git a/tests/integration/test_sqlquery_service.py b/tests/integration/test_sqlquery_service.py index 2e608c83..d12f66a8 100644 --- a/tests/integration/test_sqlquery_service.py +++ b/tests/integration/test_sqlquery_service.py @@ -20,8 +20,10 @@ from advanced_alchemy.service.typing import ( is_msgspec_model, is_msgspec_model_with_field, + is_msgspec_model_without_field, is_pydantic_model, is_pydantic_model_with_field, + is_pydantic_model_without_field, ) from advanced_alchemy.utils.fixtures import open_fixture, open_fixture_async @@ -161,6 +163,7 @@ def test_sync_fixture_and_query() -> None: assert isinstance(_pydantic_obj, StateQueryBaseModel) assert is_pydantic_model(_pydantic_obj) assert is_pydantic_model_with_field(_pydantic_obj, "state_abbreviation") + assert not is_pydantic_model_without_field(_pydantic_obj, "state_abbreviation") _msgspec_obj = query_service.to_schema( data=_get_one_or_none_1, @@ -169,6 +172,7 @@ def test_sync_fixture_and_query() -> None: assert isinstance(_msgspec_obj, StateQueryStruct) assert is_msgspec_model(_msgspec_obj) assert is_msgspec_model_with_field(_msgspec_obj, "state_abbreviation") + assert not is_msgspec_model_without_field(_msgspec_obj, "state_abbreviation") _get_one_or_none = query_service.repository.get_one_or_none( statement=select(StateQuery).filter_by(state_name="Nope"), @@ -234,6 +238,8 @@ async def test_async_fixture_and_query() -> None: assert isinstance(_pydantic_obj, StateQueryBaseModel) assert is_pydantic_model(_pydantic_obj) assert is_pydantic_model_with_field(_pydantic_obj, "state_abbreviation") + assert not is_pydantic_model_without_field(_pydantic_obj, "state_abbreviation") + _msgspec_obj = query_service.to_schema( data=_get_one_or_none_1, schema_type=StateQueryStruct, @@ -244,4 +250,5 @@ async def test_async_fixture_and_query() -> None: _get_one_or_none = await query_service.repository.get_one_or_none( select(StateQuery).filter_by(state_name="Nope"), ) + assert not is_msgspec_model_without_field(_msgspec_obj, "state_abbreviation") assert _get_one_or_none is None