From 6bf75166fe6966595faeec160a6708fa5715e9b1 Mon Sep 17 00:00:00 2001 From: Artem <58334441+wckdman@users.noreply.github.com> Date: Sun, 16 Feb 2025 16:05:59 +0100 Subject: [PATCH] refactor: Rework `get_online_features` helper functions (#5060) * rework _populate_response_from_feature_data Signed-off-by: Artem Petrov <58334441+wckdman@users.noreply.github.com> * rework _convert_rows_to_protobuf Signed-off-by: Artem Petrov <58334441+wckdman@users.noreply.github.com> * fix typing Signed-off-by: Artem Petrov <58334441+wckdman@users.noreply.github.com> --------- Signed-off-by: Artem Petrov <58334441+wckdman@users.noreply.github.com> --- sdk/python/feast/feature_store.py | 3 +- .../feast/infra/online_stores/online_store.py | 10 +- sdk/python/feast/utils.py | 137 +++++++++--------- .../tests/unit/test_unit_feature_store.py | 6 +- 4 files changed, 79 insertions(+), 77 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 0f092538cf0..f0bdc4c1f28 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2018,7 +2018,7 @@ def _retrieve_from_online_store_v2( entity_key_dict[key] = [] entity_key_dict[key].append(python_value) - table_entity_values, idxs = utils._get_unique_entities_from_values( + table_entity_values, idxs, output_len = utils._get_unique_entities_from_values( entity_key_dict, ) @@ -2040,6 +2040,7 @@ def _retrieve_from_online_store_v2( full_feature_names=False, requested_features=features_to_request, table=table, + output_len=output_len, ) return OnlineResponse(online_features_response) diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index a86fdba4017..f5202b66f66 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -187,7 +187,7 @@ def get_online_features( for table, requested_features in grouped_refs: # Get the correct set of entity values with the correct join keys. - table_entity_values, idxs = utils._get_unique_entities( + table_entity_values, idxs, output_len = utils._get_unique_entities( table, join_key_values, entity_name_to_join_key_map, @@ -215,6 +215,7 @@ def get_online_features( full_feature_names, requested_features, table, + output_len, ) if requested_on_demand_feature_views: @@ -274,7 +275,7 @@ async def get_online_features_async( async def query_table(table, requested_features): # Get the correct set of entity values with the correct join keys. - table_entity_values, idxs = utils._get_unique_entities( + table_entity_values, idxs, output_len = utils._get_unique_entities( table, join_key_values, entity_name_to_join_key_map, @@ -290,7 +291,7 @@ async def query_table(table, requested_features): requested_features=requested_features, ) - return idxs, read_rows + return idxs, read_rows, output_len all_responses = await asyncio.gather( *[ @@ -299,7 +300,7 @@ async def query_table(table, requested_features): ] ) - for (idxs, read_rows), (table, requested_features) in zip( + for (idxs, read_rows, output_len), (table, requested_features) in zip( all_responses, grouped_refs ): feature_data = utils._convert_rows_to_protobuf( @@ -314,6 +315,7 @@ async def query_table(table, requested_features): full_feature_names, requested_features, table, + output_len, ) if requested_on_demand_feature_views: diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 86cb08ec932..e64e38b143a 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -490,16 +490,28 @@ def _group_feature_refs( return fvs_result, odfvs_result -def apply_list_mapping( - lst: Iterable[Any], mapping_indexes: Iterable[List[int]] -) -> Iterable[Any]: - output_len = sum(len(item) for item in mapping_indexes) - output = [None] * output_len - for elem, destinations in zip(lst, mapping_indexes): +def construct_response_feature_vector( + values_vector: Iterable[Any], + statuses_vector: Iterable[Any], + timestamp_vector: Iterable[Any], + mapping_indexes: Iterable[List[int]], + output_len: int, +) -> GetOnlineFeaturesResponse.FeatureVector: + values_output: Iterable[Any] = [None] * output_len + statuses_output: Iterable[Any] = [None] * output_len + timestamp_output: Iterable[Any] = [None] * output_len + + for i, destinations in enumerate(mapping_indexes): for idx in destinations: - output[idx] = elem - - return output + values_output[idx] = values_vector[i] # type: ignore[index] + statuses_output[idx] = statuses_vector[i] # type: ignore[index] + timestamp_output[idx] = timestamp_vector[i] # type: ignore[index] + + return GetOnlineFeaturesResponse.FeatureVector( + values=values_output, + statuses=statuses_output, + event_timestamps=timestamp_output, + ) def _augment_response_with_on_demand_transforms( @@ -674,7 +686,7 @@ def _get_unique_entities( table: "FeatureView", join_key_values: Dict[str, List[ValueProto]], entity_name_to_join_key_map: Dict[str, str], -) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]: +) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]: """Return the set of unique composite Entities for a Feature View and the indexes at which they appear. This method allows us to query the OnlineStore for data we need only once @@ -712,7 +724,7 @@ def _get_unique_entities( # If there are no rows, return empty tuples. if not rowise: - return (), () + return (), (), 0 # Sort rowise so that rows with the same join key values are adjacent. rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1])) @@ -725,16 +737,16 @@ def _get_unique_entities( # If no groups were formed (should not happen for valid input), return empty tuples. if not groups: - return (), () + return (), (), 0 # Unpack the unique entities and their original row indexes. unique_entities, indexes = tuple(zip(*groups)) - return unique_entities, indexes + return unique_entities, indexes, len(rowise) def _get_unique_entities_from_values( table_entity_values: Dict[str, List[ValueProto]], -) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]: +) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]: """Return the set of unique composite Entities for a Feature View and the indexes at which they appear. This method allows us to query the OnlineStore for data we need only once @@ -758,7 +770,7 @@ def _get_unique_entities_from_values( ] ) ) - return unique_entities, indexes + return unique_entities, indexes, len(rowise) def _drop_unneeded_columns( @@ -835,6 +847,7 @@ def _populate_response_from_feature_data( full_feature_names: bool, requested_features: Iterable[str], table: "FeatureView", + output_len: int, ): """Populate the GetOnlineFeaturesResponse with feature data. @@ -853,33 +866,22 @@ def _populate_response_from_feature_data( requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the data in `feature_data`. table: The FeatureView that `feature_data` was retrieved from. + output_len: The number of result rows in `online_features_response`. """ # Add the feature names to the response. + table_name = table.projection.name_to_use() requested_feature_refs = [ - ( - f"{table.projection.name_to_use()}__{feature_name}" - if full_feature_names - else feature_name - ) + f"{table_name}__{feature_name}" if full_feature_names else feature_name for feature_name in requested_features ] online_features_response.metadata.feature_names.val.extend(requested_feature_refs) - timestamps, statuses, values = zip(*feature_data) - - # Populate the result with data fetched from the OnlineStore - # which is guaranteed to be aligned with `requested_features`. - for ( - feature_idx, - (timestamp_vector, statuses_vector, values_vector), - ) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))): - online_features_response.results.append( - GetOnlineFeaturesResponse.FeatureVector( - values=apply_list_mapping(values_vector, indexes), - statuses=apply_list_mapping(statuses_vector, indexes), - event_timestamps=apply_list_mapping(timestamp_vector, indexes), - ) + # Process each feature vector in a single pass + for timestamp_vector, statuses_vector, values_vector in feature_data: + response_vector = construct_response_feature_vector( + values_vector, statuses_vector, timestamp_vector, indexes, output_len ) + online_features_response.results.append(response_vector) def _populate_response_from_feature_data_v2( @@ -891,6 +893,7 @@ def _populate_response_from_feature_data_v2( indexes: Iterable[List[int]], online_features_response: GetOnlineFeaturesResponse, requested_features: Iterable[str], + output_len: int, ): """Populate the GetOnlineFeaturesResponse with feature data. @@ -908,6 +911,7 @@ def _populate_response_from_feature_data_v2( "customer_fv__daily_transactions"). requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the data in `feature_data`. + output_len: The number of result rows in `online_features_response`. """ # Add the feature names to the response. requested_feature_refs = [(feature_name) for feature_name in requested_features] @@ -917,17 +921,11 @@ def _populate_response_from_feature_data_v2( # Populate the result with data fetched from the OnlineStore # which is guaranteed to be aligned with `requested_features`. - for ( - feature_idx, - (timestamp_vector, statuses_vector, values_vector), - ) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))): - online_features_response.results.append( - GetOnlineFeaturesResponse.FeatureVector( - values=apply_list_mapping(values_vector, indexes), - statuses=apply_list_mapping(statuses_vector, indexes), - event_timestamps=apply_list_mapping(timestamp_vector, indexes), - ) + for timestamp_vector, statuses_vector, values_vector in feature_data: + response_vector = construct_response_feature_vector( + values_vector, statuses_vector, timestamp_vector, indexes, output_len ) + online_features_response.results.append(response_vector) def _convert_entity_key_to_proto_to_dict( @@ -1246,33 +1244,32 @@ def _convert_rows_to_protobuf( requested_features: List[str], read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]], ) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[ValueProto]]]: - # Each row is a set of features for a given entity key. - # We only need to convert the data to Protobuf once. + # Pre-calculate the length to avoid repeated calculations + n_rows = len(read_rows) + + # Create single instances of commonly used values null_value = ValueProto() - read_row_protos = [] - for read_row in read_rows: - row_ts_proto = Timestamp() - row_ts, feature_data = read_row - # TODO (Ly): reuse whatever timestamp if row_ts is None? - if row_ts is not None: - row_ts_proto.FromDatetime(row_ts) - event_timestamps = [row_ts_proto] * len(requested_features) - if feature_data is None: - statuses = [FieldStatus.NOT_FOUND] * len(requested_features) - values = [null_value] * len(requested_features) - else: - statuses = [] - values = [] - for feature_name in requested_features: - # Make sure order of data is the same as requested_features. - if feature_name not in feature_data: - statuses.append(FieldStatus.NOT_FOUND) - values.append(null_value) - else: - statuses.append(FieldStatus.PRESENT) - values.append(feature_data[feature_name]) - read_row_protos.append((event_timestamps, statuses, values)) - return read_row_protos + null_status = FieldStatus.NOT_FOUND + null_timestamp = Timestamp() + present_status = FieldStatus.PRESENT + + requested_features_vectors = [] + for feature_name in requested_features: + ts_vector = [null_timestamp] * n_rows + status_vector = [null_status] * n_rows + value_vector = [null_value] * n_rows + for idx, read_row in enumerate(read_rows): + row_ts_proto = Timestamp() + row_ts, feature_data = read_row + # TODO (Ly): reuse whatever timestamp if row_ts is None? + if row_ts is not None: + row_ts_proto.FromDatetime(row_ts) + ts_vector[idx] = row_ts_proto + if (feature_data is not None) and (feature_name in feature_data): + status_vector[idx] = present_status + value_vector[idx] = feature_data[feature_name] + requested_features_vectors.append((ts_vector, status_vector, value_vector)) + return requested_features_vectors def has_all_tags( diff --git a/sdk/python/tests/unit/test_unit_feature_store.py b/sdk/python/tests/unit/test_unit_feature_store.py index 8d7b32760a4..3bad0ec6c59 100644 --- a/sdk/python/tests/unit/test_unit_feature_store.py +++ b/sdk/python/tests/unit/test_unit_feature_store.py @@ -38,7 +38,7 @@ def test_get_unique_entities_success(): projection=MockFeatureViewProjection(join_key_map={}), ) - unique_entities, indexes = utils._get_unique_entities( + unique_entities, indexes, output_len = utils._get_unique_entities( table=fv, join_key_values=entity_values, entity_name_to_join_key_map=entity_name_to_join_key_map, @@ -51,6 +51,7 @@ def test_get_unique_entities_success(): assert unique_entities == expected_entities assert indexes == expected_indexes + assert output_len == 3 def test_get_unique_entities_missing_join_key_success(): @@ -74,7 +75,7 @@ def test_get_unique_entities_missing_join_key_success(): projection=MockFeatureViewProjection(join_key_map={}), ) - unique_entities, indexes = utils._get_unique_entities( + unique_entities, indexes, output_len = utils._get_unique_entities( table=fv, join_key_values=entity_values, entity_name_to_join_key_map=entity_name_to_join_key_map, @@ -87,6 +88,7 @@ def test_get_unique_entities_missing_join_key_success(): assert unique_entities == expected_entities assert indexes == expected_indexes + assert output_len == 3 # We're not say anything about the entity_1 missing from the unique_entities list assert "entity_1" not in [entity.keys() for entity in unique_entities]