diff --git a/environment.yml b/environment.yml index b75674d..ee7018e 100644 --- a/environment.yml +++ b/environment.yml @@ -6,6 +6,7 @@ channels: dependencies: - python - beautifulsoup4 + - faiss-cpu - langchain - pytest - pytest-cov diff --git a/requirements.txt b/requirements.txt index 7ed58c0..f044e81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ beautifulsoup4 +faiss-cpu # this is an unofficial package langchain pytest pytest-cov diff --git a/sphinx_rag_search_engine/embedding/tests/test_sentence_transformer.py b/sphinx_rag_search_engine/embedding/tests/test_sentence_transformer.py index 8924e15..ece6f1f 100644 --- a/sphinx_rag_search_engine/embedding/tests/test_sentence_transformer.py +++ b/sphinx_rag_search_engine/embedding/tests/test_sentence_transformer.py @@ -3,7 +3,7 @@ from sphinx_rag_search_engine.embedding import SentenceTransformer -def test_xxx(): +def test_sentence_transformer(): cache_folder_path = Path(__file__).parent / "data" model_name_or_path = "sentence-transformers/paraphrase-albert-small-v2" diff --git a/sphinx_rag_search_engine/retrieval/__init__.py b/sphinx_rag_search_engine/retrieval/__init__.py new file mode 100644 index 0000000..f4936b1 --- /dev/null +++ b/sphinx_rag_search_engine/retrieval/__init__.py @@ -0,0 +1,3 @@ +from ._faiss import FAISS + +__all__ = ["FAISS"] diff --git a/sphinx_rag_search_engine/retrieval/_faiss.py b/sphinx_rag_search_engine/retrieval/_faiss.py new file mode 100644 index 0000000..e05680d --- /dev/null +++ b/sphinx_rag_search_engine/retrieval/_faiss.py @@ -0,0 +1,34 @@ +from numbers import Integral + +import faiss +from sklearn.base import BaseEstimator +from sklearn.utils._param_validation import HasMethods, Interval + + +class FAISS(BaseEstimator): + + _parameter_constraints = { + "embedding": [HasMethods(["fit_transform", "transform"])], + "top_k": [Interval(Integral, left=1, right=None, closed="left")], + } + + def __init__(self, *, embedding, top_k=1): + self.embedding = embedding + self.top_k = top_k + + def fit(self, X, y=None): + self._validate_params() + self.X_fit_ = X + self.X_embedded_ = self.embedding.fit_transform(X) + # normalize vectors to compute the cosine similarity + faiss.normalize_L2(self.X_embedded_) + self.index_ = faiss.IndexFlatIP(self.X_embedded_.shape[1]) + self.index_.add(self.X_embedded_) + return self + + def transform(self, X): + X_embedded = self.embedding.transform(X) + # normalize vectors to compute the cosine similarity + faiss.normalize_L2(X_embedded) + _, indices = self.index_.search(X_embedded, 1) + return self.X_fit_[indices[:, 0]] diff --git a/sphinx_rag_search_engine/retrieval/tests/__init__.py b/sphinx_rag_search_engine/retrieval/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sphinx_rag_search_engine/retrieval/tests/test_faiss.py b/sphinx_rag_search_engine/retrieval/tests/test_faiss.py new file mode 100644 index 0000000..eac1b7e --- /dev/null +++ b/sphinx_rag_search_engine/retrieval/tests/test_faiss.py @@ -0,0 +1,16 @@ +from pathlib import Path + +from sphinx_rag_search_engine.embedding import SentenceTransformer +from sphinx_rag_search_engine.retrieval import FAISS + + +def test_xxx(): + cache_folder_path = ( + Path(__file__).parent.parent.parent / "embedding" / "tests" / "data" + ) + model_name_or_path = "sentence-transformers/paraphrase-albert-small-v2" + + embedder = SentenceTransformer( + model_name_or_path=model_name_or_path, cache_folder=str(cache_folder_path) + ) + faiss = FAISS(embedding=embedder).fit([{"source": "hello world", "text": "hello world"}])