From 395801ec7810c8c8dd7d71e461a66cdda35c5e18 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Thu, 23 Jan 2025 09:23:57 -0800 Subject: [PATCH] Compress to one connection manager (#910) --- CHANGELOG.md | 6 + dbt/adapters/databricks/connections.py | 386 ++++++++++-------------- dbt/adapters/databricks/global_state.py | 10 - dbt/adapters/databricks/impl.py | 10 +- 4 files changed, 162 insertions(+), 250 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20511cb1..3e72506e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt-databricks 1.9.3 (TBD) + +### Under the Hood + +- Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910)) + ## dbt-databricks 1.9.2 (Jan 21, 2024) ### Features diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 0b523574..3a6b4817 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -1,5 +1,4 @@ import decimal -import os import re import sys import time @@ -10,7 +9,6 @@ from dataclasses import dataclass from multiprocessing.context import SpawnContext from numbers import Number -from threading import get_ident from typing import TYPE_CHECKING, Any, Optional, cast from dbt_common.events.contextvars import get_node_info @@ -59,7 +57,6 @@ CursorCreate, ) from dbt.adapters.databricks.events.other_events import QueryError -from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -383,6 +380,9 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) creds = cast(DatabricksCredentials, self.profile.credentials) self.api_client = DatabricksApiClient.create(creds, 15 * 60) + self.threads_compute_connections: dict[ + Hashable, dict[Hashable, DatabricksDBTConnection] + ] = {} def cancel_open(self) -> list[str]: cancelled = super().cancel_open() @@ -431,39 +431,19 @@ def set_connection_name( 'connection_named', called by 'connection_for(node)'. Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" + self._cleanup_idle_connections() conn_name: str = "master" if name is None else name # Get a connection for this thread - conn = self.get_if_exists() - - if conn and conn.name == conn_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn + conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") if conn is None: - # Create a new connection - conn = DatabricksDBTConnection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) - conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) - # Add the connection to thread_connections for this thread - self.set_thread_connection(conn) - fire_event( - NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) - ) + conn = self._create_compute_connection(conn_name, query_header_context) else: # existing connection either wasn't open or didn't have the right name - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) - if conn.name != conn_name: - orig_conn_name: str = conn.name or "" - conn.name = conn_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + conn = self._update_compute_connection(conn, conn_name) + + conn._acquire(query_header_context) return conn @@ -601,6 +581,34 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No ), ) + # override + def release(self) -> None: + with self.lock: + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if conn is None: + return + + conn._release() + + # override + def cleanup_all(self) -> None: + with self.lock: + for thread_connections in self.threads_compute_connections.values(): + for connection in thread_connections.values(): + if connection.acquire_release_count > 0: + fire_event( + ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) + ) + else: + fire_event( + ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) + ) + self.close(connection) + + # garbage collect these connections + self.thread_connections.clear() + self.threads_compute_connections.clear() + @classmethod def get_open_for_context( cls, query_header_context: Any = None @@ -617,13 +625,8 @@ def open_for_model(connection: Connection) -> Connection: @classmethod def open(cls, connection: Connection) -> Connection: - # Simply call _open with no ResultNode argument. - # Because this is an overridden method we can't just add - # a ResultNode parameter to open. - return cls._open(connection) + databricks_connection = cast(DatabricksDBTConnection, connection) - @classmethod - def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: if connection.state == ConnectionState.OPEN: return connection @@ -646,12 +649,12 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn # If a model specifies a compute resource the http path # may be different than the http_path property of creds. - http_path = _get_http_path(query_header_context, creds) + http_path = databricks_connection.http_path def connect() -> DatabricksSQLConnectionWrapper: try: # TODO: what is the error when a user specifies a catalog they don't have access to - conn: DatabricksSQLConnection = dbsql.connect( + conn = dbsql.connect( server_hostname=creds.host, http_path=http_path, credentials_provider=cls.credentials_provider, @@ -663,7 +666,11 @@ def connect() -> DatabricksSQLConnectionWrapper: _user_agent_entry=user_agent_entry, **connection_parameters, ) - logger.debug(ConnectionCreated(str(conn))) + + if conn: + databricks_connection.session_id = conn.get_session_id_hex() + databricks_connection.last_used_time = time.time() + logger.debug(ConnectionCreated(str(databricks_connection))) return DatabricksSQLConnectionWrapper( conn, @@ -693,58 +700,74 @@ def exponential_backoff(attempt: int) -> int: ) @classmethod - def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: - _query_id = getattr(cursor, "hex_query_id", None) - if cursor is None: - logger.debug("No cursor was provided. Query ID not available.") - query_id = "N/A" - else: - query_id = _query_id - message = "OK" - return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore + def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: + if connection.state == ConnectionState.OPEN: + return connection + creds: DatabricksCredentials = connection.credentials + timeout = creds.connect_timeout -class ExtendedSessionConnectionManager(DatabricksConnectionManager): - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None: - assert ( - GlobalState.get_use_long_sessions() - ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled" - super().__init__(profile, mp_context) - self.threads_compute_connections: dict[ - Hashable, dict[Hashable, DatabricksDBTConnection] - ] = {} + # gotta keep this so we don't prompt users many times + cls.credentials_provider = creds.authenticate(cls.credentials_provider) - def set_connection_name( - self, name: Optional[str] = None, query_header_context: Any = None - ) -> Connection: - """Called by 'acquire_connection' in DatabricksAdapter, which is called by - 'connection_named', called by 'connection_for(node)'. - Creates a connection for this thread if one doesn't already - exist, and will rename an existing connection.""" - self._cleanup_idle_connections() + invocation_env = creds.get_invocation_env() + user_agent_entry = cls._user_agent + if invocation_env: + user_agent_entry = f"{cls._user_agent}; {invocation_env}" - conn_name: str = "master" if name is None else name + connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] - # Get a connection for this thread - conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") + http_headers: list[tuple[str, str]] = list( + creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + ) - if conn is None: - conn = self._create_compute_connection(conn_name, query_header_context) - else: # existing connection either wasn't open or didn't have the right name - conn = self._update_compute_connection(conn, conn_name) + # If a model specifies a compute resource the http path + # may be different than the http_path property of creds. + http_path = _get_http_path(query_header_context, creds) - conn._acquire(query_header_context) + def connect() -> DatabricksSQLConnectionWrapper: + try: + # TODO: what is the error when a user specifies a catalog they don't have access to + conn: DatabricksSQLConnection = dbsql.connect( + server_hostname=creds.host, + http_path=http_path, + credentials_provider=cls.credentials_provider, + http_headers=http_headers if http_headers else None, + session_configuration=creds.session_properties, + catalog=creds.database, + use_inline_params="silent", + # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. + _user_agent_entry=user_agent_entry, + **connection_parameters, + ) + logger.debug(ConnectionCreated(str(conn))) - return conn + return DatabricksSQLConnectionWrapper( + conn, + is_cluster=creds.cluster_id is not None, + creds=creds, + user_agent=user_agent_entry, + ) + except Error as exc: + logger.error(ConnectionCreateError(exc)) + raise - # override - def release(self) -> None: - with self.lock: - conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if conn is None: - return + def exponential_backoff(attempt: int) -> int: + return attempt * attempt - conn._release() + retryable_exceptions = [] + # this option is for backwards compatibility + if creds.retry_all: + retryable_exceptions = [Error] + + return cls.retry_connection( + connection, + connect=connect, + logger=logger, + retryable_exceptions=retryable_exceptions, + retry_limit=creds.connect_retries, + retry_timeout=(timeout if timeout is not None else exponential_backoff), + ) # override @classmethod @@ -756,46 +779,22 @@ def close(cls, connection: Connection) -> Connection: connection.state = ConnectionState.CLOSED return connection - # override - def cleanup_all(self) -> None: - with self.lock: - for thread_connections in self.threads_compute_connections.values(): - for connection in thread_connections.values(): - if connection.acquire_release_count > 0: - fire_event( - ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) - ) - else: - fire_event( - ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) - ) - self.close(connection) - - # garbage collect these connections - self.thread_connections.clear() - self.threads_compute_connections.clear() - - def _update_compute_connection( - self, conn: DatabricksDBTConnection, new_name: str - ) -> DatabricksDBTConnection: - if conn.name == new_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn - - orig_conn_name: str = conn.name or "" - - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.open) - if conn.name != new_name: - conn.name = new_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - - current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: - self.clear_thread_connection() - self.set_thread_connection(conn) + @classmethod + def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: + _query_id = getattr(cursor, "hex_query_id", None) + if cursor is None: + logger.debug("No cursor was provided. Query ID not available.") + query_id = "N/A" + else: + query_id = _query_id + message = "OK" + return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore - logger.debug(ConnectionReuse(str(conn), orig_conn_name)) + def get_thread_connection(self) -> Connection: + conn = super().get_thread_connection() + self._cleanup_idle_connections() + dbr_conn = cast(DatabricksDBTConnection, conn) + logger.debug(ConnectionRetrieve(str(dbr_conn))) return conn @@ -810,28 +809,6 @@ def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: ) thread_map[conn.compute_name] = conn - def _get_compute_connections( - self, - ) -> dict[Hashable, DatabricksDBTConnection]: - """Retrieve a map of compute name to connection for the current thread.""" - - thread_id = self.get_thread_identifier() - with self.lock: - thread_map = self.threads_compute_connections.get(thread_id) - if not thread_map: - thread_map = {} - self.threads_compute_connections[thread_id] = thread_map - return thread_map - - def _get_if_exists_compute_connection( - self, compute_name: str - ) -> Optional[DatabricksDBTConnection]: - """Get the connection for the current thread and named compute, if it exists.""" - - with self.lock: - threads_map = self._get_compute_connections() - return threads_map.get(compute_name) - def _cleanup_idle_connections(self) -> None: with self.lock: # Get all connections associated with this thread. There can be multiple connections @@ -897,95 +874,51 @@ def _create_compute_connection( return conn - def get_thread_connection(self) -> Connection: - conn = super().get_thread_connection() - self._cleanup_idle_connections() - dbr_conn = cast(DatabricksDBTConnection, conn) - logger.debug(ConnectionRetrieve(str(dbr_conn))) - - return conn - - @classmethod - def open(cls, connection: Connection) -> Connection: - # Once long session management is no longer under the USE_LONG_SESSIONS toggle - # this should be renamed and replace the _open class method. - assert ( - GlobalState.get_use_long_sessions() - ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" - - databricks_connection = cast(DatabricksDBTConnection, connection) - - if connection.state == ConnectionState.OPEN: - return connection - - creds: DatabricksCredentials = connection.credentials - timeout = creds.connect_timeout - - # gotta keep this so we don't prompt users many times - cls.credentials_provider = creds.authenticate(cls.credentials_provider) - - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" + def _get_if_exists_compute_connection( + self, compute_name: str + ) -> Optional[DatabricksDBTConnection]: + """Get the connection for the current thread and named compute, if it exists.""" - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + with self.lock: + threads_map = self._get_compute_connections() + return threads_map.get(compute_name) - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() - ) + def _get_compute_connections( + self, + ) -> dict[Hashable, DatabricksDBTConnection]: + """Retrieve a map of compute name to connection for the current thread.""" - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. - http_path = databricks_connection.http_path + thread_id = self.get_thread_identifier() + with self.lock: + thread_map = self.threads_compute_connections.get(thread_id) + if not thread_map: + thread_map = {} + self.threads_compute_connections[thread_id] = thread_map + return thread_map - def connect() -> DatabricksSQLConnectionWrapper: - try: - # TODO: what is the error when a user specifies a catalog they don't have access to - conn = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. - _user_agent_entry=user_agent_entry, - **connection_parameters, - ) + def _update_compute_connection( + self, conn: DatabricksDBTConnection, new_name: str + ) -> DatabricksDBTConnection: + if conn.name == new_name and conn.state == ConnectionState.OPEN: + # Found a connection and nothing to do, so just return it + return conn - if conn: - databricks_connection.session_id = conn.get_session_id_hex() - databricks_connection.last_used_time = time.time() - logger.debug(ConnectionCreated(str(databricks_connection))) + orig_conn_name: str = conn.name or "" - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) - except Error as exc: - logger.error(ConnectionCreateError(exc)) - raise + if conn.state != ConnectionState.OPEN: + conn.handle = LazyHandle(self.open) + if conn.name != new_name: + conn.name = new_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - def exponential_backoff(attempt: int) -> int: - return attempt * attempt + current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) - retryable_exceptions = [] - # this option is for backwards compatibility - if creds.retry_all: - retryable_exceptions = [Error] + logger.debug(ConnectionReuse(str(conn), orig_conn_name)) - return cls.retry_connection( - connection, - connect=connect, - logger=logger, - retryable_exceptions=retryable_exceptions, - retry_limit=creds.connect_retries, - retry_timeout=(timeout if timeout is not None else exponential_backoff), - ) + return conn def _get_compute_name(query_header_context: Any) -> Optional[str]: @@ -1005,24 +938,18 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O """Get the http_path for the compute specified for the node. If none is specified default will be used.""" - thread_id = (os.getpid(), get_ident()) - # ResultNode *should* have relation_name attr, but we work around a core # issue by checking. relation_name = getattr(query_header_context, "relation_name", "[unknown]") # If there is no node we return the http_path for the default compute. if not query_header_context: - if not GlobalState.get_use_long_sessions(): - logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path # Get the name of the compute resource specified in the node's config. # If none is specified return the http_path for the default compute. compute_name = _get_compute_name(query_header_context) if not compute_name: - if not GlobalState.get_use_long_sessions(): - logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.") return creds.http_path # Get the http_path for the named compute. @@ -1037,11 +964,6 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O f"does not specify http_path, relation: {relation_name}" ) - if not GlobalState.get_use_long_sessions(): - logger.debug( - f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'." - ) - return http_path diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py index de240d39..cdc5df98 100644 --- a/dbt/adapters/databricks/global_state.py +++ b/dbt/adapters/databricks/global_state.py @@ -7,16 +7,6 @@ class GlobalState: single place than scattered throughout the codebase. """ - __use_long_sessions: ClassVar[Optional[bool]] = None - - @classmethod - def get_use_long_sessions(cls) -> bool: - if cls.__use_long_sessions is None: - cls.__use_long_sessions = ( - os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" - ) - return cls.__use_long_sessions - __invocation_env: ClassVar[Optional[str]] = None __invocation_env_set: ClassVar[bool] = False diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 15c333e2..d106dd1c 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -31,10 +31,7 @@ GetColumnsByInformationSchema, ) from dbt.adapters.databricks.column import DatabricksColumn -from dbt.adapters.databricks.connections import ( - DatabricksConnectionManager, - ExtendedSessionConnectionManager, -) +from dbt.adapters.databricks.connections import DatabricksConnectionManager from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, @@ -154,10 +151,7 @@ class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation Column = DatabricksColumn - if GlobalState.get_use_long_sessions(): - ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager - else: - ConnectionManager = DatabricksConnectionManager + ConnectionManager = DatabricksConnectionManager connections: DatabricksConnectionManager