From ed5cce7584d9f903d707e4e342ac272acc9680d0 Mon Sep 17 00:00:00 2001 From: Gabriel Pajot Date: Mon, 27 Jan 2025 13:31:34 +0100 Subject: [PATCH] fix: close connection thread properly if BaseException raised in connect step --- aiosqlite/core.py | 2 +- aiosqlite/tests/smoke.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/aiosqlite/core.py b/aiosqlite/core.py index 58c3fec..61f5db7 100644 --- a/aiosqlite/core.py +++ b/aiosqlite/core.py @@ -138,7 +138,7 @@ async def _connect(self) -> "Connection": future = asyncio.get_event_loop().create_future() self._tx.put_nowait((future, self._connector)) self._connection = await future - except Exception: + except BaseException: self._stop_running() self._connection = None raise diff --git a/aiosqlite/tests/smoke.py b/aiosqlite/tests/smoke.py index 6fd2566..57c1beb 100644 --- a/aiosqlite/tests/smoke.py +++ b/aiosqlite/tests/smoke.py @@ -7,6 +7,7 @@ from tempfile import TemporaryDirectory from threading import Thread from unittest import IsolatedAsyncioTestCase, SkipTest +from unittest.mock import patch import aiosqlite from .helpers import setup_logger @@ -351,6 +352,23 @@ async def test_connect_error(self): with self.assertRaisesRegex(OperationalError, "unable to open database"): await aiosqlite.connect(bad_db) + async def test_connect_base_exception(self): + # Check if connect task is cancelled, thread is properly closed. + def _raise_cancelled_error(*_, **__): + raise asyncio.CancelledError("I changed my mind") + + connection = aiosqlite.Connection(lambda: sqlite3.connect(":memory:"), 64) + with ( + patch.object(sqlite3, "connect", side_effect=_raise_cancelled_error), + self.assertRaisesRegex(asyncio.CancelledError, "I changed my mind"), + ): + async with connection: + ... + # Terminate the thread here if the test fails to have a clear error. + if connection._running: + connection._stop_running() + raise AssertionError("connection thread was not stopped") + async def test_iterdump(self): async with aiosqlite.connect(":memory:") as db: await db.execute("create table foo (i integer, k charvar(250))")