Skip to content

Commit

Permalink
refactor: Rework get_online_features helper functions (#5060)
Browse files Browse the repository at this point in the history
* rework _populate_response_from_feature_data

Signed-off-by: Artem Petrov <[email protected]>

* rework _convert_rows_to_protobuf

Signed-off-by: Artem Petrov <[email protected]>

* fix typing

Signed-off-by: Artem Petrov <[email protected]>

---------

Signed-off-by: Artem Petrov <[email protected]>
  • Loading branch information
wckdman authored Feb 16, 2025
1 parent 0fffe21 commit 6bf7516
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 77 deletions.
3 changes: 2 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,7 +2018,7 @@ def _retrieve_from_online_store_v2(
entity_key_dict[key] = []
entity_key_dict[key].append(python_value)

table_entity_values, idxs = utils._get_unique_entities_from_values(
table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
entity_key_dict,
)

Expand All @@ -2040,6 +2040,7 @@ def _retrieve_from_online_store_v2(
full_feature_names=False,
requested_features=features_to_request,
table=table,
output_len=output_len,
)

return OnlineResponse(online_features_response)
Expand Down
10 changes: 6 additions & 4 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def get_online_features(

for table, requested_features in grouped_refs:
# Get the correct set of entity values with the correct join keys.
table_entity_values, idxs = utils._get_unique_entities(
table_entity_values, idxs, output_len = utils._get_unique_entities(
table,
join_key_values,
entity_name_to_join_key_map,
Expand Down Expand Up @@ -215,6 +215,7 @@ def get_online_features(
full_feature_names,
requested_features,
table,
output_len,
)

if requested_on_demand_feature_views:
Expand Down Expand Up @@ -274,7 +275,7 @@ async def get_online_features_async(

async def query_table(table, requested_features):
# Get the correct set of entity values with the correct join keys.
table_entity_values, idxs = utils._get_unique_entities(
table_entity_values, idxs, output_len = utils._get_unique_entities(
table,
join_key_values,
entity_name_to_join_key_map,
Expand All @@ -290,7 +291,7 @@ async def query_table(table, requested_features):
requested_features=requested_features,
)

return idxs, read_rows
return idxs, read_rows, output_len

all_responses = await asyncio.gather(
*[
Expand All @@ -299,7 +300,7 @@ async def query_table(table, requested_features):
]
)

for (idxs, read_rows), (table, requested_features) in zip(
for (idxs, read_rows, output_len), (table, requested_features) in zip(
all_responses, grouped_refs
):
feature_data = utils._convert_rows_to_protobuf(
Expand All @@ -314,6 +315,7 @@ async def query_table(table, requested_features):
full_feature_names,
requested_features,
table,
output_len,
)

if requested_on_demand_feature_views:
Expand Down
137 changes: 67 additions & 70 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,16 +490,28 @@ def _group_feature_refs(
return fvs_result, odfvs_result


def apply_list_mapping(
lst: Iterable[Any], mapping_indexes: Iterable[List[int]]
) -> Iterable[Any]:
output_len = sum(len(item) for item in mapping_indexes)
output = [None] * output_len
for elem, destinations in zip(lst, mapping_indexes):
def construct_response_feature_vector(
values_vector: Iterable[Any],
statuses_vector: Iterable[Any],
timestamp_vector: Iterable[Any],
mapping_indexes: Iterable[List[int]],
output_len: int,
) -> GetOnlineFeaturesResponse.FeatureVector:
values_output: Iterable[Any] = [None] * output_len
statuses_output: Iterable[Any] = [None] * output_len
timestamp_output: Iterable[Any] = [None] * output_len

for i, destinations in enumerate(mapping_indexes):
for idx in destinations:
output[idx] = elem

return output
values_output[idx] = values_vector[i] # type: ignore[index]
statuses_output[idx] = statuses_vector[i] # type: ignore[index]
timestamp_output[idx] = timestamp_vector[i] # type: ignore[index]

return GetOnlineFeaturesResponse.FeatureVector(
values=values_output,
statuses=statuses_output,
event_timestamps=timestamp_output,
)


def _augment_response_with_on_demand_transforms(
Expand Down Expand Up @@ -674,7 +686,7 @@ def _get_unique_entities(
table: "FeatureView",
join_key_values: Dict[str, List[ValueProto]],
entity_name_to_join_key_map: Dict[str, str],
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]:
"""Return the set of unique composite Entities for a Feature View and the indexes at which they appear.
This method allows us to query the OnlineStore for data we need only once
Expand Down Expand Up @@ -712,7 +724,7 @@ def _get_unique_entities(

# If there are no rows, return empty tuples and an output length of 0.
if not rowise:
return (), ()
return (), (), 0

# Sort rowise so that rows with the same join key values are adjacent.
rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1]))
Expand All @@ -725,16 +737,16 @@ def _get_unique_entities(

# If no groups were formed (should not happen for valid input), return empty tuples and an output length of 0.
if not groups:
return (), ()
return (), (), 0

# Unpack the unique entities and their original row indexes.
unique_entities, indexes = tuple(zip(*groups))
return unique_entities, indexes
return unique_entities, indexes, len(rowise)


def _get_unique_entities_from_values(
table_entity_values: Dict[str, List[ValueProto]],
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]:
"""Return the set of unique composite Entities for a Feature View and the indexes at which they appear.
This method allows us to query the OnlineStore for data we need only once
Expand All @@ -758,7 +770,7 @@ def _get_unique_entities_from_values(
]
)
)
return unique_entities, indexes
return unique_entities, indexes, len(rowise)


def _drop_unneeded_columns(
Expand Down Expand Up @@ -835,6 +847,7 @@ def _populate_response_from_feature_data(
full_feature_names: bool,
requested_features: Iterable[str],
table: "FeatureView",
output_len: int,
):
"""Populate the GetOnlineFeaturesResponse with feature data.
Expand All @@ -853,33 +866,22 @@ def _populate_response_from_feature_data(
requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the
data in `feature_data`.
table: The FeatureView that `feature_data` was retrieved from.
output_len: The number of result rows in `online_features_response`.
"""
# Add the feature names to the response.
table_name = table.projection.name_to_use()
requested_feature_refs = [
(
f"{table.projection.name_to_use()}__{feature_name}"
if full_feature_names
else feature_name
)
f"{table_name}__{feature_name}" if full_feature_names else feature_name
for feature_name in requested_features
]
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)

timestamps, statuses, values = zip(*feature_data)

# Populate the result with data fetched from the OnlineStore
# which is guaranteed to be aligned with `requested_features`.
for (
feature_idx,
(timestamp_vector, statuses_vector, values_vector),
) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
online_features_response.results.append(
GetOnlineFeaturesResponse.FeatureVector(
values=apply_list_mapping(values_vector, indexes),
statuses=apply_list_mapping(statuses_vector, indexes),
event_timestamps=apply_list_mapping(timestamp_vector, indexes),
)
# Process each feature vector in a single pass
for timestamp_vector, statuses_vector, values_vector in feature_data:
response_vector = construct_response_feature_vector(
values_vector, statuses_vector, timestamp_vector, indexes, output_len
)
online_features_response.results.append(response_vector)


def _populate_response_from_feature_data_v2(
Expand All @@ -891,6 +893,7 @@ def _populate_response_from_feature_data_v2(
indexes: Iterable[List[int]],
online_features_response: GetOnlineFeaturesResponse,
requested_features: Iterable[str],
output_len: int,
):
"""Populate the GetOnlineFeaturesResponse with feature data.
Expand All @@ -908,6 +911,7 @@ def _populate_response_from_feature_data_v2(
"customer_fv__daily_transactions").
requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the
data in `feature_data`.
output_len: The number of result rows in `online_features_response`.
"""
# Add the feature names to the response.
requested_feature_refs = [(feature_name) for feature_name in requested_features]
Expand All @@ -917,17 +921,11 @@ def _populate_response_from_feature_data_v2(

# Populate the result with data fetched from the OnlineStore
# which is guaranteed to be aligned with `requested_features`.
for (
feature_idx,
(timestamp_vector, statuses_vector, values_vector),
) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
online_features_response.results.append(
GetOnlineFeaturesResponse.FeatureVector(
values=apply_list_mapping(values_vector, indexes),
statuses=apply_list_mapping(statuses_vector, indexes),
event_timestamps=apply_list_mapping(timestamp_vector, indexes),
)
for timestamp_vector, statuses_vector, values_vector in feature_data:
response_vector = construct_response_feature_vector(
values_vector, statuses_vector, timestamp_vector, indexes, output_len
)
online_features_response.results.append(response_vector)


def _convert_entity_key_to_proto_to_dict(
Expand Down Expand Up @@ -1246,33 +1244,32 @@ def _convert_rows_to_protobuf(
requested_features: List[str],
read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]],
) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[ValueProto]]]:
# Each row is a set of features for a given entity key.
# We only need to convert the data to Protobuf once.
# Pre-calculate the length to avoid repeated calculations
n_rows = len(read_rows)

# Create single instances of commonly used values
null_value = ValueProto()
read_row_protos = []
for read_row in read_rows:
row_ts_proto = Timestamp()
row_ts, feature_data = read_row
# TODO (Ly): reuse whatever timestamp if row_ts is None?
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
event_timestamps = [row_ts_proto] * len(requested_features)
if feature_data is None:
statuses = [FieldStatus.NOT_FOUND] * len(requested_features)
values = [null_value] * len(requested_features)
else:
statuses = []
values = []
for feature_name in requested_features:
# Make sure order of data is the same as requested_features.
if feature_name not in feature_data:
statuses.append(FieldStatus.NOT_FOUND)
values.append(null_value)
else:
statuses.append(FieldStatus.PRESENT)
values.append(feature_data[feature_name])
read_row_protos.append((event_timestamps, statuses, values))
return read_row_protos
null_status = FieldStatus.NOT_FOUND
null_timestamp = Timestamp()
present_status = FieldStatus.PRESENT

requested_features_vectors = []
for feature_name in requested_features:
ts_vector = [null_timestamp] * n_rows
status_vector = [null_status] * n_rows
value_vector = [null_value] * n_rows
for idx, read_row in enumerate(read_rows):
row_ts_proto = Timestamp()
row_ts, feature_data = read_row
# TODO (Ly): reuse whatever timestamp if row_ts is None?
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
ts_vector[idx] = row_ts_proto
if (feature_data is not None) and (feature_name in feature_data):
status_vector[idx] = present_status
value_vector[idx] = feature_data[feature_name]
requested_features_vectors.append((ts_vector, status_vector, value_vector))
return requested_features_vectors


def has_all_tags(
Expand Down
6 changes: 4 additions & 2 deletions sdk/python/tests/unit/test_unit_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_get_unique_entities_success():
projection=MockFeatureViewProjection(join_key_map={}),
)

unique_entities, indexes = utils._get_unique_entities(
unique_entities, indexes, output_len = utils._get_unique_entities(
table=fv,
join_key_values=entity_values,
entity_name_to_join_key_map=entity_name_to_join_key_map,
Expand All @@ -51,6 +51,7 @@ def test_get_unique_entities_success():

assert unique_entities == expected_entities
assert indexes == expected_indexes
assert output_len == 3


def test_get_unique_entities_missing_join_key_success():
Expand All @@ -74,7 +75,7 @@ def test_get_unique_entities_missing_join_key_success():
projection=MockFeatureViewProjection(join_key_map={}),
)

unique_entities, indexes = utils._get_unique_entities(
unique_entities, indexes, output_len = utils._get_unique_entities(
table=fv,
join_key_values=entity_values,
entity_name_to_join_key_map=entity_name_to_join_key_map,
Expand All @@ -87,6 +88,7 @@ def test_get_unique_entities_missing_join_key_success():

assert unique_entities == expected_entities
assert indexes == expected_indexes
assert output_len == 3
# We're not asserting anything about entity_1 being absent from the unique_entities list
assert "entity_1" not in [entity.keys() for entity in unique_entities]

Expand Down

0 comments on commit 6bf7516

Please sign in to comment.