Skip to content

Commit

Permalink
feat: add wrap_exceptions option to exception handler. (#363)
Browse files Browse the repository at this point in the history
There's a new option to make the exception handler more configurable.  Here's the way it'll work going forward:

If you don't want any error message mapping and you don't want the exception handler to wrap the exception in a RepositoryError, then you can set `wrap_exceptions` to `False` on the repo.

If you want the wrapped repository error, but have it skip the "better error messages" bit, send in `None` for the error messages here.  We differentiate Empty from None, so that we can disable the message section when the default is not Empty but is None.

When left alone, it's the current behavior
  • Loading branch information
cofin authored Jan 23, 2025
1 parent 04977b3 commit ce1f26a
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 1 deletion.
20 changes: 19 additions & 1 deletion advanced_alchemy/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,19 @@ def _get_error_message(error_messages: ErrorMessages, key: str, exc: Exception)


@contextmanager
def wrap_sqlalchemy_exception(
def wrap_sqlalchemy_exception( # noqa: C901
error_messages: ErrorMessages | None = None,
dialect_name: str | None = None,
wrap_exceptions: bool = True,
) -> Generator[None, None, None]:
"""Do something within context to raise a ``RepositoryError`` chained
from an original ``SQLAlchemyError``.
Args:
error_messages: Error messages to use for the exception.
dialect_name: The name of the dialect to use for the exception.
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
>>> try:
... with wrap_sqlalchemy_exception():
... raise SQLAlchemyError("Original Exception")
Expand All @@ -294,12 +300,16 @@ def wrap_sqlalchemy_exception(
yield

except MultipleResultsFound as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="multiple_rows", exc=exc)
else:
msg = "Multiple rows matched the specified data"
raise MultipleResultsFoundError(detail=msg) from exc
except SQLAlchemyIntegrityError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None and dialect_name is not None:
_keys_to_regex = {
"duplicate_key": (DUPLICATE_KEY_REGEXES.get(dialect_name, []), DuplicateKeyError),
Expand All @@ -319,18 +329,26 @@ def wrap_sqlalchemy_exception(
) from exc
raise IntegrityError(detail=f"An integrity error occurred: {exc}") from exc
except SQLAlchemyInvalidRequestError as exc:
if wrap_exceptions is False:
raise
raise InvalidRequestError(detail="An invalid request was made.") from exc
except StatementError as exc:
if wrap_exceptions is False:
raise
raise IntegrityError(
detail=cast(str, getattr(exc.orig, "detail", "There was an issue processing the statement."))
) from exc
except SQLAlchemyError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
else:
msg = f"An exception occurred: {exc}"
raise RepositoryError(detail=msg) from exc
except AttributeError as exc:
if wrap_exceptions is False:
raise
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
else:
Expand Down
7 changes: 7 additions & 0 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class SQLAlchemyAsyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Pr
auto_commit: bool
order_by: list[OrderingPair] | OrderingPair | None = None
error_messages: ErrorMessages | None = None
wrap_exceptions: bool = True

def __init__(
self,
Expand All @@ -87,6 +88,7 @@ def __init__(
execution_options: dict[str, Any] | None = None,
order_by: list[OrderingPair] | OrderingPair | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None: ...

Expand Down Expand Up @@ -408,6 +410,8 @@ class SQLAlchemyAsyncRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT], Filte
"""Default loader options for the repository."""
error_messages: ErrorMessages | None = None
"""Default error messages for the repository."""
wrap_exceptions: bool = True
"""Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
inherit_lazy_relationships: bool = True
"""Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
Expand Down Expand Up @@ -436,6 +440,7 @@ def __init__(
error_messages: ErrorMessages | None | EmptyType = Empty,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None:
"""Repository for SQLAlchemy models.
Expand All @@ -450,6 +455,7 @@ def __init__(
load: Set default relationships to be loaded
execution_options: Set default execution options
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
**kwargs: Additional arguments.
"""
Expand All @@ -461,6 +467,7 @@ def __init__(
self.error_messages = self._get_error_messages(
error_messages=error_messages, default_messages=self.error_messages
)
self.wrap_exceptions = wrap_exceptions
self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
loader_options=load if load is not None else self.loader_options,
inherit_lazy_relationships=self.inherit_lazy_relationships,
Expand Down
7 changes: 7 additions & 0 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class SQLAlchemySyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Pro
auto_commit: bool
order_by: list[OrderingPair] | OrderingPair | None = None
error_messages: ErrorMessages | None = None
wrap_exceptions: bool = True

def __init__(
self,
Expand All @@ -88,6 +89,7 @@ def __init__(
execution_options: dict[str, Any] | None = None,
order_by: list[OrderingPair] | OrderingPair | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None: ...

Expand Down Expand Up @@ -409,6 +411,8 @@ class SQLAlchemySyncRepository(SQLAlchemySyncRepositoryProtocol[ModelT], Filtera
"""Default loader options for the repository."""
error_messages: ErrorMessages | None = None
"""Default error messages for the repository."""
wrap_exceptions: bool = True
"""Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
inherit_lazy_relationships: bool = True
"""Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to
replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration."""
Expand Down Expand Up @@ -437,6 +441,7 @@ def __init__(
error_messages: ErrorMessages | None | EmptyType = Empty,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
wrap_exceptions: bool = True,
**kwargs: Any,
) -> None:
"""Repository for SQLAlchemy models.
Expand All @@ -451,6 +456,7 @@ def __init__(
load: Set default relationships to be loaded
execution_options: Set default execution options
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.
**kwargs: Additional arguments.
"""
Expand All @@ -462,6 +468,7 @@ def __init__(
self.error_messages = self._get_error_messages(
error_messages=error_messages, default_messages=self.error_messages
)
self.wrap_exceptions = wrap_exceptions
self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
loader_options=load if load is not None else self.loader_options,
inherit_lazy_relationships=self.inherit_lazy_relationships,
Expand Down
3 changes: 3 additions & 0 deletions advanced_alchemy/repository/memory/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class SQLAlchemyAsyncMockRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT]):
"order_by",
"load",
"error_messages",
"wrap_exceptions",
}

def __init__(
Expand All @@ -93,6 +94,7 @@ def __init__(
auto_commit: bool = False,
order_by: list[OrderingPair] | OrderingPair | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
wrap_exceptions: bool = True,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
**kwargs: Any,
Expand All @@ -103,6 +105,7 @@ def __init__(
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.error_messages = self._get_error_messages(error_messages=error_messages)
self.wrap_exceptions = wrap_exceptions
self.order_by = order_by
self._dialect: Dialect = create_autospec(Dialect, instance=True)
self._dialect.name = "mock"
Expand Down
3 changes: 3 additions & 0 deletions advanced_alchemy/repository/memory/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class SQLAlchemySyncMockRepository(SQLAlchemySyncRepositoryProtocol[ModelT]):
"order_by",
"load",
"error_messages",
"wrap_exceptions",
}

def __init__(
Expand All @@ -94,6 +95,7 @@ def __init__(
auto_commit: bool = False,
order_by: list[OrderingPair] | OrderingPair | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
wrap_exceptions: bool = True,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
**kwargs: Any,
Expand All @@ -104,6 +106,7 @@ def __init__(
self.auto_refresh = auto_refresh
self.auto_commit = auto_commit
self.error_messages = self._get_error_messages(error_messages=error_messages)
self.wrap_exceptions = wrap_exceptions
self.order_by = order_by
self._dialect: Dialect = create_autospec(Dialect, instance=True)
self._dialect.name = "mock"
Expand Down
3 changes: 3 additions & 0 deletions advanced_alchemy/service/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
auto_commit: bool = False,
order_by: list[OrderingPair] | OrderingPair | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
wrap_exceptions: bool = True,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
**repo_kwargs: Any,
Expand All @@ -122,6 +123,7 @@ def __init__(
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap exceptions in a RepositoryError
load: Set default relationships to be loaded
execution_options: Set default execution options
**repo_kwargs: passed as keyword args to repo instantiation.
Expand All @@ -136,6 +138,7 @@ def __init__(
auto_commit=auto_commit,
order_by=order_by,
error_messages=error_messages,
wrap_exceptions=wrap_exceptions,
load=load,
execution_options=execution_options,
**repo_kwargs,
Expand Down
3 changes: 3 additions & 0 deletions advanced_alchemy/service/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
auto_commit: bool = False,
order_by: list[OrderingPair] | OrderingPair | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
wrap_exceptions: bool = True,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
**repo_kwargs: Any,
Expand All @@ -136,6 +137,7 @@ def __init__(
auto_commit: Commit objects before returning.
order_by: Set default order options for queries.
error_messages: A set of custom error messages to use for operations
wrap_exceptions: Wrap exceptions in a RepositoryError
load: Set default relationships to be loaded
execution_options: Set default execution options
**repo_kwargs: passed as keyword args to repo instantiation.
Expand All @@ -150,6 +152,7 @@ def __init__(
auto_commit=auto_commit,
order_by=order_by,
error_messages=error_messages,
wrap_exceptions=wrap_exceptions,
load=load,
execution_options=execution_options,
**repo_kwargs,
Expand Down
127 changes: 127 additions & 0 deletions tests/unit/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import contextlib

import pytest
from sqlalchemy.exc import (
IntegrityError as SQLAlchemyIntegrityError,
)
from sqlalchemy.exc import (
InvalidRequestError as SQLAlchemyInvalidRequestError,
)
from sqlalchemy.exc import (
MultipleResultsFound,
SQLAlchemyError,
StatementError,
)

from advanced_alchemy.exceptions import (
DuplicateKeyError,
IntegrityError,
InvalidRequestError,
MultipleResultsFoundError,
RepositoryError,
wrap_sqlalchemy_exception,
)


async def test_repo_get_or_create_deprecation() -> None:
Expand All @@ -9,3 +29,110 @@ async def test_repo_get_or_create_deprecation() -> None:

with contextlib.suppress(Exception):
raise ConflictError


def test_wrap_sqlalchemy_exception_multiple_results_found() -> None:
with pytest.raises(MultipleResultsFoundError), wrap_sqlalchemy_exception():
raise MultipleResultsFound()


@pytest.mark.parametrize("dialect_name", ["postgresql", "sqlite", "mysql"])
def test_wrap_sqlalchemy_exception_integrity_error_duplicate_key(dialect_name: str) -> None:
error_message = {
"postgresql": 'duplicate key value violates unique constraint "uq_%(table_name)s_%(column_0_name)s"',
"sqlite": "UNIQUE constraint failed: %(table_name)s.%(column_0_name)s",
"mysql": "1062 (23000): Duplicate entry '%(value)s' for key '%(table_name)s.%(column_0_name)s'",
}
with pytest.raises(DuplicateKeyError), wrap_sqlalchemy_exception(
dialect_name=dialect_name,
error_messages={"duplicate_key": error_message[dialect_name]},
):
if dialect_name == "postgresql":
exception = SQLAlchemyIntegrityError(
"INSERT INTO table (id) VALUES (1)",
{"table_name": "table", "column_0_name": "id"},
Exception(
'duplicate key value violates unique constraint "uq_table_id"\nDETAIL: Key (id)=(1) already exists.',
),
)
elif dialect_name == "sqlite":
exception = SQLAlchemyIntegrityError(
"INSERT INTO table (id) VALUES (1)",
{"table_name": "table", "column_0_name": "id"},
Exception("UNIQUE constraint failed: table.id"),
)
else:
exception = SQLAlchemyIntegrityError(
"INSERT INTO table (id) VALUES (1)",
{"table_name": "table", "column_0_name": "id", "value": "1"},
Exception("1062 (23000): Duplicate entry '1' for key 'table.id'"),
)

raise exception


def test_wrap_sqlalchemy_exception_integrity_error_other() -> None:
with pytest.raises(IntegrityError), wrap_sqlalchemy_exception():
raise SQLAlchemyIntegrityError("original", {}, Exception("original"))


def test_wrap_sqlalchemy_exception_invalid_request_error() -> None:
with pytest.raises(InvalidRequestError), wrap_sqlalchemy_exception():
raise SQLAlchemyInvalidRequestError("original", {}, Exception("original"))


def test_wrap_sqlalchemy_exception_statement_error() -> None:
with pytest.raises(IntegrityError), wrap_sqlalchemy_exception():
raise StatementError("original", None, {}, Exception("original")) # pyright: ignore[reportArgumentType]


def test_wrap_sqlalchemy_exception_sqlalchemy_error() -> None:
with pytest.raises(RepositoryError), wrap_sqlalchemy_exception():
raise SQLAlchemyError("original")


def test_wrap_sqlalchemy_exception_attribute_error() -> None:
with pytest.raises(RepositoryError), wrap_sqlalchemy_exception():
raise AttributeError("original")


def test_wrap_sqlalchemy_exception_no_wrap() -> None:
with pytest.raises(SQLAlchemyError), wrap_sqlalchemy_exception(wrap_exceptions=False):
raise SQLAlchemyError("original")
with pytest.raises(SQLAlchemyIntegrityError), wrap_sqlalchemy_exception(wrap_exceptions=False):
raise SQLAlchemyIntegrityError(statement="select 1", params=None, orig=BaseException())
with pytest.raises(MultipleResultsFound), wrap_sqlalchemy_exception(wrap_exceptions=False):
raise MultipleResultsFound()
with pytest.raises(SQLAlchemyInvalidRequestError), wrap_sqlalchemy_exception(wrap_exceptions=False):
raise SQLAlchemyInvalidRequestError()
with pytest.raises(AttributeError), wrap_sqlalchemy_exception(wrap_exceptions=False):
raise AttributeError()


def test_wrap_sqlalchemy_exception_custom_error_message() -> None:
def custom_message(exc: Exception) -> str:
return f"Custom: {exc}"

with pytest.raises(RepositoryError) as excinfo, wrap_sqlalchemy_exception(
error_messages={"other": custom_message},
):
raise SQLAlchemyError("original")

assert str(excinfo.value) == "Custom: original"


def test_wrap_sqlalchemy_exception_no_error_messages() -> None:
with pytest.raises(RepositoryError) as excinfo, wrap_sqlalchemy_exception():
raise SQLAlchemyError("original")

assert str(excinfo.value) == "An exception occurred: original"


def test_wrap_sqlalchemy_exception_no_match() -> None:
with pytest.raises(IntegrityError) as excinfo, wrap_sqlalchemy_exception(
dialect_name="postgresql",
error_messages={"integrity": "Integrity error"},
):
raise SQLAlchemyIntegrityError("original", {}, Exception("original"))

assert str(excinfo.value) == "Integrity error"

0 comments on commit ce1f26a

Please sign in to comment.