Skip to content

Commit

Permalink
feat: additional tests and helper methods (#225)
Browse files Browse the repository at this point in the history
* feat: additional tests and helper methods

* feat: re-use functions

* feat: simplify logic further
  • Loading branch information
cofin authored Jun 27, 2024
1 parent cfd5e1b commit da87543
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 25 deletions.
14 changes: 14 additions & 0 deletions advanced_alchemy/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
ModelDictT,
ModelDTOT,
is_dict,
is_dict_with_field,
is_dict_without_field,
is_msgspec_model,
is_msgspec_model_with_field,
is_msgspec_model_without_field,
is_pydantic_model,
is_pydantic_model_with_field,
is_pydantic_model_without_field,
schema_dump,
)

__all__ = (
Expand All @@ -34,8 +41,15 @@
"find_filter",
"ResultConverter",
"is_dict",
"is_dict_with_field",
"is_dict_without_field",
"is_msgspec_model",
"is_pydantic_model_with_field",
"is_msgspec_model_without_field",
"is_pydantic_model",
"is_msgspec_model_with_field",
"is_pydantic_model_without_field",
"schema_dump",
"LoadSpec",
"model_from_dict",
"ModelT",
Expand Down
44 changes: 29 additions & 15 deletions advanced_alchemy/service/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from typing_extensions import TypeAlias, TypeGuard

from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.filters import StatementFilter # noqa: TCH001
from advanced_alchemy.repository.typing import ModelT

Expand All @@ -37,6 +36,8 @@
class BaseModel(Protocol): # type: ignore[no-redef] # pragma: nocover
"""Placeholder Implementation"""

model_fields: ClassVar[dict[str, Any]]

def model_dump(*args: Any, **kwargs: Any) -> dict[str, Any]:
"""Placeholder"""
return {}
Expand Down Expand Up @@ -108,24 +109,34 @@ def is_dict_without_field(v: Any, field_name: str) -> TypeGuard[dict[str, Any]]:


def is_pydantic_model_with_field(v: Any, field_name: str) -> TypeGuard[BaseModel]:
return PYDANTIC_INSTALLED and isinstance(v, BaseModel) and field_name in v.model_fields # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
return is_pydantic_model(v) and field_name in v.model_fields


def is_pydantic_model_without_field(v: Any, field_name: str) -> TypeGuard[BaseModel]:
return not is_pydantic_model_with_field(v, field_name)


def is_msgspec_model_with_field(v: Any, field_name: str) -> TypeGuard[Struct]:
return MSGSPEC_INSTALLED and isinstance(v, Struct) and field_name in v.__struct_fields__
return is_msgspec_model(v) and field_name in v.__struct_fields__


def is_msgspec_model_without_field(v: Any, field_name: str) -> TypeGuard[Struct]:
return not is_msgspec_model_with_field(v, field_name)


def schema_to_dict(v: Any, exclude_unset: bool = True) -> dict[str, Any]:
if is_dict(v):
return v
if is_pydantic_model(v):
return v.model_dump(exclude_unset=exclude_unset)
if is_msgspec_model(v) and exclude_unset:
return {f: val for f in v.__struct_fields__ if (val := getattr(v, f, None)) != UNSET}
if is_msgspec_model(v) and not exclude_unset:
return {f: getattr(v, f, None) for f in v.__struct_fields__}
msg = f"Unable to convert model to dictionary for '{type(v)}' types"
raise AdvancedAlchemyError(msg)
def schema_dump(
data: dict[str, Any] | ModelT | Struct | BaseModel,
exclude_unset: bool = True,
) -> dict[str, Any] | ModelT:
if is_dict(data):
return data
if is_pydantic_model(data):
return data.model_dump(exclude_unset=exclude_unset)
if is_msgspec_model(data) and exclude_unset:
return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET}
if is_msgspec_model(data) and not exclude_unset:
return {f: getattr(data, f, None) for f in data.__struct_fields__}
return cast("ModelT", data)


