Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine --sample and --sample-window into one CLI param #11303

Merged
merged 2 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20250212-155658.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Combine `--sample` and `--sample-window` CLI params
time: 2025-02-12T15:56:58.546879-06:00
custom:
Author: QMalcolm
Issue: "11299"
1 change: 0 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ def parse(ctx, **kwargs):
@p.event_time_start
@p.event_time_end
@p.sample
@p.sample_window
@p.select
@p.selector
@p.target_path
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/cli/option_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def convert(self, value, param, ctx):
return value


class SampleWindowType(ParamType):
name = "SAMPLE_WINDOW"
class SampleType(ParamType):
name = "SAMPLE"

def convert(
self, value, param: Optional[Parameter], ctx: Optional[Context]
Expand Down
17 changes: 4 additions & 13 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
YAML,
ChoiceTuple,
Package,
SampleWindowType,
SampleType,
WarnErrorOptionsType,
)
from dbt.cli.options import MultiOption
Expand Down Expand Up @@ -525,20 +525,11 @@
)

sample = click.option(
"--sample/--no-sample",
"--sample",
envvar="DBT_SAMPLE",
help="Run in sample mode, creating only samples of models where possible",
default=False,
is_flag=True,
hidden=True, # TODO: Unhide
)

sample_window = click.option(
"--sample-window",
envvar="DBT_SAMPLE_WINDOW",
help="The time window to use with sample mode. Example: '3 days'.",
help="Run in sample mode with given SAMPLE_WINDOW spec, such that ref/source calls are sampled by the sample window.",
default=None,
type=SampleWindowType(),
type=SampleType(),
hidden=True, # TODO: Unhide
)

Expand Down
15 changes: 7 additions & 8 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF
event_time_filter = None
sample_mode = bool(
os.environ.get("DBT_EXPERIMENTAL_SAMPLE_MODE")
and getattr(self.config.args, "sample", False)
and getattr(self.config.args, "sample_window", None)
and getattr(self.config.args, "sample", None)
)

