From 9cdd418bca0d3f91638de671ddabe2ffdf895679 Mon Sep 17 00:00:00 2001
From: Caglar Demir
Date: Wed, 4 Dec 2024 14:46:56 +0100
Subject: [PATCH] Working version of vector database integration

---
 README.md                    | 12 ++++++++--
 dicee/scripts/index_serve.py | 43 +++++++++++++++++++++++++++---------
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 8a9db356..58e5f4f6 100644
--- a/README.md
+++ b/README.md
@@ -204,12 +204,20 @@ INFO: Waiting for application startup.
 INFO: Application startup complete.
 INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
 ```
 
-Retrieve embeddings of germany
+Retrieve the embedding vector of a single entity.
 ```bash
 curl -X 'GET' 'http://0.0.0.0:8000/api/get?q=germany' -H 'accept: application/json'
+# {"result": [{"name": "germany","vector": [...]}]}
 ```
+Retrieve the embedding vectors of multiple entities.
 ```bash
-curl -X 'GET' 'http://0.0.0.0:8000/api/search?q=europe' -H 'accept: application/json'
+curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["brunei","guam"]}'
+# {"results": [{"name": "brunei","vector": [...]},{"name": "guam","vector": [...]}]}
+```
+Retrieve the mean of several embedding vectors.
+```bash
+curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["europe","northern_europe"],"reducer": "mean"}'
+# {"results": {"name": ["europe","northern_europe"],"vectors": [...]}}
 ```
 
diff --git a/dicee/scripts/index_serve.py b/dicee/scripts/index_serve.py
index 7430aab5..9f484ce9 100644
--- a/dicee/scripts/index_serve.py
+++ b/dicee/scripts/index_serve.py
@@ -13,7 +13,8 @@
 
 from fastapi import FastAPI
 import uvicorn
-
+from pydantic import BaseModel
+from typing import List, Optional
 
 def get_default_arguments():
     parser = argparse.ArgumentParser(add_help=False)
@@ -86,14 +87,19 @@ def __init__(self, args):
         # semantic search
         self.topk=5
 
-    def get(self,entity:str=None):
-        if entity is None:
-            return {"Input {entity} cannot be None"}
-        elif self.entity_to_idx.get(entity,None) is None:
-            return {f"Input {entity} not found"}
+    def retrieve_embedding(self, entity: str = None, entities: List[str] = None) -> List:
+        ids = []
+        inputs = [entity]
+        if entities is not None:
+            inputs.extend(entities)
+        for ent in inputs:
+            if (idx := self.entity_to_idx.get(ent, None)) is not None:
+                assert isinstance(idx, int)
+                ids.append(idx)
+        if len(ids) < 1:
+            return {"error": f"No indices found for ({entity} or {entities})"}
         else:
-            ids=[self.entity_to_idx[entity]]
-            return self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True)
+            return [{"name": result.payload["name"], "vector": result.vector} for result in self.qdrant_client.retrieve(collection_name=self.collection_name, ids=ids, with_vectors=True)]
 
     def search(self, entity: str):
         return self.qdrant_client.query_points(collection_name=self.collection_name, query=self.entity_to_idx[entity],limit=self.topk)
@@ -108,8 +114,25 @@ async def search_embeddings(q: str):
 
 @app.get("/api/get")
 async def retrieve_embeddings(q: str):
-    return {"result": neural_searcher.get(entity=q)}
-
+    return {"result": neural_searcher.retrieve_embedding(entity=q)}
+
+class StringListRequest(BaseModel):
+    queries: List[str]
+    reducer: Optional[str] = None  # Optional reducer; only "mean" is currently supported.
+
+
+@app.post("/api/search_batch")
+async def search_embeddings_batch(request: StringListRequest):
+    if request.reducer == "mean":
+        names = []
+        vectors = []
+        for result in neural_searcher.retrieve_embedding(entities=request.queries):
+            names.append(result["name"])
+            vectors.append(result["vector"])
+        embeddings = np.mean(vectors, axis=0).tolist()  # Element-wise mean of the retrieved vectors.
+        return {"results": {"name": names, "vectors": embeddings}}
+    else:
+        return {"results": neural_searcher.retrieve_embedding(entities=request.queries)}
 
 def serve(args):
     global neural_searcher
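
For illustration, a minimal client sketch for the `/api/search_batch` endpoint added by this patch. It assumes the index server from the README is running on `http://0.0.0.0:8000` and that the `requests` package is installed; the entity names are just the examples used above.

```python
# Minimal client sketch for the /api/search_batch endpoint.
# Assumes the dicee index server is running on http://0.0.0.0:8000
# and that the `requests` package is installed.
import requests

BASE_URL = "http://0.0.0.0:8000"

# Without a reducer: one {"name": ..., "vector": [...]} entry per query entity.
resp = requests.post(f"{BASE_URL}/api/search_batch",
                     json={"queries": ["brunei", "guam"]})
for item in resp.json()["results"]:
    print(item["name"], len(item["vector"]))

# With "reducer": "mean": the vectors are averaged into a single vector,
# returned under "vectors" alongside the list of entity names.
resp = requests.post(f"{BASE_URL}/api/search_batch",
                     json={"queries": ["europe", "northern_europe"], "reducer": "mean"})
reduced = resp.json()["results"]
print(reduced["name"], len(reduced["vectors"]))
```

The batch endpoint is a POST with a JSON body rather than a GET with query parameters, since the request carries a list of entity names plus an optional `reducer` flag, which maps directly onto the `StringListRequest` Pydantic model.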