Skip to content

Commit

Permalink
updated to implement full text search
Browse files Browse the repository at this point in the history
Signed-off-by: Francisco Javier Arceo <[email protected]>
  • Loading branch information
franciscojavierarceo committed Feb 26, 2025
1 parent 61f5050 commit 2623df8
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 52 deletions.
2 changes: 2 additions & 0 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class GetOnlineFeaturesRequest(BaseModel):
features: Optional[List[str]] = None
full_feature_names: bool = False
query_embedding: Optional[List[float]] = None
query_string: Optional[str] = None


def _get_features(request: GetOnlineFeaturesRequest, store: "feast.FeatureStore"):
Expand Down Expand Up @@ -195,6 +196,7 @@ async def retrieve_online_documents(
entity_rows=request.entities,
full_feature_names=request.full_feature_names,
query=request.query_embedding,
query_string=request.query_string,
)

response = await run_in_threadpool(
Expand Down
5 changes: 5 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,7 @@ def retrieve_online_documents_v2(
top_k: int,
features: List[str],
distance_metric: Optional[str] = "L2",
query_string: Optional[str] = None,
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.
Expand All @@ -1878,6 +1879,7 @@ def retrieve_online_documents_v2(
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
distance_metric: The distance metric to use for retrieval.
query_string: The query string to retrieve the closest document features using keyword search (bm25).
"""
if isinstance(query, str):
raise ValueError(
Expand Down Expand Up @@ -1919,6 +1921,7 @@ def retrieve_online_documents_v2(
query,
top_k,
distance_metric,
query_string,
)

def _retrieve_from_online_store(
Expand Down Expand Up @@ -1988,6 +1991,7 @@ def _retrieve_from_online_store_v2(
query: List[float],
top_k: int,
distance_metric: Optional[str],
query_string: Optional[str],
) -> OnlineResponse:
"""
Search and return document features from the online document store.
Expand All @@ -2003,6 +2007,7 @@ def _retrieve_from_online_store_v2(
query=query,
top_k=top_k,
distance_metric=distance_metric,
query_string=query_string,
)

entity_key_dict: Dict[str, List[ValueProto]] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def retrieve_online_documents_v2(
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def retrieve_online_documents_v2(
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -456,6 +457,7 @@ def retrieve_online_documents_v2(
requested_features: The list of features whose embeddings should be used for retrieval.
embedding: The embeddings to use for retrieval.
top_k: The number of documents to retrieve.
query_string: The query string to search for using keyword search (bm25) (optional)
Returns:
object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple
Expand Down
189 changes: 141 additions & 48 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.type_map import feast_value_type_to_python_type
from feast.types import FEAST_VECTOR_TYPES
from feast.types import FEAST_VECTOR_TYPES, PrimitiveFeastType
from feast.utils import (
_build_retrieve_online_document_record,
_serialize_vector_to_float_list,
Expand Down Expand Up @@ -442,6 +442,7 @@ def retrieve_online_documents_v2(
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -458,72 +459,135 @@ def retrieve_online_documents_v2(
query: Query embedding to search for
top_k: Number of items to return
distance_metric: Distance metric to use (optional)
query_string: The query string to search for using keyword search (bm25) (optional)
Returns:
List of tuples containing the event timestamp, entity key, and feature values
"""
online_store = config.online_store
if not isinstance(online_store, SqliteOnlineStoreConfig):
raise ValueError("online_store must be SqliteOnlineStoreConfig")
if not online_store.vector_enabled:
raise ValueError("Vector search is not enabled in the online store config")
if not online_store.vector_enabled and not online_store.text_search_enabled:
raise ValueError(
"You must enable either vector search or text search in the online store config"
)

conn = self._get_conn(config)
cur = conn.cursor()

if not online_store.vector_len:
if online_store.vector_enabled and not online_store.vector_len:
raise ValueError("vector_len is not configured in the online store config")

query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore
table_name = _table_id(config.project, table)
vector_field = _get_vector_field(table)

cur.execute(
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0(
vector_value float[{online_store.vector_len}]
);
"""
)
if online_store.vector_enabled:
query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore
cur.execute(
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0(
vector_value float[{online_store.vector_len}]
);
"""
)
cur.execute(
f"""
INSERT INTO vec_table (rowid, vector_value)
select rowid, vector_value from {table_name}
where feature_name = "{vector_field}"
"""
)
elif online_store.text_search_enabled:
string_field_list = [
f.name for f in table.features if f.dtype == PrimitiveFeastType.STRING
]
string_fields = ", ".join(string_field_list)
# TODO: swap this for a value configurable in each Field()
BM25_DEFAULT_WEIGHTS = ", ".join(
[
str(1.0)
for f in table.features
if f.dtype == PrimitiveFeastType.STRING
]
)
cur.execute(
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS search_table using fts5(
entity_key, fv_rowid, {string_fields}, tokenize="porter unicode61"
);
"""
)
insert_query = _generate_bm25_search_insert_query(
table_name, string_field_list
)
cur.execute(insert_query)

cur.execute(
f"""
INSERT INTO vec_table (rowid, vector_value)
select rowid, vector_value from {table_name}
where feature_name = "{vector_field}"
"""
)
else:
raise ValueError(
"Neither vector search nor text search are enabled in the online store config"
)

cur.execute(
f"""
if online_store.vector_enabled:
cur.execute(
f"""
select
fv2.entity_key,
fv2.feature_name,
fv2.value,
fv.vector_value,
f.distance,
fv.event_ts,
fv.created_ts
from (
select
rowid,
vector_value,
distance
from vec_table
where vector_value match ?
order by distance
limit ?
) f
left join {table_name} fv
on f.rowid = fv.rowid
left join {table_name} fv2
on fv.entity_key = fv2.entity_key
where fv2.feature_name != "{vector_field}"
""",
(
query_embedding_bin,
top_k,
),
)
elif online_store.text_search_enabled:
cur.execute(
f"""
select
fv2.entity_key,
fv2.feature_name,
fv2.value,
fv.entity_key,
fv.feature_name,
fv.value,
fv.vector_value,
f.distance,
fv.event_ts,
fv.created_ts
from (
select
rowid,
vector_value,
distance
from vec_table
where vector_value match ?
order by distance
limit ?
) f
left join {table_name} fv
on f.rowid = fv.rowid
left join {table_name} fv2
on fv.entity_key = fv2.entity_key
where fv2.feature_name != "{vector_field}"
""",
(
query_embedding_bin,
top_k,
),
)
from {table_name} fv
inner join (
select
fv_rowid,
entity_key,
{string_fields},
bm25(search_table, {BM25_DEFAULT_WEIGHTS}) as distance
from search_table
where search_table match ? order by distance limit ?
) f
on f.entity_key = fv.entity_key
""",
(query_string, top_k),
)

else:
raise ValueError(
"Neither vector search nor text search are enabled in the online store config"
)

rows = cur.fetchall()
results: List[
Expand Down Expand Up @@ -557,9 +621,10 @@ def retrieve_online_documents_v2(
feature_val.ParseFromString(value_bin)
entity_dict[entity_key]["entity_key_proto"] = entity_key_proto
entity_dict[entity_key][feature_name] = feature_val
entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list(
vector_value
)
if online_store.vector_enabled:
entity_dict[entity_key][vector_field] = _serialize_vector_to_float_list(
vector_value
)
entity_dict[entity_key]["distance"] = ValueProto(float_val=distance)
entity_dict[entity_key]["event_ts"] = event_ts
entity_dict[entity_key]["created_ts"] = created_ts
Expand Down Expand Up @@ -706,3 +771,31 @@ def _get_vector_field(table: FeatureView) -> str:
)
vector_field: str = vector_fields[0].name
return vector_field


def _generate_bm25_search_insert_query(
table_name: str, string_field_list: List[str]
) -> str:
"""
Generates an SQL insertion query for the given table and string fields.
Args:
table_name (str): The name of the table to select data from.
string_field_list (List[str]): The list of string fields to be used in the insertion.
Returns:
str: The generated SQL insertion query.
"""
_string_fields = ", ".join(string_field_list)
query = f"INSERT INTO search_table (entity_key, fv_rowid, {_string_fields})\nSELECT\n\tDISTINCT fv0.entity_key,\n\tfv0.rowid as fv_rowid"
from_query = f"\nFROM (select rowid, * from {table_name} where feature_name = '{string_field_list[0]}') fv0"

for i, string_field in enumerate(string_field_list):
query += f"\n\t,fv{i}.value as {string_field}"
if i > 0:
from_query += (
f"\nLEFT JOIN (select rowid, * from {table_name} where feature_name = '{string_field}') fv{i}"
+ f"\n\tON fv0.entity_key = fv{i}.entity_key"
)

return query + from_query
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def retrieve_online_documents_v2(
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List:
result = []
if self.online_store:
Expand All @@ -331,6 +332,7 @@ def retrieve_online_documents_v2(
query,
top_k,
distance_metric,
query_string,
)
return result

Expand Down
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def retrieve_online_documents_v2(
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -476,6 +477,7 @@ def retrieve_online_documents_v2(
requested_features: the requested document feature names.
query: The query embedding to search for.
top_k: The number of documents to return.
query_string: The query string to search for using keyword search (bm25) (optional)
Returns:
A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary
Expand Down
1 change: 1 addition & 0 deletions sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def retrieve_online_documents_v2(
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand Down
Loading

0 comments on commit 2623df8

Please sign in to comment.