Skip to content

Commit

Permalink
Refactor to consolidate async dbt adapter code (#1509)
Browse files Browse the repository at this point in the history
This PR addresses the feedback in PR #1474, which suggested consolidating the
dbt adapter methods in the `_asynchronous` module and, additionally, no longer
keeping dictionary-based type maps for the adapter callables.

closes: #1508
  • Loading branch information
pankajkoti authored Feb 7, 2025
1 parent 03358bf commit bdb77e4
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 138 deletions.
Empty file added cosmos/_utils/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions cosmos/_utils/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import importlib
from typing import Any, Callable


def load_method_from_module(module_path: str, method_name: str) -> Callable[..., Any]:
    """Import ``module_path`` and return its attribute named ``method_name``.

    Args:
        module_path: Dotted import path of the module to load.
        method_name: Name of the attribute to fetch from the imported module.

    Returns:
        The attribute found on the imported module.

    Raises:
        ModuleNotFoundError: If importing ``module_path`` raises
            ``ModuleNotFoundError``.
        AttributeError: If the imported module has no attribute
            ``method_name``.
    """
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as err:
        # Chain explicitly: the original error names the module that is
        # actually missing, which may be a nested dependency of module_path.
        raise ModuleNotFoundError(f"Module {module_path} not found") from err

    # Only the attribute lookup is guarded here, so an AttributeError raised
    # while the module is being imported is not reported as a missing method.
    try:
        method = getattr(module, method_name)
    except AttributeError as err:
        raise AttributeError(f"Method {method_name} not found in module {module_path}") from err
    return method  # type: ignore
18 changes: 0 additions & 18 deletions cosmos/dbt_adapters/__init__.py

This file was deleted.

33 changes: 0 additions & 33 deletions cosmos/dbt_adapters/bigquery.py

This file was deleted.

29 changes: 29 additions & 0 deletions cosmos/operators/_asynchronous/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,40 @@
from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.dataset import get_dataset_alias_name
from cosmos.exceptions import CosmosValueError
from cosmos.operators.local import AbstractDbtLocalBase

AIRFLOW_VERSION = Version(airflow.__version__)


def _mock_bigquery_adapter() -> None:
    """Replace ``BigQueryConnectionManager.execute`` with a stub.

    The stub ignores its arguments and returns a fixed
    ``BigQueryAdapterResponse`` together with an empty agate table. The dbt
    and agate imports are done inside the function body.
    """
    from typing import Optional, Tuple

    import agate
    from dbt.adapters.bigquery.connections import BigQueryAdapterResponse, BigQueryConnectionManager
    from dbt_common.clients.agate_helper import empty_table

    def execute(  # type: ignore[no-untyped-def]
        self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None
    ) -> Tuple[BigQueryAdapterResponse, agate.Table]:
        # Build the canned response first, then the empty result table.
        stub_response = BigQueryAdapterResponse("mock_bigquery_adapter_response")
        stub_table = empty_table()
        return stub_response, stub_table

    # Class-level patch: every connection manager instance sees the stub.
    setattr(BigQueryConnectionManager, "execute", execute)


def _configure_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any:
    """Set the query ``configuration`` on a BigQuery async operator object.

    Args:
        async_op_obj: Object whose ``configuration`` attribute is assigned.
        **kwargs: Must contain a non-empty ``sql`` entry holding the query.

    Returns:
        The same ``async_op_obj`` that was passed in, after assignment.

    Raises:
        CosmosValueError: If ``sql`` is missing from ``kwargs`` or is falsy.
    """
    query_text = kwargs.get("sql")
    if query_text:
        query_settings = {"query": query_text, "useLegacySql": False}
        async_op_obj.configuration = {"query": query_settings}
        return async_op_obj
    raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator")


class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator, AbstractDbtLocalBase): # type: ignore[misc]

