feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for TEI embedding service, allow reranking model to be configurable.

Signed-off-by: Kennywu <[email protected]>

* fix: add cohere default reranking model

* fix: appease pre-commit

---------

Signed-off-by: Kennywu <[email protected]>
Co-authored-by: wujiaye <[email protected]>
Co-authored-by: Tadashi <[email protected]>
3 people authored Sep 30, 2024
1 parent 2e3c17b commit 53530e2
Showing 20 changed files with 928 additions and 22 deletions.
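
Taken together, these changes let a deployment point kotaemon at a self-hosted TEI (Text Embeddings Inference) server for embeddings and pick the reranking model from configuration instead of a hard-coded default. As a rough sketch of what this enables, a flowsettings.py entry modeled on the Cohere blocks in this diff might look like the following (the "tei" key and the localhost URL are illustrative placeholders, not part of this commit):

# Hypothetical registration of a TEI embedding endpoint in flowsettings.py.
KH_EMBEDDINGS["tei"] = {
    "spec": {
        "__type__": "kotaemon.embeddings.TeiEndpointEmbeddings",
        "endpoint_url": "http://localhost:8080/embed",  # placeholder /embed route
    },
    "default": False,
}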
15 changes: 13 additions & 2 deletions flowsettings.py
@@ -92,6 +92,7 @@
}
KH_LLMS = {}
KH_EMBEDDINGS = {}
KH_RERANKINGS = {}

# populate options from config
if config("AZURE_OPENAI_API_KEY", default="") and config(
@@ -212,7 +213,7 @@
"spec": {
"__type__": "kotaemon.llms.chats.LCCohereChat",
"model_name": "command-r-plus-08-2024",
"api_key": "your-key",
"api_key": config("COHERE_API_KEY", default="your-key"),
},
"default": False,
}
@@ -222,7 +223,7 @@
"spec": {
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v3.0",
"cohere_api_key": "your-key",
"cohere_api_key": config("COHERE_API_KEY", default="your-key"),
"user_agent": "default",
},
"default": False,
@@ -235,6 +236,16 @@
# "default": False,
# }

# default reranking models
KH_RERANKINGS["cohere"] = {
"spec": {
"__type__": "kotaemon.rerankings.CohereReranking",
"model_name": "rerank-multilingual-v2.0",
"cohere_api_key": config("COHERE_API_KEY", default="your-key"),
},
"default": True,
}
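
By analogy with the Cohere entry above, the new TEI reranker could presumably be registered through the same mechanism (the "tei" key and URL are illustrative placeholders, not part of this diff):

# Hypothetical entry mirroring KH_RERANKINGS["cohere"] above.
KH_RERANKINGS["tei"] = {
    "spec": {
        "__type__": "kotaemon.rerankings.TeiFastReranking",
        "endpoint_url": "http://localhost:8080/rerank",  # placeholder /rerank route
    },
    "default": False,
}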

KH_REASONINGS = [
"ktem.reasoning.simple.FullQAPipeline",
"ktem.reasoning.simple.FullDecomposeQAPipeline",
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/embeddings/__init__.py
@@ -8,10 +8,12 @@
    LCOpenAIEmbeddings,
)
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from .tei_endpoint_embed import TeiEndpointEmbeddings

__all__ = [
    "BaseEmbeddings",
    "EndpointEmbeddings",
    "TeiEndpointEmbeddings",
    "LCOpenAIEmbeddings",
    "LCAzureOpenAIEmbeddings",
    "LCCohereEmbeddings",
105 changes: 105 additions & 0 deletions libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py
@@ -0,0 +1,105 @@
import aiohttp
import requests

from kotaemon.base import Document, DocumentWithEmbedding, Param

from .base import BaseEmbeddings

session = requests.session()


class TeiEndpointEmbeddings(BaseEmbeddings):
    """An Embeddings component that uses a
    TEI (Text Embeddings Inference) compatible API endpoint.
    Ref: https://github.com/huggingface/text-embeddings-inference

    Attributes:
        endpoint_url (str): The URL of a TEI
            (Text Embeddings Inference) compatible API endpoint.
        normalize (bool): Whether to normalize embeddings to unit length.
        truncate (bool): Whether to truncate inputs to the model's
            maximum sequence length.
    """

    endpoint_url: str = Param(None, help="TEI embedding service API base URL")
    normalize: bool = Param(
        True,
        help="Normalize embeddings to unit length",
    )
    truncate: bool = Param(
        True,
        help="Truncate inputs to the model's maximum sequence length",
    )

    async def client_(self, inputs: list[str]):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url=self.endpoint_url,
                json={
                    "inputs": inputs,
                    "normalize": self.normalize,
                    "truncate": self.truncate,
                },
            ) as resp:
                embeddings = await resp.json()
        return embeddings

    async def ainvoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        if not isinstance(text, list):
            text = [text]
        text = self.prepare_input(text)

        outputs = []
        batch_size = 6
        num_batch = max(len(text) // batch_size, 1)
        for i in range(num_batch):
            # the final batch absorbs any remainder beyond batch_size
            if i == num_batch - 1:
                mini_batch = text[batch_size * i :]
            else:
                mini_batch = text[batch_size * i : batch_size * (i + 1)]
            mini_batch = [x.content for x in mini_batch]
            embeddings = await self.client_(mini_batch)  # type: ignore
            outputs.extend(
                [
                    DocumentWithEmbedding(content=doc, embedding=embedding)
                    for doc, embedding in zip(mini_batch, embeddings)
                ]
            )

        return outputs

    def invoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        if not isinstance(text, list):
            text = [text]

        text = self.prepare_input(text)

        outputs = []
        batch_size = 6
        num_batch = max(len(text) // batch_size, 1)
        for i in range(num_batch):
            # the final batch absorbs any remainder beyond batch_size
            if i == num_batch - 1:
                mini_batch = text[batch_size * i :]
            else:
                mini_batch = text[batch_size * i : batch_size * (i + 1)]
            mini_batch = [x.content for x in mini_batch]
            embeddings = session.post(
                url=self.endpoint_url,
                json={
                    "inputs": mini_batch,
                    "normalize": self.normalize,
                    "truncate": self.truncate,
                },
            ).json()
            outputs.extend(
                [
                    DocumentWithEmbedding(content=doc, embedding=embedding)
                    for doc, embedding in zip(mini_batch, embeddings)
                ]
            )
        return outputs
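
A minimal usage sketch for the new class, assuming a TEI server is listening locally and that prepare_input accepts plain strings (the URL is a placeholder):

from kotaemon.embeddings import TeiEndpointEmbeddings

# Placeholder URL; point this at a running text-embeddings-inference /embed route.
embedder = TeiEndpointEmbeddings(endpoint_url="http://localhost:8080/embed")
docs = embedder.invoke(["hello world", "goodbye world"])
print(len(docs), len(docs[0].embedding))  # 2 documents, each carrying its vector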
5 changes: 2 additions & 3 deletions libs/kotaemon/kotaemon/indices/rankings/cohere.py
@@ -39,7 +39,7 @@ def run(self, documents: list[Document], query: str) -> list[Document]:
print("Cannot get Cohere API key from `ktem`", e)

if not self.cohere_api_key:
print("Cohere API key not found. Skipping reranking.")
print("Cohere API key not found. Skipping rerankings.")
return documents

cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ def run(self, documents: list[Document], query: str) -> list[Document]:
        response = cohere_client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
-        # print("Cohere score", [r.relevance_score for r in response.results])
        for r in response.results:
            doc = documents[r.index]
-            doc.metadata["cohere_reranking_score"] = r.relevance_score
+            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs
2 changes: 1 addition & 1 deletion libs/kotaemon/kotaemon/indices/vectorindex.py
@@ -241,7 +241,7 @@ def query_docstore():
        # if reranker is LLMReranking, limit the document with top_k items only
        if isinstance(reranker, LLMReranking):
            result = self._filter_docs(result, top_k=top_k)
-        result = reranker(documents=result, query=text)
+        result = reranker.run(documents=result, query=text)

        result = self._filter_docs(result, top_k=top_k)
        print(f"Got raw {len(result)} retrieved documents")
5 changes: 5 additions & 0 deletions libs/kotaemon/kotaemon/rerankings/__init__.py
@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .tei_fast_rerank import TeiFastReranking

__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
13 changes: 13 additions & 0 deletions libs/kotaemon/kotaemon/rerankings/base.py
@@ -0,0 +1,13 @@
from __future__ import annotations

from abc import abstractmethod

from kotaemon.base import BaseComponent, Document


class BaseReranking(BaseComponent):
    @abstractmethod
    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Main method to transform list of documents
        (re-ranking, filtering, etc)"""
        ...
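
For illustration, a pass-through reranker satisfying this interface might look like the following sketch (hypothetical, not part of this commit):

from kotaemon.base import Document
from kotaemon.rerankings import BaseReranking


class IdentityReranking(BaseReranking):
    """Hypothetical no-op reranker: returns documents in their original order."""

    def run(self, documents: list[Document], query: str) -> list[Document]:
        return documents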
56 changes: 56 additions & 0 deletions libs/kotaemon/kotaemon/rerankings/cohere.py
@@ -0,0 +1,56 @@
from __future__ import annotations

from decouple import config

from kotaemon.base import Document, Param

from .base import BaseReranking


class CohereReranking(BaseReranking):
    """Cohere Reranking model"""

    model_name: str = Param(
        "rerank-multilingual-v2.0",
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
        ),
        required=True,
    )
    cohere_api_key: str = Param(
        config("COHERE_API_KEY", ""),
        help="Cohere API key",
        required=True,
    )

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use Cohere Reranker model to re-order documents
        with their relevance score"""
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Please install Cohere: `pip install cohere` to use Cohere Reranking"
            )

        if not self.cohere_api_key:
            print("Cohere API key not found. Skipping rerankings.")
            return documents

        cohere_client = cohere.Client(self.cohere_api_key)
        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty api call
            return compressed_docs

        _docs = [d.content for d in documents]
        response = cohere_client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs
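
Assuming COHERE_API_KEY is set in the environment (picked up by decouple's config above) and that Document accepts a content argument, usage might look like this sketch:

from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

reranker = CohereReranking()  # model_name defaults to rerank-multilingual-v2.0
docs = [
    Document(content="Paris is the capital of France."),
    Document(content="Berlin is the capital of Germany."),
]
ranked = reranker.run(docs, query="What is the capital of France?")
print([d.metadata["reranking_score"] for d in ranked])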
77 changes: 77 additions & 0 deletions libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import Optional

import requests

from kotaemon.base import Document, Param

from .base import BaseReranking

session = requests.session()


class TeiFastReranking(BaseReranking):
    """Text Embeddings Inference (TEI) Reranking model
    (https://huggingface.co/docs/text-embeddings-inference/en/index)
    """

    endpoint_url: str = Param(
        None, help="TEI reranking service API base URL", required=True
    )
    model_name: Optional[str] = Param(
        None,
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://github.com/huggingface"
            "/text-embeddings-inference?tab=readme-ov-file"
            "#supported-models) to see the supported models"
        ),
    )
    is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")

    def client(self, query, texts):
        response = session.post(
            url=self.endpoint_url,
            json={
                "query": query,
                "texts": texts,
                "is_truncated": self.is_truncated,  # default is True
            },
        ).json()
        return response

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use the deployed TEI reranking service to re-order documents
        with their relevance score"""
        if not self.endpoint_url:
            print("TEI API reranking URL not found. Skipping rerankings.")
            return documents

        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty api call
            return compressed_docs

        if isinstance(documents[0], str):
            documents = self.prepare_input(documents)

        batch_size = 6
        num_batch = max(len(documents) // batch_size, 1)
        for i in range(num_batch):
            # the final batch absorbs any remainder beyond batch_size
            if i == num_batch - 1:
                mini_batch = documents[batch_size * i :]
            else:
                mini_batch = documents[batch_size * i : batch_size * (i + 1)]

            _docs = [d.content for d in mini_batch]
            rerank_resp = self.client(query, _docs)
            for r in rerank_resp:
                doc = mini_batch[r["index"]]
                doc.metadata["reranking_score"] = r["score"]
                compressed_docs.append(doc)

        compressed_docs = sorted(
            compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
        )
        return compressed_docs
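
For reference, here are the request and response shapes that client() above assumes, sketched as Python literals. Field names follow this code; note that is_truncated may not match the parameter name in the upstream TEI rerank API, which documents a truncate field:

# Request body posted to the /rerank endpoint by client():
request = {
    "query": "What is Deep Learning?",
    "texts": ["Deep Learning is a subfield of ML.", "Cheese is made from milk."],
    "is_truncated": True,
}
# Expected response: one entry per input text; run() re-sorts by score afterwards.
response = [
    {"index": 0, "score": 0.99},
    {"index": 1, "score": 0.02},
]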
2 changes: 2 additions & 0 deletions libs/ktem/ktem/embeddings/manager.py
@@ -59,6 +59,7 @@ def load_vendors(self):
            LCCohereEmbeddings,
            LCHuggingFaceEmbeddings,
            OpenAIEmbeddings,
            TeiEndpointEmbeddings,
        )

        self._vendors = [
@@ -67,6 +68,7 @@
            FastEmbedEmbeddings,
            LCCohereEmbeddings,
            LCHuggingFaceEmbeddings,
            TeiEndpointEmbeddings,
        ]

    def __getitem__(self, key: str) -> BaseEmbeddings:
