Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to consolidate async dbt adapter code #1509

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added cosmos/_utils/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions cosmos/_utils/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import importlib
from typing import Any, Callable


def load_method_from_module(module_path: str, method_name: str) -> Callable[..., Any]:
try:
module = importlib.import_module(module_path)
method = getattr(module, method_name)
return method # type: ignore
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module {module_path} not found")
except AttributeError:
raise AttributeError(f"Method {method_name} not found in module {module_path}")
18 changes: 0 additions & 18 deletions cosmos/dbt_adapters/__init__.py

This file was deleted.

33 changes: 0 additions & 33 deletions cosmos/dbt_adapters/bigquery.py

This file was deleted.

29 changes: 29 additions & 0 deletions cosmos/operators/_asynchronous/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,40 @@
from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.dataset import get_dataset_alias_name
from cosmos.exceptions import CosmosValueError
from cosmos.operators.local import AbstractDbtLocalBase

AIRFLOW_VERSION = Version(airflow.__version__)


def _mock_bigquery_adapter() -> None:
from typing import Optional, Tuple

import agate
from dbt.adapters.bigquery.connections import BigQueryAdapterResponse, BigQueryConnectionManager
from dbt_common.clients.agate_helper import empty_table

def execute( # type: ignore[no-untyped-def]
self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None
) -> Tuple[BigQueryAdapterResponse, agate.Table]:
return BigQueryAdapterResponse("mock_bigquery_adapter_response"), empty_table()

BigQueryConnectionManager.execute = execute


def _configure_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any:
sql = kwargs.get("sql")
if not sql:
raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator")
async_op_obj.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return async_op_obj


class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator, AbstractDbtLocalBase): # type: ignore[misc]

