
Support more config options in python (#379)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
benc-db authored Jul 25, 2023
1 parent 84a7e97 commit 279492c
Showing 12 changed files with 202 additions and 31 deletions.
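For orientation, an illustrative sketch (not part of the commit) of a Python model using the config options this change teaches the adapter to honor. The model body, column names, and storage path are hypothetical; the config keys are the ones read by the new py_get_writer_options macro introduced below.

import pandas as pd


def model(dbt, spark):
    # These keys are read by py_get_writer_options; the values here are illustrative.
    dbt.config(
        materialized="table",
        file_format="delta",
        location_root="s3://my-bucket/models",  # hypothetical path
        partition_by=["date"],
    )
    data = [[1, "Elia", "2023-07-01"], [2, "Teo", "2023-07-02"]]
    return pd.DataFrame(data, columns=["id", "name", "date"])

Clustering via clustered_by together with buckets is read by the same macro, for file formats that support bucketing.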
3 changes: 1 addition & 2 deletions .github/workflows/integration.yml
@@ -1,6 +1,5 @@
name: Integration Tests
-on:
-  push
+on: push
jobs:
  run-tox-tests-uc:
    runs-on: ubuntu-latest
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -3,9 +3,12 @@
### Fixes

- Fix issue where the show tables extended command is limited to 2048 characters. ([#326](https://github.com/databricks/dbt-databricks/pull/326))
- Extend python model support to cover the same config options as SQL ([#379](https://github.com/databricks/dbt-databricks/pull/379))

## dbt-databricks 1.5.5 (July 7, 2023)

### Fixes

- Fixed issue where starting a terminated cluster in the python path would never return

### Features
2 changes: 1 addition & 1 deletion dbt/include/databricks/macros/adapters.sql
@@ -77,7 +77,7 @@
TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire
dbt invocation.
--#}
-{{ py_write_table(compiled_code=compiled_code, target_relation=relation) }}
+{{ databricks__py_write_table(compiled_code=compiled_code, target_relation=relation) }}
{%- endif -%}
{%- endmacro -%}

78 changes: 78 additions & 0 deletions dbt/include/databricks/macros/python.sql
@@ -0,0 +1,78 @@
{% macro databricks__py_write_table(compiled_code, target_relation) %}
{{ compiled_code }}
# --- Autogenerated dbt materialization code. --- #
dbt = dbtObj(spark.table)
df = model(dbt, spark)

import pyspark

{{ py_try_import('pandas', 'pandas_available') }}
{{ py_try_import('pyspark.pandas', 'pyspark_pandas_api_available') }}
{{ py_try_import('databricks.koalas', 'koalas_available') }}

# preferentially convert pandas DataFrames to pandas-on-Spark or Koalas DataFrames first
# since they know how to convert pandas DataFrames better than `spark.createDataFrame(df)`
# and converting from pandas-on-Spark to Spark DataFrame has no overhead

if pandas_available and isinstance(df, pandas.core.frame.DataFrame):
    if pyspark_pandas_api_available:
        df = pyspark.pandas.frame.DataFrame(df)
    elif koalas_available:
        df = databricks.koalas.frame.DataFrame(df)

# convert to pyspark.sql.dataframe.DataFrame
if isinstance(df, pyspark.sql.dataframe.DataFrame):
    pass  # since it is already a Spark DataFrame
elif pyspark_pandas_api_available and isinstance(df, pyspark.pandas.frame.DataFrame):
    df = df.to_spark()
elif koalas_available and isinstance(df, databricks.koalas.frame.DataFrame):
    df = df.to_spark()
elif pandas_available and isinstance(df, pandas.core.frame.DataFrame):
    df = spark.createDataFrame(df)
else:
    msg = f"{type(df)} is not a supported type for dbt Python materialization"
    raise Exception(msg)

writer = (
    df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    {{ py_get_writer_options()|indent(8, True) }}
)

writer.saveAsTable("{{ target_relation }}")
{% endmacro %}

{%- macro py_get_writer_options() -%}
{%- set location_root = config.get('location_root', validator=validation.any[basestring]) -%}
{%- set file_format = config.get('file_format', validator=validation.any[basestring])|default('delta', true) -%}
{%- set partition_by = config.get('partition_by', validator=validation.any[list, basestring]) -%}
{%- set clustered_by = config.get('clustered_by', validator=validation.any[list, basestring]) -%}
{%- set buckets = config.get('buckets', validator=validation.any[int]) -%}
.format("{{ file_format }}")
{%- if location_root is not none %}
{%- set identifier = model['alias'] %}
.option("path", "{{ location_root }}/{{ identifier }}")
{%- endif -%}
{%- if partition_by is not none -%}
{%- if partition_by is string -%}
{%- set partition_by = [partition_by] -%}
{%- endif %}
.partitionBy({{ partition_by }})
{%- endif -%}
{%- if (clustered_by is not none) and (buckets is not none) -%}
{%- if clustered_by is string -%}
{%- set clustered_by = [clustered_by] -%}
{%- endif %}
.bucketBy({{ buckets }}, {{ clustered_by }})
{%- endif -%}
{% endmacro -%}
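As an illustrative rendering (not part of the commit): with file_format left at its delta default and partition_by set to ["name", "date"], the fragment emitted by this macro slots into the writer chain built by databricks__py_write_table above roughly as follows. The target table name is hypothetical; the rendered option strings match the expectations asserted in tests/unit/macros/test_python_macros.py later in this commit.

writer = (
    df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .format("delta")
    .partitionBy(['name', 'date'])
)

writer.saveAsTable("my_schema.my_python_model")  # hypothetical target relation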

{% macro py_try_import(library, var_name) -%}
# make sure {{ library }} exists before using it
try:
    import {{ library }}
    {{ var_name }} = True
except ImportError:
    {{ var_name }} = False
{% endmacro %}
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -21,6 +21,7 @@ pytest>=6.0.2
pytz
tox>=3.2.0
types-requests
types-mock

dbt-spark~=1.5.0
dbt-core~=1.5.0
2 changes: 2 additions & 0 deletions tests/integration/python/models/basic.py
@@ -2,6 +2,8 @@


def model(dbt, spark):
    dbt.config(partition_by="id")
    dbt.config(unique_key="name")
    data = [[1, "Elia"], [2, "Teo"], [3, "Fang"]]

    pdf = pd.DataFrame(data, columns=["id", "name"])
11 changes: 9 additions & 2 deletions tests/integration/python/models/models.yml
@@ -3,5 +3,12 @@ models:
  - name: basic
    config:
      materialized: table
-      tags: [ 'python' ]
-      http_path: '{{ var("http_path") }}'
+      tags: ["python"]
+      http_path: '{{ var("http_path") }}'
+    columns:
+      - name: id
+        tests:
+          - not_null
+      - name: name
+        tests:
+          - unique
@@ -4,10 +4,10 @@


@pytest.mark.skip(
-    reason="Run manually. Test must start with the Python compute\
-resource in TERMINATED or TERMINATING state"
+    reason="Run manually. Test must start with the Python compute resource in TERMINATED or \
+TERMINATING state, as the purpose of the test is to validate successful cold start."
)
-class TestPython(DBTIntegrationTest):
+class TestPythonAutostart(DBTIntegrationTest):
    @property
    def schema(self):
        return "python"
@@ -20,7 +20,7 @@ def models(self):
    def project_config(self):
        return {
            "config-version": 2,
-            "vars": {"http_path": os.getenv("DBT_DATABRICKS_CLUSTER_HTTP_PATH")},
+            "vars": {
+                "http_path": os.getenv("DBT_DATABRICKS_CLUSTER_HTTP_PATH"),
+            },
        }

    def python_exc(self):
@@ -32,4 +34,11 @@ def test_python_databricks_sql_endpoint(self):

    @use_profile("databricks_uc_sql_endpoint")
    def test_python_databricks_uc_sql_endpoint(self):
        self.use_default_project(
            {
                "vars": {
                    "http_path": os.getenv("DBT_DATABRICKS_UC_CLUSTER_HTTP_PATH"),
                }
            }
        )
        self.python_exc()
7 changes: 5 additions & 2 deletions tests/unit/macros/base.py
@@ -40,7 +40,7 @@ def _get_template(self, template_filename, parent_context=None):

        return self.jinja_env.get_template(template_filename, globals=self.default_context)

-    def _run_macro(self, name, *args):
+    def _run_macro_raw(self, name, *args):
        def dispatch(macro_name, macro_namespace=None, packages=None):
            if hasattr(self.template.module, f"databricks__{macro_name}"):
                return getattr(self.template.module, f"databricks__{macro_name}")
@@ -49,5 +49,8 @@ def dispatch(macro_name, macro_namespace=None, packages=None):

        self.default_context["adapter"].dispatch = dispatch

-        value = getattr(self.template.module, name)(*args)
+        return getattr(self.template.module, name)(*args)

+    def _run_macro(self, name, *args):
+        value = self._run_macro_raw(name, *args)
        return re.sub(r"\s\s+", " ", value).strip()
16 changes: 0 additions & 16 deletions tests/unit/macros/test_adapters_macros.py
@@ -1,4 +1,3 @@
-import re
from dbt.adapters.databricks.relation import DatabricksRelation

from tests.unit.macros.base import TestMacros
@@ -213,21 +212,6 @@ def setUp(self):

        self.relation = DatabricksRelation.from_dict(data)

-    def __run_macro2(self, template, name, relation, *args):
-        self.default_context["model"].alias = relation
-
-        def dispatch(macro_name, macro_namespace=None, packages=None):
-            if hasattr(template.module, f"databricks__{macro_name}"):
-                return getattr(template.module, f"databricks__{macro_name}")
-            else:
-                return self.default_context[f"spark__{macro_name}"]
-
-        self.default_context["adapter"].dispatch = dispatch
-
-        value = getattr(template.module, name)(*args)
-        value = re.sub(r"\s\s+", " ", value)
-        return value

    def test_macros_create_table_as(self):
        sql = self._render_create_table_as(self.relation)

75 changes: 75 additions & 0 deletions tests/unit/macros/test_python_macros.py
@@ -0,0 +1,75 @@
from mock import MagicMock
from tests.unit.macros.base import TestMacros


class TestPythonMacros(TestMacros):
    def setUp(self):
        TestMacros.setUp(self)
        self.default_context["model"] = MagicMock()
        self.template = self._get_template("python.sql", "adapters.sql")

    def test_py_get_writer__default_file_format(self):
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, '.format("delta")')

    def test_py_get_writer__specified_file_format(self):
        self.config["file_format"] = "parquet"
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, '.format("parquet")')

    def test_py_get_writer__specified_location_root(self):
        self.config["location_root"] = "s3://fake_location"
        d = {"alias": "schema"}
        self.default_context["model"].__getitem__.side_effect = d.__getitem__
        result = self._run_macro_raw("py_get_writer_options")

        expected = '.format("delta")\n.option("path", "s3://fake_location/schema")'
        self.assertEqual(result, expected)

    def test_py_get_writer__partition_by_single_column(self):
        self.config["partition_by"] = "name"
        result = self._run_macro_raw("py_get_writer_options")

        expected = ".format(\"delta\")\n.partitionBy(['name'])"
        self.assertEqual(result, expected)

    def test_py_get_writer__partition_by_array(self):
        self.config["partition_by"] = ["name", "date"]
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, (".format(\"delta\")\n.partitionBy(['name', 'date'])"))

    def test_py_get_writer__clustered_by_single_column(self):
        self.config["clustered_by"] = "name"
        self.config["buckets"] = 2
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, (".format(\"delta\")\n.bucketBy(2, ['name'])"))

    def test_py_get_writer__clustered_by_array(self):
        self.config["clustered_by"] = ["name", "date"]
        self.config["buckets"] = 2
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, (".format(\"delta\")\n.bucketBy(2, ['name', 'date'])"))

    def test_py_get_writer__clustered_by_without_buckets(self):
        self.config["clustered_by"] = ["name", "date"]
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, ('.format("delta")'))

    def test_py_try_import__golden_path(self):
        result = self._run_macro_raw("py_try_import", "pandas", "pandas_available")

        expected = (
            "# make sure pandas exists before using it\n"
            "try:\n"
            "    import pandas\n"
            "    pandas_available = True\n"
            "except ImportError:\n"
            "    pandas_available = False\n"
        )
        self.assertEqual(result, expected)
18 changes: 14 additions & 4 deletions tests/unit/test_adapter.py
@@ -394,7 +394,8 @@ def _test_databricks_sql_connector_connection(self, connect):
        self.assertEqual(connection.credentials.schema, "analytics")
        self.assertEqual(len(connection.credentials.session_properties), 1)
        self.assertEqual(
-            connection.credentials.session_properties["spark.sql.ansi.enabled"], "true"
+            connection.credentials.session_properties["spark.sql.ansi.enabled"],
+            "true",
        )
        self.assertIsNone(connection.credentials.database)

@@ -455,7 +456,10 @@ def test_simple_catalog_relation(self):
        rel_type = DatabricksRelation.get_relation_type.Table

        relation = DatabricksRelation.create(
-            database="test_catalog", schema="default_schema", identifier="mytable", type=rel_type
+            database="test_catalog",
+            schema="default_schema",
+            identifier="mytable",
+            type=rel_type,
        )
        assert relation.database == "test_catalog"

@@ -491,7 +495,10 @@ def test_parse_relation(self):
("Location", "/mnt/vo"),
("Serde Library", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat"),
("OutputFormat", "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"),
(
"OutputFormat",
"org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat",
),
("Partition Provider", "Catalog"),
]

@@ -638,7 +645,10 @@ def test_parse_relation_with_statistics(self):
("Location", "/mnt/vo"),
("Serde Library", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat"),
("OutputFormat", "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"),
(
"OutputFormat",
"org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat",
),
("Partition Provider", "Catalog"),
]

