Skip to content

Commit

Permalink
fix linting and hopefully tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Dec 5, 2023
1 parent 3ca0b7a commit 6725490
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 15 deletions.
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
88 changes: 81 additions & 7 deletions tests/unit/macros/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from typing import Any, Dict
from mock import Mock
import pytest
from jinja2 import Environment, FileSystemLoader, PackageLoader, Template
Expand All @@ -15,18 +16,27 @@ def __init__(self, template, context, relation):
class MacroTestBase:
@pytest.fixture(autouse=True)
def config(self, context) -> dict:
    """
    Anything you put in this dict will be returned by config in the rendered template.

    Returns the test-local dict so tests can seed config values before rendering.
    """
    # Removed stale duplicate assignment left over from the diff; the annotated
    # version below is the one the commit keeps.
    local_config: Dict[str, Any] = {}
    # Route template-side config.get(...) lookups through this test-local dict.
    context["config"].get = lambda key, default=None, **kwargs: local_config.get(key, default)
    return local_config

@pytest.fixture(autouse=True)
def var(self, context) -> dict:
    """
    Anything you put in this dict will be returned by var in the rendered template.

    Returns the test-local dict so tests can seed var values before rendering.
    """
    # Removed stale duplicate assignment left over from the diff; docstring also
    # fixed ("config" -> "var", copy-paste error from the config fixture).
    local_var: Dict[str, Any] = {}
    # Unlike config (which patches .get), var is called directly in templates,
    # so the whole context entry is replaced with a lookup lambda.
    context["var"] = lambda key, default=None, **kwargs: local_var.get(key, default)
    return local_var

@pytest.fixture(scope="class")
def default_context(self) -> dict:
"""
This is the default context used in all tests.
"""
context = {
"validation": Mock(),
"model": Mock(),
Expand All @@ -36,31 +46,50 @@ def default_context(self) -> dict:
"adapter": Mock(),
"var": Mock(),
"return": lambda r: r,
"is_incremental": Mock(return_value=False),
}

return context

@pytest.fixture(scope="class")
def spark_env(self) -> Environment:
    """
    Jinja environment that loads the macro templates shipped with dbt-spark.
    """
    spark_loader = PackageLoader("dbt.include.spark", "macros")
    return Environment(loader=spark_loader, extensions=["jinja2.ext.do"])

@pytest.fixture(scope="class")
def spark_template_names(self) -> list:
    """
    Spark template files to pull macros from.
    Override this when the macro under test depends on macros defined in
    templates we inherit from dbt-spark.
    """
    names = ["adapters.sql"]
    return names

@pytest.fixture(scope="class")
def spark_context(self, default_context, spark_env, spark_template_names) -> dict:
    """
    Context extended with every macro exported by the requested Spark templates.
    """
    context = self.build_up_context(default_context, spark_env, spark_template_names)
    return context

@pytest.fixture(scope="class")
def macro_folders_to_load(self) -> list:
    """
    Folders to search for Databricks macro templates, relative to the
    dbt/include/databricks folder. Folders are searched in listed order,
    so earlier entries win on name collisions.
    """
    folders = ["macros"]
    return folders

