Skip to content

Commit

Permalink
Testing improvements (#320)
Browse files Browse the repository at this point in the history
* Use temporary directories for test databases

* Stop aliasing imports
  • Loading branch information
amyreese authored Feb 3, 2025
1 parent b5ddd85 commit 8a95cd3
Showing 1 changed file with 34 additions and 36 deletions.
70 changes: 34 additions & 36 deletions aiosqlite/tests/smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,32 @@
import sqlite3
from pathlib import Path
from sqlite3 import OperationalError
from tempfile import TemporaryDirectory
from threading import Thread
from unittest import IsolatedAsyncioTestCase as TestCase, SkipTest
from unittest import IsolatedAsyncioTestCase, SkipTest

import aiosqlite
from .helpers import setup_logger

TEST_DB = Path("test.db")

# pypy uses non-standard text factory for low-level sqlite implementation
try:
from _sqlite3 import _unicode_text_factory as default_text_factory
except ImportError:
default_text_factory = str


class SmokeTest(TestCase):
class SmokeTest(IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
setup_logger()

def setUp(self):
if TEST_DB.exists():
TEST_DB.unlink()

def tearDown(self):
if TEST_DB.exists():
TEST_DB.unlink()
td = TemporaryDirectory()
self.addCleanup(td.cleanup)
self.db = Path(td.name).resolve() / "test.db"

async def test_connection_await(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)
self.assertIsInstance(db, aiosqlite.Connection)

async with db.execute("select 1, 2") as cursor:
Expand All @@ -43,21 +39,23 @@ async def test_connection_await(self):
await db.close()

async def test_connection_context(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
self.assertIsInstance(db, aiosqlite.Connection)

async with db.execute("select 1, 2") as cursor:
rows = await cursor.fetchall()
self.assertEqual(rows, [(1, 2)])

async def test_connection_locations(self):
TEST_DB = self.db.as_posix()

class Fake: # pylint: disable=too-few-public-methods
def __str__(self):
return str(TEST_DB)
return TEST_DB

locs = ("test.db", b"test.db", Path("test.db"), Fake())
locs = (Path(TEST_DB), TEST_DB, TEST_DB.encode(), Fake())

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(locs[0]) as db:
await db.execute("create table foo (i integer, k integer)")
await db.execute("insert into foo (i, k) values (1, 5)")
await db.commit()
Expand All @@ -71,27 +69,27 @@ def __str__(self):
self.assertEqual(await cursor.fetchall(), rows)

async def test_multiple_connections(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute(
"create table multiple_connections "
"(i integer primary key asc, k integer)"
)

async def do_one_conn(i):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute("insert into multiple_connections (k) values (?)", [i])
await db.commit()

await asyncio.gather(*[do_one_conn(i) for i in range(10)])

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select * from multiple_connections")
rows = await cursor.fetchall()

assert len(rows) == 10

async def test_multiple_queries(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute(
"create table multiple_queries "
"(i integer primary key asc, k integer)"
Expand All @@ -106,14 +104,14 @@ async def test_multiple_queries(self):

await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select * from multiple_queries")
rows = await cursor.fetchall()

assert len(rows) == 10

async def test_iterable_cursor(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.cursor()
await cursor.execute(
"create table iterable_cursor " "(i integer primary key asc, k integer)"
Expand All @@ -123,7 +121,7 @@ async def test_iterable_cursor(self):
)
await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select * from iterable_cursor")
rows = []
async for row in cursor:
Expand Down Expand Up @@ -165,7 +163,7 @@ async def query():
self.assertEqual(len(rows), 2)

async def test_context_cursor(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
async with db.cursor() as cursor:
await cursor.execute(
"create table context_cursor "
Expand All @@ -177,7 +175,7 @@ async def test_context_cursor(self):
)
await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
async with db.execute("select * from context_cursor") as cursor:
rows = []
async for row in cursor:
Expand All @@ -186,7 +184,7 @@ async def test_context_cursor(self):
assert len(rows) == 10

async def test_cursor_return_self(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.cursor()

result = await cursor.execute(
Expand All @@ -207,7 +205,7 @@ async def test_cursor_return_self(self):
self.assertEqual(result, cursor)

async def test_connection_properties(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
self.assertEqual(db.total_changes, 0)

async with db.cursor() as cursor:
Expand Down Expand Up @@ -262,7 +260,7 @@ async def test_connection_properties(self):
self.assertEqual(row["d"], b"hi")

async def test_fetch_all(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute(
"create table test_fetch_all (i integer primary key asc, k integer)"
)
Expand All @@ -271,14 +269,14 @@ async def test_fetch_all(self):
)
await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select k from test_fetch_all where k < 30")
rows = await cursor.fetchall()
self.assertEqual(rows, [(10,), (24,), (16,)])

async def test_enable_load_extension(self):
"""Assert that after enabling extension loading, they can be loaded"""
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
try:
await db.enable_load_extension(True)
await db.load_extension("test")
Expand All @@ -294,7 +292,7 @@ async def test_set_progress_handler(self):
"""
Assert that after setting a progress handler returning 1, DB operations are aborted
"""
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.set_progress_handler(lambda: 1, 1)
with self.assertRaises(OperationalError):
await db.execute(
Expand All @@ -310,7 +308,7 @@ def no_arg():
def one_arg(num):
return num * 2

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.create_function("no_arg", 0, no_arg)
await db.create_function("one_arg", 1, one_arg)

Expand All @@ -331,7 +329,7 @@ async def test_create_function_deterministic(self):
def one_arg(num):
return num * 2

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.create_function("one_arg", 1, one_arg, deterministic=True)
await db.execute("create table foo (id int, bar int)")

Expand All @@ -344,7 +342,7 @@ async def test_set_trace_callback(self):
def callback(statement: str):
statements.append(statement)

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.set_trace_callback(callback)

await db.execute("select 10")
Expand Down Expand Up @@ -379,7 +377,7 @@ async def test_iterdump(self):
)

async def test_cursor_on_closed_connection(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)

cursor = await db.execute("select 1, 2")
await db.close()
Expand All @@ -389,7 +387,7 @@ async def test_cursor_on_closed_connection(self):
await cursor.fetchall()

async def test_cursor_on_closed_connection_loop(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)

cursor = await db.execute("select 1, 2")
tasks = []
Expand All @@ -404,7 +402,7 @@ async def test_cursor_on_closed_connection_loop(self):
pass

async def test_close_twice(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)

await db.close()

Expand Down

0 comments on commit 8a95cd3

Please sign in to comment.