From 5f6412d71c12a31ba47446f1e4ec3d642691616a Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:03:34 -0800 Subject: [PATCH 1/2] Refactor reading env vars (#888) --- CHANGELOG.md | 6 +++ dbt/adapters/databricks/connections.py | 16 +++---- dbt/adapters/databricks/credentials.py | 8 ++-- dbt/adapters/databricks/global_state.py | 58 +++++++++++++++++++++++++ dbt/adapters/databricks/impl.py | 6 +-- dbt/adapters/databricks/logging.py | 4 +- tests/unit/test_adapter.py | 33 +++++++++----- 7 files changed, 102 insertions(+), 29 deletions(-) create mode 100644 dbt/adapters/databricks/global_state.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a4c107a9..59af816c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt-databricks 1.9.2 (TBD) + +### Under the Hood + +- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) + ## dbt-databricks 1.9.1 (December 16, 2024) ### Features diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 509686d7..0b523574 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -59,6 +59,7 @@ CursorCreate, ) from dbt.adapters.databricks.events.other_events import QueryError +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -86,9 +87,6 @@ DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") -# toggle for session managements that minimizes the number of sessions opened/closed -USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" - # Number of idle seconds before a connection is automatically closed. Only applicable if # USE_LONG_SESSIONS is true. # Updated when idle times of 180s were causing errors @@ -475,6 +473,8 @@ def add_query( auto_begin: bool = True, bindings: Optional[Any] = None, abridge_sql_log: bool = False, + retryable_exceptions: tuple[type[Exception], ...] = tuple(), + retry_limit: int = 1, *, close_cursor: bool = False, ) -> tuple[Connection, Any]: @@ -707,7 +707,7 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe class ExtendedSessionConnectionManager(DatabricksConnectionManager): def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None: assert ( - USE_LONG_SESSIONS + GlobalState.get_use_long_sessions() ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled" super().__init__(profile, mp_context) self.threads_compute_connections: dict[ @@ -910,7 +910,7 @@ def open(cls, connection: Connection) -> Connection: # Once long session management is no longer under the USE_LONG_SESSIONS toggle # this should be renamed and replace the _open class method. assert ( - USE_LONG_SESSIONS + GlobalState.get_use_long_sessions() ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" databricks_connection = cast(DatabricksDBTConnection, connection) @@ -1013,7 +1013,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O # If there is no node we return the http_path for the default compute. 
if not query_header_context: - if not USE_LONG_SESSIONS: + if not GlobalState.get_use_long_sessions(): logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path @@ -1021,7 +1021,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O # If none is specified return the http_path for the default compute. compute_name = _get_compute_name(query_header_context) if not compute_name: - if not USE_LONG_SESSIONS: + if not GlobalState.get_use_long_sessions(): logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.") return creds.http_path @@ -1037,7 +1037,7 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O f"does not specify http_path, relation: {relation_name}" ) - if not USE_LONG_SESSIONS: + if not GlobalState.get_use_long_sessions(): logger.debug( f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'." ) diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 7a318cad..387d0e76 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -19,10 +19,10 @@ CredentialSaveError, CredentialShardEvent, ) +from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" -DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" @@ -150,7 +150,7 @@ def validate_creds(self) -> None: @classmethod def get_invocation_env(cls) -> Optional[str]: - invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV) + invocation_env = GlobalState.get_invocation_env() if invocation_env: # Thrift doesn't allow nested () so we need to ensure # that the passed user agent is valid. @@ -160,9 +160,7 @@ def get_invocation_env(cls) -> Optional[str]: @classmethod def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: - http_session_headers_str: Optional[str] = os.environ.get( - DBT_DATABRICKS_HTTP_SESSION_HEADERS - ) + http_session_headers_str = GlobalState.get_http_session_headers() http_session_headers_dict: dict[str, str] = ( { diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py new file mode 100644 index 00000000..de240d39 --- /dev/null +++ b/dbt/adapters/databricks/global_state.py @@ -0,0 +1,58 @@ +import os +from typing import ClassVar, Optional + + +class GlobalState: + """Global state is a bad idea, but since we don't control instantiation, better to have it in a + single place than scattered throughout the codebase. 
+    """
+
+    __use_long_sessions: ClassVar[Optional[bool]] = None
+
+    @classmethod
+    def get_use_long_sessions(cls) -> bool:
+        if cls.__use_long_sessions is None:
+            cls.__use_long_sessions = (
+                os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
+            )
+        return cls.__use_long_sessions
+
+    __invocation_env: ClassVar[Optional[str]] = None
+    __invocation_env_set: ClassVar[bool] = False
+
+    @classmethod
+    def get_invocation_env(cls) -> Optional[str]:
+        if not cls.__invocation_env_set:
+            cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV")
+            cls.__invocation_env_set = True
+        return cls.__invocation_env
+
+    __session_headers: ClassVar[Optional[str]] = None
+    __session_headers_set: ClassVar[bool] = False
+
+    @classmethod
+    def get_http_session_headers(cls) -> Optional[str]:
+        if not cls.__session_headers_set:
+            cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
+            cls.__session_headers_set = True
+        return cls.__session_headers
+
+    __describe_char_bypass: ClassVar[Optional[bool]] = None
+
+    @classmethod
+    def get_char_limit_bypass(cls) -> bool:
+        if cls.__describe_char_bypass is None:
+            cls.__describe_char_bypass = (
+                os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE"
+            )
+        return cls.__describe_char_bypass
+
+    __connector_log_level: ClassVar[Optional[str]] = None
+
+    @classmethod
+    def get_connector_log_level(cls) -> str:
+        if cls.__connector_log_level is None:
+            cls.__connector_log_level = os.getenv(
+                "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN"
+            ).upper()
+        return cls.__connector_log_level
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index dce432c9..15c333e2 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -32,10 +32,10 @@
 )
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.connections import (
-    USE_LONG_SESSIONS,
     DatabricksConnectionManager,
     ExtendedSessionConnectionManager,
 )
+from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.python_models.python_submissions import (
     AllPurposeClusterPythonJobHelper,
     JobClusterPythonJobHelper,
@@ -142,7 +142,7 @@ def get_identifier_list_string(table_names: set[str]) -> str:
     """
     _identifier = "|".join(table_names)
 
-    bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false")
-    if bypass_2048_char_limit == "true":
+    bypass_2048_char_limit = GlobalState.get_char_limit_bypass()
+    if bypass_2048_char_limit:
         _identifier = _identifier if len(_identifier) < 2048 else "*"
     return _identifier
@@ -154,7 +154,7 @@ class DatabricksAdapter(SparkAdapter):
     Relation = DatabricksRelation
     Column = DatabricksColumn
 
-    if USE_LONG_SESSIONS:
+    if GlobalState.get_use_long_sessions():
         ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
     else:
         ConnectionManager = DatabricksConnectionManager
diff --git a/dbt/adapters/databricks/logging.py b/dbt/adapters/databricks/logging.py
index d0f1d42b..81e7449e 100644
--- a/dbt/adapters/databricks/logging.py
+++ b/dbt/adapters/databricks/logging.py
@@ -1,7 +1,7 @@
-import os
 from logging import Handler, LogRecord, getLogger
 from typing import Union
 
+from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.events.logging import AdapterLogger
 
 logger = AdapterLogger("Databricks")
@@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None:
 dbt_adapter_logger = AdapterLogger("databricks-sql-connector")
 
 pysql_logger = getLogger("databricks.sql")
-pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper()
+pysql_logger_level = GlobalState.get_connector_log_level()
 pysql_logger.setLevel(pysql_logger_level)
 
 pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level)
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 78ae12cb..d42fa5e1 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -11,8 +11,6 @@
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.credentials import (
     CATALOG_KEY_IN_SESSION_PROPERTIES,
-    DBT_DATABRICKS_HTTP_SESSION_HEADERS,
-    DBT_DATABRICKS_INVOCATION_ENV,
 )
 from dbt.adapters.databricks.impl import get_identifier_list_string
 from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType
@@ -114,7 +112,10 @@ def test_invalid_custom_user_agent(self):
         with pytest.raises(DbtValidationError) as excinfo:
             config = self._get_config()
             adapter = DatabricksAdapter(config, get_context("spawn"))
-            with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}):
+            with patch(
+                "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
+                return_value="(Some-thing)",
+            ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
@@ -128,8 +129,9 @@ def test_custom_user_agent(self):
             "dbt.adapters.databricks.connections.dbsql.connect",
             new=self._connect_func(expected_invocation_env="databricks-workflows"),
         ):
-            with patch.dict(
-                "os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"}
+            with patch(
+                "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
+                return_value="databricks-workflows",
             ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
@@ -190,9 +192,9 @@ def _test_environment_http_headers(
             "dbt.adapters.databricks.connections.dbsql.connect",
             new=self._connect_func(expected_http_headers=expected_http_headers),
         ):
-            with patch.dict(
-                "os.environ",
-                **{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str},
+            with patch(
+                "dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers",
+                return_value=http_headers_str,
             ):
                 connection = adapter.acquire_connection("dummy")
                 connection.handle  # trigger lazy-load
@@ -912,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self):
         assert get_identifier_list_string(table_names) == "|".join(table_names)
 
         # If environment variable is set, then limit the number of characters
-        with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
+        with patch(
+            "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
+            return_value=True,
+        ):
             # Long list of table names is capped
             assert get_identifier_list_string(table_names) == "*"
@@ -941,7 +946,10 @@ def test_describe_table_extended_should_limit(self):
         table_names = set([f"customers_{i}" for i in range(200)])
 
         # If environment variable is set, then limit the number of characters
-        with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
+        with patch(
+            "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
+            return_value=True,
+        ):
             # Long list of table names is capped
             assert get_identifier_list_string(table_names) == "*"
@@ -954,7 +962,10 @@ def test_describe_table_extended_may_limit(self):
         table_names = set([f"customers_{i}" for i in range(200)])
 
         # If environment variable is set, then we may limit the number of characters
-        with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
+        with patch(
+            "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
+            return_value=True,
+        ):
             # But a short list of table names is not capped
             assert get_identifier_list_string(list(table_names)[:5]) == "|".join(
                 list(table_names)[:5]

From b0ff51bd7198d456681a5b1ca822e5adba34b732 Mon Sep 17 00:00:00 2001
From: Ben Cassell <98852248+benc-db@users.noreply.github.com>
Date: Fri, 20 Dec 2024 13:53:55 -0800
Subject: [PATCH 2/2] Use single quotes in gets in templates (#889)

---
 .../macros/materializations/seeds/helpers.sql |  2 +-
 .../macros/relations/constraints.sql          | 34 +++++++++----------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/dbt/include/databricks/macros/materializations/seeds/helpers.sql b/dbt/include/databricks/macros/materializations/seeds/helpers.sql
index df690f18..82acaba3 100644
--- a/dbt/include/databricks/macros/materializations/seeds/helpers.sql
+++ b/dbt/include/databricks/macros/materializations/seeds/helpers.sql
@@ -6,7 +6,7 @@
 
   {% set batch_size = get_batch_size() %}
   {% set column_override = model['config'].get('column_types', {}) %}
-  {% set must_cast = model['config'].get("file_format", "delta") == "parquet" %}
+  {% set must_cast = model['config'].get('file_format', 'delta') == 'parquet' %}
 
   {% set statements = [] %}
 
diff --git a/dbt/include/databricks/macros/relations/constraints.sql b/dbt/include/databricks/macros/relations/constraints.sql
index 68f3a44f..34d5b415 100644
--- a/dbt/include/databricks/macros/relations/constraints.sql
+++ b/dbt/include/databricks/macros/relations/constraints.sql
@@ -106,19 +106,19 @@
 
 {% macro get_constraint_sql(relation, constraint, model, column={}) %}
   {% set statements = [] %}
-  {% set type = constraint.get("type", "") %}
+  {% set type = constraint.get('type', '') %}
   {% if type == 'check' %}
-    {% set expression = constraint.get("expression", "") %}
+    {% set expression = constraint.get('expression', '') %}
     {% if not expression %}
      {{ exceptions.raise_compiler_error('Invalid check constraint expression') }}
     {% endif %}
-    {% set name = constraint.get("name") %}
+    {% set name = constraint.get('name') %}
     {% if not name %}
       {% if local_md5 %}
        {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }}
-        {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get("name", "") ~ ";" ~ expression ~ ";") -%}
+        {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get('name', '') ~ ";" ~ expression ~ ";") -%}
       {% else %}
        {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }}
       {% endif %}
     {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " check (" ~ expression ~ ");" %}
     {% do statements.append(stmt) %}
   {% elif type == 'not_null' %}
-    {% set column_names = constraint.get("columns", []) %}
+    {% set column_names = constraint.get('columns', []) %}
     {% if column and not column_names %}
       {% set column_names = [column['name']] %}
     {% endif %}
@@ -144,7 +144,7 @@
     {% if constraint.get('warn_unenforced') %}
       {{ exceptions.warn("unenforced constraint type: " ~ type)}}
     {% endif %}
-    {% set column_names = constraint.get("columns", []) %}
+    {% set column_names = constraint.get('columns', []) %}
     {% if column and not column_names %}
       {% set column_names = [column['name']] %}
     {% endif %}
@@ -161,7 +161,7 @@
 
     {% set joined_names = quoted_names|join(", ") %}
 
-    {% set name = constraint.get("name") %}
+    {% set name = constraint.get('name') %}
     {% if not name %}
       {% if local_md5 %}
         {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }}
@@ -178,7 +178,7 @@
       {{ exceptions.warn("unenforced constraint type: " ~ constraint.type)}}
     {% endif %}
 
-    {% set name = constraint.get("name") %}
+    {% set name = constraint.get('name') %}
 
     {% if constraint.get('expression') %}
 
@@ -193,7 +193,7 @@
       {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %}
     {% else %}
-      {% set column_names = constraint.get("columns", []) %}
+      {% set column_names = constraint.get('columns', []) %}
       {% if column and not column_names %}
         {% set column_names = [column['name']] %}
       {% endif %}
@@ -210,7 +210,7 @@
 
     {% set joined_names = quoted_names|join(", ") %}
 
-      {% set parent = constraint.get("to") %}
+      {% set parent = constraint.get('to') %}
       {% if not parent %}
         {{ exceptions.raise_compiler_error('No parent table defined for foreign key: ' ~ expression) }}
       {% endif %}
@@ -228,7 +228,7 @@
       {% endif %}
       {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %}
-      {% set parent_columns = constraint.get("to_columns") %}
+      {% set parent_columns = constraint.get('to_columns') %}
       {% if parent_columns %}
         {% set stmt = stmt ~ "(" ~ parent_columns|join(", ") ~ ")"%}
       {% endif %}
@@ -236,13 +236,13 @@
     {% set stmt = stmt ~ ";" %}
     {% do statements.append(stmt) %}
   {% elif type == 'custom' %}
-    {% set expression = constraint.get("expression", "") %}
+    {% set expression = constraint.get('expression', '') %}
     {% if not expression %}
       {{ exceptions.raise_compiler_error('Missing custom constraint expression') }}
     {% endif %}
 
-    {% set name = constraint.get("name") %}
-    {% set expression = constraint.get("expression") %}
+    {% set name = constraint.get('name') %}
+    {% set expression = constraint.get('expression') %}
     {% if not name %}
       {% if local_md5 %}
         {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. Generating hash instead for relation " ~ relation.identifier) }}
@@ -264,15 +264,15 @@
   {# convert constraints defined using the original databricks format #}
   {% set dbt_constraints = [] %}
   {% for constraint in constraints %}
-    {% if constraint.get and constraint.get("type") %}
+    {% if constraint.get and constraint.get('type') %}
       {# already in model contract format #}
       {% do dbt_constraints.append(constraint) %}
     {% else %}
      {% if column %}
        {% if constraint == "not_null" %}
-          {% do dbt_constraints.append({"type": "not_null", "columns": [column.get("name")]}) %}
+          {% do dbt_constraints.append({"type": "not_null", "columns": [column.get('name')]}) %}
        {% else %}
-          {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get("name", "") ~ '. Only `not_null` is supported.') }}
+          {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get('name', '') ~ '. Only `not_null` is supported.') }}
        {% endif %}
      {% else %}
        {% set name = constraint['name'] %}
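
Note on the caching semantics introduced by GlobalState in PATCH 1/2: each getter reads its
environment variable at most once and caches the result for the life of the process, so a value
such as DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS can no longer be toggled mid-run. A minimal sketch of
that behavior (illustrative only, using the getters defined in
dbt/adapters/databricks/global_state.py):

    import os

    from dbt.adapters.databricks.global_state import GlobalState

    # First call reads the env var, normalizes it, and caches the result.
    os.environ["DBT_DATABRICKS_CONNECTOR_LOG_LEVEL"] = "debug"
    assert GlobalState.get_connector_log_level() == "DEBUG"

    # Later mutations of the environment are ignored; the cached value wins.
    os.environ["DBT_DATABRICKS_CONNECTOR_LOG_LEVEL"] = "error"
    assert GlobalState.get_connector_log_level() == "DEBUG"

This is also why the unit tests in PATCH 1/2 patch the getters directly (for example
GlobalState.get_char_limit_bypass) instead of using patch.dict("os.environ", ...): once any
earlier test has triggered a read, patching the environment would no longer affect the cached
value.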