# TODO The number of branches here is getting rough. We should consider ways to simplify
Expand All @@ -263,13 +262,13 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF
# Sample mode microbatch models
if sample_mode:
start = (
self.config.args.sample_window.start
if self.config.args.sample_window.start > self.model.batch.event_time_start
self.config.args.sample.start
if self.config.args.sample.start > self.model.batch.event_time_start
else self.model.batch.event_time_start
)
end = (
self.config.args.sample_window.end
if self.config.args.sample_window.end < self.model.batch.event_time_end
self.config.args.sample.end
if self.config.args.sample.end < self.model.batch.event_time_end
else self.model.batch.event_time_end
)
event_time_filter = EventTimeFilter(
Expand All @@ -290,8 +289,8 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF
elif sample_mode:
event_time_filter = EventTimeFilter(
field_name=target.config.event_time,
start=self.config.args.sample_window.start,
end=self.config.args.sample_window.end,
start=self.config.args.sample.start,
end=self.config.args.sample.end,
)

return event_time_filter
Expand Down
10 changes: 4 additions & 6 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,11 @@
event_time_start = getattr(self.config.args, "EVENT_TIME_START", None)
event_time_end = getattr(self.config.args, "EVENT_TIME_END", None)

if (
os.environ.get("DBT_EXPERIMENTAL_SAMPLE_MODE")
and getattr(self.config.args, "SAMPLE", None)
and getattr(self.config.args, "SAMPLE_WINDOW", None)
if os.environ.get("DBT_EXPERIMENTAL_SAMPLE_MODE") and getattr(

Check warning on line 564 in core/dbt/task/run.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/run.py#L564

Added line #L564 was not covered by tests
self.config.args, "SAMPLE", None
):
event_time_start = self.config.args.sample_window.start
event_time_end = self.config.args.sample_window.end
event_time_start = self.config.args.sample.start
event_time_end = self.config.args.sample.end

Check warning on line 568 in core/dbt/task/run.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/run.py#L567-L568

Added lines #L567 - L568 were not covered by tests

microbatch_builder = MicrobatchBuilder(
model=model,
Expand Down
67 changes: 31 additions & 36 deletions tests/functional/sample_mode/test_sample_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
{{ config(materialized='table', event_time='event_time') }}

{% if execute %}
{{ log("Sample mode: " ~ invocation_args_dict.get("sample"), info=true) }}
{{ log("Sample window: " ~ invocation_args_dict.get("sample_window"), info=true) }}
{{ log("Sample: " ~ invocation_args_dict.get("sample"), info=true) }}
{% endif %}

SELECT * FROM {{ ref("input_model") }}
Expand All @@ -65,7 +64,7 @@

{% if execute %}
{{ log("is_incremental: " ~ is_incremental(), info=true) }}
{{ log("sample window: " ~ invocation_args_dict.get("sample_window"), info=true) }}
{{ log("sample: " ~ invocation_args_dict.get("sample"), info=true) }}
{% endif %}

SELECT * FROM {{ ref("input_model") }}
Expand Down Expand Up @@ -102,12 +101,12 @@ def event_catcher(self) -> EventCatcher:
return EventCatcher(event_to_catch=JinjaLogInfo) # type: ignore

@pytest.mark.parametrize(
"sample_mode_available,run_sample_mode,expected_row_count,arg_value_in_jinja",
"sample_mode_available,run_sample_mode,expected_row_count",
[
(True, True, 2, True),
(True, False, 3, False),
(False, True, 3, True),
(False, False, 3, False),
(True, True, 2),
(True, False, 3),
(False, True, 3),
(False, False, 3),
],
)
@freezegun.freeze_time("2025-01-03T02:03:0Z")
Expand All @@ -119,13 +118,12 @@ def test_sample_mode(
sample_mode_available: bool,
run_sample_mode: bool,
expected_row_count: int,
arg_value_in_jinja: bool,
):
run_args = ["run"]
expected_sample_window = None
expected_sample = None
if run_sample_mode:
run_args.extend(["--sample", "--sample-window=1 day"])
expected_sample_window = SampleWindow(
run_args.append("--sample=1 day")
expected_sample = SampleWindow(
start=datetime(2025, 1, 2, 2, 3, 0, 0, tzinfo=pytz.UTC),
end=datetime(2025, 1, 3, 2, 3, 0, 0, tzinfo=pytz.UTC),
)
Expand All @@ -134,9 +132,8 @@ def test_sample_mode(
mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "1"})

_ = run_dbt(run_args, callbacks=[event_catcher.catch])
assert len(event_catcher.caught_events) == 2
assert event_catcher.caught_events[0].info.msg == f"Sample mode: {arg_value_in_jinja}" # type: ignore
assert event_catcher.caught_events[1].info.msg == f"Sample window: {expected_sample_window}" # type: ignore
assert len(event_catcher.caught_events) == 1
assert event_catcher.caught_events[0].info.msg == f"Sample: {expected_sample}" # type: ignore
self.assert_row_count(
project=project,
relation_name="sample_mode_model",
Expand Down Expand Up @@ -207,7 +204,7 @@ def test_sample_mode(
if sample_mode_available:
mocker.patch.dict(os.environ, {"DBT_EXPERIMENTAL_SAMPLE_MODE": "True"})
_ = run_dbt(
["run", "--sample", "--sample-window=2 day"],
["run", "--sample=2 day"],
callbacks=[event_time_end_catcher.catch, event_time_start_catcher.catch],
)
assert len(event_time_start_catcher.caught_events) == len(expected_batches)
Expand Down Expand Up @@ -254,12 +251,12 @@ def event_catcher(self) -> EventCatcher:
return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "is_incremental: True" in event.info.msg) # type: ignore

@pytest.mark.parametrize(
"sample_mode_available,run_sample_mode,sample_window,expected_rows",
"sample_mode_available,sample,expected_rows",
[
(True, False, None, 6),
(True, True, "3 days", 6),
(True, True, "2 days", 5),
(False, True, "2 days", 6),
(True, None, 6),
(True, "3 days", 6),
(True, "2 days", 5),
(False, "2 days", 6),
],
)
@freezegun.freeze_time("2025-01-06T18:03:0Z")
Expand All @@ -269,8 +266,7 @@ def test_incremental_model_sample(
mocker: MockerFixture,
event_catcher: EventCatcher,
sample_mode_available: bool,
run_sample_mode: bool,
sample_window: Optional[str],
sample: Optional[str],
expected_rows: int,
):
# writing the input_model is necessary because we've parametrized the test
Expand All @@ -293,8 +289,8 @@ def test_incremental_model_sample(
write_file(later_input_model_sql, "models", "input_model.sql")

run_args = ["run"]
if run_sample_mode:
run_args.extend(["--sample", f"--sample-window={sample_window}"])
if sample is not None:
run_args.extend([f"--sample={sample}"])

_ = run_dbt(run_args, callbacks=[event_catcher.catch])

Expand Down Expand Up @@ -322,14 +318,14 @@ def event_catcher(self) -> EventCatcher:
return EventCatcher(event_to_catch=JinjaLogInfo, predicate=lambda event: "is_incremental: True" in event.info.msg) # type: ignore

@pytest.mark.parametrize(
"sample_mode_available,run_sample_mode,sample_window,expected_rows",
"sample_mode_available,sample,expected_rows",
[
(True, False, None, 6),
(True, True, "{'start': '2025-01-03', 'end': '2025-01-07'}", 6),
(True, True, "{'start': '2025-01-04', 'end': '2025-01-06'}", 5),
(True, True, "{'start': '2025-01-05', 'end': '2025-01-07'}", 5),
(True, True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 3),
(False, True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 6),
(True, None, 6),
(True, "{'start': '2025-01-03', 'end': '2025-01-07'}", 6),
(True, "{'start': '2025-01-04', 'end': '2025-01-06'}", 5),
(True, "{'start': '2025-01-05', 'end': '2025-01-07'}", 5),
(True, "{'start': '2024-12-31', 'end': '2025-01-03'}", 3),
(False, "{'start': '2024-12-31', 'end': '2025-01-03'}", 6),
],
)
def test_incremental_model_sample(
Expand All @@ -338,8 +334,7 @@ def test_incremental_model_sample(
mocker: MockerFixture,
event_catcher: EventCatcher,
sample_mode_available: bool,
run_sample_mode: bool,
sample_window: Optional[str],
sample: Optional[str],
expected_rows: int,
):
# writing the input_model is necessary because we've parametrized the test
Expand All @@ -362,8 +357,8 @@ def test_incremental_model_sample(
write_file(later_input_model_sql, "models", "input_model.sql")

run_args = ["run"]
if run_sample_mode:
run_args.extend(["--sample", f"--sample-window={sample_window}"])
if sample is not None:
run_args.extend([f"--sample={sample}"])

_ = run_dbt(run_args, callbacks=[event_catcher.catch])

Expand Down
8 changes: 4 additions & 4 deletions tests/unit/cli/test_option_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytz
from click import BadParameter, Option

from dbt.cli.option_types import YAML, SampleWindowType
from dbt.cli.option_types import YAML, SampleType
from dbt.event_time.sample_window import SampleWindow


Expand All @@ -32,7 +32,7 @@ def test_yaml_init_invalid_yaml_str(self, invalid_yaml_str):
assert "--vars" in e.value.format_message()


class TestSampleWindowType:
class TestSampleType:
@pytest.mark.parametrize(
"input,expected_result",
[
Expand Down Expand Up @@ -61,7 +61,7 @@ class TestSampleWindowType:
)
def test_convert(self, input: str, expected_result: Union[SampleWindow, Exception]):
try:
result = SampleWindowType().convert(input, Option(["--sample-window"]), None)
result = SampleType().convert(input, Option(["--sample"]), None)
assert result == expected_result
except Exception as e:
assert str(e) == str(expected_result)
Expand All @@ -76,5 +76,5 @@ def test_convert_relative(self):
start=datetime(2025, 1, 25, 2, 3, 0, 0, pytz.UTC),
end=datetime(2025, 1, 28, 2, 3, 0, 0, pytz.UTC),
)
result = SampleWindowType().convert(input, Option(["--sample-window"]), None)
result = SampleType().convert(input, Option(["--sample"]), None)
assert result == expected_result
Loading
Loading