@pytest.fixture(scope="class")
def databricks_env(self, macro_folders_to_load) -> Environment:
"""
The environment used for rendering Databricks macros
"""
return Environment(
loader=FileSystemLoader(
[f"dbt/include/databricks/{folder}" for folder in macro_folders_to_load]
Expand All @@ -70,32 +99,51 @@ def databricks_env(self, macro_folders_to_load) -> Environment:

@pytest.fixture(scope="class")
def databricks_template_names(self) -> list:
    """
    Extra Databricks templates whose macros the rendered template references.
    Do not include the template named by template_name itself; use this when a
    macro defined in a different template is needed for the test.
    Ex: when testing python.sql you will also need to load ["adapters.sql"].
    """
    return list()

@pytest.fixture(scope="class")
def databricks_context(self, spark_context, databricks_env, databricks_template_names) -> dict:
    """
    Context extended with the macros from the requested Databricks templates;
    falls back to the Spark context when no Databricks templates are requested.
    """
    if databricks_template_names:
        return self.build_up_context(spark_context, databricks_env, databricks_template_names)
    return spark_context

def build_up_context(self, context, env, template_names):
    """
    Return a copy of context augmented with every macro defined by the named
    templates in the given Jinja environment. The input context is not mutated.
    """
    merged = dict(context)
    for name in template_names:
        # template.module exposes the template's exported macros as attributes.
        module = env.get_template(name, globals=context).module
        merged.update(module.__dict__)
    return merged

@pytest.fixture
def context(self, databricks_context) -> dict:
    # NOTE(review): a second `context` fixture is defined later in this class
    # (returning template.globals); since the later definition wins in Python,
    # this one looks like dead code left over from the refactor — confirm and
    # remove one of the two.
    return databricks_context.copy()

@pytest.fixture(scope="class")
def template_name(self) -> str:
    """
    Name (without path) of the Databricks template under test.
    Example: "adapters.sql"

    Raises:
        NotImplementedError: always; subclasses must override this fixture.
    """
    raise NotImplementedError("Must be implemented by subclasses")

@pytest.fixture
def template(self, template_name, context, databricks_env) -> Template:
def template(self, template_name, databricks_context, databricks_env) -> Template:
"""
This creates the template you will test against.
You generally don't want to override this.
"""
context = databricks_context.copy()
current_template = databricks_env.get_template(template_name, globals=context)

def dispatch(macro_name, macro_namespace=None, packages=None):
Expand All @@ -110,8 +158,21 @@ def dispatch(macro_name, macro_namespace=None, packages=None):

return current_template

@pytest.fixture
def context(self, template) -> dict:
    """
    The context the template was rendered with.
    Mutating it works for mocking adapter calls, but may not take effect for
    macros; to mock a macro, see how is_incremental is handled in
    default_context.
    """
    rendered_globals = template.globals
    return rendered_globals

@pytest.fixture(scope="class")
def relation(self):
"""
Dummy relation to use in tests.
"""
data = {
"path": {
"database": "some_database",
Expand All @@ -125,15 +186,28 @@ def relation(self):

@pytest.fixture
def template_bundle(self, template, context, relation):
    """
    Package the compiled template, its context, and the dummy relation together
    for convenience in tests.
    """
    # Point the mocked model at the relation so alias-based macros line up.
    context["model"].alias = relation.identifier
    bundle = TemplateBundle(template, context, relation)
    return bundle

def run_macro_raw(self, template, name, *args):
    """
    Look up the named macro on the template's module and invoke it with the
    given arguments, returning the raw rendered value.
    """
    macro = getattr(template.module, name)
    return macro(*args)

def run_macro(self, template, name, *args):
"""
Run the named macro from a template, and return the rendered value.
This version strips off extra whitespace and newlines.
"""
value = self.run_macro_raw(template, name, *args)
return re.sub(r"\s\s+", " ", value).strip()

def render_bundle(self, template_bundle, name, *args):
"""
Convenience method for macros that take a relation as a first argument.
"""
return self.run_macro(template_bundle.template, name, template_bundle.relation, *args)
3 changes: 0 additions & 3 deletions tests/unit/macros/relations/test_table_macros.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from mock import Mock
from jinja2 import Environment, FileSystemLoader, PackageLoader
import re
import pytest

from tests.unit.macros.base import MacroTestBase
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/macros/test_adapters_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def render_constraint_sql(self, template_bundle, constraint, *args):
"get_constraint_sql",
template_bundle.relation,
constraint,
*args
*args,
)

@pytest.fixture(scope="class")
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/macros/test_python_macros.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from jinja2 import Template
from mock import MagicMock
from tests.unit.macros.base import MacroTestBase

Expand All @@ -6,7 +7,7 @@

class TestPythonMacros(MacroTestBase):
@pytest.fixture(scope="class", autouse=True)
def modify_context(self, default_context) -> dict:
def modify_context(self, default_context) -> None:
default_context["model"] = MagicMock()
d = {"alias": "schema"}
default_context["model"].__getitem__.side_effect = d.__getitem__
Expand All @@ -32,15 +33,16 @@ def test_py_get_writer__specified_file_format(self, config, template):

def test_py_get_writer__specified_location_root(self, config, template, context):
    """Writer options should target location_root/<alias> on a full refresh."""
    config["location_root"] = "s3://fake_location"
    # NOTE(review): default_context in this commit already mocks is_incremental
    # to return False, so this per-test mock may be a redundant diff leftover —
    # confirm against the final file before removing.
    context["is_incremental"] = MagicMock(return_value=False)
    result = self.run_macro_raw(template, "py_get_writer_options")

    expected = '.format("delta")\n.option("path", "s3://fake_location/schema")'
    assert result == expected

def test_py_get_writer__specified_location_root_on_incremental(self, config, template, context):
def test_py_get_writer__specified_location_root_on_incremental(
self, config, template: Template, context
):
config["location_root"] = "s3://fake_location"
context["is_incremental"] = MagicMock(return_value=True)
context["is_incremental"].return_value = True
result = self.run_macro_raw(template, "py_get_writer_options")

expected = '.format("delta")\n.option("path", "s3://fake_location/schema__dbt_tmp")'
Expand Down

0 comments on commit 6725490

Please sign in to comment.