Skip to content

Commit

Permalink
[PECO-1387] Fix: mv/st network request authentication is being overri…
Browse files Browse the repository at this point in the history
…dden by local .netrc file (#555)

Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
  • Loading branch information
Jesse authored Jan 18, 2024
1 parent b2a77d9 commit 319c427
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Fix dbt incremental_strategy behavior by fixing schema table existing check (thanks @case-k-git!) ([530](https://github.com/databricks/dbt-databricks/pull/530))
- Fixed bug that was causing streaming tables to be dropped and recreated instead of refreshed. ([552](https://github.com/databricks/dbt-databricks/pull/552))
- Fix: Python models authentication could be overridden by a `.netrc` file in the user's home directory ([338](https://github.com/databricks/dbt-databricks/pull/338))
- Fix: MV/ST REST api authentication could be overriden by a `.netrc` file in the user's home directory ([555](https://github.com/databricks/dbt-databricks/pull/555))

### Under the Hood

Expand Down
37 changes: 21 additions & 16 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import sys
import threading
import time
import requests
from threading import get_ident
from typing import (
Any,
Expand Down Expand Up @@ -78,6 +77,7 @@
from databricks.sdk.core import HeaderFactory

import keyring
from requests import Session

logger = AdapterLogger("Databricks")

Expand Down Expand Up @@ -124,11 +124,13 @@ def emit(self, record: logging.LogRecord) -> None:


class BearerAuth(AuthBase):
"""See issue #337.
"""This mix-in is passed to our requests Session to explicitly
use the bearer authentication method.
We use this mix-in to stop requests from implicitly reading .netrc
Without this, a local .netrc file in the user's home directory
will override the auth headers provided by our header_factory.
Solution taken from SO post in issue description.
More details in issue #337.
"""

def __init__(self, headers: HeaderFactory):
Expand Down Expand Up @@ -578,10 +580,13 @@ def pollRefreshPipeline(
stopped_states = ("COMPLETED", "FAILED", "CANCELED")
host: str = self._creds.host or ""
headers = self._cursor.connection.thrift_backend._auth_provider._header_factory()
headers["User-Agent"] = self._user_agent

pipeline_id = _get_table_view_pipeline_id(host, headers, model_name)
pipeline = _get_pipeline_state(host, headers, pipeline_id)
session = Session()
session.auth = BearerAuth(headers)
session.headers = {"User-Agent": self._user_agent}

pipeline_id = _get_table_view_pipeline_id(session, host, model_name)
pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the most recently created update for the pipeline
latest_update = _find_update(pipeline)
if not latest_update:
Expand All @@ -606,7 +611,7 @@ def pollRefreshPipeline(
# should we do exponential backoff?
time.sleep(polling_interval)

pipeline = _get_pipeline_state(host, headers, pipeline_id)
pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the update we are currently polling
update = _find_update(pipeline, update_id)
if not update:
Expand All @@ -623,7 +628,7 @@ def pollRefreshPipeline(

if state == "FAILED":
logger.error(f"pipeline {pipeline_id} update {update_id} failed")
msg = _get_update_error_msg(host, headers, pipeline_id, update_id)
msg = _get_update_error_msg(session, host, pipeline_id, update_id)
if msg:
logger.error(msg)

Expand All @@ -644,7 +649,7 @@ def pollRefreshPipeline(
raise dbt.exceptions.DbtRuntimeError("timed out waiting for materialized view refresh")

if state == "FAILED":
msg = _get_update_error_msg(host, headers, pipeline_id, update_id)
msg = _get_update_error_msg(session, host, pipeline_id, update_id)
raise dbt.exceptions.DbtRuntimeError(f"error refreshing model {model_name} {msg}")

if state == "CANCELED":
Expand Down Expand Up @@ -1476,9 +1481,9 @@ def _should_poll_refresh(sql: str) -> Tuple[bool, str]:
return refresh_search is not None, name


def _get_table_view_pipeline_id(host: str, headers: dict, name: str) -> str:
def _get_table_view_pipeline_id(session: Session, host: str, name: str) -> str:
table_url = f"https://{host}/api/2.1/unity-catalog/tables/{name}"
resp1 = requests.get(table_url, headers=headers)
resp1 = session.get(table_url)
if resp1.status_code != 200:
raise dbt.exceptions.DbtRuntimeError(
f"Error getting info for materialized view/streaming table: {name}"
Expand All @@ -1493,10 +1498,10 @@ def _get_table_view_pipeline_id(host: str, headers: dict, name: str) -> str:
return pipeline_id


def _get_pipeline_state(host: str, headers: dict, pipeline_id: str) -> dict:
def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

response = requests.get(pipeline_url, headers=headers)
response = session.get(pipeline_url)
if response.status_code != 200:
raise dbt.exceptions.DbtRuntimeError(f"Error getting pipeline info: {pipeline_id}")

Expand All @@ -1520,9 +1525,9 @@ def _find_update(pipeline: dict, id: str = "") -> Optional[Dict]:
return None


def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id: str) -> str:
def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str:
events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
response = requests.get(events_url, headers=headers)
response = session.get(events_url)
if response.status_code != 200:
raise dbt.exceptions.DbtRuntimeError(f"Error getting pipeline event info: {pipeline_id}")

Expand Down

0 comments on commit 319c427

Please sign in to comment.