Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support nomic ai model #12

Merged
merged 6 commits into from
Sep 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions chunked_pooling/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
from transformers.modeling_outputs import BaseModelOutputWithPooling


def construct_document(doc):
if isinstance(doc, str):
return doc
elif 'title' in doc:
return f'{doc["title"]} {doc["text"].strip()}'
else:
return doc['text'].strip()


class JinaEmbeddingsV3Wrapper(nn.Module):
Expand Down Expand Up @@ -31,7 +41,7 @@ def encode_corpus(
*args,
**kwargs,
):
_sentences = [self._construct_document(sentence) for sentence in sentences]
_sentences = [construct_document(sentence) for sentence in sentences]
return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs)

def get_instructions(self):
Expand All @@ -45,13 +55,56 @@ def forward(self, *args, **kwargs):
)
return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs)

def _construct_document(self, doc):
if isinstance(doc, str):
return doc
elif 'title' in doc:
return f'{doc["title"]} {doc["text"].strip()}'
else:
return doc['text'].strip()
@property
def device(self):
return self._model.device

@staticmethod
def has_instructions():
return True


class NomicAIWrapper(nn.Module):
def __init__(self, model_name, **model_kwargs):
super().__init__()
self._model = SentenceTransformer(
model_name, trust_remote_code=True, **model_kwargs
)
self.instructions = ['search_query: ', 'search_document: ']

def get_instructions(self):
return self.instructions

def forward(self, *args, **kwargs):
model_output = self._model.forward(kwargs)
base_model_output = BaseModelOutputWithPooling(
last_hidden_state=model_output['token_embeddings'],
pooler_output=model_output['sentence_embedding'],
attentions=model_output['attention_mask'],
)
return base_model_output

def encode_queries(
self,
sentences: Union[str, List[str]],
*args,
**kwargs,
):
return self._model.encode(
[self.instructions[0] + s for s in sentences], *args, **kwargs
)

def encode_corpus(
self,
sentences: Union[str, List[str]],
*args,
**kwargs,
):
return self._model.encode(
[self.instructions[1] + construct_document(s) for s in sentences],
*args,
**kwargs,
)

@property
def device(self):
Expand All @@ -65,7 +118,9 @@ def has_instructions():
MODEL_WRAPPERS = {
'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
'nomic-ai/nomic-embed-text-v1': NomicAIWrapper,
}

MODELS_WITHOUT_PROMPT_NAME_ARG = [
'jinaai/jina-embeddings-v2-small-en',
'jinaai/jina-embeddings-v2-base-en',
Expand Down
Loading