Commit: Support more config options in python (#379)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
Showing 12 changed files with 202 additions and 31 deletions.
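For orientation, the options this commit wires through to the Spark writer are set from a dbt Python model's config. A rough sketch of how a model could opt into them (the model body, names, and values here are invented for illustration, not part of this commit):

```python
# Hypothetical dbt Python model exercising the new config keys; each key maps
# to a writer call emitted by py_get_writer_options below (values are made up).
def model(dbt, session):
    dbt.config(
        materialized="table",
        file_format="delta",                 # -> .format("delta")
        location_root="s3://bucket/prefix",  # -> .option("path", "<root>/<alias>")
        partition_by=["date"],               # -> .partitionBy(['date'])
        clustered_by="name",                 # -> .bucketBy(2, ['name']); takes effect
        buckets=2,                           #    only when buckets is also set
    )
    return session.table("samples.trips")    # hypothetical source table
```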
@@ -1,6 +1,5 @@
 name: Integration Tests
-on:
-  push
+on: push
 jobs:
   run-tox-tests-uc:
     runs-on: ubuntu-latest
@@ -0,0 +1,78 @@
{% macro databricks__py_write_table(compiled_code, target_relation) %}
{{ compiled_code }}
# --- Autogenerated dbt materialization code. --- #
dbt = dbtObj(spark.table)
df = model(dbt, spark)

import pyspark

{{ py_try_import('pandas', 'pandas_available') }}
{{ py_try_import('pyspark.pandas', 'pyspark_pandas_api_available') }}
{{ py_try_import('databricks.koalas', 'koalas_available') }}

# preferentially convert pandas DataFrames to pandas-on-Spark or Koalas DataFrames first
# since they know how to convert pandas DataFrames better than `spark.createDataFrame(df)`
# and converting from pandas-on-Spark to Spark DataFrame has no overhead

if pandas_available and isinstance(df, pandas.core.frame.DataFrame):
    if pyspark_pandas_api_available:
        df = pyspark.pandas.frame.DataFrame(df)
    elif koalas_available:
        df = databricks.koalas.frame.DataFrame(df)

# convert to pyspark.sql.dataframe.DataFrame
if isinstance(df, pyspark.sql.dataframe.DataFrame):
    pass  # since it is already a Spark DataFrame
elif pyspark_pandas_api_available and isinstance(df, pyspark.pandas.frame.DataFrame):
    df = df.to_spark()
elif koalas_available and isinstance(df, databricks.koalas.frame.DataFrame):
    df = df.to_spark()
elif pandas_available and isinstance(df, pandas.core.frame.DataFrame):
    df = spark.createDataFrame(df)
else:
    msg = f"{type(df)} is not a supported type for dbt Python materialization"
    raise Exception(msg)

writer = (
    df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    {{ py_get_writer_options()|indent(8, True) }}
)

writer.saveAsTable("{{ target_relation }}")
{% endmacro %}

{%- macro py_get_writer_options() -%}
{%- set location_root = config.get('location_root', validator=validation.any[basestring]) -%}
{%- set file_format = config.get('file_format', validator=validation.any[basestring])|default('delta', true) -%}
{%- set partition_by = config.get('partition_by', validator=validation.any[list, basestring]) -%}
{%- set clustered_by = config.get('clustered_by', validator=validation.any[list, basestring]) -%}
{%- set buckets = config.get('buckets', validator=validation.any[int]) -%}
.format("{{ file_format }}")
{%- if location_root is not none %}
{%- set identifier = model['alias'] %}
.option("path", "{{ location_root }}/{{ identifier }}")
{%- endif -%}
{%- if partition_by is not none -%}
{%- if partition_by is string -%}
{%- set partition_by = [partition_by] -%}
{%- endif %}
.partitionBy({{ partition_by }})
{%- endif -%}
{%- if (clustered_by is not none) and (buckets is not none) -%}
{%- if clustered_by is string -%}
{%- set clustered_by = [clustered_by] -%}
{%- endif %}
.bucketBy({{ buckets }}, {{ clustered_by }})
{%- endif -%}
{% endmacro -%}

{% macro py_try_import(library, var_name) -%}
# make sure {{ library }} exists before using it
try:
    import {{ library }}
    {{ var_name }} = True
except ImportError:
    {{ var_name }} = False
{% endmacro %}
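To make the writer options concrete: for a hypothetical model configured with `file_format="parquet"` and `partition_by="date"`, the rendered materialization code would end roughly like this (an illustrative render with an invented target relation; the unit tests further below assert the exact strings py_get_writer_options emits):

```python
# Approximate render of the writer block above for a config of
# {"file_format": "parquet", "partition_by": "date"} -- illustrative only.
writer = (
    df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
        .format("parquet")
        .partitionBy(['date'])
)

writer.saveAsTable("my_schema.my_table")  # invented target relation
```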
@@ -21,6 +21,7 @@ pytest>=6.0.2
 pytz
 tox>=3.2.0
 types-requests
+types-mock

 dbt-spark~=1.5.0
 dbt-core~=1.5.0
@@ -0,0 +1,75 @@
from mock import MagicMock
from tests.unit.macros.base import TestMacros


class TestPythonMacros(TestMacros):
    def setUp(self):
        TestMacros.setUp(self)
        self.default_context["model"] = MagicMock()
        self.template = self._get_template("python.sql", "adapters.sql")

    def test_py_get_writer__default_file_format(self):
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, '.format("delta")')

    def test_py_get_writer__specified_file_format(self):
        self.config["file_format"] = "parquet"
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, '.format("parquet")')

    def test_py_get_writer__specified_location_root(self):
        self.config["location_root"] = "s3://fake_location"
        d = {"alias": "schema"}
        self.default_context["model"].__getitem__.side_effect = d.__getitem__
        result = self._run_macro_raw("py_get_writer_options")

        expected = '.format("delta")\n.option("path", "s3://fake_location/schema")'
        self.assertEqual(result, expected)

    def test_py_get_writer__partition_by_single_column(self):
        self.config["partition_by"] = "name"
        result = self._run_macro_raw("py_get_writer_options")

        expected = ".format(\"delta\")\n.partitionBy(['name'])"
        self.assertEqual(result, expected)

    def test_py_get_writer__partition_by_array(self):
        self.config["partition_by"] = ["name", "date"]
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, ".format(\"delta\")\n.partitionBy(['name', 'date'])")

    def test_py_get_writer__clustered_by_single_column(self):
        self.config["clustered_by"] = "name"
        self.config["buckets"] = 2
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, ".format(\"delta\")\n.bucketBy(2, ['name'])")

    def test_py_get_writer__clustered_by_array(self):
        self.config["clustered_by"] = ["name", "date"]
        self.config["buckets"] = 2
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, ".format(\"delta\")\n.bucketBy(2, ['name', 'date'])")

    def test_py_get_writer__clustered_by_without_buckets(self):
        self.config["clustered_by"] = ["name", "date"]
        result = self._run_macro_raw("py_get_writer_options")

        self.assertEqual(result, '.format("delta")')

    def test_py_try_import__golden_path(self):
        result = self._run_macro_raw("py_try_import", "pandas", "pandas_available")

        expected = (
            "# make sure pandas exists before using it\n"
            "try:\n"
            "    import pandas\n"
            "    pandas_available = True\n"
            "except ImportError:\n"
            "    pandas_available = False\n"
        )
        self.assertEqual(result, expected)
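These tests depend on a TestMacros base class that is not part of this diff. A minimal sketch of what that helper plausibly provides, assuming it renders the macro templates through a plain Jinja environment with dbt's config, validation, and model objects stubbed out (everything below is an assumption for illustration, not the repository's actual base module):

```python
# Assumed shape of tests/unit/macros/base.py -- a sketch, not the real file.
import unittest
from unittest.mock import MagicMock

from jinja2 import Environment, FileSystemLoader


class TestMacros(unittest.TestCase):
    def setUp(self):
        self.config = {}
        # The macros call config.get(key, validator=...), index into
        # validation.any[...], and read model['alias']; stub all of them.
        config = MagicMock()
        config.get.side_effect = lambda key, validator=None: self.config.get(key)
        self.default_context = {"config": config, "validation": MagicMock()}

    def _get_template(self, name, *dependencies):
        # dependencies (e.g. "adapters.sql") are ignored in this sketch.
        env = Environment(loader=FileSystemLoader("dbt/include/databricks/macros"))
        return env.get_template(name, globals=self.default_context)

    def _run_macro_raw(self, name, *args):
        # {% macro %} definitions are exported as attributes of the template module.
        module = self.template.make_module(vars=self.default_context)
        return str(getattr(module, name)(*args))
```

With a helper like this, each test above only has to mutate `self.config` (and, for the location_root case, `self.default_context["model"]`) before rendering a single macro and comparing its raw output.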