Merge branch 'main' into 1.10.latest

benc-db committed Jan 7, 2025
2 parents 87ec5d1 + b0ff51b, commit ea2b52d

Showing 9 changed files with 120 additions and 47 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,12 @@

- Implement new constraint logic for use_materialization_v2 flag ([846](https://github.com/databricks/dbt-databricks/pull/846/files)), ([876](https://github.com/databricks/dbt-databricks/pull/876))

## dbt-databricks 1.9.2 (TBD)

### Under the Hood

- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888))

## dbt-databricks 1.9.1 (December 16, 2024)

### Features
16 changes: 8 additions & 8 deletions dbt/adapters/databricks/connections.py
@@ -59,6 +59,7 @@
CursorCreate,
)
from dbt.adapters.databricks.events.other_events import QueryError
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.utils import redact_credentials
@@ -86,9 +87,6 @@
DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")


# toggle for session management that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"

# Number of idle seconds before a connection is automatically closed. Only applicable if
# USE_LONG_SESSIONS is true.
# Updated when idle times of 180s were causing errors
@@ -475,6 +473,8 @@ def add_query(
auto_begin: bool = True,
bindings: Optional[Any] = None,
abridge_sql_log: bool = False,
retryable_exceptions: tuple[type[Exception], ...] = tuple(),
retry_limit: int = 1,
*,
close_cursor: bool = False,
) -> tuple[Connection, Any]:
@@ -707,7 +707,7 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe
class ExtendedSessionConnectionManager(DatabricksConnectionManager):
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
assert (
USE_LONG_SESSIONS
GlobalState.get_use_long_sessions()
), "This connection manager should only be used when USE_LONG_SESSIONS is enabled"
super().__init__(profile, mp_context)
self.threads_compute_connections: dict[
@@ -910,7 +910,7 @@ def open(cls, connection: Connection) -> Connection:
# Once long session management is no longer under the USE_LONG_SESSIONS toggle
# this should be renamed and replace the _open class method.
assert (
USE_LONG_SESSIONS
GlobalState.get_use_long_sessions()
), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS"

databricks_connection = cast(DatabricksDBTConnection, connection)
@@ -1013,15 +1013,15 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O

# If there is no node we return the http_path for the default compute.
if not query_header_context:
if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(f"Thread {thread_id}: using default compute resource.")
return creds.http_path

# Get the name of the compute resource specified in the node's config.
# If none is specified return the http_path for the default compute.
compute_name = _get_compute_name(query_header_context)
if not compute_name:
if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.")
return creds.http_path

Expand All @@ -1037,7 +1037,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O
f"does not specify http_path, relation: {relation_name}"
)

if not USE_LONG_SESSIONS:
if not GlobalState.get_use_long_sessions():
logger.debug(
f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'."
)
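Earlier in this file, add_query also gains retryable_exceptions and retry_limit parameters; the hunk does not show how they are consumed. A minimal retry loop consistent with that signature might look like the following sketch (the helper name, the linear backoff, and the reading of retry_limit as "retries after the first attempt" are all assumptions, not code from this PR):

import time
from collections.abc import Callable


def execute_with_retries(
    run: Callable[[], None],
    retryable_exceptions: tuple[type[Exception], ...] = tuple(),
    retry_limit: int = 1,
    backoff_seconds: float = 1.0,
) -> None:
    # Attempt once, then retry up to retry_limit times on declared exceptions.
    for attempt in range(retry_limit + 1):
        try:
            run()
            return
        except retryable_exceptions:
            # An empty tuple catches nothing, so by default errors propagate.
            if attempt == retry_limit:
                raise
            time.sleep(backoff_seconds * (attempt + 1))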
8 changes: 3 additions & 5 deletions dbt/adapters/databricks/credentials.py
@@ -19,10 +19,10 @@
CredentialSaveError,
CredentialShardEvent,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV"
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS"
@@ -150,7 +150,7 @@ def validate_creds(self) -> None:

@classmethod
def get_invocation_env(cls) -> Optional[str]:
invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
invocation_env = GlobalState.get_invocation_env()
if invocation_env:
# Thrift doesn't allow nested () so we need to ensure
# that the passed user agent is valid.
@@ -160,9 +160,7 @@ def get_invocation_env(cls) -> Optional[str]:

@classmethod
def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]:
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)
http_session_headers_str = GlobalState.get_http_session_headers()

http_session_headers_dict: dict[str, str] = (
{
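The dict construction above is truncated by the hunk; the DBT_DATABRICKS_HTTP_SESSION_HEADERS value evidently gets expanded into header key/value pairs. A sketch of that parse, assuming a JSON-encoded object (the encoding is an assumption — the hunk cuts off before the parse itself):

import json
from typing import Optional


def parse_session_headers(raw: Optional[str]) -> dict[str, str]:
    # No env var set means no extra HTTP session headers.
    if not raw:
        return {}
    # Assumes the value is a JSON object mapping header names to values.
    return {str(k): str(v) for k, v in json.loads(raw).items()}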
58 changes: 58 additions & 0 deletions dbt/adapters/databricks/global_state.py
@@ -0,0 +1,58 @@
import os
from typing import ClassVar, Optional


class GlobalState:
"""Global state is a bad idea, but since we don't control instantiation, better to have it in a
single place than scattered throughout the codebase.
"""

__use_long_sessions: ClassVar[Optional[bool]] = None

@classmethod
def get_use_long_sessions(cls) -> bool:
if cls.__use_long_sessions is None:
cls.__use_long_sessions = (
os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
)
return cls.__use_long_sessions

__invocation_env: ClassVar[Optional[str]] = None
__invocation_env_set: ClassVar[bool] = False

@classmethod
def get_invocation_env(cls) -> Optional[str]:
if not cls.__invocation_env_set:
cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV")
cls.__invocation_env_set = True
return cls.__invocation_env

__session_headers: ClassVar[Optional[str]] = None
__session_headers_set: ClassVar[bool] = False

@classmethod
def get_http_session_headers(cls) -> Optional[str]:
if not cls.__session_headers_set:
cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
cls.__session_headers_set = True
return cls.__session_headers

__describe_char_bypass: ClassVar[Optional[bool]] = None

@classmethod
def get_char_limit_bypass(cls) -> bool:
if cls.__describe_char_bypass is None:
cls.__describe_char_bypass = (
os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE"
)
return cls.__describe_char_bypass

__connector_log_level: ClassVar[Optional[str]] = None

@classmethod
def get_connector_log_level(cls) -> str:
if cls.__connector_log_level is None:
cls.__connector_log_level = os.getenv(
"DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN"
).upper()
return cls.__connector_log_level
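Each getter reads its environment variable once and memoizes the result, so changes to the environment after the first call are not observed. A short illustration of that caching behavior (standalone sketch; assumes the module is importable as laid out in this diff):

import os

from dbt.adapters.databricks.global_state import GlobalState

os.environ["DBT_DATABRICKS_CONNECTOR_LOG_LEVEL"] = "debug"
print(GlobalState.get_connector_log_level())  # "DEBUG" -- read and cached

os.environ["DBT_DATABRICKS_CONNECTOR_LOG_LEVEL"] = "error"
print(GlobalState.get_connector_log_level())  # still "DEBUG"; first read wins

This memoization is what makes the class suitable for process-wide toggles like USE_LONG_SESSIONS that must not flip mid-run.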
6 changes: 3 additions & 3 deletions dbt/adapters/databricks/impl.py
@@ -33,10 +33,10 @@
)
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
DatabricksConnectionManager,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.python_models.python_submissions import (
AllPurposeClusterPythonJobHelper,
JobClusterPythonJobHelper,
@@ -152,7 +152,7 @@ def get_identifier_list_string(table_names: set[str]) -> str:
"""

_identifier = "|".join(table_names)
bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false")
bypass_2048_char_limit = GlobalState.get_char_limit_bypass()
if bypass_2048_char_limit == "true":
if bypass_2048_char_limit:
_identifier = _identifier if len(_identifier) < 2048 else "*"
return _identifier
@@ -164,7 +164,7 @@ class DatabricksAdapter(SparkAdapter):
Relation = DatabricksRelation
Column = DatabricksColumn

if USE_LONG_SESSIONS:
if GlobalState.get_use_long_sessions():
ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
else:
ConnectionManager = DatabricksConnectionManager
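Restated outside the adapter, the bypass in get_identifier_list_string keeps the pipe-joined identifier list only while it stays under the 2048-character DESCRIBE limit, otherwise falling back to a wildcard. A standalone sketch (hypothetical function name; behavior mirrors the hunk above):

def identifier_list(table_names: set[str], bypass_2048: bool) -> str:
    # Pipe-join the names into a single LIKE-style match pattern.
    identifier = "|".join(table_names)
    if bypass_2048:
        # Flag set: swap in a wildcard rather than exceed the 2048-char limit.
        return identifier if len(identifier) < 2048 else "*"
    return identifier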
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/logging.py
@@ -1,7 +1,7 @@
import os
from logging import Handler, LogRecord, getLogger
from typing import Union

from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Databricks")
@@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None:
dbt_adapter_logger = AdapterLogger("databricks-sql-connector")

pysql_logger = getLogger("databricks.sql")
pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()
pysql_logger_level = GlobalState.get_connector_log_level()
pysql_logger.setLevel(pysql_logger_level)

pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level)
@@ -6,7 +6,7 @@

{% set batch_size = get_batch_size() %}
{% set column_override = model['config'].get('column_types', {}) %}
{% set must_cast = model['config'].get("file_format", "delta") == "parquet" %}
{% set must_cast = model['config'].get('file_format', 'delta') == 'parquet' %}

{% set statements = [] %}

34 changes: 17 additions & 17 deletions dbt/include/databricks/macros/relations/constraints.sql
@@ -106,27 +106,27 @@

{% macro get_constraint_sql(relation, constraint, model, column={}) %}
{% set statements = [] %}
{% set type = constraint.get("type", "") %}
{% set type = constraint.get('type', '') %}

{% if type == 'check' %}
{% set expression = constraint.get("expression", "") %}
{% set expression = constraint.get('expression', '') %}
{% if not expression %}
{{ exceptions.raise_compiler_error('Invalid check constraint expression') }}
{% endif %}

{% set name = constraint.get("name") %}
{% set name = constraint.get('name') %}
{% if not name %}
{% if local_md5 %}
{{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }}
{%- set name = local_md5 (relation.identifier ~ ";" ~ column.get("name", "") ~ ";" ~ expression ~ ";") -%}
{%- set name = local_md5 (relation.identifier ~ ";" ~ column.get('name', '') ~ ";" ~ expression ~ ";") -%}
{% else %}
{{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }}
{% endif %}
{% endif %}
{% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " check (" ~ expression ~ ");" %}
{% do statements.append(stmt) %}
{% elif type == 'not_null' %}
{% set column_names = constraint.get("columns", []) %}
{% set column_names = constraint.get('columns', []) %}
{% if column and not column_names %}
{% set column_names = [column['name']] %}
{% endif %}
@@ -144,7 +144,7 @@
{% if constraint.get('warn_unenforced') %}
{{ exceptions.warn("unenforced constraint type: " ~ type)}}
{% endif %}
{% set column_names = constraint.get("columns", []) %}
{% set column_names = constraint.get('columns', []) %}
{% if column and not column_names %}
{% set column_names = [column['name']] %}
{% endif %}
@@ -161,7 +161,7 @@

{% set joined_names = quoted_names|join(", ") %}

{% set name = constraint.get("name") %}
{% set name = constraint.get('name') %}
{% if not name %}
{% if local_md5 %}
{{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }}
@@ -178,7 +178,7 @@
{{ exceptions.warn("unenforced constraint type: " ~ constraint.type)}}
{% endif %}

{% set name = constraint.get("name") %}
{% set name = constraint.get('name') %}

{% if constraint.get('expression') %}

@@ -193,7 +193,7 @@

{% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %}
{% else %}
{% set column_names = constraint.get("columns", []) %}
{% set column_names = constraint.get('columns', []) %}
{% if column and not column_names %}
{% set column_names = [column['name']] %}
{% endif %}
@@ -210,7 +210,7 @@

{% set joined_names = quoted_names|join(", ") %}

{% set parent = constraint.get("to") %}
{% set parent = constraint.get('to') %}
{% if not parent %}
{{ exceptions.raise_compiler_error('No parent table defined for foreign key: ' ~ expression) }}
{% endif %}
@@ -228,21 +228,21 @@
{% endif %}

{% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %}
{% set parent_columns = constraint.get("to_columns") %}
{% set parent_columns = constraint.get('to_columns') %}
{% if parent_columns %}
{% set stmt = stmt ~ "(" ~ parent_columns|join(", ") ~ ")"%}
{% endif %}
{% endif %}
{% set stmt = stmt ~ ";" %}
{% do statements.append(stmt) %}
{% elif type == 'custom' %}
{% set expression = constraint.get("expression", "") %}
{% set expression = constraint.get('expression', '') %}
{% if not expression %}
{{ exceptions.raise_compiler_error('Missing custom constraint expression') }}
{% endif %}

{% set name = constraint.get("name") %}
{% set expression = constraint.get("expression") %}
{% set name = constraint.get('name') %}
{% set expression = constraint.get('expression') %}
{% if not name %}
{% if local_md5 %}
{{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }}
@@ -264,15 +264,15 @@
{# convert constraints defined using the original databricks format #}
{% set dbt_constraints = [] %}
{% for constraint in constraints %}
{% if constraint.get and constraint.get("type") %}
{% if constraint.get and constraint.get('type') %}
{# already in model contract format #}
{% do dbt_constraints.append(constraint) %}
{% else %}
{% if column %}
{% if constraint == "not_null" %}
{% do dbt_constraints.append({"type": "not_null", "columns": [column.get("name")]}) %}
{% do dbt_constraints.append({"type": "not_null", "columns": [column.get('name')]}) %}
{% else %}
{{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get("name", "") ~ '. Only `not_null` is supported.') }}
{{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get('name', '') ~ '. Only `not_null` is supported.') }}
{% endif %}
{% else %}
{% set name = constraint['name'] %}
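When a constraint arrives without a name, the macro falls back to hashing the relation identifier, column name, and expression with local_md5, keeping generated names deterministic across runs. An equivalent of that fallback in Python (illustrative only; the macro itself uses dbt's local_md5 utility):

import hashlib


def fallback_constraint_name(identifier: str, column: str, expression: str) -> str:
    # Mirrors local_md5(identifier ~ ";" ~ column ~ ";" ~ expression ~ ";")
    payload = f"{identifier};{column};{expression};"
    return hashlib.md5(payload.encode()).hexdigest()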