From 61f3fb7c498b89b5d16d5b2aa190c3dbcb30244c Mon Sep 17 00:00:00 2001
From: eric-wang-1990 <115501094+eric-wang-1990@users.noreply.github.com>
Date: Mon, 3 Feb 2025 08:59:12 -0800
Subject: [PATCH] Bump python sdk version (#827)

---
 CHANGELOG.md                           |   4 +
 dbt/adapters/databricks/api_client.py  |   6 +-
 dbt/adapters/databricks/auth.py        | 100 -------
 dbt/adapters/databricks/connections.py |  14 +-
 dbt/adapters/databricks/constraints.py |   4 +-
 dbt/adapters/databricks/credentials.py | 355 +++++++++++++------------
 docs/oauth.md                          |  44 ++-
 pyproject.toml                         |   2 +-
 tests/profiles.py                      |   2 +
 tests/unit/test_adapter.py             |   7 +-
 tests/unit/test_auth.py                |  14 +-
 tests/unit/test_compute_config.py      |   5 +-
 tests/unit/test_idle_config.py         |  13 +-
 13 files changed, 254 insertions(+), 316 deletions(-)
 delete mode 100644 dbt/adapters/databricks/auth.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b507d111a..bb2cf8b2d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,10 +2,14 @@

 ### Features

+- Support Databricks OAuth M2M auth type. Updated the OAuth readme doc with instructions. ([827](https://github.com/databricks/dbt-databricks/pull/827))
+
 - Introduced use_materialization_v2 flag for gating materialization revamps. ([844](https://github.com/databricks/dbt-databricks/pull/844))

 ### Under the Hood

+- Update pinned Python SDK version from 0.17.0 to 0.41.0. ([827](https://github.com/databricks/dbt-databricks/pull/827))
+
 - Implement new constraint logic for use_materialization_v2 flag ([846](https://github.com/databricks/dbt-databricks/pull/846/files)), ([876](https://github.com/databricks/dbt-databricks/pull/876))

 ## dbt-databricks 1.9.5 (TBD)

diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py
index 57bc4ca3c..c5d9add62 100644
--- a/dbt/adapters/databricks/api_client.py
+++ b/dbt/adapters/databricks/api_client.py
@@ -13,8 +13,7 @@

 from dbt.adapters.databricks import utils
 from dbt.adapters.databricks.__version__ import version
-from dbt.adapters.databricks.auth import BearerAuth
-from dbt.adapters.databricks.credentials import DatabricksCredentials
+from dbt.adapters.databricks.credentials import BearerAuth, DatabricksCredentials
 from dbt.adapters.databricks.logging import logger

 DEFAULT_POLLING_INTERVAL = 10
@@ -557,8 +556,7 @@ def create(
         http_headers = credentials.get_all_http_headers(
             connection_parameters.pop("http_headers", {})
         )
-        credentials_provider = credentials.authenticate(None)
-        header_factory = credentials_provider(None)  # type: ignore
+        header_factory = credentials.authenticate().credentials_provider()
         session.auth = BearerAuth(header_factory)

         session.headers.update({"User-Agent": user_agent, **http_headers})
diff --git a/dbt/adapters/databricks/auth.py b/dbt/adapters/databricks/auth.py
deleted file mode 100644
index 439fa23fd..000000000
--- a/dbt/adapters/databricks/auth.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from typing import Any, Optional
-
-from requests import PreparedRequest
-from requests.auth import AuthBase
-
-from databricks.sdk.core import Config, CredentialsProvider, HeaderFactory, credentials_provider
-from databricks.sdk.oauth import ClientCredentials, Token, TokenSource
-
-
-class token_auth(CredentialsProvider):
-    _token: str
-
-    def __init__(self, token: str) -> None:
-        self._token = token
-
-    def auth_type(self) -> str:
-        return "token"
-
-    def as_dict(self) -> dict:
-        return {"token": self._token}
-
-    @staticmethod
-    def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]:
-        if not raw:
-            return None
-        return token_auth(raw["token"])
-
-    def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
-        static_credentials = {"Authorization": f"Bearer {self._token}"}
-
-        def inner() -> dict[str, str]:
-            return static_credentials
-
-        return inner
-
-
-class m2m_auth(CredentialsProvider):
-    _token_source: Optional[TokenSource] = None
-
-    def __init__(self, host: str, client_id: str, client_secret: str) -> None:
-        @credentials_provider("noop", [])
-        def noop_credentials(_: Any):  # type: ignore
-            return lambda: {}
-
-        config = Config(host=host, credentials_provider=noop_credentials)
-        oidc = config.oidc_endpoints
-        scopes = ["all-apis"]
-        if not oidc:
-            raise ValueError(f"{host} does not support OAuth")
-        if config.is_azure:
-            # Azure AD only supports full access to Azure Databricks.
-            scopes = [f"{config.effective_azure_login_app_id}/.default"]
-        self._token_source = ClientCredentials(
-            client_id=client_id,
-            client_secret=client_secret,
-            token_url=oidc.token_endpoint,
-            scopes=scopes,
-            use_header="microsoft" not in oidc.token_endpoint,
-            use_params="microsoft" in oidc.token_endpoint,
-        )
-
-    def auth_type(self) -> str:
-        return "oauth"
-
-    def as_dict(self) -> dict:
-        if self._token_source:
-            return {"token": self._token_source.token().as_dict()}
-        else:
-            return {"token": {}}
-
-    @staticmethod
-    def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> CredentialsProvider:
-        c = m2m_auth(host=host, client_id=client_id, client_secret=client_secret)
-        c._token_source._token = Token.from_dict(raw["token"])  # type: ignore
-        return c
-
-    def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
-        def inner() -> dict[str, str]:
-            token = self._token_source.token()  # type: ignore
-            return {"Authorization": f"{token.token_type} {token.access_token}"}
-
-        return inner
-
-
-class BearerAuth(AuthBase):
-    """This mix-in is passed to our requests Session to explicitly
-    use the bearer authentication method.
-
-    Without this, a local .netrc file in the user's home directory
-    will override the auth headers provided by our header_factory.
-
-    More details in issue #337.
-    """
-
-    def __init__(self, header_factory: HeaderFactory):
-        self.header_factory = header_factory
-
-    def __call__(self, r: PreparedRequest) -> PreparedRequest:
-        r.headers.update(**self.header_factory())
-        return r
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 3a6b4817f..f38b7e4af 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -32,7 +32,10 @@
 )
 from dbt.adapters.databricks.__version__ import version as __version__
 from dbt.adapters.databricks.api_client import DatabricksApiClient
-from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
+from dbt.adapters.databricks.credentials import (
+    DatabricksCredentialManager,
+    DatabricksCredentials,
+)
 from dbt.adapters.databricks.events.connection_events import (
     ConnectionAcquire,
     ConnectionCancel,
@@ -373,7 +376,7 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None:

 class DatabricksConnectionManager(SparkConnectionManager):
     TYPE: str = "databricks"
-    credentials_provider: Optional[TCredentialProvider] = None
+    credentials_manager: Optional[DatabricksCredentialManager] = None
     _user_agent = f"dbt-databricks/{__version__}"

     def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
@@ -634,7 +637,7 @@ def open(cls, connection: Connection) -> Connection:
         timeout = creds.connect_timeout

         # gotta keep this so we don't prompt users many times
-        cls.credentials_provider = creds.authenticate(cls.credentials_provider)
+        cls.credentials_manager = creds.authenticate()

         invocation_env = creds.get_invocation_env()
         user_agent_entry = cls._user_agent
@@ -652,12 +655,13 @@ def open(cls, connection: Connection) -> Connection:
         http_path = databricks_connection.http_path

         def connect() -> DatabricksSQLConnectionWrapper:
+            assert cls.credentials_manager is not None
             try:
                 # TODO: what is the error when a user specifies a catalog they don't have access to
                 conn = dbsql.connect(
                     server_hostname=creds.host,
                     http_path=http_path,
-                    credentials_provider=cls.credentials_provider,
+                    credentials_provider=cls.credentials_manager.credentials_provider,
                     http_headers=http_headers if http_headers else None,
                     session_configuration=creds.session_properties,
                     catalog=creds.database,
@@ -708,7 +712,7 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn
         timeout = creds.connect_timeout

         # gotta keep this so we don't prompt users many times
-        cls.credentials_provider = creds.authenticate(cls.credentials_provider)
+        cls.credentials_manager = creds.authenticate()

         invocation_env = creds.get_invocation_env()
         user_agent_entry = cls._user_agent
diff --git a/dbt/adapters/databricks/constraints.py b/dbt/adapters/databricks/constraints.py
index 5761dacf5..bce44f8b6 100644
--- a/dbt/adapters/databricks/constraints.py
+++ b/dbt/adapters/databricks/constraints.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Optional, Type, TypeVar
+from typing import Any, ClassVar, Optional, TypeVar

 from dbt_common.contracts.constraints import (
     ColumnLevelConstraint,
@@ -32,7 +32,7 @@ class TypedConstraint(ModelLevelConstraint, ABC):
     str_type: ClassVar[str]

     @classmethod
-    def __post_deserialize__(cls: Type[T], obj: T) -> T:
+    def __post_deserialize__(cls: type[T], obj: T) -> T:
         assert obj.type == cls.str_type, "Mismatched constraint type"
         obj._validate()
         return obj
diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py
index 250e79f65..c68d87044 100644
--- a/dbt/adapters/databricks/credentials.py
+++ b/dbt/adapters/databricks/credentials.py
@@ -1,24 +1,19 @@
 import itertools
 import json
-import os
 import re
 import threading
 from collections.abc import Iterable
-from dataclasses import dataclass
-from typing import Any, Optional, Union, cast
+from dataclasses import dataclass, field
+from typing import Any, Callable, Optional, cast

-import keyring
 from dbt_common.exceptions import DbtConfigError, DbtValidationError
+from mashumaro import DataClassDictMixin
+from requests import PreparedRequest
+from requests.auth import AuthBase

-from databricks.sdk.core import CredentialsProvider
-from databricks.sdk.oauth import OAuthClient, SessionCredentials
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.core import Config, CredentialsProvider
 from dbt.adapters.contracts.connection import Credentials
-from dbt.adapters.databricks.auth import m2m_auth, token_auth
-from dbt.adapters.databricks.events.credential_events import (
-    CredentialLoadError,
-    CredentialSaveError,
-    CredentialShardEvent,
-)
 from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.logging import logger

@@ -36,8 +31,6 @@
 # also expire after 24h. Silently accept this in this case.
 SPA_CLIENT_FIXED_TIME_LIMIT_ERROR = "AADSTS700084"

-TCredentialProvider = Union[CredentialsProvider, SessionCredentials]
-

 @dataclass
 class DatabricksCredentials(Credentials):
@@ -48,6 +41,8 @@ class DatabricksCredentials(Credentials):
     token: Optional[str] = None
     client_id: Optional[str] = None
     client_secret: Optional[str] = None
+    azure_client_id: Optional[str] = None
+    azure_client_secret: Optional[str] = None
     oauth_redirect_url: Optional[str] = None
     oauth_scopes: Optional[list[str]] = None
     session_properties: Optional[dict[str, Any]] = None
@@ -62,8 +57,7 @@ class DatabricksCredentials(Credentials):
     connect_timeout: Optional[int] = None
     retry_all: bool = False
     connect_max_idle: Optional[int] = None
-
-    _credentials_provider: Optional[dict[str, Any]] = None
+    _credentials_manager: Optional["DatabricksCredentialManager"] = None
     _lock = threading.Lock()  # to avoid concurrent auth

     _ALIASES = {
@@ -134,6 +128,7 @@ def __post_init__(self) -> None:
             if "_socket_timeout" not in connection_parameters:
                 connection_parameters["_socket_timeout"] = 600
         self.connection_parameters = connection_parameters
+        self._credentials_manager = DatabricksCredentialManager.create_from(self)

     def validate_creds(self) -> None:
         for key in ["host", "http_path"]:
@@ -150,6 +145,14 @@ def validate_creds(self) -> None:
                     "to Databricks when 'client_secret' is present"
                 )

+        if (not self.azure_client_id and self.azure_client_secret) or (
+            self.azure_client_id and not self.azure_client_secret
+        ):
+            raise DbtConfigError(
+                "The config 'azure_client_id' and 'azure_client_secret' "
+                "must be both present or both absent"
+            )
+
     @classmethod
     def get_invocation_env(cls) -> Optional[str]:
         invocation_env = GlobalState.get_invocation_env()
@@ -234,172 +237,174 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
     def cluster_id(self) -> Optional[str]:
         return self.extract_cluster_id(self.http_path)  # type: ignore[arg-type]

-    def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentialProvider:
+    def authenticate(self) -> "DatabricksCredentialManager":
         self.validate_creds()
-        host: str = self.host or ""
-        if self._credentials_provider:
-            return self._provider_from_dict()  # type: ignore
-        if in_provider:
-            if isinstance(in_provider, m2m_auth) or isinstance(in_provider, token_auth):
-                self._credentials_provider = in_provider.as_dict()
-            return in_provider
-
-        provider: TCredentialProvider
-        # dbt will spin up multiple threads. This has to be sync. So lock here
-        self._lock.acquire()
-        try:
-            if self.token:
-                provider = token_auth(self.token)
-                self._credentials_provider = provider.as_dict()
-                return provider
-
-            if self.client_id and self.client_secret:
-                provider = m2m_auth(
-                    host=host,
-                    client_id=self.client_id or "",
-                    client_secret=self.client_secret or "",
-                )
-                self._credentials_provider = provider.as_dict()
-                return provider
-
-            client_id = self.client_id or CLIENT_ID
-            redirect_url = self.oauth_redirect_url or REDIRECT_URL
-            scopes = self.oauth_scopes or SCOPES
-
-            oauth_client = OAuthClient(
-                host=host,
-                client_id=client_id,
-                client_secret="",
-                redirect_url=redirect_url,
-                scopes=scopes,
-            )
-            # optional branch. Try and keep going if it does not work
-            try:
-                # try to get cached credentials
-                credsdict = self.get_sharded_password("dbt-databricks", host)
-
-                if credsdict:
-                    provider = SessionCredentials.from_dict(oauth_client, json.loads(credsdict))
-                    # if refresh token is expired, this will throw
-                    try:
-                        if provider.token().valid:
-                            self._credentials_provider = provider.as_dict()
-                            if json.loads(credsdict) != provider.as_dict():
-                                # if the provider dict has changed, most likely because of a token
-                                # refresh, save it for further use
-                                self.set_sharded_password(
-                                    "dbt-databricks", host, json.dumps(self._credentials_provider)
-                                )
-                            return provider
-                    except Exception as e:
-                        # SPA token are supposed to expire after 24h, no need to warn
-                        if SPA_CLIENT_FIXED_TIME_LIMIT_ERROR in str(e):
-                            logger.debug(CredentialLoadError(e))
-                        else:
-                            logger.warning(CredentialLoadError(e))
-                        # whatever it is, get rid of the cache
-                        self.delete_sharded_password("dbt-databricks", host)
-
-            # error with keyring. Maybe machine has no password persistency
-            except Exception as e:
-                logger.warning(CredentialLoadError(e))
-
-            # no token, go fetch one
-            consent = oauth_client.initiate_consent()
-
-            provider = consent.launch_external_browser()
-            # save for later
-            self._credentials_provider = provider.as_dict()
-            try:
-                self.set_sharded_password(
-                    "dbt-databricks", host, json.dumps(self._credentials_provider)
-                )
-            # error with keyring. Maybe machine has no password persistency
-            except Exception as e:
-                logger.warning(CredentialSaveError(e))
+        assert self._credentials_manager is not None, "Credentials manager is not set."
+        return self._credentials_manager

-            return provider
-        finally:
-            self._lock.release()

+class BearerAuth(AuthBase):
+    """This mix-in is passed to our requests Session to explicitly
+    use the bearer authentication method.

-    def set_sharded_password(self, service_name: str, username: str, password: str) -> None:
-        max_size = MAX_NT_PASSWORD_SIZE
+    Without this, a local .netrc file in the user's home directory
+    will override the auth headers provided by our header_factory.

-        # if not Windows or "small" password, stick to the default
-        if os.name != "nt" or len(password) < max_size:
-            keyring.set_password(service_name, username, password)
-        else:
-            logger.debug(CredentialShardEvent(len(password)))
-
-            password_shards = [
-                password[i : i + max_size] for i in range(0, len(password), max_size)
-            ]
-            shard_info = {
-                "sharded_password": True,
-                "shard_count": len(password_shards),
-            }
+    More details in issue #337.
+    """

-        # store the "shard info" as the "base" password
-        keyring.set_password(service_name, username, json.dumps(shard_info))
-        # then store all shards with the shard number as postfix
-        for i, s in enumerate(password_shards):
-            keyring.set_password(service_name, f"{username}__{i}", s)
-
-    def get_sharded_password(self, service_name: str, username: str) -> Optional[str]:
-        password = keyring.get_password(service_name, username)
-
-        # check for "shard info" stored as json
-        try:
-            password_as_dict = json.loads(str(password))
-            if password_as_dict.get("sharded_password"):
-                # if password was stored shared, reconstruct it
-                shard_count = int(password_as_dict.get("shard_count"))
-
-                password = ""
-                for i in range(shard_count):
-                    password += str(keyring.get_password(service_name, f"{username}__{i}"))
-        except ValueError:
-            pass
-
-        return password
-
-    def delete_sharded_password(self, service_name: str, username: str) -> None:
-        password = keyring.get_password(service_name, username)
-
-        # check for "shard info" stored as json. If so delete all shards
-        try:
-            password_as_dict = json.loads(str(password))
-            if password_as_dict.get("sharded_password"):
-                shard_count = int(password_as_dict.get("shard_count"))
-                for i in range(shard_count):
-                    keyring.delete_password(service_name, f"{username}__{i}")
-        except ValueError:
-            pass
-
-        # delete "base" password
-        keyring.delete_password(service_name, username)
-
-    def _provider_from_dict(self) -> Optional[TCredentialProvider]:
-        if self.token:
-            return token_auth.from_dict(self._credentials_provider)
-
-        if self.client_id and self.client_secret:
-            return m2m_auth.from_dict(
-                host=self.host or "",
-                client_id=self.client_id or "",
-                client_secret=self.client_secret or "",
-                raw=self._credentials_provider or {"token": {}},
-            )
+    def __init__(self, header_factory: CredentialsProvider):
+        self.header_factory = header_factory
+
+    def __call__(self, r: PreparedRequest) -> PreparedRequest:
+        r.headers.update(**self.header_factory())
+        return r
+
+
+PySQLCredentialProvider = Callable[[], Callable[[], dict[str, str]]]
+
+
+@dataclass
+class DatabricksCredentialManager(DataClassDictMixin):
+    host: str
+    client_id: str
+    client_secret: str
+    azure_client_id: Optional[str] = None
+    azure_client_secret: Optional[str] = None
+    oauth_redirect_url: str = REDIRECT_URL
+    oauth_scopes: list[str] = field(default_factory=lambda: SCOPES)
+    token: Optional[str] = None
+    auth_type: Optional[str] = None
+
+    @classmethod
+    def create_from(cls, credentials: DatabricksCredentials) -> "DatabricksCredentialManager":
+        return DatabricksCredentialManager(
+            host=credentials.host or "",
+            token=credentials.token,
+            client_id=credentials.client_id or CLIENT_ID,
+            client_secret=credentials.client_secret or "",
+            azure_client_id=credentials.azure_client_id,
+            azure_client_secret=credentials.azure_client_secret,
+            oauth_redirect_url=credentials.oauth_redirect_url or REDIRECT_URL,
+            oauth_scopes=credentials.oauth_scopes or SCOPES,
+            auth_type=credentials.auth_type,
+        )
+
+    def authenticate_with_pat(self) -> Config:
+        return Config(
+            host=self.host,
+            token=self.token,
+        )

-        oauth_client = OAuthClient(
-            host=self.host or "",
-            client_id=self.client_id or CLIENT_ID,
-            client_secret="",
-            redirect_url=self.oauth_redirect_url or REDIRECT_URL,
-            scopes=self.oauth_scopes or SCOPES,
+    def authenticate_with_oauth_m2m(self) -> Config:
+        return Config(
+            host=self.host,
+            client_id=self.client_id,
+            client_secret=self.client_secret,
+            auth_type="oauth-m2m",
         )
-        return SessionCredentials.from_dict(
-            client=oauth_client, raw=self._credentials_provider or {"token": {}}
+    def authenticate_with_external_browser(self) -> Config:
+        return Config(
+            host=self.host,
+            client_id=self.client_id,
+            client_secret=self.client_secret,
+            auth_type="external-browser",
         )
+
+    def legacy_authenticate_with_azure_client_secret(self) -> Config:
+        return Config(
+            host=self.host,
+            azure_client_id=self.client_id,
+            azure_client_secret=self.client_secret,
+            auth_type="azure-client-secret",
+        )
+
+    def authenticate_with_azure_client_secret(self) -> Config:
+        return Config(
+            host=self.host,
+            azure_client_id=self.azure_client_id,
+            azure_client_secret=self.azure_client_secret,
+            auth_type="azure-client-secret",
+        )
+
+    def __post_init__(self) -> None:
+        self._lock = threading.Lock()
+        with self._lock:
+            if not hasattr(self, "_config"):
+                self._config: Optional[Config] = None
+            if self._config is not None:
+                return
+
+            if self.token:
+                self._config = self.authenticate_with_pat()
+            elif self.azure_client_id and self.azure_client_secret:
+                self._config = self.authenticate_with_azure_client_secret()
+            elif not self.client_secret:
+                self._config = self.authenticate_with_external_browser()
+            else:
+                auth_methods = {
+                    "oauth-m2m": self.authenticate_with_oauth_m2m,
+                    "legacy-azure-client-secret": self.legacy_authenticate_with_azure_client_secret,
+                }
+
+                # If the secret starts with "dose", odds are it is a Databricks OAuth secret
+                if self.client_secret.startswith("dose"):
+                    auth_sequence = ["oauth-m2m", "legacy-azure-client-secret"]
+                else:
+                    auth_sequence = ["legacy-azure-client-secret", "oauth-m2m"]
+
+                exceptions = []
+                for i, auth_type in enumerate(auth_sequence):
+                    try:
+                        # The Config constructor implicitly initializes auth and raises if it fails
+                        self._config = auth_methods[auth_type]()
+                        if auth_type == "legacy-azure-client-secret":
+                            logger.warning(
+                                "You are using an Azure Service Principal; "
+                                "please use 'azure_client_id' and 'azure_client_secret' instead."
+                            )
+                        break  # Exit loop if authentication is successful
+                    except Exception as e:
+                        exceptions.append((auth_type, e))
+                        next_auth_type = (
+                            auth_sequence[i + 1] if i + 1 < len(auth_sequence) else None
+                        )
+                        if next_auth_type:
+                            logger.warning(
+                                f"Failed to authenticate with {auth_type}, "
+                                f"trying {next_auth_type} next. Error: {e}"
+                            )
+                        else:
+                            logger.error(
+                                f"Failed to authenticate with {auth_type}. "
+                                f"No more authentication methods to try. Error: {e}"
+                            )
+                            raise Exception(
+                                f"All authentication methods failed. Details: {exceptions}"
+                            )
+
+    @property
+    def api_client(self) -> WorkspaceClient:
+        return WorkspaceClient(config=self._config)
+
+    @property
+    def credentials_provider(self) -> PySQLCredentialProvider:
+        def inner() -> Callable[[], dict[str, str]]:
+            return self.header_factory
+
+        return inner
+
+    @property
+    def header_factory(self) -> CredentialsProvider:
+        if self._config is None:
+            raise RuntimeError("Config is not initialized")
+        header_factory = self._config._header_factory
+        assert header_factory is not None, "Header factory is not set."
+        return header_factory
+
+    @property
+    def config(self) -> Config:
+        if self._config is None:
+            raise RuntimeError("Config is not initialized")
+        return self._config
diff --git a/docs/oauth.md b/docs/oauth.md
index 9d95b652c..4ca15494c 100644
--- a/docs/oauth.md
+++ b/docs/oauth.md
@@ -2,11 +2,11 @@

 This feature is in [Public Preview](https://docs.databricks.com/release-notes/release-types.html).

-Databricks DBT adapter now supports authentication via OAuth in AWS and Azure. This is a much safer method as it enables you to generate short-lived (one hour) OAuth access tokens, which eliminates the risk of accidentally exposing longer-lived tokens such as Databricks personal access tokens through version control checkins or other means. OAuth also enables better server-side session invalidation and scoping.
+## User to Machine (U2M):

-Once an admin correctly configured OAuth in Databricks, you can simply add the config `auth_type` and set it to `oauth`. Config `token` is no longer necessary.
+The Databricks dbt adapter now supports authentication via the OAuth U2M flow in all clouds. This is a much safer method, as it generates short-lived (one-hour) OAuth access tokens, which eliminates the risk of accidentally exposing longer-lived tokens, such as Databricks personal access tokens, through version-control check-ins or other means. OAuth also enables better server-side session invalidation and scoping.

-For Azure, you admin needs to create a Public AD application for dbt and provide you with its client_id.
+Simply add the config `auth_type` and set it to `oauth`. The `token` config is no longer necessary.

 ```YAML
 jaffle_shop:
   outputs:
     dev:
       host:
       http_path:
       catalog:
       schema:
       auth_type: oauth # new
-      client_id: # only necessary for Azure
       type: databricks
       target: dev
 ```

-## Troubleshooting
+### Troubleshooting

 DBT expects the OAuth application to have the "All APIs" scope and redirect URL `http://localhost:8020` by default.

-If the OAuth application has only been configured with SQL access scopes or a custom redirect URL, you may need to update your profile accordingly:
+The default OAuth app for dbt-databricks is auto-enabled in every account with the expected settings. You can find it in the [Account Console](https://accounts.cloud.databricks.com) > [Settings](https://accounts.cloud.databricks.com/settings) > [App Connections](https://accounts.cloud.databricks.com/settings/app-integrations) > dbt adapter for Databricks. If you cannot find it, dbt may have been disabled as an OAuth app in your account; refer to this [guide](https://docs.databricks.com/en/integrations/enable-disable-oauth.html) to re-enable it.
+
+If you encounter any issues, please refer to the [OAuth user-to-machine (U2M) authentication guide](https://docs.databricks.com/en/dev-tools/auth/oauth-u2m.html).
+
+## Machine to Machine (M2M):
+
+The Databricks dbt adapter also supports authentication via the OAuth M2M flow in all clouds.
+Simply add the config `auth_type` and set it to `oauth`. Follow this [guide](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html) to create a Databricks service principal and an OAuth secret for it.
+Set `client_id` to your Databricks service principal ID and `client_secret` to the OAuth secret for the service principal.

 ```YAML
 jaffle_shop:
   outputs:
     dev:
       host:
       http_path:
       catalog:
       schema:
       auth_type: oauth # new
-      client_id: # only necessary for Azure
-      oauth_redirect_url: https://example.com
-      oauth_scopes:
-        - sql
-        - offline_access
+      client_id:
+      client_secret:
       type: databricks
       target: dev
 ```

-You can find these settings in [Account Console](https://accounts.cloud.databricks.com) > [Settings](https://accounts.cloud.databricks.com/settings) > [App Connections](https://accounts.cloud.databricks.com/settings/app-integrations) > dbt adapter for Databricks
+### Azure Service Principal
+If you are on Azure Databricks and want to use an Azure Service Principal, just set `azure_client_id` to your Azure client ID and `azure_client_secret` to your Azure client secret.
-
-If you encounter any issues, please refer to the [OAuth user-to-machine (U2M) authentication guide](https://docs.databricks.com/en/dev-tools/auth/oauth-u2m.html).
+```YAML
+jaffle_shop:
+  outputs:
+    dev:
+      host:
+      http_path:
+      catalog:
+      schema:
+      auth_type: oauth # new
+      azure_client_id:
+      azure_client_secret:
+      type: databricks
+      target: dev
+```
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 8a5fc58cb..668564870 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]
 dependencies = [
-    "databricks-sdk==0.17.0",
+    "databricks-sdk==0.41.0",
     "databricks-sql-connector>=3.5.0, <3.7.0",
     "dbt-adapters>=1.9.0, <2.0",
     "dbt-common>=1.12.0, <2.0",
diff --git a/tests/profiles.py b/tests/profiles.py
index e0c88447e..a3b173e6e 100644
--- a/tests/profiles.py
+++ b/tests/profiles.py
@@ -26,6 +26,8 @@ def _build_databricks_cluster_target(
         "token": os.getenv("DBT_DATABRICKS_TOKEN"),
         "client_id": os.getenv("DBT_DATABRICKS_CLIENT_ID"),
         "client_secret": os.getenv("DBT_DATABRICKS_CLIENT_SECRET"),
+        "azure_client_id": os.getenv("DBT_DATABRICKS_AZURE_CLIENT_ID"),
+        "azure_client_secret": os.getenv("DBT_DATABRICKS_AZURE_CLIENT_SECRET"),
         "connect_retries": 3,
         "connect_timeout": 5,
         "retry_all": True,
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index de33440da..3a094c4a4 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -246,9 +246,12 @@ def connect(
     ):
         assert server_hostname == "yourorg.databricks.com"
         assert http_path == "sql/protocolv1/o/1234567890123456/1234-567890-test123"
-        if not (expected_no_token or expected_client_creds):
-            assert credentials_provider._token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"

+        if not (expected_no_token or expected_client_creds):
+            assert (
+                credentials_provider()().get("Authorization")
+                == "Bearer dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+            )
         if expected_client_creds:
             assert kwargs.get("client_id") == "foo"
             assert kwargs.get("client_secret") == "bar"
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 424cd39fa..90ba0f594 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -65,22 +65,21 @@ def test_token(self):
             http_path="http://foo",
             schema="dbt",
         )
-        provider = creds.authenticate(None)
+        credential_manager = creds.authenticate()
+        provider = credential_manager.credentials_provider()
         assert provider is not None

-        headers_fn = provider()
+        headers_fn = provider
         headers = headers_fn()
         assert headers is not None

-        raw = provider.as_dict()
+        raw = credential_manager._config.as_dict()
        assert raw is not None

-        provider_b = creds._provider_from_dict()
-        headers_fn2 = provider_b()
-        headers2 = headers_fn2()
-        assert headers == headers2
+        assert headers == {"Authorization": "Bearer foo"}


+@pytest.mark.skip(reason="Cache moved to databricks sdk TokenCache")
 class TestShardedPassword:
     def test_store_and_delete_short_password(self):
         # set the keyring to mock class
@@ -133,6 +132,7 @@ def test_store_and_delete_long_password(self):
         assert retrieved_password is None


+@pytest.mark.skip(reason="Cache moved to databricks sdk TokenCache")
 class MockKeyring(keyring.backend.KeyringBackend):
     def __init__(self):
         self.file_location = self._generate_test_root_dir()
diff --git a/tests/unit/test_compute_config.py b/tests/unit/test_compute_config.py
index 994d4ae9a..b76cdc226 100644
--- a/tests/unit/test_compute_config.py
+++ b/tests/unit/test_compute_config.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock
+from unittest.mock import Mock, patch

 import pytest
 from dbt_common.exceptions import DbtRuntimeError
@@ -23,7 +23,8 @@ def path(self):

     @pytest.fixture
     def creds(self, path):
-        return DatabricksCredentials(http_path=path)
+        with patch("dbt.adapters.databricks.credentials.Config"):
+            return DatabricksCredentials(http_path=path)

     @pytest.fixture
     def node(self):
diff --git a/tests/unit/test_idle_config.py b/tests/unit/test_idle_config.py
index de3545680..a733c07d5 100644
--- a/tests/unit/test_idle_config.py
+++ b/tests/unit/test_idle_config.py
@@ -1,3 +1,5 @@
+from unittest.mock import patch
+
 import pytest
 from dbt_common.exceptions import DbtRuntimeError

@@ -6,6 +8,7 @@
 from dbt.contracts.graph import model_config, nodes


+@patch("dbt.adapters.databricks.credentials.Config")
 class TestDatabricksConnectionMaxIdleTime:
     """Test the various cases for determining a specified warehouse."""

@@ -13,7 +16,7 @@ class TestDatabricksConnectionMaxIdleTime:
         "Compute resource foo does not exist or does not specify http_path, "
         "relation: a_relation"
     )

-    def test_get_max_idle_default(self):
+    def test_get_max_idle_default(self, _):
         creds = DatabricksCredentials()

         # No node and nothing specified in creds
@@ -72,7 +75,7 @@ def test_get_max_idle_default(self):
         #     path = connections._get_http_path(node, creds)
         #     self.assertEqual("alternate_path", path)

-    def test_get_max_idle_creds(self):
+    def test_get_max_idle_creds(self, _):
         creds_idle_time = 77
         creds = DatabricksCredentials(connect_max_idle=creds_idle_time)

@@ -123,7 +126,7 @@ def test_get_max_idle_creds(self):
         time = connections._get_max_idle_time(node, creds)
         assert creds_idle_time == time

-    def test_get_max_idle_compute(self):
+    def test_get_max_idle_compute(self, _):
         creds_idle_time = 88
         compute_idle_time = 77
         creds = DatabricksCredentials(connect_max_idle=creds_idle_time)

@@ -151,7 +154,7 @@ def test_get_max_idle_compute(self):
         time = connections._get_max_idle_time(node, creds)
         assert compute_idle_time == time

-    def test_get_max_idle_invalid(self):
+    def test_get_max_idle_invalid(self, _):
         creds_idle_time = "foo"
         compute_idle_time = "bar"
         creds = DatabricksCredentials(connect_max_idle=creds_idle_time)

@@ -204,7 +207,7 @@ def test_get_max_idle_invalid(self):
         "1,002.3 is not a valid value for connect_max_idle. "
         "Must be a number of seconds."
     ) in str(info.value)

-    def test_get_max_idle_simple_string_conversion(self):
+    def test_get_max_idle_simple_string_conversion(self, _):
         creds_idle_time = "12"
         compute_idle_time = "34"
         creds = DatabricksCredentials(connect_max_idle=creds_idle_time)
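
---

A quick way to sanity-check the new authentication path outside of dbt is to drive the credential manager directly, mirroring the updated `tests/unit/test_auth.py` above. The sketch below is illustrative only: the host, `http_path`, and token are hypothetical placeholder values, and it assumes the `databricks-sdk==0.41.0` pinned by this patch is installed.

```python
from dbt.adapters.databricks.credentials import DatabricksCredentials

# Placeholder profile values -- substitute a real workspace, warehouse, and PAT.
creds = DatabricksCredentials(
    host="myworkspace.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/abc123",
    token="dapi-example-token",
    schema="dbt",
)

# authenticate() now returns a DatabricksCredentialManager wrapping a single
# SDK Config, instead of the old keyring-backed credentials provider.
manager = creds.authenticate()

# The PySQL-style provider is a zero-arg callable that returns a header factory.
headers = manager.credentials_provider()()
assert headers == {"Authorization": "Bearer dapi-example-token"}
```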