Skip to content

Commit

Permalink
bug fix for addition of new features
Browse files Browse the repository at this point in the history
Signed-off-by: Francisco Javier Arceo <[email protected]>
  • Loading branch information
franciscojavierarceo committed Feb 23, 2025
1 parent 4e9ddec commit 61f5050
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
38 changes: 24 additions & 14 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def retrieve_online_documents(
# Convert the embedding to a binary format instead of using SerializeToString()
query_embedding_bin = serialize_f32(embedding, config.online_store.vector_len)
table_name = _table_id(project, table)
vector_field = _get_vector_field(table)

cur.execute(
f"""
Expand All @@ -369,14 +370,15 @@ def retrieve_online_documents(
f"""
INSERT INTO vec_table(rowid, vector_value)
select rowid, vector_value from {table_name}
where feature_name = "{vector_field}"
"""
)
cur.execute(
f"""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_table using vec0(
vector_value float[{config.online_store.vector_len}]
);
"""
INSERT INTO vec_table(rowid, vector_value)
VALUES (?, ?)
""",
(0, query_embedding_bin),
)

# Have to join this with the {table_name} to get the feature name and entity_key
Expand Down Expand Up @@ -473,16 +475,7 @@ def retrieve_online_documents_v2(

query_embedding_bin = serialize_f32(query, online_store.vector_len) # type: ignore
table_name = _table_id(config.project, table)
vector_fields: List[Field] = [
f for f in table.features if getattr(f, "vector_index", None)
]
assert len(vector_fields) > 0, (
f"No vector field found, please update feature view = {table.name} to declare a vector field"
)
assert len(vector_fields) < 2, (
"Only one vector field is supported, please update feature view = {table.name} to declare one vector field"
)
vector_field: str = vector_fields[0].name
vector_field = _get_vector_field(table)

cur.execute(
f"""
Expand Down Expand Up @@ -696,3 +689,20 @@ def update(self):

def teardown(self):
self.conn.execute(f"DROP TABLE IF EXISTS {self.name}")


def _get_vector_field(table: FeatureView) -> str:
"""
Get the vector field from the feature view. There can be only one.
"""
vector_fields: List[Field] = [
f for f in table.features if getattr(f, "vector_index", None)
]
assert len(vector_fields) > 0, (
f"No vector field found, please update feature view = {table.name} to declare a vector field"
)
assert len(vector_fields) < 2, (
"Only one vector field is supported, please update feature view = {table.name} to declare one vector field"
)
vector_field: str = vector_fields[0].name
return vector_field
7 changes: 4 additions & 3 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,10 @@ def test_sqlite_get_online_documents() -> None:
)
assert record_count == len(data) * len(document_embeddings_fv.features)

query_embedding = np.random.random(
vector_length,
)
# query_embedding = np.random.random(
# vector_length,
# )
query_embedding = [float(x) for x in np.random.random(vector_length)]
result = store.retrieve_online_documents(
feature="document_embeddings:Embeddings", query=query_embedding, top_k=3
).to_dict()
Expand Down

0 comments on commit 61f5050

Please sign in to comment.