Commit

Revert "Revert "fix wrong merge""
This reverts commit f0ceb10.
eric-wang-1990 committed Jan 28, 2025
1 parent a287d60 commit e68d4d8
Showing 46 changed files with 176 additions and 886 deletions.
57 changes: 1 addition & 56 deletions dbt/adapters/databricks/api_client.py
@@ -245,7 +245,7 @@ def _poll_api(


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class CommandExecution:
class CommandExecution(object):
command_id: str
context_id: str
cluster_id: str
@@ -459,60 +459,6 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
return response_json["run_id"]


class DltPipelineApi(PollableApi):
def __init__(self, session: Session, host: str, polling_interval: int):
super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60)

def poll_for_completion(self, pipeline_id: str) -> None:
self._poll_api(
url=f"/{pipeline_id}",
params={},
get_state_func=lambda response: response.json()["state"],
terminal_states={"IDLE", "FAILED", "DELETED"},
expected_end_state="IDLE",
unexpected_end_state_func=self._get_exception,
)

def _get_exception(self, response: Response) -> None:
response_json = response.json()
cause = response_json.get("cause")
if cause:
raise DbtRuntimeError(f"Pipeline {response_json.get('pipeline_id')} failed: {cause}")
else:
latest_update = response_json.get("latest_updates")[0]
last_error = self.get_update_error(response_json.get("pipeline_id"), latest_update)
raise DbtRuntimeError(
f"Pipeline {response_json.get('pipeline_id')} failed: {last_error}"
)

def get_update_error(self, pipeline_id: str, update_id: str) -> str:
response = self.session.get(f"/{pipeline_id}/events")
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


class DatabricksApiClient:
def __init__(
self,
@@ -534,7 +480,6 @@ def __init__(
self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
self.workflows = WorkflowJobApi(session, host)
self.workflow_permissions = JobPermissionsApi(session, host)
self.dlt_pipelines = DltPipelineApi(session, host, polling_interval)

@staticmethod
def create(
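For context, the DltPipelineApi deleted above follows the adapter's PollableApi pattern: repeatedly GET /api/2.0/pipelines/{pipeline_id} until the reported state reaches a terminal value ({"IDLE", "FAILED", "DELETED"}), then raise unless the end state is IDLE. A minimal self-contained sketch of that loop, with hypothetical names (the real logic lives in _poll_api):

import time

# Sketch of the polling contract the removed DltPipelineApi relied on.
# In the real client, get_state wraps a GET on /api/2.0/pipelines/{id}
# and reads response.json()["state"].
def poll_until_terminal(get_state, terminal_states, expected_end_state,
                        polling_interval=10, timeout=60 * 60):
    elapsed = 0
    while elapsed <= timeout:
        state = get_state()
        if state in terminal_states:
            if state != expected_end_state:
                raise RuntimeError(f"Pipeline ended in unexpected state {state!r}")
            return state
        time.sleep(polling_interval)
        elapsed += polling_interval
    raise TimeoutError("Timed out waiting for a terminal pipeline state")

# Example with a fake state source that goes RUNNING -> RUNNING -> IDLE.
states = iter(["RUNNING", "RUNNING", "IDLE"])
print(poll_until_terminal(lambda: next(states), {"IDLE", "FAILED", "DELETED"},
                          "IDLE", polling_interval=0))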
18 changes: 2 additions & 16 deletions dbt/adapters/databricks/column.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Any, ClassVar, Optional
from typing import ClassVar, Optional

from dbt.adapters.databricks.utils import quote
from dbt.adapters.spark.column import SparkColumn


@@ -28,17 +27,4 @@ def data_type(self) -> str:
return self.translate_type(self.dtype)

def __repr__(self) -> str:
return f"<DatabricksColumn {self.name} ({self.data_type})>"

@staticmethod
def get_name(column: dict[str, Any]) -> str:
name = column["name"]
return quote(name) if column.get("quote", False) else name

@staticmethod
def format_remove_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([quote(c.name) for c in columns])

@staticmethod
def format_add_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([f"{quote(c.name)} {c.data_type}" for c in columns])
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)
33 changes: 16 additions & 17 deletions dbt/adapters/databricks/credentials.py
@@ -16,16 +16,10 @@
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import Config, CredentialsProvider
from dbt.adapters.contracts.connection import Credentials
from dbt.adapters.databricks.auth import m2m_auth, token_auth
from dbt.adapters.databricks.events.credential_events import (
CredentialLoadError,
CredentialSaveError,
CredentialShardEvent,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV"
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"
@@ -76,10 +70,8 @@ class DatabricksCredentials(Credentials):
@classmethod
def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]:
data = super().__pre_deserialize__(data)
data.setdefault("database", None)
data.setdefault("connection_parameters", {})
data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30)
data["connection_parameters"].setdefault("_retry_delay_max", 60)
if "database" not in data:
data["database"] = None
return data

def __post_init__(self) -> None:
@@ -141,16 +133,21 @@ def __post_init__(self) -> None:
def validate_creds(self) -> None:
for key in ["host", "http_path"]:
if not getattr(self, key):
raise DbtConfigError(f"The config '{key}' is required to connect to Databricks")
raise DbtConfigError(
"The config '{}' is required to connect to Databricks".format(key)
)

if not self.token and self.auth_type != "oauth":
raise DbtConfigError(
"The config `auth_type: oauth` is required when not using access token"
("The config `auth_type: oauth` is required when not using access token")
)

if not self.client_id and self.client_secret:
raise DbtConfigError(
"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
(
"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
)
)

if (not self.azure_client_id and self.azure_client_secret) or (
@@ -165,7 +162,7 @@ def validate_creds(self) -> None:

@classmethod
def get_invocation_env(cls) -> Optional[str]:
invocation_env = GlobalState.get_invocation_env()
invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
if invocation_env:
# Thrift doesn't allow nested () so we need to ensure
# that the passed user agent is valid.
@@ -175,7 +172,9 @@ def get_invocation_env(cls) -> Optional[str]:

@classmethod
def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]:
http_session_headers_str = GlobalState.get_http_session_headers()
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)

http_session_headers_dict: dict[str, str] = (
{
48 changes: 0 additions & 48 deletions dbt/adapters/databricks/global_state.py

This file was deleted.

4 changes: 2 additions & 2 deletions dbt/adapters/databricks/logging.py
@@ -1,7 +1,7 @@
import os
from logging import Handler, LogRecord, getLogger
from typing import Union

from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Databricks")
@@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None:
dbt_adapter_logger = AdapterLogger("databricks-sql-connector")

pysql_logger = getLogger("databricks.sql")
pysql_logger_level = GlobalState.get_connector_log_level()
pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()
pysql_logger.setLevel(pysql_logger_level)

pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level)
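This file and credentials.py both swap GlobalState accessors back to direct os.environ reads, and global_state.py itself is deleted above. Reconstructed from the call sites visible in this diff, the reverted class looked roughly like the sketch below (method names and defaults inferred from those call sites, not the deleted file's exact contents):

import os
from typing import Optional

# Inferred sketch of the deleted GlobalState: one place to read the
# DBT_DATABRICKS_* environment variables instead of scattered
# os.environ.get calls.
class GlobalState:
    @classmethod
    def get_invocation_env(cls) -> Optional[str]:
        return os.environ.get("DBT_DATABRICKS_INVOCATION_ENV")

    @classmethod
    def get_http_session_headers(cls) -> Optional[str]:
        return os.environ.get("DBT_DATABRICKS_HTTP_SESSION_HEADERS")

    @classmethod
    def get_connector_log_level(cls) -> str:
        return os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()

print(GlobalState.get_connector_log_level())  # "WARN" unless overridden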
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/python_models/run_tracking.py
@@ -6,7 +6,7 @@
from dbt.adapters.databricks.logging import logger


class PythonRunTracker:
class PythonRunTracker(object):
_run_ids: set[str] = set()
_commands: set[CommandExecution] = set()
_lock = threading.Lock()
6 changes: 2 additions & 4 deletions dbt/adapters/databricks/relation.py
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Optional, Type # noqa
from typing import Any, Optional, Type

from dbt_common.dataclass_schema import StrEnum
from dbt_common.exceptions import DbtRuntimeError
@@ -39,8 +39,6 @@ class DatabricksRelationType(StrEnum):
Foreign = "foreign"
StreamingTable = "streaming_table"
External = "external"
ManagedShallowClone = "managed_shallow_clone"
ExternalShallowClone = "external_shallow_clone"
Unknown = "unknown"


@@ -133,7 +131,7 @@ def matches(
return match

@classproperty
def get_relation_type(cls) -> Type[DatabricksRelationType]: # noqa
def get_relation_type(cls) -> Type[DatabricksRelationType]:
return DatabricksRelationType

def information_schema(self, view_name: Optional[str] = None) -> InformationSchema:
1 change: 0 additions & 1 deletion dbt/adapters/databricks/relation_configs/tblproperties.py
@@ -38,7 +38,6 @@ class TblPropertiesConfig(DatabricksComponentConfig):
"delta.feature.rowTracking",
"delta.rowTracking.materializedRowCommitVersionColumnName",
"delta.rowTracking.materializedRowIdColumnName",
"spark.internal.pipelines.top_level_entry.user_specified_name",
]

def __eq__(self, __value: Any) -> bool:
4 changes: 0 additions & 4 deletions dbt/adapters/databricks/utils.py
@@ -73,7 +73,3 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T:
if check_not_found_error(errmsg):
return default
raise e


def quote(name: str) -> str:
return f"`{name}`"
17 changes: 0 additions & 17 deletions dbt/include/databricks/macros/adapters/columns.sql
@@ -25,20 +25,3 @@

{% do return(load_result('get_columns_comments_via_information_schema').table) %}
{% endmacro %}

{% macro databricks__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %}
{% if remove_columns %}
{% if not relation.is_delta %}
{{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }}
{% endif %}
{%- call statement('alter_relation_remove_columns') -%}
ALTER TABLE {{ relation.render() }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }})
{%- endcall -%}
{% endif %}

{% if add_columns %}
{%- call statement('alter_relation_add_columns') -%}
ALTER TABLE {{ relation.render() }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }})
{%- endcall -%}
{% endif %}
{% endmacro %}
25 changes: 21 additions & 4 deletions dbt/include/databricks/macros/adapters/persist_docs.sql
@@ -1,10 +1,12 @@
{% macro databricks__alter_column_comment(relation, column_dict) %}
{% if config.get('file_format', default='delta') in ['delta', 'hudi'] %}
{% for column in column_dict.values() %}
{% set comment = column['description'] %}
{% for column_name in column_dict %}
{% set comment = column_dict[column_name]['description'] %}
{% set escaped_comment = comment | replace('\'', '\\\'') %}
{% set comment_query %}
alter table {{ relation.render()|lower }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}';
alter table {{ relation }} change column
{{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }}
comment '{{ escaped_comment }}';
{% endset %}
{% do run_query(comment_query) %}
{% endfor %}
@@ -13,7 +15,7 @@

{% macro alter_table_comment(relation, model) %}
{% set comment_query %}
comment on table {{ relation.render()|lower }} is '{{ model.description | replace("'", "\\'") }}'
comment on table {{ relation|lower }} is '{{ model.description | replace("'", "\\'") }}'
{% endset %}
{% do run_query(comment_query) %}
{% endmacro %}
@@ -28,3 +30,18 @@
{% do alter_column_comment(relation, columns_to_persist_docs) %}
{% endif %}
{% endmacro %}

{% macro get_column_comment_sql(column_name, column_dict) -%}
{% if column_name in column_dict and column_dict[column_name]["description"] -%}
{% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %}
{% set column_comment_clause = "comment '" ~ escaped_description ~ "'" %}
{%- endif -%}
{{ adapter.quote(column_name) }} {{ column_comment_clause }}
{% endmacro %}

{% macro get_persist_docs_column_list(model_columns, query_columns) %}
{% for column_name in query_columns %}
{{ get_column_comment_sql(column_name, model_columns) }}
{{- ", " if not loop.last else "" }}
{% endfor %}
{% endmacro %}
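The get_persist_docs_column_list macro restored above renders each queried column, quoted, with a comment clause whenever the model supplies a description. A Python sketch of the equivalent string assembly (assuming backtick quoting for adapter.quote, as used elsewhere in this adapter):

# Mirrors the restored macros: quote each query column and append a
# comment clause when the model defines a description; single quotes in
# descriptions are escaped as \'.
def get_column_comment_sql(column_name, column_dict):
    clause = ""
    description = column_dict.get(column_name, {}).get("description")
    if description:
        clause = " comment '" + description.replace("'", "\\'") + "'"
    return f"`{column_name}`{clause}"

def get_persist_docs_column_list(model_columns, query_columns):
    return ", ".join(get_column_comment_sql(c, model_columns) for c in query_columns)

model_columns = {"id": {"description": "primary key"}}
print(get_persist_docs_column_list(model_columns, ["id", "name"]))
# `id` comment 'primary key', `name`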
@@ -91,15 +91,8 @@ select {{source_cols_csv}} from {{ source_relation }}

{%- set not_matched_by_source_action = config.get('not_matched_by_source_action') -%}
{%- set not_matched_by_source_condition = config.get('not_matched_by_source_condition') -%}

{%- set not_matched_by_source_action_trimmed = not_matched_by_source_action | lower | trim(' \n\t') %}
{%- set not_matched_by_source_action_is_set = (
not_matched_by_source_action_trimmed == 'delete'
or not_matched_by_source_action_trimmed.startswith('update')
)
%}



{% if unique_key %}
{% if unique_key is sequence and unique_key is not mapping and unique_key is not string %}
{% for key in unique_key %}
@@ -144,12 +137,12 @@ select {{source_cols_csv}} from {{ source_relation }}
then insert
{{ get_merge_insert(on_schema_change, source_columns, source_alias) }}
{%- endif %}
{%- if not_matched_by_source_action_is_set %}
{%- if not_matched_by_source_action == 'delete' %}
when not matched by source
{%- if not_matched_by_source_condition %}
and ({{ not_matched_by_source_condition }})
{%- endif %}
then {{ not_matched_by_source_action }}
then delete
{%- endif %}
{% endmacro %}

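The last hunk narrows not_matched_by_source_action handling: the deleted logic accepted 'delete' or any clause starting with 'update' after trimming and lowercasing, while the restored branch fires only on the exact string 'delete'. Side by side, as plain predicates:

# The condition removed above vs. the one restored.
def action_is_set_removed(action: str) -> bool:
    trimmed = action.lower().strip(" \n\t")
    return trimmed == "delete" or trimmed.startswith("update")

def action_is_set_restored(action: str) -> bool:
    return action == "delete"

for action in ["delete", "  DELETE\n", "update set t.active = false"]:
    print(repr(action), action_is_set_removed(action), action_is_set_restored(action))
# 'delete' True True
# '  DELETE\n' True False
# 'update set t.active = false' True False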