template_fields: Sequence[str] = (
Expand Down
14 changes: 9 additions & 5 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from packaging.version import Version

from cosmos import cache, settings
from cosmos._utils.importer import load_method_from_module
from cosmos.cache import (
_copy_cached_package_lockfile_to_project,
_get_latest_cached_package_lockfile,
Expand Down Expand Up @@ -68,7 +69,6 @@
parse_number_of_warnings_subprocess,
)
from cosmos.dbt.project import create_symlinks
from cosmos.dbt_adapters import PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP, associate_async_operator_args
from cosmos.hooks.subprocess import (
FullOutputSubprocessHook,
FullOutputSubprocessResult,
Expand Down Expand Up @@ -444,9 +444,9 @@ def _mock_dbt_adapter(async_context: dict[str, Any] | None) -> None:
if "profile_type" not in async_context:
raise CosmosValueError("`profile_type` needs to be specified in `async_context` when running as async")
profile_type = async_context["profile_type"]
if profile_type not in PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP:
raise CosmosValueError(f"Mock adapter callable function not available for profile_type {profile_type}")
mock_adapter_callable = PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP[profile_type]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_mock_{profile_type}_adapter"
mock_adapter_callable = load_method_from_module(module_path, method_name)
mock_adapter_callable()

def _handle_datasets(self, context: Context) -> None:
Expand All @@ -473,7 +473,11 @@ def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None

def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
associate_async_operator_args(self, async_context["profile_type"], sql=sql)
profile_type = async_context["profile_type"]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_configure_{profile_type}_async_op_args"
async_op_configurator = load_method_from_module(module_path, method_name)
async_op_configurator(self, sql=sql)
async_context["async_operator"].execute(self, context)

def run_command(
Expand Down
45 changes: 0 additions & 45 deletions tests/dbt_adapters/test_bigquery.py

This file was deleted.

15 changes: 0 additions & 15 deletions tests/dbt_adapters/test_init.py

This file was deleted.

46 changes: 44 additions & 2 deletions tests/operators/_asynchronous/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.exceptions import CosmosValueError
from cosmos.operators._asynchronous.bigquery import (
DbtRunAirflowAsyncBigqueryOperator,
_configure_bigquery_async_op_args,
_mock_bigquery_adapter,
)


@pytest.fixture
Expand Down Expand Up @@ -69,3 +74,40 @@ def test_dbt_run_airflow_async_bigquery_operator_execute(mock_build_and_run_cmd,
"async_operator": BigQueryInsertJobOperator,
},
)


@pytest.fixture
def async_operator_mock():
"""Fixture to create a mock async operator object."""
return Mock()


@pytest.mark.integration
def test_mock_bigquery_adapter():
"""Test _mock_bigquery_adapter to verify it modifies BigQueryConnectionManager.execute."""
from dbt.adapters.bigquery.connections import BigQueryConnectionManager

_mock_bigquery_adapter()

assert hasattr(BigQueryConnectionManager, "execute")

response, table = BigQueryConnectionManager.execute(None, sql="SELECT 1")
assert response._message == "mock_bigquery_adapter_response"
assert table is not None


def test_configure_bigquery_async_op_args_valid(async_operator_mock):
"""Test _configure_bigquery_async_op_args correctly configures the async operator."""
sql_query = "SELECT * FROM test_table"

result = _configure_bigquery_async_op_args(async_operator_mock, sql=sql_query)

assert result == async_operator_mock
assert result.configuration["query"]["query"] == sql_query
assert result.configuration["query"]["useLegacySql"] is False


def test_configure_bigquery_async_op_args_missing_sql(async_operator_mock):
"""Test _configure_bigquery_async_op_args raises CosmosValueError when 'sql' is missing."""
with pytest.raises(CosmosValueError, match="Keyword argument 'sql' is required for BigQuery Async operator"):
_configure_bigquery_async_op_args(async_operator_mock)
29 changes: 9 additions & 20 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,21 +1362,7 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st
mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value)


MOCK_ADAPTER_CALLABLE_MAP = {
"snowflake": MagicMock(),
"bigquery": MagicMock(),
}


@pytest.fixture
def mock_adapter_map(monkeypatch):
monkeypatch.setattr(
"cosmos.operators.local.PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP",
MOCK_ADAPTER_CALLABLE_MAP,
)


def test_mock_dbt_adapter_valid_context(mock_adapter_map):
def test_mock_dbt_adapter_valid_context():
"""
Test that the _mock_dbt_adapter method calls the correct mock adapter function
when provided with a valid async_context.
Expand All @@ -1387,9 +1373,12 @@ def test_mock_dbt_adapter_valid_context(mock_adapter_map):
}
AbstractDbtLocalBase.__abstractmethods__ = set()
operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock())
operator._mock_dbt_adapter(async_context)
with patch("cosmos.operators.local.load_method_from_module") as mock_load_method:
operator._mock_dbt_adapter(async_context)

MOCK_ADAPTER_CALLABLE_MAP["bigquery"].assert_called_once()
expected_module_path = "cosmos.operators._asynchronous.bigquery"
expected_method_name = "_mock_bigquery_adapter"
mock_load_method.assert_called_once_with(expected_module_path, expected_method_name)


def test_mock_dbt_adapter_missing_async_context():
Expand Down Expand Up @@ -1433,7 +1422,7 @@ def test_mock_dbt_adapter_missing_profile_type():
operator._mock_dbt_adapter(async_context)


def test_mock_dbt_adapter_unsupported_profile_type(mock_adapter_map):
def test_mock_dbt_adapter_unsupported_profile_type():
"""
Test that the _mock_dbt_adapter method raises a CosmosValueError
when the profile_type is not supported.
Expand All @@ -1445,7 +1434,7 @@ def test_mock_dbt_adapter_unsupported_profile_type(mock_adapter_map):
AbstractDbtLocalBase.__abstractmethods__ = set()
operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock())
with pytest.raises(
CosmosValueError,
match="Mock adapter callable function not available for profile_type unsupported_profile",
ModuleNotFoundError,
match="Module cosmos.operators._asynchronous.unsupported_profile not found",
):
operator._mock_dbt_adapter(async_context)
48 changes: 48 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import sys
import types

import pytest

from cosmos._utils.importer import load_method_from_module

dummy_module_text = """
def dummy_method():
return "Hello from dummy_method"
"""


@pytest.fixture
def create_dummy_module():
"""Creates a temporary in-memory module with a dummy method."""
module_name = "test_dummy_module"
method_name = "dummy_method"

module = types.ModuleType(module_name)
exec(dummy_module_text, module.__dict__)

sys.modules[module_name] = module

yield module_name, method_name

del sys.modules[module_name]


def test_load_valid_method(create_dummy_module):
"""Test that a valid method is loaded successfully."""
module_name, method_name = create_dummy_module
method = load_method_from_module(module_name, method_name)
assert callable(method)
assert method() == "Hello from dummy_method"


def test_load_invalid_module():
"""Test that ModuleNotFoundError is raised for an invalid module."""
with pytest.raises(ModuleNotFoundError, match="Module invalid_module not found"):
load_method_from_module("invalid_module", "dummy_method")


def test_load_invalid_method(create_dummy_module):
"""Test that AttributeError is raised for a missing method in a valid module."""
module_name, _ = create_dummy_module
with pytest.raises(AttributeError, match=f"Method invalid_method not found in module {module_name}"):
load_method_from_module(module_name, "invalid_method")