diff --git a/advanced_alchemy/exceptions.py b/advanced_alchemy/exceptions.py index c3cf47c9..1f14dfb6 100644 --- a/advanced_alchemy/exceptions.py +++ b/advanced_alchemy/exceptions.py @@ -274,13 +274,19 @@ def _get_error_message(error_messages: ErrorMessages, key: str, exc: Exception) @contextmanager -def wrap_sqlalchemy_exception( +def wrap_sqlalchemy_exception( # noqa: C901 error_messages: ErrorMessages | None = None, dialect_name: str | None = None, + wrap_exceptions: bool = True, ) -> Generator[None, None, None]: """Do something within context to raise a ``RepositoryError`` chained from an original ``SQLAlchemyError``. + Args: + error_messages: Error messages to use for the exception. + dialect_name: The name of the dialect to use for the exception. + wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised. + >>> try: ... with wrap_sqlalchemy_exception(): ... raise SQLAlchemyError("Original Exception") @@ -294,12 +300,16 @@ def wrap_sqlalchemy_exception( yield except MultipleResultsFound as exc: + if wrap_exceptions is False: + raise if error_messages is not None: msg = _get_error_message(error_messages=error_messages, key="multiple_rows", exc=exc) else: msg = "Multiple rows matched the specified data" raise MultipleResultsFoundError(detail=msg) from exc except SQLAlchemyIntegrityError as exc: + if wrap_exceptions is False: + raise if error_messages is not None and dialect_name is not None: _keys_to_regex = { "duplicate_key": (DUPLICATE_KEY_REGEXES.get(dialect_name, []), DuplicateKeyError), @@ -319,18 +329,26 @@ def wrap_sqlalchemy_exception( ) from exc raise IntegrityError(detail=f"An integrity error occurred: {exc}") from exc except SQLAlchemyInvalidRequestError as exc: + if wrap_exceptions is False: + raise raise InvalidRequestError(detail="An invalid request was made.") from exc except StatementError as exc: + if wrap_exceptions is False: + raise raise IntegrityError( detail=cast(str, getattr(exc.orig, "detail", "There was an issue processing the statement.")) ) from exc except SQLAlchemyError as exc: + if wrap_exceptions is False: + raise if error_messages is not None: msg = _get_error_message(error_messages=error_messages, key="other", exc=exc) else: msg = f"An exception occurred: {exc}" raise RepositoryError(detail=msg) from exc except AttributeError as exc: + if wrap_exceptions is False: + raise if error_messages is not None: msg = _get_error_message(error_messages=error_messages, key="other", exc=exc) else: diff --git a/advanced_alchemy/repository/_async.py b/advanced_alchemy/repository/_async.py index 39fdb82b..62e41be3 100644 --- a/advanced_alchemy/repository/_async.py +++ b/advanced_alchemy/repository/_async.py @@ -74,6 +74,7 @@ class SQLAlchemyAsyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Pr auto_commit: bool order_by: list[OrderingPair] | OrderingPair | None = None error_messages: ErrorMessages | None = None + wrap_exceptions: bool = True def __init__( self, @@ -87,6 +88,7 @@ def __init__( execution_options: dict[str, Any] | None = None, order_by: list[OrderingPair] | OrderingPair | None = None, error_messages: ErrorMessages | None | EmptyType = Empty, + wrap_exceptions: bool = True, **kwargs: Any, ) -> None: ... @@ -408,6 +410,8 @@ class SQLAlchemyAsyncRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT], Filte """Default loader options for the repository.""" error_messages: ErrorMessages | None = None """Default error messages for the repository.""" + wrap_exceptions: bool = True + """Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.""" inherit_lazy_relationships: bool = True """Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration.""" @@ -436,6 +440,7 @@ def __init__( error_messages: ErrorMessages | None | EmptyType = Empty, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, + wrap_exceptions: bool = True, **kwargs: Any, ) -> None: """Repository for SQLAlchemy models. @@ -450,6 +455,7 @@ def __init__( load: Set default relationships to be loaded execution_options: Set default execution options error_messages: A set of custom error messages to use for operations + wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised. **kwargs: Additional arguments. """ @@ -461,6 +467,7 @@ def __init__( self.error_messages = self._get_error_messages( error_messages=error_messages, default_messages=self.error_messages ) + self.wrap_exceptions = wrap_exceptions self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options( loader_options=load if load is not None else self.loader_options, inherit_lazy_relationships=self.inherit_lazy_relationships, diff --git a/advanced_alchemy/repository/_sync.py b/advanced_alchemy/repository/_sync.py index 3a3171b9..16c7be6e 100644 --- a/advanced_alchemy/repository/_sync.py +++ b/advanced_alchemy/repository/_sync.py @@ -75,6 +75,7 @@ class SQLAlchemySyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Pro auto_commit: bool order_by: list[OrderingPair] | OrderingPair | None = None error_messages: ErrorMessages | None = None + wrap_exceptions: bool = True def __init__( self, @@ -88,6 +89,7 @@ def __init__( execution_options: dict[str, Any] | None = None, order_by: list[OrderingPair] | OrderingPair | None = None, error_messages: ErrorMessages | None | EmptyType = Empty, + wrap_exceptions: bool = True, **kwargs: Any, ) -> None: ... @@ -409,6 +411,8 @@ class SQLAlchemySyncRepository(SQLAlchemySyncRepositoryProtocol[ModelT], Filtera """Default loader options for the repository.""" error_messages: ErrorMessages | None = None """Default error messages for the repository.""" + wrap_exceptions: bool = True + """Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised.""" inherit_lazy_relationships: bool = True """Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or ``default_loader_options`` configuration.""" @@ -437,6 +441,7 @@ def __init__( error_messages: ErrorMessages | None | EmptyType = Empty, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, + wrap_exceptions: bool = True, **kwargs: Any, ) -> None: """Repository for SQLAlchemy models. @@ -451,6 +456,7 @@ def __init__( load: Set default relationships to be loaded execution_options: Set default execution options error_messages: A set of custom error messages to use for operations + wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised. **kwargs: Additional arguments. """ @@ -462,6 +468,7 @@ def __init__( self.error_messages = self._get_error_messages( error_messages=error_messages, default_messages=self.error_messages ) + self.wrap_exceptions = wrap_exceptions self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options( loader_options=load if load is not None else self.loader_options, inherit_lazy_relationships=self.inherit_lazy_relationships, diff --git a/advanced_alchemy/repository/memory/_async.py b/advanced_alchemy/repository/memory/_async.py index e6490c43..80a4c5e6 100644 --- a/advanced_alchemy/repository/memory/_async.py +++ b/advanced_alchemy/repository/memory/_async.py @@ -81,6 +81,7 @@ class SQLAlchemyAsyncMockRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT]): "order_by", "load", "error_messages", + "wrap_exceptions", } def __init__( @@ -93,6 +94,7 @@ def __init__( auto_commit: bool = False, order_by: list[OrderingPair] | OrderingPair | None = None, error_messages: ErrorMessages | None | EmptyType = Empty, + wrap_exceptions: bool = True, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, **kwargs: Any, @@ -103,6 +105,7 @@ def __init__( self.auto_refresh = auto_refresh self.auto_commit = auto_commit self.error_messages = self._get_error_messages(error_messages=error_messages) + self.wrap_exceptions = wrap_exceptions self.order_by = order_by self._dialect: Dialect = create_autospec(Dialect, instance=True) self._dialect.name = "mock" diff --git a/advanced_alchemy/repository/memory/_sync.py b/advanced_alchemy/repository/memory/_sync.py index 2169d660..7b52e0fa 100644 --- a/advanced_alchemy/repository/memory/_sync.py +++ b/advanced_alchemy/repository/memory/_sync.py @@ -82,6 +82,7 @@ class SQLAlchemySyncMockRepository(SQLAlchemySyncRepositoryProtocol[ModelT]): "order_by", "load", "error_messages", + "wrap_exceptions", } def __init__( @@ -94,6 +95,7 @@ def __init__( auto_commit: bool = False, order_by: list[OrderingPair] | OrderingPair | None = None, error_messages: ErrorMessages | None | EmptyType = Empty, + wrap_exceptions: bool = True, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, **kwargs: Any, @@ -104,6 +106,7 @@ def __init__( self.auto_refresh = auto_refresh self.auto_commit = auto_commit self.error_messages = self._get_error_messages(error_messages=error_messages) + self.wrap_exceptions = wrap_exceptions self.order_by = order_by self._dialect: Dialect = create_autospec(Dialect, instance=True) self._dialect.name = "mock" diff --git a/advanced_alchemy/service/_async.py b/advanced_alchemy/service/_async.py index 21b07c31..e823b2fc 100644 --- a/advanced_alchemy/service/_async.py +++ b/advanced_alchemy/service/_async.py @@ -108,6 +108,7 @@ def __init__( auto_commit: bool = False, order_by: list[OrderingPair] | OrderingPair | None = None, error_messages: ErrorMessages | None | EmptyType = Empty, + wrap_exceptions: bool = True, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, **repo_kwargs: Any, @@ -122,6 +123,7 @@ def __init__( auto_commit: Commit objects before returning. order_by: Set default order options for queries. error_messages: A set of custom error messages to use for operations + wrap_exceptions: Wrap exceptions in a RepositoryError load: Set default relationships to be loaded execution_options: Set default execution options **repo_kwargs: passed as keyword args to repo instantiation. @@ -136,6 +138,7 @@ def __init__( auto_commit=auto_commit, order_by=order_by, error_messages=error_messages, + wrap_exceptions=wrap_exceptions, load=load, execution_options=execution_options, **repo_kwargs, diff --git a/advanced_alchemy/service/_sync.py b/advanced_alchemy/service/_sync.py index 7616f1f1..f30ed52d 100644 --- a/advanced_alchemy/service/_sync.py +++ b/advanced_alchemy/service/_sync.py @@ -122,6 +122,7 @@ def __init__( auto_commit: bool = False, order_by: list[OrderingPair] | OrderingPair | None = None, error_messages: ErrorMessages | None | EmptyType = Empty, + wrap_exceptions: bool = True, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, **repo_kwargs: Any, @@ -136,6 +137,7 @@ def __init__( auto_commit: Commit objects before returning. order_by: Set default order options for queries. error_messages: A set of custom error messages to use for operations + wrap_exceptions: Wrap exceptions in a RepositoryError load: Set default relationships to be loaded execution_options: Set default execution options **repo_kwargs: passed as keyword args to repo instantiation. @@ -150,6 +152,7 @@ def __init__( auto_commit=auto_commit, order_by=order_by, error_messages=error_messages, + wrap_exceptions=wrap_exceptions, load=load, execution_options=execution_options, **repo_kwargs, diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py index bc5ac421..171576c9 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/test_exceptions.py @@ -1,6 +1,26 @@ import contextlib import pytest +from sqlalchemy.exc import ( + IntegrityError as SQLAlchemyIntegrityError, +) +from sqlalchemy.exc import ( + InvalidRequestError as SQLAlchemyInvalidRequestError, +) +from sqlalchemy.exc import ( + MultipleResultsFound, + SQLAlchemyError, + StatementError, +) + +from advanced_alchemy.exceptions import ( + DuplicateKeyError, + IntegrityError, + InvalidRequestError, + MultipleResultsFoundError, + RepositoryError, + wrap_sqlalchemy_exception, +) async def test_repo_get_or_create_deprecation() -> None: @@ -9,3 +29,110 @@ async def test_repo_get_or_create_deprecation() -> None: with contextlib.suppress(Exception): raise ConflictError + + +def test_wrap_sqlalchemy_exception_multiple_results_found() -> None: + with pytest.raises(MultipleResultsFoundError), wrap_sqlalchemy_exception(): + raise MultipleResultsFound() + + +@pytest.mark.parametrize("dialect_name", ["postgresql", "sqlite", "mysql"]) +def test_wrap_sqlalchemy_exception_integrity_error_duplicate_key(dialect_name: str) -> None: + error_message = { + "postgresql": 'duplicate key value violates unique constraint "uq_%(table_name)s_%(column_0_name)s"', + "sqlite": "UNIQUE constraint failed: %(table_name)s.%(column_0_name)s", + "mysql": "1062 (23000): Duplicate entry '%(value)s' for key '%(table_name)s.%(column_0_name)s'", + } + with pytest.raises(DuplicateKeyError), wrap_sqlalchemy_exception( + dialect_name=dialect_name, + error_messages={"duplicate_key": error_message[dialect_name]}, + ): + if dialect_name == "postgresql": + exception = SQLAlchemyIntegrityError( + "INSERT INTO table (id) VALUES (1)", + {"table_name": "table", "column_0_name": "id"}, + Exception( + 'duplicate key value violates unique constraint "uq_table_id"\nDETAIL: Key (id)=(1) already exists.', + ), + ) + elif dialect_name == "sqlite": + exception = SQLAlchemyIntegrityError( + "INSERT INTO table (id) VALUES (1)", + {"table_name": "table", "column_0_name": "id"}, + Exception("UNIQUE constraint failed: table.id"), + ) + else: + exception = SQLAlchemyIntegrityError( + "INSERT INTO table (id) VALUES (1)", + {"table_name": "table", "column_0_name": "id", "value": "1"}, + Exception("1062 (23000): Duplicate entry '1' for key 'table.id'"), + ) + + raise exception + + +def test_wrap_sqlalchemy_exception_integrity_error_other() -> None: + with pytest.raises(IntegrityError), wrap_sqlalchemy_exception(): + raise SQLAlchemyIntegrityError("original", {}, Exception("original")) + + +def test_wrap_sqlalchemy_exception_invalid_request_error() -> None: + with pytest.raises(InvalidRequestError), wrap_sqlalchemy_exception(): + raise SQLAlchemyInvalidRequestError("original", {}, Exception("original")) + + +def test_wrap_sqlalchemy_exception_statement_error() -> None: + with pytest.raises(IntegrityError), wrap_sqlalchemy_exception(): + raise StatementError("original", None, {}, Exception("original")) # pyright: ignore[reportArgumentType] + + +def test_wrap_sqlalchemy_exception_sqlalchemy_error() -> None: + with pytest.raises(RepositoryError), wrap_sqlalchemy_exception(): + raise SQLAlchemyError("original") + + +def test_wrap_sqlalchemy_exception_attribute_error() -> None: + with pytest.raises(RepositoryError), wrap_sqlalchemy_exception(): + raise AttributeError("original") + + +def test_wrap_sqlalchemy_exception_no_wrap() -> None: + with pytest.raises(SQLAlchemyError), wrap_sqlalchemy_exception(wrap_exceptions=False): + raise SQLAlchemyError("original") + with pytest.raises(SQLAlchemyIntegrityError), wrap_sqlalchemy_exception(wrap_exceptions=False): + raise SQLAlchemyIntegrityError(statement="select 1", params=None, orig=BaseException()) + with pytest.raises(MultipleResultsFound), wrap_sqlalchemy_exception(wrap_exceptions=False): + raise MultipleResultsFound() + with pytest.raises(SQLAlchemyInvalidRequestError), wrap_sqlalchemy_exception(wrap_exceptions=False): + raise SQLAlchemyInvalidRequestError() + with pytest.raises(AttributeError), wrap_sqlalchemy_exception(wrap_exceptions=False): + raise AttributeError() + + +def test_wrap_sqlalchemy_exception_custom_error_message() -> None: + def custom_message(exc: Exception) -> str: + return f"Custom: {exc}" + + with pytest.raises(RepositoryError) as excinfo, wrap_sqlalchemy_exception( + error_messages={"other": custom_message}, + ): + raise SQLAlchemyError("original") + + assert str(excinfo.value) == "Custom: original" + + +def test_wrap_sqlalchemy_exception_no_error_messages() -> None: + with pytest.raises(RepositoryError) as excinfo, wrap_sqlalchemy_exception(): + raise SQLAlchemyError("original") + + assert str(excinfo.value) == "An exception occurred: original" + + +def test_wrap_sqlalchemy_exception_no_match() -> None: + with pytest.raises(IntegrityError) as excinfo, wrap_sqlalchemy_exception( + dialect_name="postgresql", + error_messages={"integrity": "Integrity error"}, + ): + raise SQLAlchemyIntegrityError("original", {}, Exception("original")) + + assert str(excinfo.value) == "Integrity error"