Skip to content

Commit

Permalink
Finished testing readers
Browse files Browse the repository at this point in the history
  • Loading branch information
lewis-chambers committed Sep 16, 2024
1 parent f03b9a1 commit fc6deef
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 43 deletions.
44 changes: 27 additions & 17 deletions src/driutils/io/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ def __enter__(self) -> Self:

def __exit__(self, *args) -> None:
"""Closes the connection when exiting the context"""
self._connection.close()
self.close()

def __del__(self):
"""Closes the connection when deleted"""
self.close()

def close(self) -> None:
"""Closes the connection"""
self._connection.close()

@abstractmethod
Expand All @@ -49,10 +53,6 @@ def read(self, query: str, params: Optional[List] = None) -> DuckDBPyConnection:

return self._connection.execute(query, params)

def close(self) -> None:
"""Close the connection"""
self._connection.close()


class DuckDBS3Reader(DuckDBReader):
"""Concrete Implementation of a DuckDB reader for reading
Expand All @@ -70,36 +70,46 @@ def __init__(self, auth_type: str, endpoint_url: Optional[str] = None, use_ssl:

super().__init__()

auth_type = auth_type.lower()
auth_type = str(auth_type).lower()

VALID_AUTH_METHODS = ["auto", "sts", "custom_endpoint"]

if auth_type not in VALID_AUTH_METHODS:
raise ValueError(f"Invalid `auth_type`, must be one of {VALID_AUTH_METHODS}")

self._connection.install_extension("httpfs")
self._connection.load_extension("httpfs")
self._connection.execute("""
INSTALL httpfs;
LOAD httpfs;
SET force_download = true;
SET http_keep_alive = false;
""")

if auth_type == "auto":
self._authenticate(auth_type, endpoint_url, use_ssl)

def _authenticate(self, method: str, endpoint_url: Optional[str] = None, use_ssl: Optional[bool] = None) -> None:
"""Handles authentication selection
Args:
method: method of authentication used
endpoint_url: Custom s3 endpoint
use_ssl: Flag for using ssl (https connections)
"""
if method == "auto":
self._auto_auth()
elif auth_type == "sts":
elif method == "sts":
self._sts_auth()
elif auth_type == "custom_endpoint":
if not isinstance(endpoint_url, str):
endpoint_url = str(endpoint_url)
elif method == "custom_endpoint":
if not endpoint_url:
raise ValueError("`endpoint_url` must be provided for `custom_endpoint` authentication")

self._custom_endpoint_auth(endpoint_url, use_ssl)

def _auto_auth(self) -> None:
"""Automatically authenticates using environment variables"""

self._connection.install_extension("aws")
self._connection.load_extension("aws")
self._connection.execute("""
INSTALL aws;
LOAD aws;
CREATE SECRET (
TYPE S3,
PROVIDER CREDENTIAL_CHAIN
Expand All @@ -108,10 +118,10 @@ def _auto_auth(self) -> None:

def _sts_auth(self) -> None:
"""Authenicates using assumed roles on AWS"""
self._connection.install_extension("aws")
self._connection.load_extension("aws")

self._connection.execute("""
INSTALL aws;
LOAD aws;
CREATE SECRET (
TYPE S3,
PROVIDER CREDENTIAL_CHAIN,
Expand Down
94 changes: 68 additions & 26 deletions tests/io/test_readers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import patch, MagicMock
from driutils.io.read import DuckDBFileReader
from driutils.io.read import DuckDBFileReader, DuckDBS3Reader
from duckdb import DuckDBPyConnection
from parameterized import parameterized

class TestDuckDBFileReader(unittest.TestCase):

Expand All @@ -11,50 +12,91 @@ def test_initialization(self):

self.assertIsInstance(reader._connection, DuckDBPyConnection)

def test_context_manager_is_function(self):
@patch("driutils.io.read.DuckDBFileReader.close")
def test_context_manager_is_functional(self, mock):
"""Should be able to use context manager to auto-close file connection"""

mock = MagicMock()

with DuckDBFileReader() as con:
self.assertIsInstance(con._connection, DuckDBPyConnection)

con._connection = mock

self.assertTrue(mock.close.called)
mock.assert_called_once()

def test_connection_closed_on_delete(self):
@patch("driutils.io.read.DuckDBFileReader.close")
def test_connection_closed_on_delete(self, mock):
"""Tests that duckdb connection is closed when object is deleted"""
assert False

reader = DuckDBFileReader()
del reader
mock.assert_called_once()

def test_close_method_closes_connection(self):
"""Tests that the .close() method closes the connection"""
assert False

reader = DuckDBFileReader()
reader._connection = MagicMock()

reader.close()

reader._connection.close.assert_called()

def test_read_executes_query(self):
"""Tests that the .read() method executes a query"""
assert False

reader = DuckDBFileReader()

reader._connection = MagicMock()

query = "read this plz"
params = ["param1", "param2"]

reader.read(query, params)

reader._connection.execute.assert_called_once_with(query, params)

class TestDuckDBS3Reader(unittest.TestCase):

def test_value_error_if_invalid_auth_option(self):
@parameterized.expand(["a", 1, "cutom_endpoint"])
def test_value_error_if_invalid_auth_option(self, value):
"""Test that a ValueError is raised if a bad auth option is selected"""
assert False

with self.assertRaises(ValueError):
DuckDBS3Reader(value)

def test_init_auto_authentication(self):
@parameterized.expand(["auto", "AUTO", "aUtO"])
@patch("driutils.io.read.DuckDBS3Reader._authenticate")
def test_upper_or_lowercase_option_accepted(self, value, mock):
"""Tests that the auth options can be provided in any case"""
DuckDBS3Reader(value)

mock.assert_called_once()

@patch.object(DuckDBS3Reader, "_auto_auth", side_effect=DuckDBS3Reader._auto_auth, autospec=True)
def test_init_auto_authentication(self, mock):
"""Tests that the reader can use the 'auto' auth option"""
assert False

def test_init_sts_authentication(self):
"""Tests that the reader can use the 'sts' auth option"""
assert False
DuckDBS3Reader("auto")
mock.assert_called_once()

def test_init_custom_endpoint_authentication_https(self):
@patch.object(DuckDBS3Reader, "_sts_auth", side_effect=DuckDBS3Reader._sts_auth, autospec=True)
def test_init_sts_authentication(self, mock):
"""Tests that the reader can use the 'sts' auth option"""
DuckDBS3Reader("sts")
mock.assert_called_once()

@parameterized.expand([
["https://s3-a-real-endpoint", True],
["http://localhost:8080", False]
])
@patch.object(DuckDBS3Reader, "_custom_endpoint_auth", wraps=DuckDBS3Reader._custom_endpoint_auth, autospec=True)
def test_init_custom_endpoint_authentication_https(self, url, ssl, mock):
"""Tests that the reader can authenticate to a custom endpoint
with https protocol"""
assert False

def test_init_custom_endpoint_authentication_http(self):
"""Tests that the reader can authenticate to a custom endpoint
with http protocol"""
assert False

reader = DuckDBS3Reader("custom_endpoint", url, ssl)
mock.assert_called_once_with(reader, url, ssl)

def test_error_if_custom_endpoint_not_provided(self):
"""Test that an error is raised if custom_endpoint authentication used but
endpoint_url_not_given"""

with self.assertRaises(ValueError):
DuckDBS3Reader("custom_endpoint")

0 comments on commit fc6deef

Please sign in to comment.