feat: Implement date_partition_column for SparkSource #4844

Merged
feast/infra/offline_stores/contrib/spark_offline_store/spark.py:

@@ -99,6 +99,8 @@ def pull_latest_from_table_or_query(
fields_as_string = ", ".join(fields_with_aliases)
aliases_as_string = ", ".join(aliases)

date_partition_column = data_source.date_partition_column

start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
query = f"""
@@ -109,7 +111,7 @@ def pull_latest_from_table_or_query(
SELECT {fields_as_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
FROM {from_expression} t1
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){" AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''}
) t2
WHERE feast_row_ = 1
"""
@@ -641,8 +643,15 @@ def _cast_data_frame(
{% endfor %}
FROM {{ featureview.table_subquery }}
WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'
{% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}'
{% endif %}

{% if featureview.ttl == 0 %}{% else %}
AND {{ featureview.timestamp_field }} >= '{{ featureview.min_event_timestamp }}'
{% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
AND {{ featureview.date_partition_column }} >= '{{ featureview.min_event_timestamp[:10] }}'
{% endif %}
{% endif %}
),
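
To make the template change concrete, here is an illustrative sketch (made-up values; jinja2 assumed available) that renders the new upper-bound branch and shows the SQL it emits:

```python
from jinja2 import Template

# Only the upper-bound fragment of the template, reproduced for illustration.
fragment = Template(
    "WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'\n"
    '{% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}'
    "AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}'"
    "{% endif %}"
)

print(
    fragment.render(
        featureview={
            "timestamp_field": "event_ts",
            "max_event_timestamp": "2021-01-02 00:00:00.000000",
            "date_partition_column": "effective_date",
        }
    )
)
# WHERE event_ts <= '2021-01-02 00:00:00.000000'
# AND effective_date <= '2021-01-02'
```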

feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py:

@@ -45,6 +45,7 @@ def __init__(
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
timestamp_field: Optional[str] = None,
date_partition_column: Optional[str] = None,
):
"""Creates a SparkSource object.

@@ -64,6 +65,8 @@ def __init__(
maintainer.
timestamp_field: Event timestamp field used for point-in-time joins of
feature values.
date_partition_column: The column to partition the data on for faster
retrieval. This is useful for large tables and will limit the number of
files read.
"""
# If no name, use the table as the default name.
if name is None and table is None:
@@ -77,6 +80,7 @@ def __init__(
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping,
description=description,
date_partition_column=date_partition_column,
tags=tags,
owner=owner,
)
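
For context, a hedged usage sketch of the new argument; the table and column names here are invented for illustration:

```python
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)

# Hypothetical source: a daily-partitioned Spark table. The partition
# predicate generated in spark.py lets Spark prune partitions outside the
# requested [start_date, end_date] window.
driver_stats = SparkSource(
    name="driver_hourly_stats",
    table="warehouse.driver_hourly_stats",
    timestamp_field="event_timestamp",
    created_timestamp_column="created",
    date_partition_column="event_date",
)
```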
@@ -135,6 +139,7 @@ def from_proto(data_source: DataSourceProto) -> Any:
query=spark_options.query,
path=spark_options.path,
file_format=spark_options.file_format,
date_partition_column=data_source.date_partition_column,
timestamp_field=data_source.timestamp_field,
created_timestamp_column=data_source.created_timestamp_column,
description=data_source.description,
@@ -148,6 +153,7 @@ def to_proto(self) -> DataSourceProto:
type=DataSourceProto.BATCH_SPARK,
data_source_class_type="feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SparkSource",
field_mapping=self.field_mapping,
date_partition_column=self.date_partition_column,
spark_options=self.spark_options.to_proto(),
description=self.description,
tags=self.tags,
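
Because the field now travels through the proto (see from_proto/to_proto above), a quick round-trip sanity sketch, assuming SparkSource.from_proto(source.to_proto()) reconstructs an equivalent source:

```python
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)

source = SparkSource(
    name="events",
    table="warehouse.events",
    timestamp_field="event_ts",
    date_partition_column="event_date",
)

# The new field should survive serialization in both directions.
restored = SparkSource.from_proto(source.to_proto())
assert restored.date_partition_column == "event_date"
```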
Spark offline store tests:

@@ -71,6 +71,68 @@ def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_with_nested_timestamp_or_query_and_date_partition_column_set(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_nested_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
date_partition_column="effective_date",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_header.event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, nested_timestamp, created_timestamp

FROM (
SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') AND effective_date >= '2021-01-01' AND effective_date <= '2021-01-02'
) t2
WHERE feast_row_ = 1""" # noqa: W293, W291

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
@@ -127,3 +189,62 @@ def test_pull_latest_from_table_without_nested_timestamp_or_query(

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_without_nested_timestamp_or_query_and_date_partition_column_set(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="event_published_datetime_utc",
date_partition_column="effective_date",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp

FROM (
SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') AND effective_date >= '2021-01-01' AND effective_date <= '2021-01-02'
) t2
WHERE feast_row_ = 1""" # noqa: W293, W291

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()