template_fields: Sequence[str] = (
Expand Down
14 changes: 9 additions & 5 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from packaging.version import Version

from cosmos import cache, settings
from cosmos._utils.importer import load_method_from_module
from cosmos.cache import (
_copy_cached_package_lockfile_to_project,
_get_latest_cached_package_lockfile,
Expand Down Expand Up @@ -68,7 +69,6 @@
parse_number_of_warnings_subprocess,
)
from cosmos.dbt.project import create_symlinks
from cosmos.dbt_adapters import PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP, associate_async_operator_args
from cosmos.hooks.subprocess import (
FullOutputSubprocessHook,
FullOutputSubprocessResult,
Expand Down Expand Up @@ -444,9 +444,9 @@ def _mock_dbt_adapter(async_context: dict[str, Any] | None) -> None:
if "profile_type" not in async_context:
raise CosmosValueError("`profile_type` needs to be specified in `async_context` when running as async")
profile_type = async_context["profile_type"]
if profile_type not in PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP:
raise CosmosValueError(f"Mock adapter callable function not available for profile_type {profile_type}")
mock_adapter_callable = PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP[profile_type]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_mock_{profile_type}_adapter"
mock_adapter_callable = load_method_from_module(module_path, method_name)
mock_adapter_callable()

def _handle_datasets(self, context: Context) -> None:
Expand All @@ -473,7 +473,11 @@ def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None

def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
associate_async_operator_args(self, async_context["profile_type"], sql=sql)
profile_type = async_context["profile_type"]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_configure_{profile_type}_async_op_args"
async_op_configurator = load_method_from_module(module_path, method_name)
async_op_configurator(self, sql=sql)
async_context["async_operator"].execute(self, context)

def run_command(
Expand Down
45 changes: 0 additions & 45 deletions tests/dbt_adapters/test_bigquery.py

This file was deleted.

15 changes: 0 additions & 15 deletions tests/dbt_adapters/test_init.py

This file was deleted.

46 changes: 44 additions & 2 deletions tests/operators/_asynchronous/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.exceptions import CosmosValueError
from cosmos.operators._asynchronous.bigquery import (
DbtRunAirflowAsyncBigqueryOperator,
_configure_bigquery_async_op_args,
_mock_bigquery_adapter,
)


@pytest.fixture
Expand Down Expand Up @@ -69,3 +74,40 @@ def test_dbt_run_airflow_async_bigquery_operator_execute(mock_build_and_run_cmd,
"async_operator": BigQueryInsertJobOperator,
},
)


@pytest.fixture
def async_operator_mock():
    """Provide a fresh ``Mock`` standing in for an async operator object."""
    operator_double = Mock()
    return operator_double


@pytest.mark.integration
def test_mock_bigquery_adapter():
    """Check that _mock_bigquery_adapter swaps in the stubbed ``execute``."""
    from dbt.adapters.bigquery.connections import BigQueryConnectionManager

    _mock_bigquery_adapter()
    assert hasattr(BigQueryConnectionManager, "execute")

    # Call the patched function unbound, passing None in place of ``self``.
    adapter_response, result_table = BigQueryConnectionManager.execute(None, sql="SELECT 1")

    assert adapter_response._message == "mock_bigquery_adapter_response"
    assert result_table is not None


def test_configure_bigquery_async_op_args_valid(async_operator_mock):
    """Check that a supplied ``sql`` ends up in the operator's query configuration."""
    expected_sql = "SELECT * FROM test_table"

    configured = _configure_bigquery_async_op_args(async_operator_mock, sql=expected_sql)

    # The helper hands back the object it was given.
    assert configured == async_operator_mock
    assert configured.configuration["query"]["query"] == expected_sql
    assert configured.configuration["query"]["useLegacySql"] is False


def test_configure_bigquery_async_op_args_missing_sql(async_operator_mock):
    """Check that omitting ``sql`` raises CosmosValueError with the expected message."""
    expected_error = "Keyword argument 'sql' is required for BigQuery Async operator"
    with pytest.raises(CosmosValueError, match=expected_error):
        _configure_bigquery_async_op_args(async_operator_mock)
29 changes: 9 additions & 20 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,21 +1362,7 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st
mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value)


MOCK_ADAPTER_CALLABLE_MAP = {
"snowflake": MagicMock(),
"bigquery": MagicMock(),
}


@pytest.fixture
def mock_adapter_map(monkeypatch):
monkeypatch.setattr(
"cosmos.operators.local.PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP",
MOCK_ADAPTER_CALLABLE_MAP,
)


def test_mock_dbt_adapter_valid_context(mock_adapter_map):
def test_mock_dbt_adapter_valid_context():
"""
Test that the _mock_dbt_adapter method calls the correct mock adapter function
when provided with a valid async_context.
Expand All @@ -1387,9 +1373,12 @@ def test_mock_dbt_adapter_valid_context(mock_adapter_map):
}
AbstractDbtLocalBase.__abstractmethods__ = set()
operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock())
operator._mock_dbt_adapter(async_context)
with patch("cosmos.operators.local.load_method_from_module") as mock_load_method:
operator._mock_dbt_adapter(async_context)

MOCK_ADAPTER_CALLABLE_MAP["bigquery"].assert_called_once()
expected_module_path = "cosmos.operators._asynchronous.bigquery"
expected_method_name = "_mock_bigquery_adapter"
mock_load_method.assert_called_once_with(expected_module_path, expected_method_name)


def test_mock_dbt_adapter_missing_async_context():
Expand Down Expand Up @@ -1433,7 +1422,7 @@ def test_mock_dbt_adapter_missing_profile_type():
operator._mock_dbt_adapter(async_context)


def test_mock_dbt_adapter_unsupported_profile_type(mock_adapter_map):
def test_mock_dbt_adapter_unsupported_profile_type():
"""
Test that the _mock_dbt_adapter method raises a CosmosValueError
when the profile_type is not supported.
Expand All @@ -1445,7 +1434,7 @@ def test_mock_dbt_adapter_unsupported_profile_type(mock_adapter_map):
AbstractDbtLocalBase.__abstractmethods__ = set()
operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock())
with pytest.raises(
CosmosValueError,
match="Mock adapter callable function not available for profile_type unsupported_profile",
ModuleNotFoundError,
match="Module cosmos.operators._asynchronous.unsupported_profile not found",
):
operator._mock_dbt_adapter(async_context)
48 changes: 48 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import sys
import types

