From 9678427904567ccefb71539e17fc08a4032177d3 Mon Sep 17 00:00:00 2001
From: Quigley Malcolm
Date: Sun, 2 Feb 2025 13:50:58 -0600
Subject: [PATCH] Add sample mode tests for incremental models

---
 .../sample_mode/test_sample_mode.py | 157 +++++++++++++++++-
 1 file changed, 156 insertions(+), 1 deletion(-)

diff --git a/tests/functional/sample_mode/test_sample_mode.py b/tests/functional/sample_mode/test_sample_mode.py
index b25a6710add..84c535c4ded 100644
--- a/tests/functional/sample_mode/test_sample_mode.py
+++ b/tests/functional/sample_mode/test_sample_mode.py
@@ -1,5 +1,6 @@
 import os
 from datetime import datetime
+from typing import Optional
 
 import freezegun
 import pytest
@@ -10,7 +11,7 @@ from dbt.event_time.sample_window import SampleWindow
 from dbt.events.types import JinjaLogInfo
 from dbt.materializations.incremental.microbatch import MicrobatchBuilder
-from dbt.tests.util import read_file, relation_from_name, run_dbt
+from dbt.tests.util import read_file, relation_from_name, run_dbt, write_file
 from tests.utils import EventCatcher
 
 input_model_sql = """
@@ -22,6 +23,21 @@ select 3 as id, TIMESTAMP '2025-01-02 12:32:00-0' as event_time
 """
 
+later_input_model_sql = """
+{{ config(materialized='table', event_time='event_time') }}
+select 1 as id, TIMESTAMP '2020-01-01 01:25:00-0' as event_time
+UNION ALL
+select 2 as id, TIMESTAMP '2025-01-02 13:47:00-0' as event_time
+UNION ALL
+select 3 as id, TIMESTAMP '2025-01-03 12:32:00-0' as event_time
+UNION ALL
+select 4 as id, TIMESTAMP '2025-01-04 14:32:00-0' as event_time
+UNION ALL
+select 5 as id, TIMESTAMP '2025-01-05 20:32:00-0' as event_time
+UNION ALL
+select 6 as id, TIMESTAMP '2025-01-06 12:32:00-0' as event_time
+"""
+
 sample_mode_model_sql = """
 {{ config(materialized='table', event_time='event_time') }}
 
@@ -44,6 +60,21 @@ SELECT * FROM {{ ref("input_model") }}
 """
 
+sample_incremental_merge_sql = """
+{{ config(materialized='incremental', incremental_strategy='merge', unique_key='id')}}
+
+{% if execute %}
+    {{ log("is_incremental: " ~ is_incremental(), info=true) }}
+    {{ log("sample window: " ~ invocation_args_dict.get("sample_window"), info=true) }}
+{% endif %}
+
+SELECT * FROM {{ ref("input_model") }}
+
+{% if is_incremental() %}
+    WHERE event_time >= (SELECT max(event_time) FROM {{ this }})
+{% endif %}
+"""
+
 
 class BaseSampleMode:
     # TODO This is now used in 3 test files, it might be worth turning into a full test utility method
@@ -190,3 +221,127 @@ def test_sample_mode(
             relation_name="sample_microbatch_model",
             expected_row_count=2,
         )
+
+
+class TestIncrementalModelSampleModeRelative(BaseSampleMode):
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "input_model.sql": input_model_sql,
+            "sample_incremental_merge.sql": sample_incremental_merge_sql,
+        }
+
+    @pytest.fixture
+    def event_catcher(self) -> EventCatcher:
+        return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "is_incremental: True" in event.info.msg)  # type: ignore
+
+    @pytest.mark.parametrize(
+        "sample_mode_available,run_sample_mode,sample_window,expected_rows",
+        [
+            (True, False, None, 6),
+            (True, True, "3 days", 6),
+            (True, True, "2 days", 5),
+        ],
+    )
+    @freezegun.freeze_time("2025-01-06T18:03:0Z")
+    def test_incremental_model_sample(
+        self,
+        project,
+        mocker: MockerFixture,
+        event_catcher: EventCatcher,
+        sample_mode_available: bool,
+        run_sample_mode: bool,
+        sample_window: Optional[str],
+        expected_rows: int,
+    ):
+        write_file(input_model_sql, "models", "input_model.sql")
+        if sample_mode_available:
+            mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "True"})
+
+        _ = run_dbt(["run", "--full-refresh"], callbacks=[event_catcher.catch])
+
+        assert len(event_catcher.caught_events) == 0
+        self.assert_row_count(
+            project=project,
+            relation_name="sample_incremental_merge",
+            expected_row_count=3,
+        )
+
+        write_file(later_input_model_sql, "models", "input_model.sql")
+
+        run_args = ["run"]
+        if run_sample_mode:
+            run_args.extend(["--sample", f"--sample-window={sample_window}"])
+
+        _ = run_dbt(run_args, callbacks=[event_catcher.catch])
+
+        assert len(event_catcher.caught_events) == 1
+        self.assert_row_count(
+            project=project,
+            relation_name="sample_incremental_merge",
+            expected_row_count=expected_rows,
+        )
+
+
+class TestIncrementalModelSampleModeSpecific(BaseSampleMode):
+    # This had to be split out from the "relative" tests because `freezegun.freeze_time`
+    # breaks how timestamps get created.
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "input_model.sql": input_model_sql,
+            "sample_incremental_merge.sql": sample_incremental_merge_sql,
+        }
+
+    @pytest.fixture
+    def event_catcher(self) -> EventCatcher:
+        return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "is_incremental: True" in event.info.msg)  # type: ignore
+
+    @pytest.mark.parametrize(
+        "sample_mode_available,run_sample_mode,sample_window,expected_rows",
+        [
+            (True, False, None, 6),
+            (True, True, "{'start': '2025-01-03', 'end': '2025-01-07'}", 6),
+            (True, True, "{'start': '2025-01-04', 'end': '2025-01-06'}", 5),
+            (True, True, "{'start': '2025-01-05', 'end': '2025-01-07'}", 5),
+            (True, True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 3),
+        ],
+    )
+    def test_incremental_model_sample(
+        self,
+        project,
+        mocker: MockerFixture,
+        event_catcher: EventCatcher,
+        sample_mode_available: bool,
+        run_sample_mode: bool,
+        sample_window: Optional[str],
+        expected_rows: int,
+    ):
+        write_file(input_model_sql, "models", "input_model.sql")
+        if sample_mode_available:
+            mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "True"})
+
+        _ = run_dbt(["run", "--full-refresh"], callbacks=[event_catcher.catch])
+
+        assert len(event_catcher.caught_events) == 0
+        self.assert_row_count(
+            project=project,
+            relation_name="sample_incremental_merge",
+            expected_row_count=3,
+        )
+
+        write_file(later_input_model_sql, "models", "input_model.sql")
+
+        run_args = ["run"]
+        if run_sample_mode:
+            run_args.extend(["--sample", f"--sample-window={sample_window}"])
+
+        _ = run_dbt(run_args, callbacks=[event_catcher.catch])
+
+        assert len(event_catcher.caught_events) == 1
+        self.assert_row_count(
+            project=project,
+            relation_name="sample_incremental_merge",
+            expected_row_count=expected_rows,
+        )
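
The expected row counts in the new tests follow from simple window arithmetic over the fixture data. The sketch below is a hedged illustration rather than dbt's actual sampling implementation: it assumes the relative sample window is measured back from the freezegun-frozen clock and that the lower bound is inclusive, and it reproduces the "2 days" case from TestIncrementalModelSampleModeRelative.

from datetime import datetime, timedelta, timezone

# Frozen "now" from the @freezegun.freeze_time decorator on the relative tests.
now = datetime(2025, 1, 6, 18, 3, tzinfo=timezone.utc)
window_start = now - timedelta(days=2)  # assumed meaning of --sample-window="2 days"

# event_time values from later_input_model_sql, keyed by id.
later_input = {
    1: datetime(2020, 1, 1, 1, 25, tzinfo=timezone.utc),
    2: datetime(2025, 1, 2, 13, 47, tzinfo=timezone.utc),
    3: datetime(2025, 1, 3, 12, 32, tzinfo=timezone.utc),
    4: datetime(2025, 1, 4, 14, 32, tzinfo=timezone.utc),
    5: datetime(2025, 1, 5, 20, 32, tzinfo=timezone.utc),
    6: datetime(2025, 1, 6, 12, 32, tzinfo=timezone.utc),
}

# Rows the sample window would let through on the incremental run.
sampled_ids = {row_id for row_id, ts in later_input.items() if ts >= window_start}
assert sampled_ids == {5, 6}

# The initial --full-refresh run loaded ids 1-3; merging the sampled rows on the
# unique_key adds two new ids, giving the expected_row_count of 5 asserted above.
existing_ids = {1, 2, 3}
assert len(existing_ids | sampled_ids) == 5

The explicit {'start': ..., 'end': ...} windows in TestIncrementalModelSampleModeSpecific follow the same pattern, given comparable assumptions about which boundary is inclusive.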