Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add on_warning_callback to DbtSourceKubernetesOperator and refactor previous operators #1501

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from abc import ABC
from os import PathLike
from typing import Any, Callable, Sequence

Expand Down Expand Up @@ -136,31 +137,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtSourceKubernetesOperator(DbtSourceMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt source freshness command.

Combines ``DbtSourceMixin`` with ``DbtKubernetesBaseOperator``. The
constructor adds no behaviour of its own; it only forwards its arguments.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Pure pass-through: every positional and keyword argument is handed to
# the next ``__init__`` in the method resolution order.
super().__init__(*args, **kwargs)


class DbtRunKubernetesOperator(DbtRunMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt core run command.

Combines ``DbtRunMixin`` with ``DbtKubernetesBaseOperator``. The
constructor adds no behaviour of its own; it only forwards its arguments.
"""

# Templated fields are the base operator's fields followed by the ones
# declared on DbtRunMixin.
template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Pure pass-through: every positional and keyword argument is handed to
# the next ``__init__`` in the method resolution order.
super().__init__(*args, **kwargs)


class DbtTestKubernetesOperator(DbtTestMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt core test command.
"""

class DbtWarningKubernetesOperator(DbtKubernetesBaseOperator, ABC):
def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None:
if not on_warning_callback:
super().__init__(**kwargs)
Expand All @@ -181,7 +158,7 @@ def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwar
kwargs["is_delete_operator_pod"] = False
kwargs["on_finish_action"] = OnFinishAction.KEEP_POD

# Add an additional callback to both success and failure callbacks.
# Add a callback to both success and failure callbacks.
# In case of success, check for a warning in the logs and clean up the pod.
self.on_success_callback = kwargs.get("on_success_callback", None) or []
if isinstance(self.on_success_callback, list):
Expand All @@ -208,7 +185,10 @@ def _handle_warnings(self, context: Context) -> None:
"""
if not (
isinstance(context["task_instance"], TaskInstance)
and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
and (
isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
or isinstance(context["task_instance"].task, DbtSourceKubernetesOperator)
)
):
return
task = context["task_instance"].task
Expand Down Expand Up @@ -243,7 +223,10 @@ def _cleanup_pod(self, context: Context) -> None:
"""
if not (
isinstance(context["task_instance"], TaskInstance)
and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
and (
isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
or isinstance(context["task_instance"].task, DbtSourceKubernetesOperator)
)
):
return
task = context["task_instance"].task
Expand All @@ -252,6 +235,35 @@ def _cleanup_pod(self, context: Context) -> None:
task.cleanup(pod=task.pod, remote_pod=task.remote_pod)


class DbtTestKubernetesOperator(DbtTestMixin, DbtWarningKubernetesOperator):
"""
Executes a dbt core test command.

Combines ``DbtTestMixin`` with ``DbtWarningKubernetesOperator``, whose
constructor accepts an optional ``on_warning_callback``. This class adds
no constructor logic of its own.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Pure pass-through: every positional and keyword argument is handed to
# the next ``__init__`` in the method resolution order.
super().__init__(*args, **kwargs)


class DbtSourceKubernetesOperator(DbtSourceMixin, DbtWarningKubernetesOperator):
"""
Executes a dbt source freshness command.

Combines ``DbtSourceMixin`` with ``DbtWarningKubernetesOperator``, whose
constructor accepts an optional ``on_warning_callback``. This class adds
no constructor logic of its own.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Pure pass-through: every positional and keyword argument is handed to
# the next ``__init__`` in the method resolution order.
super().__init__(*args, **kwargs)


class DbtRunKubernetesOperator(DbtRunMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt core run command.

Combines ``DbtRunMixin`` with ``DbtKubernetesBaseOperator``. The
constructor adds no behaviour of its own; it only forwards its arguments.
"""

# Templated fields are the base operator's fields followed by the ones
# declared on DbtRunMixin.
template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Pure pass-through: every positional and keyword argument is handed to
# the next ``__init__`` in the method resolution order.
super().__init__(*args, **kwargs)


class DbtRunOperationKubernetesOperator(DbtRunOperationMixin, DbtKubernetesBaseOperator):
"""
Executes a dbt core run-operation command.
Expand Down
84 changes: 82 additions & 2 deletions tests/operators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DbtLSKubernetesOperator,
DbtRunKubernetesOperator,
DbtSeedKubernetesOperator,
DbtSourceKubernetesOperator,
DbtTestKubernetesOperator,
)

Expand Down Expand Up @@ -118,11 +119,9 @@ def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock, b
"no_version_check": True,
}


if version.parse(airflow_version) == version.parse("2.4"):
base_kwargs["name"] = "some-pod-name"


result_map = {
"ls": DbtLSKubernetesOperator(**base_kwargs),
"run": DbtRunKubernetesOperator(**base_kwargs),
Expand Down Expand Up @@ -208,6 +207,62 @@ def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_re
assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3])


