Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing improvements #320

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 34 additions & 36 deletions aiosqlite/tests/smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,32 @@
import sqlite3
from pathlib import Path
from sqlite3 import OperationalError
from tempfile import TemporaryDirectory
from threading import Thread
from unittest import IsolatedAsyncioTestCase as TestCase, SkipTest
from unittest import IsolatedAsyncioTestCase, SkipTest

import aiosqlite
from .helpers import setup_logger

TEST_DB = Path("test.db")

# pypy uses non-standard text factory for low-level sqlite implementation
try:
from _sqlite3 import _unicode_text_factory as default_text_factory
except ImportError:
default_text_factory = str


class SmokeTest(TestCase):
class SmokeTest(IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
setup_logger()

def setUp(self):
if TEST_DB.exists():
TEST_DB.unlink()

def tearDown(self):
if TEST_DB.exists():
TEST_DB.unlink()
td = TemporaryDirectory()
self.addCleanup(td.cleanup)
self.db = Path(td.name).resolve() / "test.db"

async def test_connection_await(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)
self.assertIsInstance(db, aiosqlite.Connection)

async with db.execute("select 1, 2") as cursor:
Expand All @@ -43,21 +39,23 @@ async def test_connection_await(self):
await db.close()

async def test_connection_context(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
self.assertIsInstance(db, aiosqlite.Connection)

async with db.execute("select 1, 2") as cursor:
rows = await cursor.fetchall()
self.assertEqual(rows, [(1, 2)])

async def test_connection_locations(self):
TEST_DB = self.db.as_posix()

class Fake: # pylint: disable=too-few-public-methods
def __str__(self):
return str(TEST_DB)
return TEST_DB

locs = ("test.db", b"test.db", Path("test.db"), Fake())
locs = (Path(TEST_DB), TEST_DB, TEST_DB.encode(), Fake())

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(locs[0]) as db:
await db.execute("create table foo (i integer, k integer)")
await db.execute("insert into foo (i, k) values (1, 5)")
await db.commit()
Expand All @@ -71,27 +69,27 @@ def __str__(self):
self.assertEqual(await cursor.fetchall(), rows)

async def test_multiple_connections(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute(
"create table multiple_connections "
"(i integer primary key asc, k integer)"
)

async def do_one_conn(i):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute("insert into multiple_connections (k) values (?)", [i])
await db.commit()

await asyncio.gather(*[do_one_conn(i) for i in range(10)])

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select * from multiple_connections")
rows = await cursor.fetchall()

assert len(rows) == 10

async def test_multiple_queries(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute(
"create table multiple_queries "
"(i integer primary key asc, k integer)"
Expand All @@ -106,14 +104,14 @@ async def test_multiple_queries(self):

await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select * from multiple_queries")
rows = await cursor.fetchall()

assert len(rows) == 10

async def test_iterable_cursor(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.cursor()
await cursor.execute(
"create table iterable_cursor " "(i integer primary key asc, k integer)"
Expand All @@ -123,7 +121,7 @@ async def test_iterable_cursor(self):
)
await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select * from iterable_cursor")
rows = []
async for row in cursor:
Expand Down Expand Up @@ -165,7 +163,7 @@ async def query():
self.assertEqual(len(rows), 2)

async def test_context_cursor(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
async with db.cursor() as cursor:
await cursor.execute(
"create table context_cursor "
Expand All @@ -177,7 +175,7 @@ async def test_context_cursor(self):
)
await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
async with db.execute("select * from context_cursor") as cursor:
rows = []
async for row in cursor:
Expand All @@ -186,7 +184,7 @@ async def test_context_cursor(self):
assert len(rows) == 10

async def test_cursor_return_self(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.cursor()

result = await cursor.execute(
Expand All @@ -207,7 +205,7 @@ async def test_cursor_return_self(self):
self.assertEqual(result, cursor)

async def test_connection_properties(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
self.assertEqual(db.total_changes, 0)

async with db.cursor() as cursor:
Expand Down Expand Up @@ -262,7 +260,7 @@ async def test_connection_properties(self):
self.assertEqual(row["d"], b"hi")

async def test_fetch_all(self):
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.execute(
"create table test_fetch_all (i integer primary key asc, k integer)"
)
Expand All @@ -271,14 +269,14 @@ async def test_fetch_all(self):
)
await db.commit()

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
cursor = await db.execute("select k from test_fetch_all where k < 30")
rows = await cursor.fetchall()
self.assertEqual(rows, [(10,), (24,), (16,)])

async def test_enable_load_extension(self):
"""Assert that after enabling extension loading, they can be loaded"""
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
try:
await db.enable_load_extension(True)
await db.load_extension("test")
Expand All @@ -294,7 +292,7 @@ async def test_set_progress_handler(self):
"""
Assert that after setting a progress handler returning 1, DB operations are aborted
"""
async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.set_progress_handler(lambda: 1, 1)
with self.assertRaises(OperationalError):
await db.execute(
Expand All @@ -310,7 +308,7 @@ def no_arg():
def one_arg(num):
return num * 2

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.create_function("no_arg", 0, no_arg)
await db.create_function("one_arg", 1, one_arg)

Expand All @@ -331,7 +329,7 @@ async def test_create_function_deterministic(self):
def one_arg(num):
return num * 2

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.create_function("one_arg", 1, one_arg, deterministic=True)
await db.execute("create table foo (id int, bar int)")

Expand All @@ -344,7 +342,7 @@ async def test_set_trace_callback(self):
def callback(statement: str):
statements.append(statement)

async with aiosqlite.connect(TEST_DB) as db:
async with aiosqlite.connect(self.db) as db:
await db.set_trace_callback(callback)

await db.execute("select 10")
Expand Down Expand Up @@ -379,7 +377,7 @@ async def test_iterdump(self):
)

async def test_cursor_on_closed_connection(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)

cursor = await db.execute("select 1, 2")
await db.close()
Expand All @@ -389,7 +387,7 @@ async def test_cursor_on_closed_connection(self):
await cursor.fetchall()

async def test_cursor_on_closed_connection_loop(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)

cursor = await db.execute("select 1, 2")
tasks = []
Expand All @@ -404,7 +402,7 @@ async def test_cursor_on_closed_connection_loop(self):
pass

async def test_close_twice(self):
db = await aiosqlite.connect(TEST_DB)
db = await aiosqlite.connect(self.db)

await db.close()

Expand Down