diff --git a/.pylintrc b/.pylintrc index 588b150..fd798d3 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,2 +1,2 @@ [MASTER] -disable=fixme,missing-function-docstring,missing-module-docstring +disable=fixme,missing-function-docstring,missing-module-docstring,too-few-public-methods \ No newline at end of file diff --git a/src/destinations/dune.py b/src/destinations/dune.py index 0d7bada..c598930 100644 --- a/src/destinations/dune.py +++ b/src/destinations/dune.py @@ -22,6 +22,7 @@ class DuneDestination(Destination[DataFrame]): def __init__(self, api_key: str, table_name: str): self.client = DuneClient(api_key) self.table_name: str = table_name + super().__init__() def validate(self) -> bool: # Nothing I can think of to validate here... diff --git a/src/destinations/postgres.py b/src/destinations/postgres.py index 302b36d..5ea3a46 100644 --- a/src/destinations/postgres.py +++ b/src/destinations/postgres.py @@ -27,6 +27,7 @@ def __init__( self.engine: sqlalchemy.engine.Engine = create_engine(db_url) self.table_name: str = table_name self.if_exists: TableExistsPolicy = if_exists + super().__init__() def validate(self) -> bool: # Nothing I can think of to validate here... diff --git a/src/interfaces.py b/src/interfaces.py index 78b3ff3..ad4d375 100644 --- a/src/interfaces.py +++ b/src/interfaces.py @@ -10,29 +10,33 @@ T = TypeVar("T") -class Source(ABC, Generic[T]): +class Validate(ABC): + """Enforces validation on inheriting classes""" + + def __init__(self) -> None: + if not self.validate(): + raise ValueError(f"Config for {self.__class__.__name__} is invalid") + + @abstractmethod + def validate(self) -> bool: + """Validate the configuration""" + + +class Source(Validate, Generic[T]): """Abstract base class for data sources""" @abstractmethod def fetch(self) -> T: """Fetch data from the source""" - @abstractmethod - def validate(self) -> bool: - """Validate the source configuration""" - @abstractmethod def is_empty(self, data: T) -> bool: """Return True if the fetched data is empty""" -class Destination(ABC, Generic[T]): +class Destination(Validate, Generic[T]): """Abstract base class for data destinations""" @abstractmethod def save(self, data: T) -> None: """Save data to the destination""" - - @abstractmethod - def validate(self) -> bool: - """Validate the destination configuration""" diff --git a/src/sources/dune.py b/src/sources/dune.py index 52487ae..785fa33 100644 --- a/src/sources/dune.py +++ b/src/sources/dune.py @@ -70,6 +70,7 @@ def __init__( self.query = query self.poll_frequency = poll_frequency self.client = DuneClient(api_key, performance=query_engine) + super().__init__() def validate(self) -> bool: # Nothing I can think of to validate here... diff --git a/src/sources/postgres.py b/src/sources/postgres.py index d5f4481..6216321 100644 --- a/src/sources/postgres.py +++ b/src/sources/postgres.py @@ -36,7 +36,7 @@ def __init__(self, db_url: str, query_string: str): self.engine: sqlalchemy.engine.Engine = create_engine(db_url) self.query_string = "" self._set_query_string(query_string) - self.validate() + super().__init__() def validate(self) -> bool: try: diff --git a/tests/fixtures/config/basic.yaml b/tests/fixtures/config/basic.yaml index 8dc8554..99e0302 100644 --- a/tests/fixtures/config/basic.yaml +++ b/tests/fixtures/config/basic.yaml @@ -26,7 +26,7 @@ jobs: ref: postgres table_name: foo_table if_exists: append - query_string: SELECT * FROM foo; + query_string: SELECT 1; destination: ref: dune table_name: table_name diff --git a/tests/unit/sources_test.py b/tests/unit/sources_test.py index b39c063..0de59e5 100644 --- a/tests/unit/sources_test.py +++ b/tests/unit/sources_test.py @@ -3,13 +3,14 @@ from unittest.mock import patch import pandas as pd +import sqlalchemy from dune_client.models import ExecutionResult, ResultMetadata from sqlalchemy import BIGINT from sqlalchemy.dialects.postgresql import BYTEA from src.config import RuntimeConfig from src.sources.dune import _reformat_varbinary_columns, dune_result_to_df -from src.sources.postgres import _convert_bytea_to_hex +from src.sources.postgres import PostgresSource, _convert_bytea_to_hex from tests import fixtures_root, config_root @@ -70,14 +71,21 @@ def test_convert_bytea_to_hex(self): class TestPostgresSource(unittest.TestCase): - @patch.dict( - os.environ, - { - "DUNE_API_KEY": "test_key", - "DB_URL": "postgresql://postgres:postgres@localhost:5432/postgres", - }, - clear=True, - ) + @classmethod + def setUpClass(cls): + cls.env_patcher = patch.dict( + os.environ, + { + "DUNE_API_KEY": "test_key", + "DB_URL": "postgresql://postgres:postgres@localhost:5432/postgres", + }, + clear=True, + ) + cls.env_patcher.start() + + # TODO: This test is a Config loader test not directly testing PostgresSource + # When changing it to call PGSource directly, yields a bug with the constructor. + # The constructor only accepts string input, not Path! def test_load_sql_file(self): os.chdir(fixtures_root) @@ -88,3 +96,30 @@ def test_load_sql_file(self): missing_file.unlink(missing_ok=True) with self.assertRaises(RuntimeError): RuntimeConfig.load_from_yaml(config_root / "invalid_sql_file.yaml") + + def test_invalid_query_string(self): + with self.assertRaises(ValueError) as context: + PostgresSource( + db_url=os.environ["DB_URL"], + query_string="SELECT * FROM does_not_exist", + ) + self.assertEqual("Config for PostgresSource is invalid", str(context.exception)) + + def test_invalid_connection_string(self): + with self.assertRaises(sqlalchemy.exc.ArgumentError) as context: + PostgresSource( + db_url="invalid connection string", + query_string="SELECT 1", + ) + self.assertEqual( + "Could not parse SQLAlchemy URL from string 'invalid connection string'", + str(context.exception), + ) + + def test_invalid_db_url(self): + with self.assertRaises(ValueError) as context: + PostgresSource( + db_url="postgresql://postgres:BAD_PASSWORD@localhost:5432/postgres", + query_string="SELECT 1", + ) + self.assertEqual("Config for PostgresSource is invalid", str(context.exception))