Skip to content

Commit

Permalink
try some clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Dec 6, 2024
1 parent 69aa772 commit f5eb3f8
Show file tree
Hide file tree
Showing 14 changed files with 21 additions and 23 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _poll_api(


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class CommandExecution(object):
class CommandExecution:
command_id: str
context_id: str
cluster_id: str
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def data_type(self) -> str:
return self.translate_type(self.dtype)

def __repr__(self) -> str:
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)
return f"<DatabricksColumn {self.name} ({self.data_type})>"

@staticmethod
def get_name(column: dict[str, Any]) -> str:
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import time
import uuid
import warnings
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Hashable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing.context import SpawnContext
from numbers import Number
from threading import get_ident
from typing import TYPE_CHECKING, Any, Hashable, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
Expand Down Expand Up @@ -488,7 +488,7 @@ def add_query(
try:
log_sql = redact_credentials(sql)
if abridge_sql_log:
log_sql = "{}...".format(log_sql[:512])
log_sql = f"{log_sql[:512]}..."

fire_event(
SQLQuery(
Expand Down
8 changes: 4 additions & 4 deletions dbt/adapters/databricks/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,19 @@ def validate_creds(self) -> None:
for key in ["host", "http_path"]:
if not getattr(self, key):
raise DbtConfigError(
"The config '{}' is required to connect to Databricks".format(key)
f"The config '{key}' is required to connect to Databricks"
)
if not self.token and self.auth_type != "oauth":
raise DbtConfigError(
("The config `auth_type: oauth` is required when not using access token")
"The config `auth_type: oauth` is required when not using access token"
)

if not self.client_id and self.client_secret:
raise DbtConfigError(
(

"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
)

)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _get_catalog_for_relation_map(
used_schemas: frozenset[tuple[str, str]],
) -> tuple["Table", list[Exception]]:
with executor(self.config) as tpe:
futures: list[Future["Table"]] = []
futures: list[Future[Table]] = []
for schema, relations in relation_map.items():
if schema in used_schemas:
identifier = get_identifier_list_string(relations)
Expand Down Expand Up @@ -804,7 +804,7 @@ def get_from_relation(
) -> DatabricksRelationConfig:
"""Get the relation config from the relation."""

relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
relation_config = super().get_from_relation(adapter, relation)

# Ensure any current refreshes are completed before returning the relation config
tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/python_models/run_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dbt.adapters.databricks.logging import logger


class PythonRunTracker(object):
class PythonRunTracker:
_run_ids: set[str] = set()
_commands: set[CommandExecution] = set()
_lock = threading.Lock()
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/relation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Optional, Type
from typing import Any, Optional

from dbt_common.dataclass_schema import StrEnum
from dbt_common.exceptions import DbtRuntimeError
Expand Down Expand Up @@ -131,7 +131,7 @@ def matches(
return match

@classproperty
def get_relation_type(cls) -> Type[DatabricksRelationType]:
def get_relation_type(cls) -> type[DatabricksRelationType]: # type: ignore
return DatabricksRelationType

def information_schema(self, view_name: Optional[str] = None) -> InformationSchema:
Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ dependencies = [
"dbt-core>=1.8.7, <2.0",
"dbt-spark>=1.8.0, <2.0",
"keyring>=23.13.0",
"pandas<2.2.0",
"pydantic>=1.10.0, <2",
]

[project.urls]
Expand Down Expand Up @@ -102,7 +100,7 @@ line-length = 100
target-version = 'py39'

[tool.ruff.lint]
select = ["E", "W", "F", "I"]
select = ["E", "W", "F", "I", "UP"]
ignore = ["E203"]

[tool.pytest.ini_options]
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/ephemeral/test_ephemeral.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_ephemeral_nested(self, project):
results = util.run_dbt(["run"])
assert len(results) == 2
assert os.path.exists("./target/run/test/models/root_view.sql")
with open("./target/run/test/models/root_view.sql", "r") as fp:
with open("./target/run/test/models/root_view.sql") as fp:
sql_file = fp.read()

sql_file = re.sub(r"\d+", "", sql_file)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/hooks/test_model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_ctx_vars(self, state, count, project):
"invocation_id",
"thread_id",
]
field_list = ", ".join(["{}".format(f) for f in fields])
field_list = ", ".join([f"{f}" for f in fields])
query = (
f"select {field_list} from {project.test_schema}.on_model_hook"
f" where test_state = '{state}'"
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/hooks/test_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_ctx_vars(self, state, project):
"invocation_id",
"thread_id",
]
field_list = ", ".join(["{}".format(f) for f in fields])
field_list = ", ".join([f"{f}" for f in fields])
query = (
f"select {field_list} from {project.test_schema}.on_run_hook where test_state = "
f"'{state}'"
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/python_model/test_python_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_changing_schema_with_log_validation(self, project, logs_dir):
)
util.run_dbt(["run"])
log_file = os.path.join(logs_dir, "dbt.log")
with open(log_file, "r") as f:
with open(log_file) as f:
log = f.read()
# validate #5510 log_code_execution works
assert "On model.test.simple_python_model:" in log
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/api_client/test_workspace_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_upload_notebook__non_200(self, api, session):

def test_upload_notebook__200(self, api, session, host):
session.post.return_value.status_code = 200
encoded = base64.b64encode("code".encode()).decode()
encoded = base64.b64encode(b"code").decode()
api.upload_notebook("path", "code")
session.post.assert_called_once_with(
f"https://{host}/api/2.0/workspace/import",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_password(self, servicename, username):
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
with open(file_path) as file:
password = file.read()

return password
Expand Down

0 comments on commit f5eb3f8

Please sign in to comment.