Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Dec 5, 2023
1 parent cf08f0d commit 60e53fd
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 1 deletion.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
dependencies:
- python
- beautifulsoup4
- faiss-cpu
- langchain
- pytest
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
beautifulsoup4
faiss-cpu # NOTE: this PyPI package is an unofficial build of faiss
langchain
pytest
pytest-cov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sphinx_rag_search_engine.embedding import SentenceTransformer


def test_xxx():
def test_sentence_transformer():
cache_folder_path = Path(__file__).parent / "data"
model_name_or_path = "sentence-transformers/paraphrase-albert-small-v2"

Expand Down
3 changes: 3 additions & 0 deletions sphinx_rag_search_engine/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Public entry point of the ``retrieval`` subpackage: re-export the FAISS
# retriever implemented in the private ``_faiss`` module.
from ._faiss import FAISS

# Names exported by ``from sphinx_rag_search_engine.retrieval import *``.
__all__ = ["FAISS"]
34 changes: 34 additions & 0 deletions sphinx_rag_search_engine/retrieval/_faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from numbers import Integral

import faiss
from sklearn.base import BaseEstimator
from sklearn.utils._param_validation import HasMethods, Interval


class FAISS(BaseEstimator):
    """Retrieve the fitted items most similar to a query with a FAISS index.

    The items given to :meth:`fit` are embedded, L2-normalized and added to a
    ``faiss.IndexFlatIP`` index; queries are embedded and normalized the same
    way, so the inner-product search ranks items by cosine similarity.

    Parameters
    ----------
    embedding : object
        Embedder exposing ``fit_transform`` and ``transform``.
        NOTE(review): its output is passed as-is to ``faiss.normalize_L2``
        and to the index, which presumably expect a 2-D float32 array --
        confirm against the embedder in use.
    top_k : int, default=1
        Number of items retrieved for each query. Must be at least 1.

    Attributes
    ----------
    X_fit_ : object
        The container passed to :meth:`fit`, kept to map index positions back
        to the original items.
    X_embedded_ : array-like
        Normalized embedding of ``X_fit_``.
    index_ : faiss.IndexFlatIP
        Index holding ``X_embedded_``.
    """

    _parameter_constraints = {
        "embedding": [HasMethods(["fit_transform", "transform"])],
        "top_k": [Interval(Integral, left=1, right=None, closed="left")],
    }

    def __init__(self, *, embedding, top_k=1):
        self.embedding = embedding
        self.top_k = top_k

    def fit(self, X, y=None):
        """Embed `X` and build the search index.

        Parameters
        ----------
        X : sequence or array-like
            Items to index; forwarded to ``embedding.fit_transform``.
        y : None
            Ignored; present for scikit-learn API compatibility.

        Returns
        -------
        self : FAISS
            The fitted retriever.
        """
        self._validate_params()
        self.X_fit_ = X
        self.X_embedded_ = self.embedding.fit_transform(X)
        # normalize vectors to compute the cosine similarity
        faiss.normalize_L2(self.X_embedded_)
        self.index_ = faiss.IndexFlatIP(self.X_embedded_.shape[1])
        self.index_.add(self.X_embedded_)
        return self

    def transform(self, X):
        """Return the fitted items closest to each query in `X`.

        Parameters
        ----------
        X : sequence or array-like
            Queries; forwarded to ``embedding.transform``.

        Returns
        -------
        matches : list or array-like
            When ``top_k == 1``, the best-matching item for each query (one
            entry per query). When ``top_k > 1``, a list with one entry per
            query, each holding up to ``top_k`` items in the order returned
            by the index search.
        """
        X_embedded = self.embedding.transform(X)
        # normalize vectors to compute the cosine similarity
        faiss.normalize_L2(X_embedded)
        # Never request more neighbours than are indexed: FAISS pads missing
        # results with -1, which would silently select the last fitted item.
        n_neighbors = min(self.top_k, self.index_.ntotal)
        _, indices = self.index_.search(X_embedded, n_neighbors)
        if self.top_k == 1:
            # Keep the historical output: a single best match per query.
            return self._take(indices[:, 0])
        return [self._take(row) for row in indices]

    def _take(self, positions):
        """Select the fitted items located at `positions`."""
        if isinstance(self.X_fit_, (list, tuple)):
            # Built-in sequences cannot be indexed with an integer array.
            return [self.X_fit_[int(position)] for position in positions]
        return self.X_fit_[positions]
Empty file.
16 changes: 16 additions & 0 deletions sphinx_rag_search_engine/retrieval/tests/test_faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pathlib import Path

from sphinx_rag_search_engine.embedding import SentenceTransformer
from sphinx_rag_search_engine.retrieval import FAISS


def test_xxx():
    """Smoke test: a FAISS retriever can be fitted on a single document."""
    model_name_or_path = "sentence-transformers/paraphrase-albert-small-v2"
    # Point the model cache at the data folder of the embedding tests.
    package_root = Path(__file__).parent.parent.parent
    cache_folder_path = package_root / "embedding" / "tests" / "data"

    documents = [{"source": "hello world", "text": "hello world"}]
    embedder = SentenceTransformer(
        model_name_or_path=model_name_or_path,
        cache_folder=str(cache_folder_path),
    )
    FAISS(embedding=embedder).fit(documents)

0 comments on commit 60e53fd

Please sign in to comment.