diff --git a/.changes/unreleased/Features-20250212-155658.yaml b/.changes/unreleased/Features-20250212-155658.yaml new file mode 100644 index 00000000000..d56ba76cb63 --- /dev/null +++ b/.changes/unreleased/Features-20250212-155658.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Combine `--sample` and `--sample-window` CLI params +time: 2025-02-12T15:56:58.546879-06:00 +custom: + Author: QMalcolm + Issue: "11299" diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 91e51242d44..a324049c90b 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -556,7 +556,6 @@ def parse(ctx, **kwargs): @p.event_time_start @p.event_time_end @p.sample -@p.sample_window @p.select @p.selector @p.target_path diff --git a/core/dbt/cli/option_types.py b/core/dbt/cli/option_types.py index 006b28c7c19..c9af65fa151 100644 --- a/core/dbt/cli/option_types.py +++ b/core/dbt/cli/option_types.py @@ -94,8 +94,8 @@ def convert(self, value, param, ctx): return value -class SampleWindowType(ParamType): - name = "SAMPLE_WINDOW" +class SampleType(ParamType): + name = "SAMPLE" def convert( self, value, param: Optional[Parameter], ctx: Optional[Context] diff --git a/core/dbt/cli/params.py b/core/dbt/cli/params.py index c9bebef59bd..7153c7b0e08 100644 --- a/core/dbt/cli/params.py +++ b/core/dbt/cli/params.py @@ -6,7 +6,7 @@ YAML, ChoiceTuple, Package, - SampleWindowType, + SampleType, WarnErrorOptionsType, ) from dbt.cli.options import MultiOption @@ -525,20 +525,11 @@ ) sample = click.option( - "--sample/--no-sample", + "--sample", envvar="DBT_SAMPLE", - help="Run in sample mode, creating only samples of models where possible", - default=False, - is_flag=True, - hidden=True, # TODO: Unhide -) - -sample_window = click.option( - "--sample-window", - envvar="DBT_SAMPLE_WINDOW", - help="The time window to use with sample mode. Example: '3 days'.", + help="Run in sample mode with given SAMPLE_WINDOW spec, such that ref/source calls are sampled by the sample window.", default=None, - type=SampleWindowType(), + type=SampleType(), hidden=True, # TODO: Unhide ) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 41b674f0d5e..46b57b140a8 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -239,8 +239,7 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF event_time_filter = None sample_mode = bool( os.environ.get("DBT_EXPERIMENTAL_SAMPLE_MODE") - and getattr(self.config.args, "sample", False) - and getattr(self.config.args, "sample_window", None) + and getattr(self.config.args, "sample", None) ) # TODO The number of branches here is getting rough. We should consider ways to simplify @@ -263,13 +262,13 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF # Sample mode microbatch models if sample_mode: start = ( - self.config.args.sample_window.start - if self.config.args.sample_window.start > self.model.batch.event_time_start + self.config.args.sample.start + if self.config.args.sample.start > self.model.batch.event_time_start else self.model.batch.event_time_start ) end = ( - self.config.args.sample_window.end - if self.config.args.sample_window.end < self.model.batch.event_time_end + self.config.args.sample.end + if self.config.args.sample.end < self.model.batch.event_time_end else self.model.batch.event_time_end ) event_time_filter = EventTimeFilter( @@ -290,8 +289,8 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF elif sample_mode: event_time_filter = EventTimeFilter( field_name=target.config.event_time, - start=self.config.args.sample_window.start, - end=self.config.args.sample_window.end, + start=self.config.args.sample.start, + end=self.config.args.sample.end, ) return event_time_filter diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index ac9fe761df6..f4d0fc584d4 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -561,13 +561,11 @@ def _execute_microbatch_materialization( event_time_start = getattr(self.config.args, "EVENT_TIME_START", None) event_time_end = getattr(self.config.args, "EVENT_TIME_END", None) - if ( - os.environ.get("DBT_EXPERIMENTAL_SAMPLE_MODE") - and getattr(self.config.args, "SAMPLE", None) - and getattr(self.config.args, "SAMPLE_WINDOW", None) + if os.environ.get("DBT_EXPERIMENTAL_SAMPLE_MODE") and getattr( + self.config.args, "SAMPLE", None ): - event_time_start = self.config.args.sample_window.start - event_time_end = self.config.args.sample_window.end + event_time_start = self.config.args.sample.start + event_time_end = self.config.args.sample.end microbatch_builder = MicrobatchBuilder( model=model, diff --git a/tests/functional/sample_mode/test_sample_mode.py b/tests/functional/sample_mode/test_sample_mode.py index 54737da2877..e4883e55287 100644 --- a/tests/functional/sample_mode/test_sample_mode.py +++ b/tests/functional/sample_mode/test_sample_mode.py @@ -42,8 +42,7 @@ {{ config(materialized='table', event_time='event_time') }} {% if execute %} - {{ log("Sample mode: " ~ invocation_args_dict.get("sample"), info=true) }} - {{ log("Sample window: " ~ invocation_args_dict.get("sample_window"), info=true) }} + {{ log("Sample: " ~ invocation_args_dict.get("sample"), info=true) }} {% endif %} SELECT * FROM {{ ref("input_model") }} @@ -65,7 +64,7 @@ {% if execute %} {{ log("is_incremental: " ~ is_incremental(), info=true) }} - {{ log("sample window: " ~ invocation_args_dict.get("sample_window"), info=true) }} + {{ log("sample: " ~ invocation_args_dict.get("sample"), info=true) }} {% endif %} SELECT * FROM {{ ref("input_model") }} @@ -102,12 +101,12 @@ def event_catcher(self) -> EventCatcher: return EventCatcher(event_to_catch=JinjaLogInfo) # type: ignore @pytest.mark.parametrize( - "sample_mode_available,run_sample_mode,expected_row_count,arg_value_in_jinja", + "sample_mode_available,run_sample_mode,expected_row_count", [ - (True, True, 2, True), - (True, False, 3, False), - (False, True, 3, True), - (False, False, 3, False), + (True, True, 2), + (True, False, 3), + (False, True, 3), + (False, False, 3), ], ) @freezegun.freeze_time("2025-01-03T02:03:0Z") @@ -119,13 +118,12 @@ def test_sample_mode( sample_mode_available: bool, run_sample_mode: bool, expected_row_count: int, - arg_value_in_jinja: bool, ): run_args = ["run"] - expected_sample_window = None + expected_sample = None if run_sample_mode: - run_args.extend(["--sample", "--sample-window=1 day"]) - expected_sample_window = SampleWindow( + run_args.append("--sample=1 day") + expected_sample = SampleWindow( start=datetime(2025, 1, 2, 2, 3, 0, 0, tzinfo=pytz.UTC), end=datetime(2025, 1, 3, 2, 3, 0, 0, tzinfo=pytz.UTC), ) @@ -134,9 +132,8 @@ def test_sample_mode( mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "1"}) _ = run_dbt(run_args, callbacks=[event_catcher.catch]) - assert len(event_catcher.caught_events) == 2 - assert event_catcher.caught_events[0].info.msg == f"Sample mode: {arg_value_in_jinja}" # type: ignore - assert event_catcher.caught_events[1].info.msg == f"Sample window: {expected_sample_window}" # type: ignore + assert len(event_catcher.caught_events) == 1 + assert event_catcher.caught_events[0].info.msg == f"Sample: {expected_sample}" # type: ignore self.assert_row_count( project=project, relation_name="sample_mode_model", @@ -207,7 +204,7 @@ def test_sample_mode( if sample_mode_available: mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "True"}) _ = run_dbt( - ["run", "--sample", "--sample-window=2 day"], + ["run", "--sample=2 day"], callbacks=[event_time_end_catcher.catch, event_time_start_catcher.catch], ) assert len(event_time_start_catcher.caught_events) == len(expected_batches) @@ -254,12 +251,12 @@ def event_catcher(self) -> EventCatcher: return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "is_incremental: True" in event.info.msg) # type: ignore @pytest.mark.parametrize( - "sample_mode_available,run_sample_mode,sample_window,expected_rows", + "sample_mode_available,sample,expected_rows", [ - (True, False, None, 6), - (True, True, "3 days", 6), - (True, True, "2 days", 5), - (False, True, "2 days", 6), + (True, None, 6), + (True, "3 days", 6), + (True, "2 days", 5), + (False, "2 days", 6), ], ) @freezegun.freeze_time("2025-01-06T18:03:0Z") @@ -269,8 +266,7 @@ def test_incremental_model_sample( mocker: MockerFixture, event_catcher: EventCatcher, sample_mode_available: bool, - run_sample_mode: bool, - sample_window: Optional[str], + sample: Optional[str], expected_rows: int, ): # writing the input_model is necessary because we've parametrized the test @@ -293,8 +289,8 @@ def test_incremental_model_sample( write_file(later_input_model_sql, "models", "input_model.sql") run_args = ["run"] - if run_sample_mode: - run_args.extend(["--sample", f"--sample-window={sample_window}"]) + if sample is not None: + run_args.extend([f"--sample={sample}"]) _ = run_dbt(run_args, callbacks=[event_catcher.catch]) @@ -322,14 +318,14 @@ def event_catcher(self) -> EventCatcher: return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "is_incremental: True" in event.info.msg) # type: ignore @pytest.mark.parametrize( - "sample_mode_available,run_sample_mode,sample_window,expected_rows", + "sample_mode_available,sample,expected_rows", [ - (True, False, None, 6), - (True, True, "{'start': '2025-01-03', 'end': '2025-01-07'}", 6), - (True, True, "{'start': '2025-01-04', 'end': '2025-01-06'}", 5), - (True, True, "{'start': '2025-01-05', 'end': '2025-01-07'}", 5), - (True, True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 3), - (False, True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 6), + (True, None, 6), + (True, "{'start': '2025-01-03', 'end': '2025-01-07'}", 6), + (True, "{'start': '2025-01-04', 'end': '2025-01-06'}", 5), + (True, "{'start': '2025-01-05', 'end': '2025-01-07'}", 5), + (True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 3), + (False, "{'start': '2024-12-31', 'end': '2025-01-03'}", 6), ], ) def test_incremental_model_sample( @@ -338,8 +334,7 @@ def test_incremental_model_sample( mocker: MockerFixture, event_catcher: EventCatcher, sample_mode_available: bool, - run_sample_mode: bool, - sample_window: Optional[str], + sample: Optional[str], expected_rows: int, ): # writing the input_model is necessary because we've parametrized the test @@ -362,8 +357,8 @@ def test_incremental_model_sample( write_file(later_input_model_sql, "models", "input_model.sql") run_args = ["run"] - if run_sample_mode: - run_args.extend(["--sample", f"--sample-window={sample_window}"]) + if sample is not None: + run_args.extend([f"--sample={sample}"]) _ = run_dbt(run_args, callbacks=[event_catcher.catch]) diff --git a/tests/unit/cli/test_option_types.py b/tests/unit/cli/test_option_types.py index bf43df8d13a..efc878e23c4 100644 --- a/tests/unit/cli/test_option_types.py +++ b/tests/unit/cli/test_option_types.py @@ -6,7 +6,7 @@ import pytz from click import BadParameter, Option -from dbt.cli.option_types import YAML, SampleWindowType +from dbt.cli.option_types import YAML, SampleType from dbt.event_time.sample_window import SampleWindow @@ -32,7 +32,7 @@ def test_yaml_init_invalid_yaml_str(self, invalid_yaml_str): assert "--vars" in e.value.format_message() -class TestSampleWindowType: +class TestSampleType: @pytest.mark.parametrize( "input,expected_result", [ @@ -61,7 +61,7 @@ class TestSampleWindowType: ) def test_convert(self, input: str, expected_result: Union[SampleWindow, Exception]): try: - result = SampleWindowType().convert(input, Option(["--sample-window"]), None) + result = SampleType().convert(input, Option(["--sample"]), None) assert result == expected_result except Exception as e: assert str(e) == str(expected_result) @@ -76,5 +76,5 @@ def test_convert_relative(self): start=datetime(2025, 1, 25, 2, 3, 0, 0, pytz.UTC), end=datetime(2025, 1, 28, 2, 3, 0, 0, pytz.UTC), ) - result = SampleWindowType().convert(input, Option(["--sample-window"]), None) + result = SampleType().convert(input, Option(["--sample"]), None) assert result == expected_result diff --git a/tests/unit/context/test_providers.py b/tests/unit/context/test_providers.py index 415aee86dde..f4ce723cdcf 100644 --- a/tests/unit/context/test_providers.py +++ b/tests/unit/context/test_providers.py @@ -46,7 +46,7 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): assert resolver.resolve_limit == expected_resolve_limit @pytest.mark.parametrize( - "use_microbatch_batches,materialized,incremental_strategy,sample_mode_available,run_sample_mode,sample_window,resolver_model_node,expect_filter", + "use_microbatch_batches,materialized,incremental_strategy,sample_mode_available,sample,resolver_model_node,expect_filter", [ # Microbatch model without sample ( @@ -54,7 +54,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "incremental", "microbatch", True, - False, None, True, True, @@ -65,7 +64,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "incremental", "microbatch", True, - True, SampleWindow( start=datetime(2024, 1, 1, tzinfo=pytz.UTC), end=datetime(2025, 1, 1, tzinfo=pytz.UTC), @@ -79,7 +77,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "table", None, True, - True, SampleWindow( start=datetime(2024, 1, 1, tzinfo=pytz.UTC), end=datetime(2025, 1, 1, tzinfo=pytz.UTC), @@ -93,7 +90,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "incremental", "merge", True, - True, SampleWindow( start=datetime(2024, 1, 1, tzinfo=pytz.UTC), end=datetime(2025, 1, 1, tzinfo=pytz.UTC), @@ -107,7 +103,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "table", None, False, - True, SampleWindow( start=datetime(2024, 1, 1, tzinfo=pytz.UTC), end=datetime(2025, 1, 1, tzinfo=pytz.UTC), @@ -121,7 +116,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "table", None, True, - True, SampleWindow( start=datetime(2024, 1, 1, tzinfo=pytz.UTC), end=datetime(2025, 1, 1, tzinfo=pytz.UTC), @@ -135,7 +129,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "incremental", "microbatch", False, - False, None, False, False, @@ -146,7 +139,6 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "incremental", "microbatch", False, - False, None, True, False, @@ -157,13 +149,12 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): "table", "microbatch", False, - False, None, True, False, ), # Incremental merge - (True, "incremental", "merge", False, False, None, True, False), + (True, "incremental", "merge", False, None, True, False), ], ) def test_resolve_event_time_filter( @@ -174,8 +165,7 @@ def test_resolve_event_time_filter( materialized: str, incremental_strategy: Optional[str], sample_mode_available: bool, - run_sample_mode: bool, - sample_window: Optional[SampleWindow], + sample: Optional[SampleWindow], resolver_model_node: bool, expect_filter: bool, ) -> None: @@ -191,8 +181,7 @@ def test_resolve_event_time_filter( # Resolver mocking resolver.config.args.EVENT_TIME_END = None resolver.config.args.EVENT_TIME_START = None - resolver.config.args.sample = run_sample_mode - resolver.config.args.sample_window = sample_window + resolver.config.args.sample = sample if resolver_model_node: resolver.model = mock.MagicMock(spec=ModelNode) resolver.model.batch = BatchContext(