From 0dff5f16d17ce398fa1b28cfc3822608988d621d Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Wed, 26 Feb 2025 11:45:45 -0500 Subject: [PATCH] updated to implement full text search Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/feature_server.py | 2 + sdk/python/feast/feature_store.py | 5 + .../milvus_online_store/milvus.py | 1 + .../feast/infra/online_stores/online_store.py | 2 + .../feast/infra/online_stores/sqlite.py | 189 +++++++++++++----- .../feast/infra/passthrough_provider.py | 2 + sdk/python/feast/infra/provider.py | 2 + sdk/python/tests/foo_provider.py | 1 + .../online_store/test_online_retrieval.py | 79 +++++++- 9 files changed, 231 insertions(+), 52 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index ed742bcb98f..023ce7d1115 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -75,6 +75,7 @@ class GetOnlineFeaturesRequest(BaseModel): features: Optional[List[str]] = None full_feature_names: bool = False query_embedding: Optional[List[float]] = None + query_string: Optional[str] = None def _get_features(request: GetOnlineFeaturesRequest, store: "feast.FeatureStore"): @@ -195,6 +196,7 @@ async def retrieve_online_documents( entity_rows=request.entities, full_feature_names=request.full_feature_names, query=request.query_embedding, + query_string=request.query_string, ) response = await run_in_threadpool( diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index f0bdc4c1f28..c3a8cd7a2bc 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1867,6 +1867,7 @@ def retrieve_online_documents_v2( top_k: int, features: List[str], distance_metric: Optional[str] = "L2", + query_string: Optional[str] = None, ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. @@ -1878,6 +1879,7 @@ def retrieve_online_documents_v2( query: The query to retrieve the closest document features for. top_k: The number of closest document features to retrieve. distance_metric: The distance metric to use for retrieval. + query_string: The query string to retrieve the closest document features using keyword search (bm25). """ if isinstance(query, str): raise ValueError( @@ -1919,6 +1921,7 @@ def retrieve_online_documents_v2( query, top_k, distance_metric, + query_string, ) def _retrieve_from_online_store( @@ -1988,6 +1991,7 @@ def _retrieve_from_online_store_v2( query: List[float], top_k: int, distance_metric: Optional[str], + query_string: Optional[str], ) -> OnlineResponse: """ Search and return document features from the online document store. @@ -2003,6 +2007,7 @@ def _retrieve_from_online_store_v2( query=query, top_k=top_k, distance_metric=distance_metric, + query_string=query_string, ) entity_key_dict: Dict[str, List[ValueProto]] = {} diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index 4d14d826085..9b9f003ebba 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -463,6 +463,7 @@ def retrieve_online_documents_v2( embedding: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index f5202b66f66..7165ef59a3e 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -439,6 +439,7 @@ def retrieve_online_documents_v2( embedding: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -456,6 +457,7 @@ def retrieve_online_documents_v2( requested_features: The list of features whose embeddings should be used for retrieval. embedding: The embeddings to use for retrieval. top_k: The number of documents to retrieve. + query_string: The query string to search for using keyword search (bm25) (optional) Returns: object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 2fcf36b1001..87da851ce5f 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -40,7 +40,7 @@ from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.type_map import feast_value_type_to_python_type -from feast.types import FEAST_VECTOR_TYPES +from feast.types import FEAST_VECTOR_TYPES, PrimitiveFeastType from feast.utils import ( _build_retrieve_online_document_record, _serialize_vector_to_float_list, @@ -442,6 +442,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -458,72 +459,135 @@ def retrieve_online_documents_v2( query: Query embedding to search for top_k: Number of items to return distance_metric: Distance metric to use (optional) + query_string: The query string to search for using keyword search (bm25) (optional) Returns: List of tuples containing the event timestamp, entity key, and feature values """ online_store = config.online_store if not isinstance(online_store, SqliteOnlineStoreConfig): raise ValueError("online_store must be SqliteOnlineStoreConfig") - if not online_store.vector_enabled: - raise ValueError("Vector search is not enabled in the online store config") + if not online_store.vector_enabled and not online_store.text_search_enabled: + raise ValueError( + "You must enable either vector search or text search in the online store config" + ) conn = self._get_conn(config) cur = conn.cursor() - if not online_store.vector_len: + if online_store.vector_enabled and not online_store.vector_len: raise ValueError("vector_len is not configured in the online store config") - query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore table_name = _table_id(config.project, table) vector_field = _get_vector_field(table) - cur.execute( - f""" - CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0( - vector_value float[{online_store.vector_len}] - ); - """ - ) + if online_store.vector_enabled: + query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore + cur.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0( + vector_value float[{online_store.vector_len}] + ); + """ + ) + cur.execute( + f""" + INSERT INTO vec_table (rowid, vector_value) + select rowid, vector_value from {table_name} + where feature_name = "{vector_field}" + """ + ) + elif online_store.text_search_enabled: + string_field_list = [ + f.name for f in table.features if f.dtype == PrimitiveFeastType.STRING + ] + string_fields = ", ".join(string_field_list) + # TODO: swap this for a value configurable in each Field() + BM25_DEFAULT_WEIGHTS = ", ".join( + [ + str(1.0) + for f in table.features + if f.dtype == PrimitiveFeastType.STRING + ] + ) + cur.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS search_table using fts5( + entity_key, fv_rowid, {string_fields}, tokenize="porter unicode61" + ); + """ + ) + insert_query = _generate_bm25_search_insert_query( + table_name, string_field_list + ) + cur.execute(insert_query) - cur.execute( - f""" - INSERT INTO vec_table (rowid, vector_value) - select rowid, vector_value from {table_name} - where feature_name = "{vector_field}" - """ - ) + else: + raise ValueError( + "Neither vector search nor text search are enabled in the online store config" + ) - cur.execute( - f""" + if online_store.vector_enabled: + cur.execute( + f""" + select + fv2.entity_key, + fv2.feature_name, + fv2.value, + fv.vector_value, + f.distance, + fv.event_ts, + fv.created_ts + from ( + select + rowid, + vector_value, + distance + from vec_table + where vector_value match ? + order by distance + limit ? + ) f + left join {table_name} fv + on f.rowid = fv.rowid + left join {table_name} fv2 + on fv.entity_key = fv2.entity_key + where fv2.feature_name != "{vector_field}" + """, + ( + query_embedding_bin, + top_k, + ), + ) + elif online_store.text_search_enabled: + cur.execute( + f""" select - fv2.entity_key, - fv2.feature_name, - fv2.value, + fv.entity_key, + fv.feature_name, + fv.value, fv.vector_value, f.distance, fv.event_ts, fv.created_ts - from ( - select - rowid, - vector_value, - distance - from vec_table - where vector_value match ? - order by distance - limit ? - ) f - left join {table_name} fv - on f.rowid = fv.rowid - left join {table_name} fv2 - on fv.entity_key = fv2.entity_key - where fv2.feature_name != "{vector_field}" - """, - ( - query_embedding_bin, - top_k, - ), - ) + from {table_name} fv + inner join ( + select + fv_rowid, + entity_key, + {string_fields}, + bm25(search_table, {BM25_DEFAULT_WEIGHTS}) as distance + from search_table + where search_table match ? order by distance limit ? + ) f + on f.entity_key = fv.entity_key + """, + (query_string, top_k), + ) + + else: + raise ValueError( + "Neither vector search nor text search are enabled in the online store config" + ) rows = cur.fetchall() results: List[ @@ -557,9 +621,10 @@ def retrieve_online_documents_v2( feature_val.ParseFromString(value_bin) entity_dict[entity_key]["entity_key_proto"] = entity_key_proto entity_dict[entity_key][feature_name] = feature_val - entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list( - vector_value - ) + if online_store.vector_enabled: + entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list( + vector_value + ) entity_dict[entity_key]["distance"] = ValueProto(float_val=distance) entity_dict[entity_key]["event_ts"] = event_ts entity_dict[entity_key]["created_ts"] = created_ts @@ -706,3 +771,31 @@ def _get_vector_field(table: FeatureView) -> str: ) vector_field: str = vector_fields[0].name return vector_field + + +def _generate_bm25_search_insert_query( + table_name: str, string_field_list: List[str] +) -> str: + """ + Generates an SQL insertion query for the given table and string fields. + + Args: + table_name (str): The name of the table to select data from. + string_field_list (List[str]): The list of string fields to be used in the insertion. + + Returns: + str: The generated SQL insertion query. + """ + _string_fields = ", ".join(string_field_list) + query = f"INSERT INTO search_table (entity_key, fv_rowid, {_string_fields})\nSELECT\n\tDISTINCT fv0.entity_key,\n\tfv0.rowid as fv_rowid" + from_query = f"\nFROM (select rowid, * from {table_name} where feature_name = '{string_field_list[0]}') fv0" + + for i, string_field in enumerate(string_field_list): + query += f"\n\t,fv{i}.value as {string_field}" + if i > 0: + from_query += ( + f"\nLEFT JOIN (select rowid, * from {table_name} where feature_name = '{string_field}') fv{i}" + + f"\n\tON fv0.entity_key = fv{i}.entity_key" + ) + + return query + from_query diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 74b05113282..2a896a58b06 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -321,6 +321,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List: result = [] if self.online_store: @@ -331,6 +332,7 @@ def retrieve_online_documents_v2( query, top_k, distance_metric, + query_string, ) return result diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index f765e754436..5155e76dddf 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -459,6 +459,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -476,6 +477,7 @@ def retrieve_online_documents_v2( requested_features: the requested document feature names. query: The query embedding to search for. top_k: The number of documents to return. + query_string: The query string to search for using keyword search (bm25) (optional) Returns: A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index ca6a02c4bd0..df8edf1232e 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -172,6 +172,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index b91f3360fc0..c88cc57ef13 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -713,10 +713,9 @@ def test_sqlite_get_online_documents() -> None: ) assert record_count == len(data) * len(document_embeddings_fv.features) - # query_embedding = np.random.random( - # vector_length, - # ) - query_embedding = [float(x) for x in np.random.random(vector_length)] + query_embedding = np.random.random( + vector_length, + ) result = store.retrieve_online_documents( feature="document_embeddings:Embeddings", query=query_embedding, top_k=3 ).to_dict() @@ -930,6 +929,78 @@ def test_sqlite_get_online_documents_v2() -> None: assert len(result["distance"]) == 3 +def test_sqlite_get_online_documents_v2_search() -> None: + """Test retrieving documents using v2 method with key word search""" + n = 10 + vector_length = 8 + runner = CliRunner() + with runner.local_repo( + get_example_repo("example_feature_repo_1.py"), "file" + ) as store: + store.config.online_store.text_search_enabled = True + store.config.entity_key_serialization_version = 3 + document_embeddings_fv = store.get_feature_view(name="document_embeddings") + + provider = store._get_provider() + + # Create test data + item_keys = [ + EntityKeyProto( + join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)] + ) + for i in range(n) + ] + data = [] + for i, item_key in enumerate(item_keys): + data.append( + ( + item_key, + { + "Embeddings": ValueProto( + float_list_val=FloatListProto( + val=[float(x) for x in np.random.random(vector_length)] + ) + ), + "content": ValueProto( + string_val=f"the {i}th sentence with some text" + ), + "title": ValueProto(string_val=f"Title {i}"), + }, + _utc_now(), + _utc_now(), + ) + ) + + provider.online_write_batch( + config=store.config, + table=document_embeddings_fv, + data=data, + progress=None, + ) + + # Test vector similarity search + query_embedding = [float(x) for x in np.random.random(vector_length)] + result = store.retrieve_online_documents_v2( + features=[ + "document_embeddings:Embeddings", + "document_embeddings:content", + "document_embeddings:title", + ], + query=query_embedding, + query_string="(content: 5) OR (title: 1) OR (title: 3)", + top_k=3, + ).to_dict() + + assert "Embeddings" in result + assert "content" in result + assert "title" in result + assert "distance" in result + assert ["1th sentence with some text" in r for r in result["content"]] + assert ["Title " in r for r in result["title"]] + assert len(result["distance"]) == 2 + assert result["distance"] == [-1.8458267450332642, -1.8458267450332642] + + @pytest.mark.skip(reason="Skipping this test as CI struggles with it") def test_local_milvus() -> None: import random