diff --git a/chunked_pooling/wrappers.py b/chunked_pooling/wrappers.py index a4bb0ef..59f5f36 100644 --- a/chunked_pooling/wrappers.py +++ b/chunked_pooling/wrappers.py @@ -4,6 +4,16 @@ import torch.nn as nn from sentence_transformers import SentenceTransformer from transformers import AutoModel +from transformers.modeling_outputs import BaseModelOutputWithPooling + + +def construct_document(doc): + if isinstance(doc, str): + return doc + elif 'title' in doc: + return f'{doc["title"]} {doc["text"].strip()}' + else: + return doc['text'].strip() class JinaEmbeddingsV3Wrapper(nn.Module): @@ -31,7 +41,7 @@ def encode_corpus( *args, **kwargs, ): - _sentences = [self._construct_document(sentence) for sentence in sentences] + _sentences = [construct_document(sentence) for sentence in sentences] return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs) def get_instructions(self): @@ -45,13 +55,56 @@ def forward(self, *args, **kwargs): ) return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs) - def _construct_document(self, doc): - if isinstance(doc, str): - return doc - elif 'title' in doc: - return f'{doc["title"]} {doc["text"].strip()}' - else: - return doc['text'].strip() + @property + def device(self): + return self._model.device + + @staticmethod + def has_instructions(): + return True + + +class NomicAIWrapper(nn.Module): + def __init__(self, model_name, **model_kwargs): + super().__init__() + self._model = SentenceTransformer( + model_name, trust_remote_code=True, **model_kwargs + ) + self.instructions = ['search_query: ', 'search_document: '] + + def get_instructions(self): + return self.instructions + + def forward(self, *args, **kwargs): + model_output = self._model.forward(kwargs) + base_model_output = BaseModelOutputWithPooling( + last_hidden_state=model_output['token_embeddings'], + pooler_output=model_output['sentence_embedding'], + attentions=model_output['attention_mask'], + ) + return base_model_output + + def encode_queries( + self, + sentences: Union[str, List[str]], + *args, + **kwargs, + ): + return self._model.encode( + [self.instructions[0] + s for s in sentences], *args, **kwargs + ) + + def encode_corpus( + self, + sentences: Union[str, List[str]], + *args, + **kwargs, + ): + return self._model.encode( + [self.instructions[1] + construct_document(s) for s in sentences], + *args, + **kwargs, + ) @property def device(self): @@ -65,7 +118,9 @@ def has_instructions(): MODEL_WRAPPERS = { 'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper, 'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer, + 'nomic-ai/nomic-embed-text-v1': NomicAIWrapper, } + MODELS_WITHOUT_PROMPT_NAME_ARG = [ 'jinaai/jina-embeddings-v2-small-en', 'jinaai/jina-embeddings-v2-base-en',