__all__ = (
Expand All @@ -143,9 +154,12 @@ def schema_to_dict(v: Any, exclude_unset: bool = True) -> dict[str, Any]:
"UNSET",
"is_dict",
"is_dict_with_field",
"is_dict_without_field",
"is_msgspec_model",
"is_pydantic_model_with_field",
"is_msgspec_model_without_field",
"is_pydantic_model",
"is_msgspec_model_with_field",
"schema_to_dict",
"is_pydantic_model_without_field",
"schema_dump",
)
10 changes: 5 additions & 5 deletions tests/fixtures/bigint/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
SQLAlchemyAsyncRepositoryService,
SQLAlchemySyncRepositoryService,
)
from advanced_alchemy.service.typing import ModelDictT, is_dict_with_field, is_dict_without_field, schema_to_dict
from advanced_alchemy.service.typing import ModelDictT, is_dict_with_field, is_dict_without_field, schema_dump
from tests.fixtures.bigint.models import (
BigIntAuthor,
BigIntBook,
Expand Down Expand Up @@ -228,7 +228,7 @@ async def to_model(
data: ModelDictT[BigIntSlugBook],
operation: str | None = None,
) -> BigIntSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = await self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand All @@ -250,7 +250,7 @@ def to_model(
data: ModelDictT[BigIntSlugBook],
operation: str | None = None,
) -> BigIntSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand All @@ -272,7 +272,7 @@ async def to_model(
data: ModelDictT[BigIntSlugBook],
operation: str | None = None,
) -> BigIntSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = await self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand All @@ -294,7 +294,7 @@ def to_model(
data: ModelDictT[BigIntSlugBook],
operation: str | None = None,
) -> BigIntSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand Down
10 changes: 5 additions & 5 deletions tests/fixtures/uuid/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PydanticOrMsgspecT,
is_dict_with_field,
is_dict_without_field,
schema_to_dict,
schema_dump,
)
from tests.fixtures.uuid.models import (
UUIDAuthor,
Expand Down Expand Up @@ -229,7 +229,7 @@ async def to_model(
data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT,
operation: str | None = None,
) -> UUIDSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = await self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand All @@ -250,7 +250,7 @@ def to_model(
data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT,
operation: str | None = None,
) -> UUIDSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand All @@ -272,7 +272,7 @@ async def to_model(
data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT,
operation: str | None = None,
) -> UUIDSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = await self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand All @@ -294,7 +294,7 @@ def to_model(
data: UUIDSlugBook | dict[str, Any] | PydanticOrMsgspecT,
operation: str | None = None,
) -> UUIDSlugBook:
data = schema_to_dict(data)
data = schema_dump(data)
if is_dict_without_field(data, "slug") and operation == "create":
data["slug"] = self.repository.get_available_slug(data["title"])
if is_dict_without_field(data, "slug") and is_dict_with_field(data, "title") and operation == "update":
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/test_sqlquery_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from advanced_alchemy.service.typing import (
is_msgspec_model,
is_msgspec_model_with_field,
is_msgspec_model_without_field,
is_pydantic_model,
is_pydantic_model_with_field,
is_pydantic_model_without_field,
)
from advanced_alchemy.utils.fixtures import open_fixture, open_fixture_async

Expand Down Expand Up @@ -161,6 +163,7 @@ def test_sync_fixture_and_query() -> None:
assert isinstance(_pydantic_obj, StateQueryBaseModel)
assert is_pydantic_model(_pydantic_obj)
assert is_pydantic_model_with_field(_pydantic_obj, "state_abbreviation")
assert not is_pydantic_model_without_field(_pydantic_obj, "state_abbreviation")

_msgspec_obj = query_service.to_schema(
data=_get_one_or_none_1,
Expand All @@ -169,6 +172,7 @@ def test_sync_fixture_and_query() -> None:
assert isinstance(_msgspec_obj, StateQueryStruct)
assert is_msgspec_model(_msgspec_obj)
assert is_msgspec_model_with_field(_msgspec_obj, "state_abbreviation")
assert not is_msgspec_model_without_field(_msgspec_obj, "state_abbreviation")

_get_one_or_none = query_service.repository.get_one_or_none(
statement=select(StateQuery).filter_by(state_name="Nope"),
Expand Down Expand Up @@ -234,6 +238,8 @@ async def test_async_fixture_and_query() -> None:
assert isinstance(_pydantic_obj, StateQueryBaseModel)
assert is_pydantic_model(_pydantic_obj)
assert is_pydantic_model_with_field(_pydantic_obj, "state_abbreviation")
assert not is_pydantic_model_without_field(_pydantic_obj, "state_abbreviation")

_msgspec_obj = query_service.to_schema(
data=_get_one_or_none_1,
schema_type=StateQueryStruct,
Expand All @@ -244,4 +250,5 @@ async def test_async_fixture_and_query() -> None:
_get_one_or_none = await query_service.repository.get_one_or_none(
select(StateQuery).filter_by(state_name="Nope"),
)
assert not is_msgspec_model_without_field(_msgspec_obj, "state_abbreviation")
assert _get_one_or_none is None

0 comments on commit da87543

Please sign in to comment.