Skip to content

Commit

Permalink
Working version of vector database integration
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Dec 4, 2024
1 parent 75576cf commit 9cdd418
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,20 @@ INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```
Retrieve embeddings of germany
Retrieve an embedding vector.
```bash
curl -X 'GET' 'http://0.0.0.0:8000/api/get?q=germany' -H 'accept: application/json'
# {"result": [{"name": "europe","vector": [...]}]}
```
Retrieve embedding vectors.
```bash
curl -X 'GET' 'http://0.0.0.0:8000/api/search?q=europe' -H 'accept: application/json'
curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["brunei","guam"]}'
# {"results": [{ "name": "europe","vector": [...]},{ "name": "northern_europe","vector": [...]}]}
```
Retrieve an average of embedding vectors.
```bash
curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["europe","northern_europe"],"reducer": "mean"}'
# {"results":{"name": ["europe","northern_europe"],"vectors": [...]}}
```

</details>
Expand Down
43 changes: 33 additions & 10 deletions dicee/scripts/index_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from fastapi import FastAPI
import uvicorn

from pydantic import BaseModel
from typing import List, Optional

def get_default_arguments():
parser = argparse.ArgumentParser(add_help=False)
Expand Down Expand Up @@ -86,14 +87,19 @@ def __init__(self, args):
# semantic search
self.topk=5

def get(self,entity:str=None):
if entity is None:
return {"Input {entity} cannot be None"}
elif self.entity_to_idx.get(entity,None) is None:
return {f"Input {entity} not found"}
def retrieve_embedding(self,entity:str=None,entities:List[str]=None)->List:
ids=[]
inputs= [entity]
if entities is not None:
inputs.extend(entities)
for ent in inputs:
if idx := self.entity_to_idx.get(ent, None):
assert isinstance(idx, int)
ids.append(idx)
if len(ids)<1:
return {"error":f"IDs are not found for ({entity} or {entities})"}
else:
ids=[self.entity_to_idx[entity]]
return self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True)
return [{"name": result.payload["name"], "vector": result.vector} for result in self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True)]

def search(self, entity: str):
return self.qdrant_client.query_points(collection_name=self.collection_name, query=self.entity_to_idx[entity],limit=self.topk)
Expand All @@ -108,8 +114,25 @@ async def search_embeddings(q: str):

@app.get("/api/get")
async def retrieve_embeddings(q: str):
return {"result": neural_searcher.get(entity=q)}

return {"result": neural_searcher.retrieve_embedding(entity=q)}

class StringListRequest(BaseModel):
queries: List[str]
reducer: Optional[str] = None # Add the reducer flag with default as None


@app.post("/api/search_batch")
async def search_embeddings_batch(request: StringListRequest):
if request.reducer == "mean":
names=[]
vectors=[]
for result in neural_searcher.retrieve_embedding(entities=request.queries):
names.append(result["name"])
vectors.append(result["vector"])
embeddings = np.mean(vectors, axis=0).tolist() # Reduce to mean
return {"results":{"name":names,"vectors":embeddings}}
else:
return {"results": neural_searcher.retrieve_embedding(entities=request.queries)}

def serve(args):
global neural_searcher
Expand Down

0 comments on commit 9cdd418

Please sign in to comment.