Commit

Merge branch 'main' into issue-337-redo

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>

Jesse Whitehouse committed Jan 16, 2024
2 parents ad96587 + 79e8707 commit 54d0491
Showing 15 changed files with 245 additions and 105 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/build_cluster_http_path.py
@@ -1,13 +1,19 @@
import os
import re

workspace_id = os.getenv("DBT_DATABRICKS_HOST_NAME")[4:18]
workspace_re = re.compile(r"^.*-(\d+)\..*$")
hostname = os.getenv("DBT_DATABRICKS_HOST_NAME", "")
matches = workspace_re.match(hostname)
if matches:
workspace_id = matches.group(1)
print(workspace_id)
cluster_id = os.getenv("TEST_PECO_CLUSTER_ID")
uc_cluster_id = os.getenv("TEST_PECO_UC_CLUSTER_ID")
http_path = f"sql/protocolv1/o/{workspace_id}/{cluster_id}"
uc_http_path = f"sql/protocolv1/o/{workspace_id}/{uc_cluster_id}"

# https://stackoverflow.com/a/72225291/5093960
env_file = os.getenv("GITHUB_ENV")
env_file = os.getenv("GITHUB_ENV", "")
with open(env_file, "a") as myfile:
myfile.write(f"DBT_DATABRICKS_CLUSTER_HTTP_PATH={http_path}\n")
myfile.write(f"DBT_DATABRICKS_UC_CLUSTER_HTTP_PATH={uc_http_path}\n")
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,15 @@
## dbt-databricks 1.7.4 (TBD)

### Fixes

- Added python model specific connection handling to prevent using invalid sessions ([547](https://github.com/databricks/dbt-databricks/pull/547))
- Allow schema to be specified in testing (thanks @case-k-git!) ([538](https://github.com/databricks/dbt-databricks/pull/538))
- Fix dbt incremental_strategy behavior by fixing the schema/table existence check (thanks @case-k-git!) ([530](https://github.com/databricks/dbt-databricks/pull/530))

### Under the Hood

- Adding retries around API calls in python model submission ([549](https://github.com/databricks/dbt-databricks/pull/549))

## dbt-databricks 1.7.3 (Dec 12, 2023)

### Fixes
77 changes: 67 additions & 10 deletions dbt/adapters/databricks/connections.py
@@ -737,11 +737,26 @@ class DatabricksDBTConnection(Connection):
thread_identifier: Tuple[int, int] = (0, 0)
max_idle_time: int = DEFAULT_MAX_IDLE_TIME

# If the connection is being used for a model, we want to track the model language.
# We do this because we need special handling for python models. Python models will
# acquire a connection, but do not actually use it to run the model. This can lead to the
# session timing out on the back end. However, when the connection is released we set the
# last_used_time, essentially indicating that the connection was in use while the python
# model was running. So the session is not refreshed by idle-connection cleanup and errors
# out the next time it is used.
language: Optional[str] = None

def _acquire(self, node: Optional[ResultNode]) -> None:
"""Indicate that this connection is in use."""
logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}")
self._log_usage(node)
self.acquire_release_count += 1
if self.last_used_time is None:
self.last_used_time = time.time()
if node and hasattr(node, "language"):
self.language = node.language
else:
self.language = None

def _release(self) -> None:
"""Indicate that this connection is not in use."""
@@ -751,7 +766,9 @@ def _release(self) -> None:
if self.acquire_release_count > 0:
self.acquire_release_count -= 1

if self.acquire_release_count == 0:
# We don't update the last_used_time for python models because the python model
# is submitted through a different mechanism and doesn't actually use the connection.
if self.acquire_release_count == 0 and self.language != "python":
self.last_used_time = time.time()

def _get_idle_time(self) -> float:
@@ -765,7 +782,7 @@ def _get_conn_info_str(self) -> str:
return (
f"name: {self.name}, thread: {self.thread_identifier}, "
f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count},"
f" idle time: {self._get_idle_time()}s"
f" idle time: {self._get_idle_time()}s, language: {self.language}"
)

def _log_usage(self, node: Optional[ResultNode]) -> None:
@@ -783,6 +800,13 @@ def _log_usage(self, node: Optional[ResultNode]) -> None:
else:
logger.debug(f"Thread {self.thread_identifier} using default compute resource.")

def _reset_handle(self, open: Callable[[Connection], Connection]) -> None:
logger.debug(f"DatabricksDBTConnection._reset_handle: {self._get_conn_info_str()}")
self.handle = LazyHandle(open)
# Reset last_used_time to None because, once refreshed, this connection is associated
# with a new session that hasn't been used yet.
self.last_used_time = None
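
For intuition, a minimal stand-in for the lazy-handle pattern relied on here (dbt's real LazyHandle comes from dbt-core; this sketch only assumes the deferred-open behavior, it is not the actual class):

```python
from typing import Callable

class LazyHandleSketch:
    """Defers opening until the connection is first used."""

    def __init__(self, opener: Callable[["Connection"], "Connection"]) -> None:
        self.opener = opener

    def resolve(self, connection: "Connection") -> "Connection":
        # No session is consumed until a query actually needs one, which is
        # why last_used_time can safely be reset to None after a refresh.
        return self.opener(connection)
```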


class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
@@ -889,6 +913,19 @@ def release(self) -> None:

conn._release()

# override
@classmethod
def close(cls, connection: Connection) -> Connection:
if not USE_LONG_SESSIONS:
return super().close(connection)

try:
return super().close(connection)
except Exception as e:
logger.warning(f"ignoring error when closing connection: {e}")
connection.state = ConnectionState.CLOSED
return connection
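# Rationale (presumably): with USE_LONG_SESSIONS the backing session may
# already have expired server-side, so a failure while closing is only
# logged and the connection is still marked CLOSED rather than failing
# the run.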

# override
def cleanup_all(self) -> None:
if not USE_LONG_SESSIONS:
@@ -1063,12 +1100,32 @@ def _cleanup_idle_connections(self) -> None:
), "This path, '_cleanup_idle_connections', should only be reachable with USE_LONG_SESSIONS"

with self.lock:
for thread_conns in self.threads_compute_connections.values():
for conn in thread_conns.values():
if conn.acquire_release_count == 0 and conn._idle_too_long():
logger.debug(f"closing idle connection: {conn._get_conn_info_str()}")
self.close(conn)
conn.handle = LazyHandle(self._open2)
# Get all connections associated with this thread. There can be multiple connections
# if different models use different compute resources.
thread_conns = self._get_compute_connections()
for conn in thread_conns.values():
# Generally speaking, we only want to close/refresh the connection if the
# acquire_release_count is zero, i.e. the connection is not currently in use.
# However, python models acquire a connection and then run the python model, which
# doesn't actually use the connection. If the python model takes long enough to
# run, the connection can be idle long enough to time out on the back end.
# If additional sql needs to be run after the python model, but before the
# connection is released, the connection needs to be refreshed or there will
# be a failure. Making an exception when the language is 'python' allows the
# call to _cleanup_idle_connections from get_thread_connection to refresh the
# connection in this scenario.
if (
conn.acquire_release_count == 0 or conn.language == "python"
) and conn._idle_too_long():
logger.debug(f"closing idle connection: {conn._get_conn_info_str()}")
self.close(conn)
conn._reset_handle(self._open2)
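
Distilled into a standalone predicate (an illustrative sketch, not the adapter's API), the refresh decision above amounts to:

```python
from typing import Optional

def should_refresh(acquire_release_count: int, language: Optional[str],
                   idle_too_long: bool) -> bool:
    # Idle, unused connections are refreshed as before; connections held by
    # python models are refreshed even while acquired, because the model runs
    # out-of-band and the backing session may have timed out on the server.
    return (acquire_release_count == 0 or language == "python") and idle_too_long
```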

def get_thread_connection(self) -> Connection:
if USE_LONG_SESSIONS:
self._cleanup_idle_connections()

return super().get_thread_connection()

def add_query(
self,
@@ -1181,15 +1238,15 @@ def _execute_cursor(
def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
database = database.strip("`")
if schema:
schema = schema.strip("`")
schema = schema.strip("`").lower()
return self._execute_cursor(
f"GetSchemas(database={database}, schema={schema})",
lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema),
)

def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> Table:
database = database.strip("`")
schema = schema.strip("`")
schema = schema.strip("`").lower()
if identifier:
identifier = identifier.strip("`")
return self._execute_cursor(
64 changes: 42 additions & 22 deletions dbt/adapters/databricks/python_submissions.py
@@ -1,19 +1,28 @@
from typing import Any, Dict, Tuple, Optional, Callable, Union

from requests import Session

from dbt.adapters.databricks.__version__ import version
from dbt.adapters.databricks.connections import DatabricksCredentials
from dbt.adapters.databricks import utils

import base64
import time
import requests
import uuid

from urllib3.util.retry import Retry

from dbt.events import AdapterLogger
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.spark import __version__

from databricks.sdk.core import CredentialsProvider, HeaderFactory
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from requests import PreparedRequest



logger = AdapterLogger("Databricks")

@@ -23,7 +32,7 @@
DBT_SPARK_VERSION = __version__.version


class BearerAuth(requests.auth.AuthBase):
class BearerAuth(AuthBase):
"""See issue #337.
We use this mix-in to stop requests from implicitly reading .netrc
@@ -34,7 +43,7 @@ class BearerAuth(requests.auth.AuthBase):
def __init__(self, headers: HeaderFactory):
self.headers = headers()

def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
def __call__(self, r: PreparedRequest) -> PreparedRequest:
r.headers.update(**self.headers)
return r
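
A usage sketch for BearerAuth (the token, host, and header factory below are placeholders; in the adapter the headers come from the databricks-sdk CredentialsProvider):

```python
import requests

def make_headers() -> dict:
    # Placeholder header factory standing in for the SDK-provided one.
    return {"Authorization": "Bearer <token>"}

session = requests.Session()
# Because auth is passed explicitly, requests will not fall back to ~/.netrc.
session.get(
    "https://<workspace-host>/api/2.0/clusters/list",
    auth=BearerAuth(make_headers),
)
```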

@@ -47,6 +56,13 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
self.parsed_model = parsed_model
self.timeout = self.get_timeout()
self.polling_interval = DEFAULT_POLLING_INTERVAL

# This should be passed in, but not sure where this is actually instantiated
retry_strategy = Retry(total=4, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session = Session()
self.session.mount("https://", adapter)
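# Note: with urllib3's Retry, backoff_factor=0.5 backs off exponentially
# between attempts (on the order of 1s, 2s, then 4s; exact timing varies by
# urllib3 version), and mounting the adapter on "https://" applies the
# policy to every request made through this session.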

self.check_credentials()
self.extra_headers = {
"User-Agent": f"dbt-labs-dbt-spark/{DBT_SPARK_VERSION} (Databricks)",
@@ -70,7 +86,7 @@ def check_credentials(self) -> None:
)

def _create_work_dir(self, path: str) -> None:
response = requests.post(
response = self.session.post(
f"https://{self.credentials.host}/api/2.0/workspace/mkdirs",
auth=self.auth,
headers=self.extra_headers,
@@ -91,7 +107,7 @@ def _update_with_acls(self, cluster_dict: dict) -> dict:

def _upload_notebook(self, path: str, compiled_code: str) -> None:
b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
response = requests.post(
response = self.session.post(
f"https://{self.credentials.host}/api/2.0/workspace/import",
auth=self.auth,
headers=self.extra_headers,
@@ -137,7 +153,7 @@ def _submit_job(self, path: str, cluster_spec: dict) -> str:
libraries.append(lib)

job_spec.update({"libraries": libraries}) # type: ignore
submit_response = requests.post(
submit_response = self.session.post(
f"https://{self.credentials.host}/api/2.1/jobs/runs/submit",
auth=self.auth,
headers=self.extra_headers,
@@ -163,7 +179,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:
run_id = self._submit_job(whole_file_path, cluster_spec)

self.polling(
status_func=requests.get,
status_func=self.session.get,
status_func_kwargs={
"url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}",
"auth": self.auth,
@@ -176,7 +192,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:
)

# get end state to return to user
run_output = requests.get(
run_output = self.session.get(
f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}",
auth=self.auth,
headers=self.extra_headers,
@@ -243,12 +259,14 @@ def __init__(
credentials: DatabricksCredentials,
cluster_id: str,
auth: Union[BearerAuth, None],
extra_headers: dict,
auth_header: dict,
session: Session,
) -> None:
self.auth = auth
self.extra_headers = extra_headers
self.auth_header = auth_header
self.cluster_id = cluster_id
self.host = credentials.host
self.session = session

def create(self) -> str:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#create-an-execution-context
@@ -262,10 +280,10 @@
if current_status != "RUNNING":
self._wait_for_cluster_to_start()

response = requests.post(
response = self.session.post(
f"https://{self.host}/api/1.2/contexts/create",
auth=self.auth,
headers=self.extra_headers,
headers=self.auth_header,
json={
"clusterId": self.cluster_id,
"language": SUBMISSION_LANGUAGE,
@@ -279,7 +297,7 @@

def destroy(self, context_id: str) -> str:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#delete-an-execution-context
response = requests.post(
response = self.session.post(
f"https://{self.host}/api/1.2/contexts/destroy",
auth=self.auth,
headers=self.extra_headers,
@@ -297,7 +315,7 @@ def destroy(self, context_id: str) -> str:
def get_cluster_status(self) -> Dict:
# https://docs.databricks.com/dev-tools/api/latest/clusters.html#get

response = requests.get(
response = self.session.get(
f"https://{self.host}/api/2.0/clusters/get",
auth=self.auth,
headers=self.extra_headers,
@@ -321,7 +339,7 @@ def start_cluster(self) -> None:

logger.debug(f"Sending restart command for cluster id {self.cluster_id}")

response = requests.post(
response = self.session.post(
f"https://{self.host}/api/2.0/clusters/start",
auth=self.auth,
headers=self.extra_headers,
@@ -362,19 +380,21 @@ def __init__(
credentials: DatabricksCredentials,
cluster_id: str,
auth: Union[BearerAuth, None],
extra_headers: dict,
auth_header: dict,
session: Session,
) -> None:
self.auth = auth
self.extra_headers = extra_headers
self.auth_header = auth_header
self.cluster_id = cluster_id
self.host = credentials.host
self.session = session

def execute(self, context_id: str, command: str) -> str:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#run-a-command
response = requests.post(
response = self.session.post(
f"https://{self.host}/api/1.2/commands/execute",
auth=self.auth,
headers=self.extra_headers,
headers=self.auth_header,
json={
"clusterId": self.cluster_id,
"contextId": context_id,
@@ -391,7 +411,7 @@ def execute(self, context_id: str, command: str) -> str:

def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#get-information-about-a-command
response = requests.get(
response = self.session.get(
f"https://{self.host}/api/1.2/commands/status",
auth=self.auth,
headers=self.extra_headers,
@@ -421,8 +441,8 @@ def submit(self, compiled_code: str) -> None:
config = {"existing_cluster_id": self.cluster_id}
self._submit_through_notebook(compiled_code, self._update_with_acls(config))
else:
context = DBContext(self.credentials, self.cluster_id, self.auth, self.extra_headers)
command = DBCommand(self.credentials, self.cluster_id, self.auth, self.extra_headers)
context = DBContext(self.credentials, self.cluster_id, self.auth, self.auth_header, self.session)
command = DBCommand(self.credentials, self.cluster_id, self.auth, self.auth_header, self.session)
context_id = context.create()
try:
command_id = command.execute(context_id, compiled_code)
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,4 +1,4 @@
databricks-sql-connector>=2.9.3, <3.0.0
dbt-spark==1.7.1
databricks-sdk>=0.9.0
dbt-spark~=1.7.1
databricks-sdk>=0.9.0, <0.16.0
keyring>=23.13.0
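
Note: dbt-spark~=1.7.1 is PEP 440 compatible-release pinning, allowing any 1.7.x at or above 1.7.1 while excluding 1.8, and the new <0.16.0 cap on databricks-sdk presumably guards against breaking changes in later SDK releases.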