-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for TEI embedding service, allow reranking model to be configurable. Signed-off-by: Kennywu <[email protected]> * fix: add cohere default reranking model * fix: comfort pre-commit --------- Signed-off-by: Kennywu <[email protected]> Co-authored-by: wujiaye <[email protected]> Co-authored-by: Tadashi <[email protected]>
- Loading branch information
1 parent
2e3c17b
commit 53530e2
Showing
20 changed files
with
928 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
libs/kotaemon/kotaemon/embeddings/tei_endpoint_embed.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import aiohttp | ||
import requests | ||
|
||
from kotaemon.base import Document, DocumentWithEmbedding, Param | ||
|
||
from .base import BaseEmbeddings | ||
|
||
session = requests.session() | ||
|
||
|
||
class TeiEndpointEmbeddings(BaseEmbeddings): | ||
"""An Embeddings component that uses an | ||
TEI (Text-Embedding-Inference) API compatible endpoint. | ||
Ref: https://github.com/huggingface/text-embeddings-inference | ||
Attributes: | ||
endpoint_url (str): The url of an TEI | ||
(Text-Embedding-Inference) API compatible endpoint. | ||
normalize (bool): Whether to normalize embeddings to unit length. | ||
truncate (bool): Whether to truncate embeddings | ||
to a fixed/default length. | ||
""" | ||
|
||
endpoint_url: str = Param(None, help="TEI embedding service api base URL") | ||
normalize: bool = Param( | ||
True, | ||
help="Normalize embeddings to unit length", | ||
) | ||
truncate: bool = Param( | ||
True, | ||
help="Truncate embeddings to a fixed/default length", | ||
) | ||
|
||
async def client_(self, inputs: list[str]): | ||
async with aiohttp.ClientSession() as session: | ||
async with session.post( | ||
url=self.endpoint_url, | ||
json={ | ||
"inputs": inputs, | ||
"normalize": self.normalize, | ||
"truncate": self.truncate, | ||
}, | ||
) as resp: | ||
embeddings = await resp.json() | ||
return embeddings | ||
|
||
async def ainvoke( | ||
self, text: str | list[str] | Document | list[Document], *args, **kwargs | ||
) -> list[DocumentWithEmbedding]: | ||
if not isinstance(text, list): | ||
text = [text] | ||
text = self.prepare_input(text) | ||
|
||
outputs = [] | ||
batch_size = 6 | ||
num_batch = max(len(text) // batch_size, 1) | ||
for i in range(num_batch): | ||
if i == num_batch - 1: | ||
mini_batch = text[batch_size * i :] | ||
else: | ||
mini_batch = text[batch_size * i : batch_size * (i + 1)] | ||
mini_batch = [x.content for x in mini_batch] | ||
embeddings = await self.client_(mini_batch) # type: ignore | ||
outputs.extend( | ||
[ | ||
DocumentWithEmbedding(content=doc, embedding=embedding) | ||
for doc, embedding in zip(mini_batch, embeddings) | ||
] | ||
) | ||
|
||
return outputs | ||
|
||
def invoke( | ||
self, text: str | list[str] | Document | list[Document], *args, **kwargs | ||
) -> list[DocumentWithEmbedding]: | ||
if not isinstance(text, list): | ||
text = [text] | ||
|
||
text = self.prepare_input(text) | ||
|
||
outputs = [] | ||
batch_size = 6 | ||
num_batch = max(len(text) // batch_size, 1) | ||
for i in range(num_batch): | ||
if i == num_batch - 1: | ||
mini_batch = text[batch_size * i :] | ||
else: | ||
mini_batch = text[batch_size * i : batch_size * (i + 1)] | ||
mini_batch = [x.content for x in mini_batch] | ||
embeddings = session.post( | ||
url=self.endpoint_url, | ||
json={ | ||
"inputs": mini_batch, | ||
"normalize": self.normalize, | ||
"truncate": self.truncate, | ||
}, | ||
).json() | ||
outputs.extend( | ||
[ | ||
DocumentWithEmbedding(content=doc, embedding=embedding) | ||
for doc, embedding in zip(mini_batch, embeddings) | ||
] | ||
) | ||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .base import BaseReranking | ||
from .cohere import CohereReranking | ||
from .tei_fast_rerank import TeiFastReranking | ||
|
||
__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
|
||
from kotaemon.base import BaseComponent, Document | ||
|
||
|
||
class BaseReranking(BaseComponent): | ||
@abstractmethod | ||
def run(self, documents: list[Document], query: str) -> list[Document]: | ||
"""Main method to transform list of documents | ||
(re-ranking, filtering, etc)""" | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import annotations | ||
|
||
from decouple import config | ||
|
||
from kotaemon.base import Document, Param | ||
|
||
from .base import BaseReranking | ||
|
||
|
||
class CohereReranking(BaseReranking): | ||
"""Cohere Reranking model""" | ||
|
||
model_name: str = Param( | ||
"rerank-multilingual-v2.0", | ||
help=( | ||
"ID of the model to use. You can go to [Supported Models]" | ||
"(https://docs.cohere.com/docs/rerank-2) to see the supported models" | ||
), | ||
required=True, | ||
) | ||
cohere_api_key: str = Param( | ||
config("COHERE_API_KEY", ""), | ||
help="Cohere API key", | ||
required=True, | ||
) | ||
|
||
def run(self, documents: list[Document], query: str) -> list[Document]: | ||
"""Use Cohere Reranker model to re-order documents | ||
with their relevance score""" | ||
try: | ||
import cohere | ||
except ImportError: | ||
raise ImportError( | ||
"Please install Cohere " "`pip install cohere` to use Cohere Reranking" | ||
) | ||
|
||
if not self.cohere_api_key: | ||
print("Cohere API key not found. Skipping rerankings.") | ||
return documents | ||
|
||
cohere_client = cohere.Client(self.cohere_api_key) | ||
compressed_docs: list[Document] = [] | ||
|
||
if not documents: # to avoid empty api call | ||
return compressed_docs | ||
|
||
_docs = [d.content for d in documents] | ||
response = cohere_client.rerank( | ||
model=self.model_name, query=query, documents=_docs | ||
) | ||
for r in response.results: | ||
doc = documents[r.index] | ||
doc.metadata["reranking_score"] = r.relevance_score | ||
compressed_docs.append(doc) | ||
|
||
return compressed_docs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Optional | ||
|
||
import requests | ||
|
||
from kotaemon.base import Document, Param | ||
|
||
from .base import BaseReranking | ||
|
||
session = requests.session() | ||
|
||
|
||
class TeiFastReranking(BaseReranking): | ||
"""Text Embeddings Inference (TEI) Reranking model | ||
(https://huggingface.co/docs/text-embeddings-inference/en/index) | ||
""" | ||
|
||
endpoint_url: str = Param( | ||
None, help="TEI Reranking service api base URL", required=True | ||
) | ||
model_name: Optional[str] = Param( | ||
None, | ||
help=( | ||
"ID of the model to use. You can go to [Supported Models]" | ||
"(https://github.com/huggingface" | ||
"/text-embeddings-inference?tab=readme-ov-file" | ||
"#supported-models) to see the supported models" | ||
), | ||
) | ||
is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs") | ||
|
||
def client(self, query, texts): | ||
response = session.post( | ||
url=self.endpoint_url, | ||
json={ | ||
"query": query, | ||
"texts": texts, | ||
"is_truncated": self.is_truncated, # default is True | ||
}, | ||
).json() | ||
return response | ||
|
||
def run(self, documents: list[Document], query: str) -> list[Document]: | ||
"""Use the deployed TEI rerankings service to re-order documents | ||
with their relevance score""" | ||
if not self.endpoint_url: | ||
print("TEI API reranking URL not found. Skipping rerankings.") | ||
return documents | ||
|
||
compressed_docs: list[Document] = [] | ||
|
||
if not documents: # to avoid empty api call | ||
return compressed_docs | ||
|
||
if isinstance(documents[0], str): | ||
documents = self.prepare_input(documents) | ||
|
||
batch_size = 6 | ||
num_batch = max(len(documents) // batch_size, 1) | ||
for i in range(num_batch): | ||
if i == num_batch - 1: | ||
mini_batch = documents[batch_size * i :] | ||
else: | ||
mini_batch = documents[batch_size * i : batch_size * (i + 1)] | ||
|
||
_docs = [d.content for d in mini_batch] | ||
rerank_resp = self.client(query, _docs) | ||
for r in rerank_resp: | ||
doc = mini_batch[r["index"]] | ||
doc.metadata["reranking_score"] = r["score"] | ||
compressed_docs.append(doc) | ||
|
||
compressed_docs = sorted( | ||
compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True | ||
) | ||
return compressed_docs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.