Skip to content

Commit

Permalink
Pin to databricks-sdk 0.17.0 (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db authored Jan 30, 2024
2 parents f9bef99 + 67aedd5 commit 3f050c3
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 24 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## dbt-databricks 1.7.5 (Jan 29, 2024)

### Fixes

- Pin databricks sdk to 0.17.0 to fix connection timeout issue ([571](https://github.com/databricks/dbt-databricks/pull/571))

## dbt-databricks 1.7.4 (Jan 24, 2024)

### Fixes
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version: str = "1.7.4"
version: str = "1.7.5"
15 changes: 8 additions & 7 deletions dbt/adapters/databricks/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional
from databricks.sdk.oauth import ClientCredentials, Token, TokenSource
from databricks.sdk.oauth import ClientCredentials, Token
from databricks.sdk.core import CredentialsProvider, HeaderFactory, Config, credentials_provider
from databricks.sdk.oauth import TokenSource


class token_auth(CredentialsProvider):
Expand All @@ -16,12 +17,12 @@ def as_dict(self) -> dict:
return {"token": self._token}

@staticmethod
def from_dict(raw: Optional[dict]) -> CredentialsProvider:
def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]:
if not raw:
return None
return token_auth(raw["token"])

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
static_credentials = {"Authorization": f"Bearer {self._token}"}

def inner() -> Dict[str, str]:
Expand All @@ -31,7 +32,7 @@ def inner() -> Dict[str, str]:


class m2m_auth(CredentialsProvider):
_token_source: TokenSource = None
_token_source: Optional[TokenSource] = None

def __init__(self, host: str, client_id: str, client_secret: str) -> None:
@credentials_provider("noop", [])
Expand Down Expand Up @@ -67,12 +68,12 @@ def as_dict(self) -> dict:
@staticmethod
def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> CredentialsProvider:
c = m2m_auth(host=host, client_id=client_id, client_secret=client_secret)
c._token_source._token = Token.from_dict(raw["token"])
c._token_source._token = Token.from_dict(raw["token"]) # type: ignore
return c

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
def inner() -> Dict[str, str]:
token = self._token_source.token()
token = self._token_source.token() # type: ignore
return {"Authorization": f"{token.token_type} {token.access_token}"}

return inner
27 changes: 17 additions & 10 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from dbt.events.types import ConnectionUsed, SQLQuery, SQLQueryStatus
from dbt.utils import DECIMALS, cast_to_str

from databricks import sql as dbsql
import databricks.sql as dbsql
from databricks.sql.client import (
Connection as DatabricksSQLConnection,
Cursor as DatabricksSQLCursor,
Expand All @@ -82,6 +82,9 @@
logger = AdapterLogger("Databricks")


TCredentialProvider = Union[CredentialsProvider, SessionCredentials]


class DbtCoreHandler(logging.Handler):
def __init__(self, level: Union[str, int], dbt_logger: AdapterLogger):
super().__init__(level=level)
Expand Down Expand Up @@ -344,15 +347,17 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
def cluster_id(self) -> Optional[str]:
return self.extract_cluster_id(self.http_path) # type: ignore[arg-type]

def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentialProvider:
self.validate_creds()
host: str = self.host or ""
if self._credentials_provider:
return self._provider_from_dict()
return self._provider_from_dict() # type: ignore
if in_provider:
self._credentials_provider = in_provider.as_dict()
if isinstance(in_provider, m2m_auth) or isinstance(in_provider, token_auth):
self._credentials_provider = in_provider.as_dict()
return in_provider

provider: TCredentialProvider
# dbt will spin up multiple threads. This has to be sync. So lock here
self._lock.acquire()
try:
Expand All @@ -373,7 +378,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
oauth_client = OAuthClient(
host=host,
client_id=self.client_id if self.client_id else CLIENT_ID,
client_secret=None,
client_secret="",
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)
Expand Down Expand Up @@ -416,7 +421,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
finally:
self._lock.release()

def _provider_from_dict(self) -> CredentialsProvider:
def _provider_from_dict(self) -> Optional[TCredentialProvider]:
if self.token:
return token_auth.from_dict(self._credentials_provider)

Expand All @@ -429,14 +434,16 @@ def _provider_from_dict(self) -> CredentialsProvider:
)

oauth_client = OAuthClient(
host=self.host,
host=self.host or "",
client_id=CLIENT_ID,
client_secret=None,
client_secret="",
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)

return SessionCredentials.from_dict(client=oauth_client, raw=self._credentials_provider)
return SessionCredentials.from_dict(
client=oauth_client, raw=self._credentials_provider or {"token": {}}
)


class DatabricksSQLConnectionWrapper:
Expand Down Expand Up @@ -844,7 +851,7 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None:

class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
credentials_provider: CredentialsProvider = None
credentials_provider: Optional[TCredentialProvider] = None

def __init__(self, profile: AdapterRequiredConfig) -> None:
super().__init__(profile)
Expand Down
7 changes: 3 additions & 4 deletions dbt/adapters/databricks/python_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from requests import Session

from dbt.adapters.databricks.__version__ import version
from dbt.adapters.databricks.connections import DatabricksCredentials
from dbt.adapters.databricks.connections import DatabricksCredentials, TCredentialProvider
from dbt.adapters.databricks import utils

import base64
Expand All @@ -16,7 +16,6 @@
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper

from databricks.sdk.core import CredentialsProvider
from requests.adapters import HTTPAdapter
from dbt.adapters.databricks.connections import BearerAuth

Expand Down Expand Up @@ -442,7 +441,7 @@ def submit(self, compiled_code: str) -> None:

class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper):
credentials: DatabricksCredentials # type: ignore[assignment]
_credentials_provider: CredentialsProvider = None
_credentials_provider: Optional[TCredentialProvider] = None

def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
super().__init__(
Expand All @@ -463,7 +462,7 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No
connection_parameters.pop("http_headers", {})
)
self._credentials_provider = credentials.authenticate(self._credentials_provider)
header_factory = self._credentials_provider()
header_factory = self._credentials_provider(None) # type: ignore
self.session.auth = BearerAuth(header_factory)

self.extra_headers.update({"User-Agent": user_agent, **http_headers})
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
databricks-sql-connector>=3.0.0, <3.1.0
dbt-spark~=1.7.1
databricks-sdk>=0.9.0, <0.16.0
databricks-sdk==0.17.0
keyring>=23.13.0
pandas<2.2.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_plugin_version() -> str:
install_requires=[
"dbt-spark~=1.7.1",
"databricks-sql-connector>=3.0.0, <3.1.0",
"databricks-sdk>=0.9.0, <0.16.0",
"databricks-sdk==0.17.0",
"keyring>=23.13.0",
"pandas<2.2.0",
],
Expand Down

0 comments on commit 3f050c3

Please sign in to comment.