import pytest

from cosmos._utils.importer import load_method_from_module

# Source text exec'd into an in-memory module by the tests below. The body of
# ``dummy_method`` must be indented inside the literal, otherwise ``exec``
# raises IndentationError before any test can run.
dummy_module_text = """
def dummy_method():
    return "Hello from dummy_method"
"""


@pytest.fixture
def create_dummy_module():
    """Register an in-memory module built from ``dummy_module_text``.

    Yields a ``(module_name, method_name)`` tuple and removes the module from
    ``sys.modules`` again once the test resumes the fixture.
    """
    module_name, method_name = "test_dummy_module", "dummy_method"

    dummy = types.ModuleType(module_name)
    # Run the source text with the new module's namespace as its globals.
    exec(dummy_module_text, vars(dummy))
    sys.modules[module_name] = dummy

    yield module_name, method_name

    del sys.modules[module_name]


def test_load_valid_method(create_dummy_module):
    """Check that an existing method is returned and can be called."""
    dummy_name, dummy_attr = create_dummy_module

    loaded = load_method_from_module(dummy_name, dummy_attr)

    assert callable(loaded)
    assert loaded() == "Hello from dummy_method"


def test_load_invalid_module():
    """Check that an unknown module path raises ModuleNotFoundError."""
    expected_error = "Module invalid_module not found"
    with pytest.raises(ModuleNotFoundError, match=expected_error):
        load_method_from_module("invalid_module", "dummy_method")


def test_load_invalid_method(create_dummy_module):
    """Check that a missing attribute on an importable module raises AttributeError."""
    module_name = create_dummy_module[0]
    expected_error = f"Method invalid_method not found in module {module_name}"
    with pytest.raises(AttributeError, match=expected_error):
        load_method_from_module(module_name, "invalid_method")

0 comments on commit bdb77e4

Please sign in to comment.