From 65cd347bbf8b0cfd0cbb4389c68f980731aebc21 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sat, 22 Feb 2025 05:59:47 -0500 Subject: [PATCH 1/4] Updating retrieve online documents v2 to work for other fields for sqlite Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/feature_view.py | 4 + .../feast/infra/online_stores/sqlite.py | 209 +++++++++++------- sdk/python/feast/types.py | 13 +- .../example_repos/example_feature_repo_1.py | 2 + .../online_store/test_online_retrieval.py | 105 ++++++++- 5 files changed, 249 insertions(+), 84 deletions(-) diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 4aeb9a9c1dc..2ee454dd459 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -191,6 +191,10 @@ def __init__( else: features.append(field) + assert len([f for f in features if f.vector_index]) < 2, ( + f"Only one vector feature is allowed per feature view. Please update {self.name}." + ) + # TODO(felixwang9817): Add more robust validation of features. cols = [field.name for field in schema] for col in cols: diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 945ec965fc9..3d04ef82cce 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -18,12 +18,13 @@ import sys from datetime import date, datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union from pydantic import StrictStr from feast import Entity from feast.feature_view import FeatureView +from feast.field import Field from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject from feast.infra.key_encoding_utils import ( deserialize_entity_key, @@ -38,7 +39,13 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.utils import _build_retrieve_online_document_record, to_naive_utc +from feast.type_map import feast_value_type_to_python_type +from feast.types import FEAST_VECTOR_TYPES +from feast.utils import ( + _build_retrieve_online_document_record, + _serialize_vector_to_float_list, + to_naive_utc, +) def adapt_date_iso(val: date): @@ -94,6 +101,7 @@ class SqliteOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): vector_enabled: bool = False vector_len: Optional[int] = None + text_search_enabled: bool = False class SqliteOnlineStore(OnlineStore): @@ -144,9 +152,8 @@ def online_write_batch( progress: Optional[Callable[[int], Any]], ) -> None: conn = self._get_conn(config) - project = config.project - + feature_type_dict = {f.name: f.dtype for f in table.features} with conn: for entity_key, values, timestamp, created_ts in data: entity_key_bin = serialize_entity_key( @@ -160,71 +167,51 @@ def online_write_batch( table_name = _table_id(project, table) for feature_name, val in values.items(): if config.online_store.vector_enabled: - vector_bin = serialize_f32( - val.float_list_val.val, config.online_store.vector_len - ) # type: ignore - conn.execute( - f""" - UPDATE {table_name} - SET value = ?, vector_value = ?, event_ts = ?, created_ts = ? - WHERE (entity_key = ? AND feature_name = ?) - """, - ( - # SET - val.SerializeToString(), - vector_bin, - timestamp, - created_ts, - # WHERE - entity_key_bin, - feature_name, - ), - ) + if feature_type_dict[feature_name] in FEAST_VECTOR_TYPES: + val_bin = serialize_f32( + val.float_list_val.val, config.online_store.vector_len + ) # type: ignore + else: + val_bin = feast_value_type_to_python_type(val) conn.execute( - f"""INSERT OR IGNORE INTO {table_name} - (entity_key, feature_name, value, vector_value, event_ts, created_ts) - VALUES (?, ?, ?, ?, ?, ?)""", + f""" + INSERT INTO {table_name} (entity_key, feature_name, value, vector_value, event_ts, created_ts) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(entity_key, feature_name) DO UPDATE SET + value = excluded.value, + vector_value = excluded.vector_value, + event_ts = excluded.event_ts, + created_ts = excluded.created_ts; + """, ( - entity_key_bin, - feature_name, - val.SerializeToString(), - vector_bin, - timestamp, - created_ts, + entity_key_bin, # entity_key + feature_name, # feature_name + val.SerializeToString(), # value + val_bin, # vector_value + timestamp, # event_ts + created_ts, # created_ts ), ) - else: conn.execute( f""" - UPDATE {table_name} - SET value = ?, event_ts = ?, created_ts = ? - WHERE (entity_key = ? AND feature_name = ?) + INSERT INTO {table_name} (entity_key, feature_name, value, event_ts, created_ts) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(entity_key, feature_name) DO UPDATE SET + value = excluded.value, + event_ts = excluded.event_ts, + created_ts = excluded.created_ts; """, ( - # SET - val.SerializeToString(), - timestamp, - created_ts, - # WHERE - entity_key_bin, - feature_name, + entity_key_bin, # entity_key + feature_name, # feature_name + val.SerializeToString(), # value + timestamp, # event_ts + created_ts, # created_ts ), ) - conn.execute( - f"""INSERT OR IGNORE INTO {table_name} - (entity_key, feature_name, value, event_ts, created_ts) - VALUES (?, ?, ?, ?, ?)""", - ( - entity_key_bin, - feature_name, - val.SerializeToString(), - timestamp, - created_ts, - ), - ) if progress: progress(1) @@ -482,13 +469,21 @@ def retrieve_online_documents_v2( conn = self._get_conn(config) cur = conn.cursor() - online_store = config.online_store - if not isinstance(online_store, SqliteOnlineStoreConfig): - raise ValueError("online_store must be SqliteOnlineStoreConfig") if not online_store.vector_len: raise ValueError("vector_len is not configured in the online store config") + query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore table_name = _table_id(config.project, table) + vector_fields: List[Field] = [ + f for f in table.features if getattr(f, "vector_index", None) + ] + assert len(vector_fields) > 0, ( + f"No vector field found, please update feature view = {table.name} to declare a vector field" + ) + assert len(vector_fields) < 2, ( + "Only one vector field is supported, please update feature view = {table.name} to declare one vector field" + ) + vector_field: str = vector_fields[0].name cur.execute( f""" @@ -500,17 +495,19 @@ def retrieve_online_documents_v2( cur.execute( f""" - INSERT INTO vec_table(rowid, vector_value) + INSERT INTO vec_table (rowid, vector_value) select rowid, vector_value from {table_name} + where feature_name = "{vector_field}" """ ) cur.execute( f""" select - fv.entity_key, - fv.feature_name, - fv.value, + fv2.entity_key, + fv2.feature_name, + fv2.value, + fv.vector_value, f.distance, fv.event_ts, fv.created_ts @@ -526,17 +523,18 @@ def retrieve_online_documents_v2( ) f left join {table_name} fv on f.rowid = fv.rowid - where fv.feature_name in ({",".join(["?" for _ in requested_features])}) + left join {table_name} fv2 + on fv.entity_key = fv2.entity_key + where fv2.feature_name != "{vector_field}" """, ( query_embedding_bin, top_k, - *[f.split(":")[-1] for f in requested_features], ), ) rows = cur.fetchall() - result: List[ + results: List[ Tuple[ Optional[datetime], Optional[EntityKeyProto], @@ -544,20 +542,61 @@ def retrieve_online_documents_v2( ] ] = [] - for entity_key, feature_name, value_bin, distance, event_ts, created_ts in rows: - val = ValueProto() - val.ParseFromString(value_bin) - entity_key_proto = None - if entity_key: - entity_key_proto = deserialize_entity_key( - entity_key, - entity_key_serialization_version=config.entity_key_serialization_version, + entity_dict: Dict[ + str, Dict[str, Union[str, ValueProto, EntityKeyProto, datetime]] + ] = {} + for ( + entity_key, + feature_name, + value_bin, + vector_value, + distance, + event_ts, + created_ts, + ) in rows: + entity_key_proto = deserialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + if entity_key not in entity_dict: + entity_dict[entity_key] = {} + + feature_val = ValueProto() + feature_val.ParseFromString(value_bin) + entity_dict[entity_key]["entity_key_proto"] = entity_key_proto + entity_dict[entity_key][feature_name] = feature_val + entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list( + vector_value + ) + entity_dict[entity_key]["distance"] = ValueProto(float_val=distance) + entity_dict[entity_key]["event_ts"] = event_ts + entity_dict[entity_key]["created_ts"] = created_ts + + for entity_key_value in entity_dict: + res_event_ts: Optional[datetime] = None + res_entity_key_proto: Optional[EntityKeyProto] = None + if isinstance(entity_dict[entity_key_value]["event_ts"], datetime): + res_event_ts = entity_dict[entity_key_value]["event_ts"] # type: ignore[assignment] + + if isinstance( + entity_dict[entity_key_value]["entity_key_proto"], EntityKeyProto + ): + res_entity_key_proto = entity_dict[entity_key_value]["entity_key_proto"] # type: ignore[assignment] + + res_dict: Dict[str, ValueProto] = { + k: v + for k, v in entity_dict[entity_key_value].items() + if isinstance(v, ValueProto) and isinstance(k, str) + } + + results.append( + ( + res_event_ts, + res_entity_key_proto, + res_dict, ) - res = {feature_name: val} - res["distance"] = ValueProto(float_val=distance) - result.append((event_ts, entity_key_proto, res)) - - return result + ) + return results def _initialize_conn( @@ -640,7 +679,17 @@ def update(self): except ModuleNotFoundError: logging.warning("Cannot use sqlite_vec for vector search") self.conn.execute( - f"CREATE TABLE IF NOT EXISTS {self.name} (entity_key BLOB, feature_name TEXT, value BLOB, vector_value BLOB, event_ts timestamp, created_ts timestamp, PRIMARY KEY(entity_key, feature_name))" + f""" + CREATE TABLE IF NOT EXISTS {self.name} ( + entity_key BLOB, + feature_name TEXT, + value BLOB, + vector_value BLOB, + event_ts timestamp, + created_ts timestamp, + PRIMARY KEY(entity_key, feature_name) + ) + """ ) self.conn.execute( f"CREATE INDEX IF NOT EXISTS {self.name}_ek ON {self.name} (entity_key);" diff --git a/sdk/python/feast/types.py b/sdk/python/feast/types.py index 59980d816a8..4f13fbf2652 100644 --- a/sdk/python/feast/types.py +++ b/sdk/python/feast/types.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from enum import Enum -from typing import Dict, Union +from typing import Dict, List, Union import pyarrow @@ -196,6 +196,17 @@ def __str__(self): UnixTimestamp: pyarrow.timestamp("us", tz=_utc_now().tzname()), } +FEAST_VECTOR_TYPES: List[Union[ValueType, PrimitiveFeastType, ComplexFeastType]] = [ + ValueType.BYTES_LIST, + ValueType.INT32_LIST, + ValueType.INT64_LIST, + ValueType.FLOAT_LIST, + ValueType.BOOL_LIST, +] +for k in VALUE_TYPES_TO_FEAST_TYPES: + if k in FEAST_VECTOR_TYPES: + FEAST_VECTOR_TYPES.append(VALUE_TYPES_TO_FEAST_TYPES[k]) + def from_feast_to_pyarrow_type(feast_type: FeastType) -> pyarrow.DataType: """ diff --git a/sdk/python/tests/example_repos/example_feature_repo_1.py b/sdk/python/tests/example_repos/example_feature_repo_1.py index ea33859f4de..1671bd0ae3a 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_1.py +++ b/sdk/python/tests/example_repos/example_feature_repo_1.py @@ -125,6 +125,8 @@ vector_search_metric="L2", ), Field(name="item_id", dtype=String), + Field(name="content", dtype=String), + Field(name="title", dtype=String), ], source=rag_documents_source, ttl=timedelta(hours=24), diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 5e5a3104c13..b9513518108 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -753,6 +753,93 @@ def test_sqlite_vec_import() -> None: assert result == [(2, 2.39), (1, 2.39)] +def test_sqlite_hybrid_search() -> None: + imdb_sample_data = { + "Rank": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}, + "Title": { + 0: "Guardians of the Galaxy", + 1: "Prometheus", + 2: "Split", + 3: "Sing", + 4: "Suicide Squad", + }, + "Genre": { + 0: "Action,Adventure,Sci-Fi", + 1: "Adventure,Mystery,Sci-Fi", + 2: "Horror,Thriller", + 3: "Animation,Comedy,Family", + 4: "Action,Adventure,Fantasy", + }, + "Description": { + 0: "A group of intergalactic criminals are forced to work together to stop a fanatical warrior from taking control of the universe.", + 1: "Following clues to the origin of mankind, a team finds a structure on a distant moon, but they soon realize they are not alone.", + 2: "Three girls are kidnapped by a man with a diagnosed 23 distinct personalities. They must try to escape before the apparent emergence of a frightful new 24th.", + 3: "In a city of humanoid animals, a hustling theater impresario's attempt to save his theater with a singing competition becomes grander than he anticipates even as its finalists' find that their lives will never be the same.", + 4: "A secret government agency recruits some of the most dangerous incarcerated super-villains to form a defensive task force. Their first mission: save the world from the apocalypse.", + }, + "Director": { + 0: "James Gunn", + 1: "Ridley Scott", + 2: "M. Night Shyamalan", + 3: "Christophe Lourdelet", + 4: "David Ayer", + }, + "Actors": { + 0: "Chris Pratt, Vin Diesel, Bradley Cooper, Zoe Saldana", + 1: "Noomi Rapace, Logan Marshall-Green, Michael Fassbender, Charlize Theron", + 2: "James McAvoy, Anya Taylor-Joy, Haley Lu Richardson, Jessica Sula", + 3: "Matthew McConaughey,Reese Witherspoon, Seth MacFarlane, Scarlett Johansson", + 4: "Will Smith, Jared Leto, Margot Robbie, Viola Davis", + }, + "Year": {0: 2014, 1: 2012, 2: 2016, 3: 2016, 4: 2016}, + "Runtime (Minutes)": {0: 121, 1: 124, 2: 117, 3: 108, 4: 123}, + "Rating": {0: 8.1, 1: 7.0, 2: 7.3, 3: 7.2, 4: 6.2}, + "Votes": {0: 757074, 1: 485820, 2: 157606, 3: 60545, 4: 393727}, + "Revenue (Millions)": {0: 333.13, 1: 126.46, 2: 138.12, 3: 270.32, 4: 325.02}, + "Metascore": {0: 76.0, 1: 65.0, 2: 62.0, 3: 59.0, 4: 40.0}, + } + df = pd.DataFrame(imdb_sample_data) + db = sqlite3.connect(":memory:") + + cur = db.cursor() + + cur.execute( + 'create virtual table imdb using fts5(title, description, genre, rating, tokenize="porter unicode61");' + ) + cur.executemany( + "insert into imdb (title, description, genre, rating) values (?,?,?,?);", + df[["Title", "Description", "Genre", "Rating"]].to_records(index=False), + ) + db.commit() + + query = "Prom" + res = cur.execute(f"""select title, description, genre, rating, rank + from imdb + where title MATCH "{query}*" + ORDER BY rank + limit 5""").fetchall() + assert len(res) == 1 + assert res[0][0] == "Prometheus" + + q = "(title : the OR of) AND (genre: Action OR Comedy)" + res_df = pd.read_sql_query( + f""" + select + rowid, + title, + description, + bm25(imdb, 10.0, 5.0) + from imdb + where imdb MATCH "{q}" + ORDER BY bm25(imdb, 10.0, 5.0) + limit 5 + """, + db, + ) + res_df["rowid"].tolist() == [1, 4, 5] + res_df["title"].tolist() == ["Guardians of the Galaxy", "Sing", "Suicide Squad"] + + @pytest.mark.skipif( sys.version_info[0:2] != (3, 10), reason="Only works on Python 3.10", @@ -780,7 +867,7 @@ def test_sqlite_get_online_documents_v2() -> None: for i in range(n) ] data = [] - for item_key in item_keys: + for i, item_key in enumerate(item_keys): data.append( ( item_key, @@ -789,7 +876,11 @@ def test_sqlite_get_online_documents_v2() -> None: float_list_val=FloatListProto( val=[float(x) for x in np.random.random(vector_length)] ) - ) + ), + "content": ValueProto( + string_val=f"the {i}th sentence with some text" + ), + "title": ValueProto(string_val=f"Title {i}"), }, _utc_now(), _utc_now(), @@ -806,13 +897,21 @@ def test_sqlite_get_online_documents_v2() -> None: # Test vector similarity search query_embedding = [float(x) for x in np.random.random(vector_length)] result = store.retrieve_online_documents_v2( - features=["document_embeddings:Embeddings"], + features=[ + "document_embeddings:Embeddings", + "document_embeddings:content", + "document_embeddings:title", + ], query=query_embedding, top_k=3, ).to_dict() assert "Embeddings" in result + assert "content" in result + assert "title" in result assert "distance" in result + assert ["1th sentence with some text" in r for r in result["content"]] + assert ["Title " in r for r in result["title"]] assert len(result["distance"]) == 3 From 198cb20f249e02339f3b05dbc23127c78bda7f70 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sat, 22 Feb 2025 23:16:29 -0500 Subject: [PATCH 2/4] updating tests...not working entirely yet but close Signed-off-by: Francisco Javier Arceo --- .../feast/infra/online_stores/sqlite.py | 1 - .../online_store/test_online_retrieval.py | 42 ++++++++++++------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 3d04ef82cce..d9b70aa010a 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -171,7 +171,6 @@ def online_write_batch( val_bin = serialize_f32( val.float_list_val.val, config.online_store.vector_len ) # type: ignore - else: val_bin = feast_value_type_to_python_type(val) conn.execute( diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index b9513518108..bf424aa5b6a 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -640,12 +640,12 @@ def test_sqlite_get_online_documents() -> None: item_keys = [ EntityKeyProto( - join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)] + join_keys=["item_id"], entity_values=[ValueProto(string_val=str(i))] ) for i in range(n) ] data = [] - for item_key in item_keys: + for i, item_key in enumerate(item_keys): data.append( ( item_key, @@ -656,19 +656,17 @@ def test_sqlite_get_online_documents() -> None: vector_length, ) ) - ) + ), + "content": ValueProto( + string_val=f"the {i}th sentence with some text" + ), + "title": ValueProto(string_val=f"Title {i}"), }, _utc_now(), _utc_now(), ) ) - provider.online_write_batch( - config=store.config, - table=document_embeddings_fv, - data=data, - progress=None, - ) documents_df = pd.DataFrame( { "item_id": [str(i) for i in range(n)], @@ -678,26 +676,42 @@ def test_sqlite_get_online_documents() -> None: ) for i in range(n) ], + "content": [f"the {i}th sentence with some text" for i in range(n)], + "title": [f"Title {i}" for i in range(n)], "event_timestamp": [_utc_now() for _ in range(n)], } ) - store.write_to_online_store( - feature_view_name="document_embeddings", - df=documents_df, + print(len(data), documents_df.shape[0]) + provider.online_write_batch( + config=store.config, + table=document_embeddings_fv, + data=data, + progress=None, ) - document_table = store._provider._online_store._conn.execute( "SELECT name FROM sqlite_master WHERE type='table' and name like '%_document_embeddings';" ).fetchall() + assert len(document_table) == 1 document_table_name = document_table[0][0] + + record_count = len( + store._provider._online_store._conn.execute( + f"select * from {document_table_name}" + ).fetchall() + ) + assert record_count == len(data) * len(document_embeddings_fv.features) + store.write_to_online_store( + feature_view_name="document_embeddings", + df=documents_df, + ) record_count = len( store._provider._online_store._conn.execute( f"select * from {document_table_name}" ).fetchall() ) - assert record_count == len(data) + documents_df.shape[0] + assert record_count == len(data) * len(document_embeddings_fv.features) query_embedding = np.random.random( vector_length, From 3ea7138dc398df07cb304b31908397e3568d44a4 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sun, 23 Feb 2025 06:25:22 -0500 Subject: [PATCH 3/4] bug fix for addition of new features Signed-off-by: Francisco Javier Arceo --- .../feast/infra/online_stores/sqlite.py | 38 ++++++++++++------- .../online_store/test_online_retrieval.py | 7 ++-- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index d9b70aa010a..2fcf36b1001 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -355,6 +355,7 @@ def retrieve_online_documents( # Convert the embedding to a binary format instead of using SerializeToString() query_embedding_bin = serialize_f32(embedding, config.online_store.vector_len) table_name = _table_id(project, table) + vector_field = _get_vector_field(table) cur.execute( f""" @@ -369,14 +370,15 @@ def retrieve_online_documents( f""" INSERT INTO vec_table(rowid, vector_value) select rowid, vector_value from {table_name} + where feature_name = "{vector_field}" """ ) cur.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0( + vector_value float[{config.online_store.vector_len}] + ); """ - INSERT INTO vec_table(rowid, vector_value) - VALUES (?, ?) - """, - (0, query_embedding_bin), ) # Have to join this with the {table_name} to get the feature name and entity_key @@ -473,16 +475,7 @@ def retrieve_online_documents_v2( query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore table_name = _table_id(config.project, table) - vector_fields: List[Field] = [ - f for f in table.features if getattr(f, "vector_index", None) - ] - assert len(vector_fields) > 0, ( - f"No vector field found, please update feature view = {table.name} to declare a vector field" - ) - assert len(vector_fields) < 2, ( - "Only one vector field is supported, please update feature view = {table.name} to declare one vector field" - ) - vector_field: str = vector_fields[0].name + vector_field = _get_vector_field(table) cur.execute( f""" @@ -696,3 +689,20 @@ def update(self): def teardown(self): self.conn.execute(f"DROP TABLE IF EXISTS {self.name}") + + +def _get_vector_field(table: FeatureView) -> str: + """ + Get the vector field from the feature view. There can be only one. + """ + vector_fields: List[Field] = [ + f for f in table.features if getattr(f, "vector_index", None) + ] + assert len(vector_fields) > 0, ( + f"No vector field found, please update feature view = {table.name} to declare a vector field" + ) + assert len(vector_fields) < 2, ( + "Only one vector field is supported, please update feature view = {table.name} to declare one vector field" + ) + vector_field: str = vector_fields[0].name + return vector_field diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index bf424aa5b6a..b91f3360fc0 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -713,9 +713,10 @@ def test_sqlite_get_online_documents() -> None: ) assert record_count == len(data) * len(document_embeddings_fv.features) - query_embedding = np.random.random( - vector_length, - ) + # query_embedding = np.random.random( + # vector_length, + # ) + query_embedding = [float(x) for x in np.random.random(vector_length)] result = store.retrieve_online_documents( feature="document_embeddings:Embeddings", query=query_embedding, top_k=3 ).to_dict() From 0dff5f16d17ce398fa1b28cfc3822608988d621d Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Wed, 26 Feb 2025 11:45:45 -0500 Subject: [PATCH 4/4] updated to implement full text search Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/feature_server.py | 2 + sdk/python/feast/feature_store.py | 5 + .../milvus_online_store/milvus.py | 1 + .../feast/infra/online_stores/online_store.py | 2 + .../feast/infra/online_stores/sqlite.py | 189 +++++++++++++----- .../feast/infra/passthrough_provider.py | 2 + sdk/python/feast/infra/provider.py | 2 + sdk/python/tests/foo_provider.py | 1 + .../online_store/test_online_retrieval.py | 79 +++++++- 9 files changed, 231 insertions(+), 52 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index ed742bcb98f..023ce7d1115 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -75,6 +75,7 @@ class GetOnlineFeaturesRequest(BaseModel): features: Optional[List[str]] = None full_feature_names: bool = False query_embedding: Optional[List[float]] = None + query_string: Optional[str] = None def _get_features(request: GetOnlineFeaturesRequest, store: "feast.FeatureStore"): @@ -195,6 +196,7 @@ async def retrieve_online_documents( entity_rows=request.entities, full_feature_names=request.full_feature_names, query=request.query_embedding, + query_string=request.query_string, ) response = await run_in_threadpool( diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index f0bdc4c1f28..c3a8cd7a2bc 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1867,6 +1867,7 @@ def retrieve_online_documents_v2( top_k: int, features: List[str], distance_metric: Optional[str] = "L2", + query_string: Optional[str] = None, ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. @@ -1878,6 +1879,7 @@ def retrieve_online_documents_v2( query: The query to retrieve the closest document features for. top_k: The number of closest document features to retrieve. distance_metric: The distance metric to use for retrieval. + query_string: The query string to retrieve the closest document features using keyword search (bm25). """ if isinstance(query, str): raise ValueError( @@ -1919,6 +1921,7 @@ def retrieve_online_documents_v2( query, top_k, distance_metric, + query_string, ) def _retrieve_from_online_store( @@ -1988,6 +1991,7 @@ def _retrieve_from_online_store_v2( query: List[float], top_k: int, distance_metric: Optional[str], + query_string: Optional[str], ) -> OnlineResponse: """ Search and return document features from the online document store. @@ -2003,6 +2007,7 @@ def _retrieve_from_online_store_v2( query=query, top_k=top_k, distance_metric=distance_metric, + query_string=query_string, ) entity_key_dict: Dict[str, List[ValueProto]] = {} diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index 4d14d826085..9b9f003ebba 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -463,6 +463,7 @@ def retrieve_online_documents_v2( embedding: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index f5202b66f66..7165ef59a3e 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -439,6 +439,7 @@ def retrieve_online_documents_v2( embedding: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -456,6 +457,7 @@ def retrieve_online_documents_v2( requested_features: The list of features whose embeddings should be used for retrieval. embedding: The embeddings to use for retrieval. top_k: The number of documents to retrieve. + query_string: The query string to search for using keyword search (bm25) (optional) Returns: object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 2fcf36b1001..87da851ce5f 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -40,7 +40,7 @@ from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.type_map import feast_value_type_to_python_type -from feast.types import FEAST_VECTOR_TYPES +from feast.types import FEAST_VECTOR_TYPES, PrimitiveFeastType from feast.utils import ( _build_retrieve_online_document_record, _serialize_vector_to_float_list, @@ -442,6 +442,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -458,72 +459,135 @@ def retrieve_online_documents_v2( query: Query embedding to search for top_k: Number of items to return distance_metric: Distance metric to use (optional) + query_string: The query string to search for using keyword search (bm25) (optional) Returns: List of tuples containing the event timestamp, entity key, and feature values """ online_store = config.online_store if not isinstance(online_store, SqliteOnlineStoreConfig): raise ValueError("online_store must be SqliteOnlineStoreConfig") - if not online_store.vector_enabled: - raise ValueError("Vector search is not enabled in the online store config") + if not online_store.vector_enabled and not online_store.text_search_enabled: + raise ValueError( + "You must enable either vector search or text search in the online store config" + ) conn = self._get_conn(config) cur = conn.cursor() - if not online_store.vector_len: + if online_store.vector_enabled and not online_store.vector_len: raise ValueError("vector_len is not configured in the online store config") - query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore table_name = _table_id(config.project, table) vector_field = _get_vector_field(table) - cur.execute( - f""" - CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0( - vector_value float[{online_store.vector_len}] - ); - """ - ) + if online_store.vector_enabled: + query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore + cur.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0( + vector_value float[{online_store.vector_len}] + ); + """ + ) + cur.execute( + f""" + INSERT INTO vec_table (rowid, vector_value) + select rowid, vector_value from {table_name} + where feature_name = "{vector_field}" + """ + ) + elif online_store.text_search_enabled: + string_field_list = [ + f.name for f in table.features if f.dtype == PrimitiveFeastType.STRING + ] + string_fields = ", ".join(string_field_list) + # TODO: swap this for a value configurable in each Field() + BM25_DEFAULT_WEIGHTS = ", ".join( + [ + str(1.0) + for f in table.features + if f.dtype == PrimitiveFeastType.STRING + ] + ) + cur.execute( + f""" + CREATE VIRTUAL TABLE IF NOT EXISTS search_table using fts5( + entity_key, fv_rowid, {string_fields}, tokenize="porter unicode61" + ); + """ + ) + insert_query = _generate_bm25_search_insert_query( + table_name, string_field_list + ) + cur.execute(insert_query) - cur.execute( - f""" - INSERT INTO vec_table (rowid, vector_value) - select rowid, vector_value from {table_name} - where feature_name = "{vector_field}" - """ - ) + else: + raise ValueError( + "Neither vector search nor text search are enabled in the online store config" + ) - cur.execute( - f""" + if online_store.vector_enabled: + cur.execute( + f""" + select + fv2.entity_key, + fv2.feature_name, + fv2.value, + fv.vector_value, + f.distance, + fv.event_ts, + fv.created_ts + from ( + select + rowid, + vector_value, + distance + from vec_table + where vector_value match ? + order by distance + limit ? + ) f + left join {table_name} fv + on f.rowid = fv.rowid + left join {table_name} fv2 + on fv.entity_key = fv2.entity_key + where fv2.feature_name != "{vector_field}" + """, + ( + query_embedding_bin, + top_k, + ), + ) + elif online_store.text_search_enabled: + cur.execute( + f""" select - fv2.entity_key, - fv2.feature_name, - fv2.value, + fv.entity_key, + fv.feature_name, + fv.value, fv.vector_value, f.distance, fv.event_ts, fv.created_ts - from ( - select - rowid, - vector_value, - distance - from vec_table - where vector_value match ? - order by distance - limit ? - ) f - left join {table_name} fv - on f.rowid = fv.rowid - left join {table_name} fv2 - on fv.entity_key = fv2.entity_key - where fv2.feature_name != "{vector_field}" - """, - ( - query_embedding_bin, - top_k, - ), - ) + from {table_name} fv + inner join ( + select + fv_rowid, + entity_key, + {string_fields}, + bm25(search_table, {BM25_DEFAULT_WEIGHTS}) as distance + from search_table + where search_table match ? order by distance limit ? + ) f + on f.entity_key = fv.entity_key + """, + (query_string, top_k), + ) + + else: + raise ValueError( + "Neither vector search nor text search are enabled in the online store config" + ) rows = cur.fetchall() results: List[ @@ -557,9 +621,10 @@ def retrieve_online_documents_v2( feature_val.ParseFromString(value_bin) entity_dict[entity_key]["entity_key_proto"] = entity_key_proto entity_dict[entity_key][feature_name] = feature_val - entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list( - vector_value - ) + if online_store.vector_enabled: + entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list( + vector_value + ) entity_dict[entity_key]["distance"] = ValueProto(float_val=distance) entity_dict[entity_key]["event_ts"] = event_ts entity_dict[entity_key]["created_ts"] = created_ts @@ -706,3 +771,31 @@ def _get_vector_field(table: FeatureView) -> str: ) vector_field: str = vector_fields[0].name return vector_field + + +def _generate_bm25_search_insert_query( + table_name: str, string_field_list: List[str] +) -> str: + """ + Generates an SQL insertion query for the given table and string fields. + + Args: + table_name (str): The name of the table to select data from. + string_field_list (List[str]): The list of string fields to be used in the insertion. + + Returns: + str: The generated SQL insertion query. + """ + _string_fields = ", ".join(string_field_list) + query = f"INSERT INTO search_table (entity_key, fv_rowid, {_string_fields})\nSELECT\n\tDISTINCT fv0.entity_key,\n\tfv0.rowid as fv_rowid" + from_query = f"\nFROM (select rowid, * from {table_name} where feature_name = '{string_field_list[0]}') fv0" + + for i, string_field in enumerate(string_field_list): + query += f"\n\t,fv{i}.value as {string_field}" + if i > 0: + from_query += ( + f"\nLEFT JOIN (select rowid, * from {table_name} where feature_name = '{string_field}') fv{i}" + + f"\n\tON fv0.entity_key = fv{i}.entity_key" + ) + + return query + from_query diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 74b05113282..2a896a58b06 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -321,6 +321,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List: result = [] if self.online_store: @@ -331,6 +332,7 @@ def retrieve_online_documents_v2( query, top_k, distance_metric, + query_string, ) return result diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index f765e754436..5155e76dddf 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -459,6 +459,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -476,6 +477,7 @@ def retrieve_online_documents_v2( requested_features: the requested document feature names. query: The query embedding to search for. top_k: The number of documents to return. + query_string: The query string to search for using keyword search (bm25) (optional) Returns: A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index ca6a02c4bd0..df8edf1232e 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -172,6 +172,7 @@ def retrieve_online_documents_v2( query: List[float], top_k: int, distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index b91f3360fc0..c88cc57ef13 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -713,10 +713,9 @@ def test_sqlite_get_online_documents() -> None: ) assert record_count == len(data) * len(document_embeddings_fv.features) - # query_embedding = np.random.random( - # vector_length, - # ) - query_embedding = [float(x) for x in np.random.random(vector_length)] + query_embedding = np.random.random( + vector_length, + ) result = store.retrieve_online_documents( feature="document_embeddings:Embeddings", query=query_embedding, top_k=3 ).to_dict() @@ -930,6 +929,78 @@ def test_sqlite_get_online_documents_v2() -> None: assert len(result["distance"]) == 3 +def test_sqlite_get_online_documents_v2_search() -> None: + """Test retrieving documents using v2 method with key word search""" + n = 10 + vector_length = 8 + runner = CliRunner() + with runner.local_repo( + get_example_repo("example_feature_repo_1.py"), "file" + ) as store: + store.config.online_store.text_search_enabled = True + store.config.entity_key_serialization_version = 3 + document_embeddings_fv = store.get_feature_view(name="document_embeddings") + + provider = store._get_provider() + + # Create test data + item_keys = [ + EntityKeyProto( + join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)] + ) + for i in range(n) + ] + data = [] + for i, item_key in enumerate(item_keys): + data.append( + ( + item_key, + { + "Embeddings": ValueProto( + float_list_val=FloatListProto( + val=[float(x) for x in np.random.random(vector_length)] + ) + ), + "content": ValueProto( + string_val=f"the {i}th sentence with some text" + ), + "title": ValueProto(string_val=f"Title {i}"), + }, + _utc_now(), + _utc_now(), + ) + ) + + provider.online_write_batch( + config=store.config, + table=document_embeddings_fv, + data=data, + progress=None, + ) + + # Test vector similarity search + query_embedding = [float(x) for x in np.random.random(vector_length)] + result = store.retrieve_online_documents_v2( + features=[ + "document_embeddings:Embeddings", + "document_embeddings:content", + "document_embeddings:title", + ], + query=query_embedding, + query_string="(content: 5) OR (title: 1) OR (title: 3)", + top_k=3, + ).to_dict() + + assert "Embeddings" in result + assert "content" in result + assert "title" in result + assert "distance" in result + assert ["1th sentence with some text" in r for r in result["content"]] + assert ["Title " in r for r in result["title"]] + assert len(result["distance"]) == 2 + assert result["distance"] == [-1.8458267450332642, -1.8458267450332642] + + @pytest.mark.skip(reason="Skipping this test as CI struggles with it") def test_local_milvus() -> None: import random