From 67bf98d0fa7d24a2affc4d2311c06827d5fba8ae Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Fri, 17 Jan 2025 10:38:15 -0800 Subject: [PATCH] Adding missing 1.9 Snapshot behavior (#904) --- CHANGELOG.md | 4 + .../macros/materializations/snapshot.sql | 88 ++--- .../simple_snapshot/test_new_record_mode.py | 74 ++++ .../simple_snapshot/test_various_configs.py | 345 ++++++++++++++++++ 4 files changed, 459 insertions(+), 52 deletions(-) create mode 100644 tests/functional/adapter/simple_snapshot/test_new_record_mode.py create mode 100644 tests/functional/adapter/simple_snapshot/test_various_configs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c01efeeb..54f17963 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## dbt-databricks 1.9.2 (TBD) +### Features + +- Update snapshot materialization to support new snapshot features ([904](https://github.com/databricks/dbt-databricks/pull/904)) + ### Under the Hood - Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) diff --git a/dbt/include/databricks/macros/materializations/snapshot.sql b/dbt/include/databricks/macros/materializations/snapshot.sql index 3d1236a1..3a513a24 100644 --- a/dbt/include/databricks/macros/materializations/snapshot.sql +++ b/dbt/include/databricks/macros/materializations/snapshot.sql @@ -1,27 +1,4 @@ -{% macro databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} - - {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, - schema=target_relation.schema, - database=target_relation.database, - type='view') -%} - - {% set select = snapshot_staging_table(strategy, sql, target_relation) %} - - {# needs to be a non-temp view so that its columns can be ascertained via `describe` #} - {% call statement('build_snapshot_staging_relation') %} - create or replace view {{ tmp_relation }} - as - {{ select }} - {% endcall %} - - {% do return(tmp_relation) %} -{% endmacro %} - - {% materialization snapshot, adapter='databricks' %} - {%- set config = model['config'] -%} - {%- set target_table = model.get('alias', model.get('name')) -%} {%- set strategy_name = config.get('strategy') -%} @@ -62,47 +39,43 @@ {{ run_hooks(pre_hooks, inside_transaction=True) }} {% set strategy_macro = strategy_dispatch(strategy_name) %} - {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %} + {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", model['config'], target_relation_exists) %} {% if not target_relation_exists %} {% set build_sql = build_snapshot_table(strategy, model['compiled_code']) %} + {% set build_or_select_sql = build_sql %} {% set final_sql = create_table_as(False, target_relation, build_sql) %} - {% call statement('main') %} - {{ final_sql }} - {% endcall %} - - {% do persist_docs(target_relation, model, for_relation=False) %} - {% else %} - {{ adapter.valid_snapshot_target(target_relation) }} + {% set columns = config.get("snapshot_table_column_names") or get_snapshot_table_column_names() %} - {% if target_relation.database is none %} - {% set staging_table = spark_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% else %} - {% set staging_table = databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} - {% endif %} + {{ adapter.assert_valid_snapshot_target_given_strategy(target_relation, columns, strategy) }} + + {% set build_or_select_sql = snapshot_staging_table(strategy, sql, target_relation) %} + {% set staging_table = build_snapshot_staging_table(strategy, sql, target_relation) %} -- this may no-op if the database does not require column expansion {% do adapter.expand_target_column_types(from_relation=staging_table, to_relation=target_relation) %} + {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} + {% if unique_key | is_list %} + {% for key in strategy.unique_key %} + {{ remove_columns.append('dbt_unique_key_' + loop.index|string) }} + {{ remove_columns.append('DBT_UNIQUE_KEY_' + loop.index|string) }} + {% endfor %} + {% endif %} + {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} {% do create_columns(target_relation, missing_columns) %} {% set source_columns = adapter.get_columns_in_relation(staging_table) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} {% set quoted_source_columns = [] %} @@ -117,23 +90,34 @@ ) %} - {% call statement_with_staging_table('main', staging_table) %} - {{ final_sql }} - {% endcall %} + {% endif %} - {% do persist_docs(target_relation, model, for_relation=True) %} - {% endif %} + {{ check_time_data_types(build_or_select_sql) }} - {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode) %} - {% do apply_grants(target_relation, grant_config, should_revoke) %} + {% call statement('main') %} + {{ final_sql }} + {% endcall %} - {% do persist_constraints(target_relation, model) %} + {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode=False) %} + {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke) %} + + {% do persist_docs(target_relation, model) %} + + {% if not target_relation_exists %} + {% do create_indexes(target_relation) %} + {% endif %} {{ run_hooks(post_hooks, inside_transaction=True) }} {{ adapter.commit() }} + {% if staging_table is defined %} + {% do post_snapshot(staging_table) %} + {% endif %} + + {% do persist_constraints(target_relation, model) %} + {{ run_hooks(post_hooks, inside_transaction=False) }} {{ return({'relations': [target_relation]}) }} diff --git a/tests/functional/adapter/simple_snapshot/test_new_record_mode.py b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py new file mode 100644 index 00000000..6b436a31 --- /dev/null +++ b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py @@ -0,0 +1,74 @@ +import pytest + +from dbt.tests.adapter.simple_snapshot.new_record_mode import ( + _delete_sql, + _invalidate_sql, + _ref_snapshot_sql, + _seed_new_record_mode, + _snapshot_actual_sql, + _snapshots_yml, + _update_sql, +) +from dbt.tests.util import check_relations_equal, run_dbt + + +class TestDatabricksSnapshotNewRecordMode: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": _snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": _snapshots_yml, + "ref_snapshot.sql": _ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def seed_new_record_mode(self): + return _seed_new_record_mode + + @pytest.fixture(scope="class") + def invalidate_sql_1(self): + return _invalidate_sql.split(";", 1)[0].replace("BEGIN", "") + + @pytest.fixture(scope="class") + def invalidate_sql_2(self): + return _invalidate_sql.split(";", 1)[1].replace("END", "").replace(";", "") + + @pytest.fixture(scope="class") + def update_sql(self): + return _update_sql.replace("text", "string") + + @pytest.fixture(scope="class") + def delete_sql(self): + return _delete_sql + + def test_snapshot_new_record_mode( + self, project, seed_new_record_mode, invalidate_sql_1, invalidate_sql_2, update_sql + ): + for sql in ( + seed_new_record_mode.replace("text", "string") + .replace("TEXT", "STRING") + .replace("BEGIN", "") + .replace("END;", "") + .replace(" WITHOUT TIME ZONE", "") + .split(";") + ): + project.run_sql(sql) + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + project.run_sql(_delete_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 diff --git a/tests/functional/adapter/simple_snapshot/test_various_configs.py b/tests/functional/adapter/simple_snapshot/test_various_configs.py new file mode 100644 index 00000000..18b82de0 --- /dev/null +++ b/tests/functional/adapter/simple_snapshot/test_various_configs.py @@ -0,0 +1,345 @@ +import datetime + +import pytest +from agate import Table + +from dbt.tests.adapter.simple_snapshot.fixtures import ( + create_multi_key_seed_sql, + create_multi_key_snapshot_expected_sql, + create_seed_sql, + create_snapshot_expected_sql, + model_seed_sql, + populate_multi_key_snapshot_expected_sql, + populate_snapshot_expected_sql, + populate_snapshot_expected_valid_to_current_sql, + ref_snapshot_sql, + seed_insert_sql, + seed_multi_key_insert_sql, + snapshot_actual_sql, + snapshots_multi_key_yml, + snapshots_no_column_names_yml, + snapshots_valid_to_current_yml, + snapshots_yml, + update_multi_key_sql, + update_sql, + update_with_current_sql, +) +from dbt.tests.util import ( + check_relations_equal, + get_manifest, + run_dbt, + run_dbt_and_capture, + run_sql_with_adapter, + update_config_file, +) + + +def text_replace(input: str) -> str: + return input.replace("TEXT", "STRING").replace("text", "string") + + +create_snapshot_expected_sql = text_replace(create_snapshot_expected_sql) +populate_snapshot_expected_sql = text_replace(populate_snapshot_expected_sql) +populate_snapshot_expected_valid_to_current_sql = text_replace( + populate_snapshot_expected_valid_to_current_sql +) +update_with_current_sql = text_replace(update_with_current_sql) +create_multi_key_snapshot_expected_sql = text_replace(create_multi_key_snapshot_expected_sql) +populate_multi_key_snapshot_expected_sql = text_replace(populate_multi_key_snapshot_expected_sql) +update_sql = text_replace(update_sql) +update_multi_key_sql = text_replace(update_multi_key_sql) + +invalidate_sql_1 = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = updated_at + interval '1 hour', + email = case when id = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end +where id >= 10 and id <= 20 +""" + +invalidate_sql_2 = """ +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = updated_at + interval '1 hour' +where id >= 10 and id <= 20; +""" + +invalidate_multi_key_sql_1 = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = updated_at + interval '1 hour', + email = case when id1 = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end +where id1 >= 10 and id1 <= 20; +""" + +invalidate_multi_key_sql_2 = """ +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = updated_at + interval '1 hour' +where id1 >= 10 and id1 <= 20; +""" + + +class BaseSnapshotColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_snapshot_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class BaseSnapshotColumnNamesFromDbtProject: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_column_names_from_project(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class BaseSnapshotInvalidColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_invalid_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + manifest = get_manifest(project.project_root) + snapshot_node = manifest.nodes["snapshot.test.snapshot_actual"] + snapshot_node.config.snapshot_meta_column_names == { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_sql) + + # Change snapshot_meta_columns and look for an error + different_columns = { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_updated_at": "test_updated_at", + } + } + } + } + update_config_file(different_columns, "dbt_project.yml") + + results, log_output = run_dbt_and_capture(["snapshot"], expect_pass=False) + assert len(results) == 1 + assert "Compilation Error in snapshot snapshot_actual" in log_output + assert "Snapshot target is missing configured columns" in log_output + + +class BaseSnapshotDbtValidToCurrent: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_valid_to_current_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_valid_to_current(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_valid_to_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + original_snapshot: Table = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + assert original_snapshot[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + original_row = list( + filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", original_snapshot) + ) + assert original_row[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + + project.run_sql(invalidate_sql_1) + project.run_sql(invalidate_sql_2) + project.run_sql(update_with_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + updated_snapshot: Table = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + print(updated_snapshot) + assert updated_snapshot[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + # Original row that was updated now has a non-current (2099/12/31) date + original_row = list( + filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", updated_snapshot) + ) + assert original_row[0][2] == datetime.datetime( + 2016, 8, 20, 16, 44, 49, tzinfo=datetime.timezone.utc + ) + updated_row = list( + filter(lambda x: x[1] == "af1f803f2179869aeacb1bfe2b23c1df", updated_snapshot) + ) + + # Updated row has a current date + assert updated_row[0][2] == datetime.datetime( + 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc + ) + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +# This uses snapshot_meta_column_names, yaml-only snapshot def, +# and multiple keys +class BaseSnapshotMultiUniqueKey: + @pytest.fixture(scope="class") + def models(self): + return { + "seed.sql": model_seed_sql, + "snapshots.yml": snapshots_multi_key_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_multi_column_unique_key(self, project): + project.run_sql(create_multi_key_seed_sql) + project.run_sql(create_multi_key_snapshot_expected_sql) + project.run_sql(seed_multi_key_insert_sql) + project.run_sql(populate_multi_key_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_multi_key_sql_1) + project.run_sql(invalidate_multi_key_sql_2) + project.run_sql(update_multi_key_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestDatabricksSnapshotColumnNames(BaseSnapshotColumnNames): + pass + + +class TestDatabricksSnapshotColumnNamesFromDbtProject(BaseSnapshotColumnNamesFromDbtProject): + pass + + +class TestDatabricksSnapshotInvalidColumnNames(BaseSnapshotInvalidColumnNames): + pass + + +class TestDatabricksSnapshotDbtValidToCurrent(BaseSnapshotDbtValidToCurrent): + pass + + +class TestDatabricksSnapshotMultiUniqueKey(BaseSnapshotMultiUniqueKey): + pass