Skip to content

Commit

Permalink
feat: support nomic ai model
Browse files Browse the repository at this point in the history
  • Loading branch information
guenthermi committed Sep 24, 2024
1 parent f707e58 commit 4ca4204
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions chunked_pooling/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
from transformers.modeling_outputs import BaseModelOutputWithPooling


def construct_document(doc):
if isinstance(doc, str):
return doc
elif 'title' in doc:
return f'{doc["title"]} {doc["text"].strip()}'
else:
return doc['text'].strip()


class JinaEmbeddingsV3Wrapper(nn.Module):
Expand Down Expand Up @@ -31,7 +41,7 @@ def encode_corpus(
*args,
**kwargs,
):
_sentences = [self._construct_document(sentence) for sentence in sentences]
_sentences = [construct_document(sentence) for sentence in sentences]
return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs)

def get_instructions(self):
Expand All @@ -45,13 +55,57 @@ def forward(self, *args, **kwargs):
)
return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs)

def _construct_document(self, doc):
if isinstance(doc, str):
return doc
elif 'title' in doc:
return f'{doc["title"]} {doc["text"].strip()}'
else:
return doc['text'].strip()
@property
def device(self):
return self._model.device

@staticmethod
def has_instructions():
return True


class NomicAIWrapper(nn.Module):
def __init__(self, model_name, **model_kwargs):
super().__init__()
self._model = SentenceTransformer(
model_name, trust_remote_code=True, **model_kwargs
)
self.instructions = ['search_query: ', 'search_document: ']

def get_instructions(self):
return self.instructions

def forward(self, *args, **kwargs):
# TODO combine kwargs into input
model_output = self._model.forward(kwargs)
base_model_output = BaseModelOutputWithPooling(
last_hidden_state=model_output['token_embeddings'],
pooler_output=model_output['sentence_embedding'],
attentions=model_output['attention_mask'],
)
return base_model_output

def encode_queries(
self,
sentences: Union[str, List[str]],
*args,
**kwargs,
):
return self._model.encode(
[self.instructions[0] + s for s in sentences], *args, **kwargs
)

def encode_corpus(
self,
sentences: Union[str, List[str]],
*args,
**kwargs,
):
return self._model.encode(
[self.instructions[1] + construct_document(s) for s in sentences],
*args,
**kwargs,
)

@property
def device(self):
Expand All @@ -65,7 +119,9 @@ def has_instructions():
MODEL_WRAPPERS = {
'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
'nomic-ai/nomic-embed-text-v1': NomicAIWrapper,
}

MODELS_WITHOUT_PROMPT_NAME_ARG = [
'jinaai/jina-embeddings-v2-small-en',
'jinaai/jina-embeddings-v2-base-en',
Expand Down

0 comments on commit 4ca4204

Please sign in to comment.