From bd42ce6eb11695cbf4efbfa6f9b67e3f93fe5d42 Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Sun, 6 Oct 2024 20:26:46 -0700 Subject: [PATCH 1/2] Use temporary directories for test databases --- aiosqlite/tests/smoke.py | 66 +++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/aiosqlite/tests/smoke.py b/aiosqlite/tests/smoke.py index 499b287..26aec0a 100644 --- a/aiosqlite/tests/smoke.py +++ b/aiosqlite/tests/smoke.py @@ -6,12 +6,11 @@ from sqlite3 import OperationalError from threading import Thread from unittest import IsolatedAsyncioTestCase as TestCase, SkipTest +from tempfile import TemporaryDirectory import aiosqlite from .helpers import setup_logger -TEST_DB = Path("test.db") - # pypy uses non-standard text factory for low-level sqlite implementation try: from _sqlite3 import _unicode_text_factory as default_text_factory @@ -25,15 +24,12 @@ def setUpClass(cls): setup_logger() def setUp(self): - if TEST_DB.exists(): - TEST_DB.unlink() - - def tearDown(self): - if TEST_DB.exists(): - TEST_DB.unlink() + td = TemporaryDirectory() + self.addCleanup(td.cleanup) + self.db = Path(td.name).resolve() / "test.db" async def test_connection_await(self): - db = await aiosqlite.connect(TEST_DB) + db = await aiosqlite.connect(self.db) self.assertIsInstance(db, aiosqlite.Connection) async with db.execute("select 1, 2") as cursor: @@ -43,7 +39,7 @@ async def test_connection_await(self): await db.close() async def test_connection_context(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: self.assertIsInstance(db, aiosqlite.Connection) async with db.execute("select 1, 2") as cursor: @@ -51,13 +47,15 @@ async def test_connection_context(self): self.assertEqual(rows, [(1, 2)]) async def test_connection_locations(self): + TEST_DB = self.db.as_posix() + class Fake: # pylint: disable=too-few-public-methods def __str__(self): - return str(TEST_DB) + return TEST_DB - locs = ("test.db", b"test.db", Path("test.db"), Fake()) + locs = (Path(TEST_DB), TEST_DB, TEST_DB.encode(), Fake()) - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(locs[0]) as db: await db.execute("create table foo (i integer, k integer)") await db.execute("insert into foo (i, k) values (1, 5)") await db.commit() @@ -71,27 +69,27 @@ def __str__(self): self.assertEqual(await cursor.fetchall(), rows) async def test_multiple_connections(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.execute( "create table multiple_connections " "(i integer primary key asc, k integer)" ) async def do_one_conn(i): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.execute("insert into multiple_connections (k) values (?)", [i]) await db.commit() await asyncio.gather(*[do_one_conn(i) for i in range(10)]) - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: cursor = await db.execute("select * from multiple_connections") rows = await cursor.fetchall() assert len(rows) == 10 async def test_multiple_queries(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.execute( "create table multiple_queries " "(i integer primary key asc, k integer)" @@ -106,14 +104,14 @@ async def test_multiple_queries(self): await db.commit() - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: cursor = await db.execute("select * from multiple_queries") rows = await cursor.fetchall() assert len(rows) == 10 async def test_iterable_cursor(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: cursor = await db.cursor() await cursor.execute( "create table iterable_cursor " "(i integer primary key asc, k integer)" @@ -123,7 +121,7 @@ async def test_iterable_cursor(self): ) await db.commit() - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: cursor = await db.execute("select * from iterable_cursor") rows = [] async for row in cursor: @@ -165,7 +163,7 @@ async def query(): self.assertEqual(len(rows), 2) async def test_context_cursor(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: async with db.cursor() as cursor: await cursor.execute( "create table context_cursor " @@ -177,7 +175,7 @@ async def test_context_cursor(self): ) await db.commit() - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: async with db.execute("select * from context_cursor") as cursor: rows = [] async for row in cursor: @@ -186,7 +184,7 @@ async def test_context_cursor(self): assert len(rows) == 10 async def test_cursor_return_self(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: cursor = await db.cursor() result = await cursor.execute( @@ -207,7 +205,7 @@ async def test_cursor_return_self(self): self.assertEqual(result, cursor) async def test_connection_properties(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: self.assertEqual(db.total_changes, 0) async with db.cursor() as cursor: @@ -262,7 +260,7 @@ async def test_connection_properties(self): self.assertEqual(row["d"], b"hi") async def test_fetch_all(self): - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.execute( "create table test_fetch_all (i integer primary key asc, k integer)" ) @@ -271,14 +269,14 @@ async def test_fetch_all(self): ) await db.commit() - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: cursor = await db.execute("select k from test_fetch_all where k < 30") rows = await cursor.fetchall() self.assertEqual(rows, [(10,), (24,), (16,)]) async def test_enable_load_extension(self): """Assert that after enabling extension loading, they can be loaded""" - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: try: await db.enable_load_extension(True) await db.load_extension("test") @@ -294,7 +292,7 @@ async def test_set_progress_handler(self): """ Assert that after setting a progress handler returning 1, DB operations are aborted """ - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.set_progress_handler(lambda: 1, 1) with self.assertRaises(OperationalError): await db.execute( @@ -310,7 +308,7 @@ def no_arg(): def one_arg(num): return num * 2 - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.create_function("no_arg", 0, no_arg) await db.create_function("one_arg", 1, one_arg) @@ -331,7 +329,7 @@ async def test_create_function_deterministic(self): def one_arg(num): return num * 2 - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.create_function("one_arg", 1, one_arg, deterministic=True) await db.execute("create table foo (id int, bar int)") @@ -344,7 +342,7 @@ async def test_set_trace_callback(self): def callback(statement: str): statements.append(statement) - async with aiosqlite.connect(TEST_DB) as db: + async with aiosqlite.connect(self.db) as db: await db.set_trace_callback(callback) await db.execute("select 10") @@ -379,7 +377,7 @@ async def test_iterdump(self): ) async def test_cursor_on_closed_connection(self): - db = await aiosqlite.connect(TEST_DB) + db = await aiosqlite.connect(self.db) cursor = await db.execute("select 1, 2") await db.close() @@ -389,7 +387,7 @@ async def test_cursor_on_closed_connection(self): await cursor.fetchall() async def test_cursor_on_closed_connection_loop(self): - db = await aiosqlite.connect(TEST_DB) + db = await aiosqlite.connect(self.db) cursor = await db.execute("select 1, 2") tasks = [] @@ -404,7 +402,7 @@ async def test_cursor_on_closed_connection_loop(self): pass async def test_close_twice(self): - db = await aiosqlite.connect(TEST_DB) + db = await aiosqlite.connect(self.db) await db.close() From c4525f7f7cf17365ad20ed49526ef6b5f8f68544 Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Sun, 6 Oct 2024 20:39:43 -0700 Subject: [PATCH 2/2] Stop aliasing imports --- aiosqlite/tests/smoke.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiosqlite/tests/smoke.py b/aiosqlite/tests/smoke.py index 26aec0a..9391100 100644 --- a/aiosqlite/tests/smoke.py +++ b/aiosqlite/tests/smoke.py @@ -4,9 +4,9 @@ import sqlite3 from pathlib import Path from sqlite3 import OperationalError -from threading import Thread -from unittest import IsolatedAsyncioTestCase as TestCase, SkipTest from tempfile import TemporaryDirectory +from threading import Thread +from unittest import IsolatedAsyncioTestCase, SkipTest import aiosqlite from .helpers import setup_logger @@ -18,7 +18,7 @@ default_text_factory = str -class SmokeTest(TestCase): +class SmokeTest(IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): setup_logger()