@pytest.mark.parametrize(
    "additional_kwargs,expected_results",
    [
        ({"on_success_callback": None, "is_delete_operator_pod": True}, (1, 1, True, "delete_pod")),
        (
            {"on_success_callback": (lambda **kwargs: None), "is_delete_operator_pod": False},
            (2, 1, False, "keep_pod"),
        ),
        (
            {"on_success_callback": [(lambda **kwargs: None), (lambda **kwargs: None)], "is_delete_operator_pod": None},
            (3, 1, True, "delete_pod"),
        ),
        (
            {"on_failure_callback": None, "is_delete_operator_pod": True, "on_finish_action": "keep_pod"},
            (1, 1, True, "delete_pod"),
        ),
        (
            {
                "on_failure_callback": (lambda **kwargs: None),
                "is_delete_operator_pod": None,
                "on_finish_action": "delete_pod",
            },
            (1, 2, True, "delete_pod"),
        ),
        (
            {
                "on_failure_callback": [(lambda **kwargs: None), (lambda **kwargs: None)],
                "is_delete_operator_pod": None,
                "on_finish_action": "delete_succeeded_pod",
            },
            (1, 3, False, "delete_succeeded_pod"),
        ),
        ({"is_delete_operator_pod": None, "on_finish_action": "keep_pod"}, (1, 1, False, "keep_pod")),
        ({}, (1, 1, True, "delete_pod")),
    ],
)
@pytest.mark.skipif(
    not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
)
def test_dbt_source_kubernetes_operator_constructor(additional_kwargs, expected_results):
    """
    Check the callbacks and pod-deletion settings recorded by the constructor.

    The operator is always built with an ``on_warning_callback``. Each
    ``expected_results`` tuple holds, in order: the expected number of success
    callbacks, the expected number of failure callbacks, the expected
    ``is_delete_operator_pod_original`` value and the expected
    ``on_finish_action_original`` value (as an ``OnFinishAction`` string).
    """
    (
        expected_success_callbacks,
        expected_failure_callbacks,
        expected_delete_pod_original,
        expected_finish_action_original,
    ) = expected_results

    source_operator = DbtSourceKubernetesOperator(
        on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs
    )

    # Both callback attributes must be lists, whatever the user passed in.
    assert isinstance(source_operator.on_success_callback, list)
    assert isinstance(source_operator.on_failure_callback, list)

    # The operator's own hooks must be registered next to any user callbacks.
    assert source_operator._handle_warnings in source_operator.on_success_callback
    assert source_operator._cleanup_pod in source_operator.on_failure_callback
    assert len(source_operator.on_success_callback) == expected_success_callbacks
    assert len(source_operator.on_failure_callback) == expected_failure_callbacks

    # The user's original pod-deletion preferences must be kept on the operator.
    assert source_operator.is_delete_operator_pod_original == expected_delete_pod_original
    assert source_operator.on_finish_action_original == OnFinishAction(expected_finish_action_original)


class FakePodManager:
def read_pod_logs(self, pod, container):
assert pod == "pod"
Expand Down Expand Up @@ -259,6 +314,31 @@ def cleanup(pod: str, remote_pod: str):
test_operator._handle_warnings(context)


@pytest.mark.skipif(
    not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
)
def test_dbt_source_kubernetes_operator_handle_warnings_and_cleanup_pod():
    """
    Run ``_handle_warnings`` on a source operator backed by a fake pod manager.

    The warning callback and the replacement ``cleanup`` function carry their
    own assertions, so the test fails if either is invoked with unexpected
    arguments.
    """

    def _check_warning_context(context: Context):
        assert context["test_names"] == ["dbt_utils_accepted_range_table_col__12__0"]
        assert context["test_results"] == ["Got 252 results, configured to warn if >0"]

    def _check_cleanup_args(pod: str, remote_pod: str):
        assert pod == remote_pod

    source_operator = DbtSourceKubernetesOperator(
        is_delete_operator_pod=True, on_warning_callback=_check_warning_context, **base_kwargs
    )

    # Swap the Kubernetes-facing pieces of the task for in-memory stand-ins.
    ti = TaskInstance(source_operator)
    ti.task.pod_manager = FakePodManager()
    ti.task.pod = "pod"
    ti.task.remote_pod = "pod"
    ti.task.cleanup = _check_cleanup_args

    ctx = Context()
    context_merge(ctx, task_instance=ti)

    source_operator._handle_warnings(ctx)


def test_created_pod():
ls_kwargs = {"env_vars": {"FOO": "BAR"}, "namespace": "foo", "append_env": False}
ls_kwargs.update(base_kwargs)
Expand Down