
Add structure to support multiple db for async operator execution #1483

Merged 18 commits on Jan 27, 2025
65 changes: 65 additions & 0 deletions cosmos/operators/_asynchronous/base.py
@@ -0,0 +1,65 @@
import importlib
import logging
from abc import ABCMeta
from typing import Any, Sequence

from airflow.utils.context import Context

from cosmos.airflow.graph import _snake_case_to_camelcase
from cosmos.config import ProfileConfig
from cosmos.constants import ExecutionMode
from cosmos.operators.local import DbtRunLocalOperator

log = logging.getLogger(__name__)


def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
"""
Dynamically constructs and returns an asynchronous operator class for the given profile type and dbt class name.
The function constructs a class path string for an asynchronous operator, based on the provided `profile_type` and
`dbt_class`. It attempts to import the corresponding class dynamically and return it. If the class cannot be found,
it falls back to returning the `DbtRunLocalOperator` class.
:param profile_type: The dbt profile type
:param dbt_class: The dbt class name, e.g. DbtRun or DbtTest.
"""
execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
try:
module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
operator_class = getattr(module, class_name)
return operator_class
except (ModuleNotFoundError, AttributeError):
log.info("Error in loading class: %s. falling back to DbtRunLocalOperator", class_path)
return DbtRunLocalOperator


class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta): # type: ignore[misc]

template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",) # type: ignore[operator]

def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any):
self.project_dir = project_dir
self.profile_config = profile_config

async_operator_class = self.create_async_operator()

# Dynamically modify the base classes.
# This is necessary because the async operator class is only known at runtime.
# When using composition instead of inheritance to initialize the async class and run its execute method,
# Airflow throws a `DuplicateTaskIdFound` error.
DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,)
super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs)

def create_async_operator(self) -> Any:

profile_type = self.profile_config.get_profile_type()

async_class_operator = _create_async_operator_class(profile_type, "DbtRun")

return async_class_operator

def execute(self, context: Context) -> None:
super().execute(context)
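
For context, here is a minimal usage sketch (not part of the diff) of the class resolution above. It assumes the Google provider is installed; the postgres lookup is expected to fall back because this PR only adds bigquery and databricks modules under cosmos/operators/_asynchronous.

from cosmos.operators._asynchronous.base import _create_async_operator_class
from cosmos.operators.local import DbtRunLocalOperator

# "bigquery" + "DbtRun" resolves to
# cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator
bigquery_class = _create_async_operator_class(profile_type="bigquery", dbt_class="DbtRun")
print(bigquery_class.__name__)  # DbtRunAirflowAsyncBigqueryOperator

# No cosmos.operators._asynchronous.postgres module exists in this PR, so the import
# fails and the helper falls back to DbtRunLocalOperator.
postgres_class = _create_async_operator_class(profile_type="postgres", dbt_class="DbtRun")
assert postgres_class is DbtRunLocalOperator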
108 changes: 108 additions & 0 deletions cosmos/operators/_asynchronous/bigquery.py
@@ -0,0 +1,108 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id


class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator): # type: ignore[misc]

template_fields: Sequence[str] = (
"full_refresh",
"gcp_project",
"dataset",
"location",
)

def __init__(
self,
project_dir: str,
profile_config: ProfileConfig,
extra_context: dict[str, Any] | None = None,
**kwargs: Any,
):
self.project_dir = project_dir
self.profile_config = profile_config
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
profile = self.profile_config.profile_mapping.profile # type: ignore
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]
self.extra_context = extra_context or {}
self.full_refresh = None
if "full_refresh" in kwargs:
self.full_refresh = kwargs.pop("full_refresh")
self.configuration: dict[str, Any] = {}
super().__init__(
gcp_conn_id=self.gcp_conn_id,
configuration=self.configuration,
deferrable=True,
**kwargs,
)

def get_remote_sql(self) -> str:
if not settings.AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
from airflow.io.path import ObjectStoragePath

file_path = self.extra_context["dbt_node_config"]["file_path"] # type: ignore
dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

remote_target_path_str = str(remote_target_path).rstrip("/")

if TYPE_CHECKING: # pragma: no cover
assert self.project_dir is not None

project_dir_parent = str(Path(self.project_dir).parent)
relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp: # type: ignore
return fp.read() # type: ignore

def drop_table_sql(self) -> None:
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

def execute(self, context: Context) -> Any | None:

if not self.full_refresh:
raise CosmosValueError("The async execution only supported for full_refresh")
else:
# It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
# https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
# https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
# We're emulating this behaviour here
# The compiled SQL has several limitations here, but these will be addressed in the PR: https://github.com/astronomer/astronomer-cosmos/pull/1474.
self.drop_table_sql()
sql = self.get_remote_sql()
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
# Prefix an explicit CREATE TABLE statement so the compiled SELECT materializes the model
sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return super().execute(context)
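
To make the behaviour above concrete, here is a small illustrative sketch (not part of the diff) of how the remote compiled-SQL path is derived and which statements the operator ends up issuing; the bucket, paths, project, dataset and model names are placeholder assumptions.

from pathlib import Path

# Placeholder values standing in for remote_target_path, the extra_context populated by
# Cosmos, and the profile's project/dataset.
remote_target_path_str = "gs://my-bucket/target"
dbt_dag_task_group_identifier = "my_cosmos_dag"
project_dir = "/usr/local/airflow/dbt/jaffle_shop"
file_path = "/usr/local/airflow/dbt/jaffle_shop/models/my_model.sql"

# Same derivation as get_remote_sql(): strip the project's parent directory from the model
# path and look the file up under <remote_target_path>/<dag identifier>/compiled/.
project_dir_parent = str(Path(project_dir).parent)
relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"
print(remote_model_path)
# gs://my-bucket/target/my_cosmos_dag/compiled/jaffle_shop/models/my_model.sql

# execute() then emulates dbt's --full-refresh with two BigQuery statements:
#   DROP TABLE IF EXISTS my_project.my_dataset.my_model;
#   CREATE TABLE my_project.my_dataset.my_model AS <compiled SELECT read from remote_model_path>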
14 changes: 14 additions & 0 deletions cosmos/operators/_asynchronous/databricks.py
@@ -0,0 +1,14 @@
# TODO: Implement it

from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncDatabricksOperator(BaseOperator):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

def execute(self, context: Context) -> None:
raise NotImplementedError()
115 changes: 9 additions & 106 deletions cosmos/operators/airflow_async.py
@@ -1,16 +1,9 @@
from __future__ import annotations

import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator
from cosmos.operators.base import AbstractDbtBaseOperator
from cosmos.operators.local import (
DbtBuildLocalOperator,
@@ -24,7 +17,6 @@
DbtSourceLocalOperator,
DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]

@@ -35,8 +27,8 @@

class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta):
def __init__(self, **kwargs) -> None: # type: ignore
self.location = kwargs.pop("location")
self.configuration = kwargs.pop("configuration", {})
if "location" in kwargs:
kwargs.pop("location")
super().__init__(**kwargs)


@@ -60,47 +52,17 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO
pass


class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator): # type: ignore

template_fields: Sequence[str] = (
"full_refresh",
"project_dir",
"gcp_project",
"dataset",
"location",
)
class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): # type: ignore

def __init__( # type: ignore
self,
project_dir: str,
profile_config: ProfileConfig,
location: str, # This is a mandatory parameter when using BigQueryInsertJobOperator with deferrable=True
full_refresh: bool = False,
extra_context: dict[str, object] | None = None,
configuration: dict[str, object] | None = None,
**kwargs,
) -> None:
# dbt task param
self.project_dir = project_dir
self.extra_context = extra_context or {}
self.full_refresh = full_refresh
self.profile_config = profile_config
if not self.profile_config or not self.profile_config.profile_mapping:
raise CosmosValueError(f"Cosmos async support is only available when using ProfileMapping")

self.profile_type: str = profile_config.get_profile_type() # type: ignore
if self.profile_type not in _SUPPORTED_DATABASES:
raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")

# airflow task param
self.location = location
self.configuration = configuration or {}
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
profile = self.profile_config.profile_mapping.profile
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]

# Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.

# Cosmos attempts to pass many kwargs that the async operator simply does not accept.
# We need to pop them.
clean_kwargs = {}
non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
@@ -113,71 +75,12 @@ def __init__(  # type: ignore

# The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
super().__init__(
gcp_conn_id=self.gcp_conn_id,
configuration=self.configuration,
location=self.location,
deferrable=True,
project_dir=project_dir,
profile_config=profile_config,
extra_context=extra_context,
**clean_kwargs,
)

def get_remote_sql(self) -> str:
if not settings.AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
from airflow.io.path import ObjectStoragePath

file_path = self.extra_context["dbt_node_config"]["file_path"] # type: ignore
dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

remote_target_path_str = str(remote_target_path).rstrip("/")

if TYPE_CHECKING:
assert self.project_dir is not None

project_dir_parent = str(Path(self.project_dir).parent)
relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp: # type: ignore
return fp.read() # type: ignore

def drop_table_sql(self) -> None:
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

def execute(self, context: Context) -> Any | None:
if not self.full_refresh:
raise CosmosValueError("The async execution only supported for full_refresh")
else:
# It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
# https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
# https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
# We're emulating this behaviour here
self.drop_table_sql()
sql = self.get_remote_sql()
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
# prefix explicit create command to create table
sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return super().execute(context)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore
pass
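
For readers wondering how these operators get exercised, the sketch below (not part of the diff) shows a DAG opting into the existing ExecutionMode.AIRFLOW_ASYNC execution mode, which is what routes dbt run nodes through DbtRunAirflowAsyncOperator. The project path, connection id and profile values are assumptions, and the remote_target_path setting referenced by the BigQuery operator above must also be configured.

from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping

jaffle_shop_async = DbtDag(
    dag_id="jaffle_shop_async",
    start_date=datetime(2025, 1, 1),
    schedule=None,
    project_config=ProjectConfig(dbt_project_path="/usr/local/airflow/dbt/jaffle_shop"),
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profile_mapping=GoogleCloudServiceAccountFileProfileMapping(
            conn_id="my_bigquery_connection",
            profile_args={"project": "my_project", "dataset": "my_dataset"},
        ),
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    # Async run execution currently requires full_refresh (see execute() above).
    operator_args={"full_refresh": True},
)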
1 change: 0 additions & 1 deletion tests/airflow/test_graph.py
@@ -312,7 +312,6 @@ def test_build_airflow_graph_with_dbt_compile_task():
"project_dir": SAMPLE_PROJ_PATH,
"conn_id": "fake_conn",
"profile_config": bigquery_profile_config,
"location": "",
}
render_config = RenderConfig(
select=["tag:some"],
24 changes: 24 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,24 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection


@pytest.fixture()
def mock_bigquery_conn(): # type: ignore
"""
Mocks and returns an Airflow BigQuery connection.
"""
extra = {
"project": "my_project",
"key_path": "my_key_path.json",
}
conn = Connection(
conn_id="my_bigquery_connection",
conn_type="google_cloud_platform",
extra=json.dumps(extra),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn
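
A hypothetical test sketch (not in this PR's test suite) showing how the fixture above can be combined with the new BigQuery async operator; the test name, project directory and assertion values are illustrative assumptions.

from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping


def test_async_bigquery_operator_reads_profile(mock_bigquery_conn):
    profile_config = ProfileConfig(
        profile_name="default",
        target_name="dev",
        profile_mapping=GoogleCloudServiceAccountFileProfileMapping(
            conn_id=mock_bigquery_conn.conn_id,
            profile_args={"project": "my_project", "dataset": "my_dataset"},
        ),
    )
    operator = DbtRunAirflowAsyncBigqueryOperator(
        task_id="run_my_model",
        project_dir="/tmp/jaffle_shop",
        profile_config=profile_config,
    )
    # The operator pulls the connection id, GCP project and dataset from the profile mapping.
    assert operator.gcp_conn_id == "my_bigquery_connection"
    assert operator.gcp_project == "my_project"
    assert operator.dataset == "my_dataset"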