diff --git a/chunked_pooling/chunked_eval_tasks.py b/chunked_pooling/chunked_eval_tasks.py
index 23dbcbf..8cd18fb 100644
--- a/chunked_pooling/chunked_eval_tasks.py
+++ b/chunked_pooling/chunked_eval_tasks.py
@@ -228,7 +228,7 @@ class LEMBWikimQARetrievalChunked(AbsTaskChunkedRetrieval):
name="LEMBWikimQARetrievalChunked",
dataset={
"path": "dwzhu/LongEmbed",
- "revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
+ "revision": "10039a580487dacecf79db69166e17ace3ede392",
"name": "LEMBWikimQARetrieval",
},
reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
@@ -297,6 +297,166 @@ def load_data(self, **kwargs):
self.data_loaded = True
+class LEMBSummScreenFDRetrievalChunked(AbsTaskChunkedRetrieval):
+ """
+ modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
+ """
+
+ _EVAL_SPLIT = "test"
+
+ metadata = TaskMetadata(
+ name="LEMBSummScreenFDRetrievalChunked",
+ dataset={
+ "path": "dwzhu/LongEmbed",
+ "revision": "10039a580487dacecf79db69166e17ace3ede392",
+ "name": "LEMBSummScreenFDRetrieval",
+ },
+ reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
+ description=("summ_screen_fd subset of dwzhu/LongEmbed dataset."),
+ type="Retrieval",
+ category="s2p",
+ modalities=["text"],
+ eval_splits=[_EVAL_SPLIT],
+ eval_langs=["eng-Latn"],
+ main_score="ndcg_at_10",
+ date=("1950-01-01", "2019-12-31"),
+ domains=None,
+ socioeconomic_status=None,
+ n_samples=None,
+ avg_character_length=None,
+ form=None,
+ text_creation=None,
+ task_subtypes=["Article retrieval"],
+ license="not specified",
+ annotations_creators="derived",
+ dialect=[],
+ sample_creation="found",
+ bibtex_citation="""
+        @inproceedings{chen-etal-2022-summscreen,
+            title={SummScreen: A Dataset for Abstractive Screenplay Summarization},
+            author={Chen, Mingda and Chu, Zewei and Wiseman, Sam and Gimpel, Kevin},
+            booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
+            year={2022}
+        }
+ """,
+ descriptive_stats={
+ "n_samples": {_EVAL_SPLIT: 500},
+ "avg_character_length": {
+ "test": {
+ "average_document_length": 30854.327,
+ "average_query_length": 591.49,
+ "num_documents": 300,
+ "num_queries": 300,
+ "average_relevant_docs_per_query": 1.0,
+ }
+ },
+ },
+ )
+
+ def load_data(self, **kwargs):
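+        # Load the summ_screen_fd config of dwzhu/LongEmbed and convert it into the
+        # corpus / queries / qrels dictionaries expected by the chunked retrieval task.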
+ if self.data_loaded:
+ return
+
+ dataset_dict = {**self.metadata.dataset}
+ dataset_dict['name'] = 'summ_screen_fd'
+
+ query_list = datasets.load_dataset(**dataset_dict)["queries"]
+ queries = {row["qid"]: row["text"] for row in query_list}
+
+ corpus_list = datasets.load_dataset(**dataset_dict)["corpus"]
+ corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
+
+ qrels_list = datasets.load_dataset(**dataset_dict)["qrels"]
+ qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
+
+ self.corpus = {self._EVAL_SPLIT: corpus}
+ self.queries = {self._EVAL_SPLIT: queries}
+ self.relevant_docs = {self._EVAL_SPLIT: qrels}
+
+ self.data_loaded = True
+
+
+class LEMBQMSumRetrievalChunked(AbsTaskChunkedRetrieval):
+ """
+ modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
+ """
+
+ _EVAL_SPLIT = "test"
+
+ metadata = TaskMetadata(
+ name="LEMBQMSumRetrievalChunked",
+ dataset={
+ "path": "dwzhu/LongEmbed",
+ "revision": "10039a580487dacecf79db69166e17ace3ede392",
+ "name": "LEMBQMSumRetrieval",
+ },
+ reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
+ description=("qmsum subset of dwzhu/LongEmbed dataset."),
+ type="Retrieval",
+ category="s2p",
+ modalities=["text"],
+ eval_splits=[_EVAL_SPLIT],
+ eval_langs=["eng-Latn"],
+ main_score="ndcg_at_10",
+ date=("1950-01-01", "2019-12-31"),
+ domains=None,
+ socioeconomic_status=None,
+ n_samples=None,
+ avg_character_length=None,
+ form=None,
+ text_creation=None,
+ task_subtypes=["Article retrieval"],
+ license="not specified",
+ annotations_creators="derived",
+ dialect=[],
+ sample_creation="found",
+ bibtex_citation="""
+        @inproceedings{zhong-etal-2021-qmsum,
+            title={QMSum: A New Benchmark for Query-based Multi-domain Meeting Summarization},
+            author={Zhong, Ming and Yin, Da and Yu, Tao and Zaidi, Ahmad and Mutuma, Mutethia and Jha, Rahul and Awadallah, Ahmed Hassan and Celikyilmaz, Asli and Liu, Yang and Qiu, Xipeng and Radev, Dragomir},
+            booktitle={Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
+            year={2021}
+        }
+ """,
+ descriptive_stats={
+ "n_samples": {_EVAL_SPLIT: 500},
+ "avg_character_length": {
+ "test": {
+ "average_document_length": 53335.817,
+ "average_query_length": 433.50,
+ "num_documents": 300,
+ "num_queries": 300,
+ "average_relevant_docs_per_query": 1.0,
+ }
+ },
+ },
+ )
+
+ def load_data(self, **kwargs):
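+        # Load the qmsum config of dwzhu/LongEmbed and convert it into the
+        # corpus / queries / qrels dictionaries expected by the chunked retrieval task.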
+ if self.data_loaded:
+ return
+
+ dataset_dict = {**self.metadata.dataset}
+ dataset_dict['name'] = 'qmsum'
+
+ query_list = datasets.load_dataset(**dataset_dict)["queries"]
+ queries = {row["qid"]: row["text"] for row in query_list}
+
+ corpus_list = datasets.load_dataset(**dataset_dict)["corpus"]
+ corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
+
+ qrels_list = datasets.load_dataset(**dataset_dict)["qrels"]
+ qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
+
+ self.corpus = {self._EVAL_SPLIT: corpus}
+ self.queries = {self._EVAL_SPLIT: queries}
+ self.relevant_docs = {self._EVAL_SPLIT: qrels}
+
+ self.data_loaded = True
+
+
class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval):
"""
modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py
diff --git a/chunked_pooling/chunking.py b/chunked_pooling/chunking.py
index facf1b0..4585aa7 100644
--- a/chunked_pooling/chunking.py
+++ b/chunked_pooling/chunking.py
@@ -150,8 +150,8 @@ def chunk(
tokenizer=tokenizer,
)
elif chunking_strategy == "fixed":
- if chunk_size < 10:
- raise ValueError("Chunk size must be greater than 10.")
+ if chunk_size < 4:
+ raise ValueError("Chunk size must be >= 4.")
return self.chunk_by_tokens(text, chunk_size, tokenizer)
elif chunking_strategy == "sentences":
return self.chunk_by_sentences(text, n_sentences, tokenizer)
diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py
index b119deb..827fc1c 100644
--- a/chunked_pooling/mteb_chunked_eval.py
+++ b/chunked_pooling/mteb_chunked_eval.py
@@ -27,6 +27,8 @@ def __init__(
model_has_instructions: bool = False,
embedding_model_name: Optional[str] = None, # for semantic chunking
truncate_max_length: Optional[int] = 8192,
+ long_late_chunking_embed_size: Optional[int] = 0,
+ long_late_chunking_overlap_size: Optional[int] = 512,
**kwargs,
):
super().__init__(**kwargs)
@@ -51,6 +53,9 @@ def __init__(
}
self.truncate_max_length = truncate_max_length
+ self.long_late_chunking_embed_size = long_late_chunking_embed_size
+ self.long_late_chunking_overlap_size = long_late_chunking_overlap_size
+
def load_data(self, **kwargs):
self.retrieval_task.load_data(**kwargs)
self.corpus = self.retrieval_task.corpus
@@ -114,6 +119,34 @@ def _truncate_documents(self, corpus):
v['text'] = v['text'][: last_token_span[1]]
return corpus
+ def _embed_with_overlap(self, model, model_inputs):
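+        # Embed inputs longer than `long_late_chunking_embed_size` in overlapping
+        # windows and stitch the token embeddings back into a single sequence.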
+
+ len_tokens = len(model_inputs["input_ids"][0])
+
+ if len_tokens > self.long_late_chunking_embed_size:
+ indices = []
+ for i in range(0, len_tokens, self.long_late_chunking_embed_size - self.long_late_chunking_overlap_size):
+ start = i
+ end = min(i + self.long_late_chunking_embed_size, len_tokens)
+ indices.append((start, end))
+ else:
+ indices = [(0, len_tokens)]
+
+ outputs = []
+ for start, end in indices:
+
+ batch_inputs = {k: v[:, start:end] for k, v in model_inputs.items()}
+
+ with torch.no_grad():
+ model_output = model(**batch_inputs)
+
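+            # For every window after the first, drop the overlapping prefix so each
+            # token embedding appears exactly once in the concatenated output.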
+ if start > 0:
+ outputs.append(model_output[0][:, self.long_late_chunking_overlap_size:])
+ else:
+ outputs.append(model_output[0])
+
+ return torch.cat(outputs, dim=1).to(model.device)
+
def _evaluate_monolingual(
self,
model,
@@ -181,17 +214,24 @@ def _evaluate_monolingual(
text_inputs,
return_tensors='pt',
padding=True,
- truncation=True,
- max_length=8192,
+ truncation=self.truncate_max_length is not None,
+ max_length=self.truncate_max_length,
)
if model.device.type == 'cuda':
model_inputs = {
k: v.to(model.device) for k, v in model_inputs.items()
}
- model_outputs = model(**model_inputs)
- output_embs = chunked_pooling(
- model_outputs, annotations, max_length=8192
- )
+
+ if self.long_late_chunking_embed_size > 0:
+ model_outputs = self._embed_with_overlap(model, model_inputs)
+ output_embs = chunked_pooling(
+ [model_outputs], annotations, max_length=None
+ )
+ else: # truncation
+ model_outputs = model(**model_inputs)
+ output_embs = chunked_pooling(
+ model_outputs, annotations, max_length=self.truncate_max_length
+ )
corpus_embs.extend(output_embs)
max_chunks = max([len(x) for x in corpus_embs])
diff --git a/explanatory_contextual_retrieval.py b/explanatory_contextual_retrieval.py
new file mode 100644
index 0000000..269b518
--- /dev/null
+++ b/explanatory_contextual_retrieval.py
@@ -0,0 +1,197 @@
+# explanatory_contextual_retrieval.py
+#
+# A simple example on a short piece of text that compares late chunking with the
+# contextual retrieval method. Contextual retrieval explicitly prepends LLM-generated
+# context to each chunk, i.e. it forces surrounding context into every chunk, so it
+# serves as a useful baseline for checking whether late chunking yields similar
+# query-chunk similarities (which it appears to do).
+
+import os
+from typing import List, Tuple
+
+import numpy as np
+import requests
+import torch
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline
+
+from chunked_pooling import chunked_pooling
+from chunked_pooling.chunking import Chunker
+from chunked_pooling.wrappers import load_model
+
+def request_anthropic_api(prompt: str):
+ url = "https://api.anthropic.com/v1/messages"
+ headers = {
+ "x-api-key": os.getenv("ANTHROPIC_API_KEY"),
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json"
+ }
+ data = {
+ "model": "claude-3-haiku-20240307",
+ "max_tokens": 2048,
+ "messages": [
+ {"role": "user", "content": prompt}
+ ]
+ }
+ response = requests.post(url, headers=headers, json=data)
+ return response.json()["content"][0]["text"]
+
+def setup_local_llm(llm_name):
+
+ model = AutoModelForCausalLM.from_pretrained(llm_name, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True)
+
+ def llm(prompt):
+ messages = [{"role": "user", "content": prompt}]
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+ inputs = inputs.to(model.device)
+ outputs = model.generate(inputs, max_new_tokens=512)
+ text_output = tokenizer.batch_decode(outputs)[0]
+ if "<|assistant|>" in text_output:
+ text_output = text_output.split("<|assistant|>")[1].strip()
+ return text_output
+
+ return llm
+
+def cosine_similarity(vector1, vector2):
+ vector1_norm = vector1 / np.linalg.norm(vector1)
+ vector2_norm = vector2 / np.linalg.norm(vector2)
+ return np.dot(vector1_norm, vector2_norm)
+
+class LateChunkingEmbedder:
+
+ def __init__(self,
+ model: AutoModel,
+ tokenizer: AutoTokenizer,
+ chunking_strategy: str = "sentences",
+ n_sentences: int = 1
+ ):
+
+ self.model = model
+ self.tokenizer = tokenizer
+
+ self.chunker = Chunker(chunking_strategy = chunking_strategy)
+ self.n_sentences = n_sentences
+
+
+ def run(self, document: str):
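+        # Late chunking: run the model over the whole document once, then pool the
+        # token embeddings within each chunk's token boundaries via chunked_pooling.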
+ annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=self.n_sentences)]
+ model_inputs = self.tokenizer(
+ document,
+ return_tensors='pt',
+ padding=True,
+ truncation=True,
+ max_length=8192,
+ )
+ model_outputs = self.model(**model_inputs)
+ self.output_embs = chunked_pooling(
+ model_outputs, annotations, max_length=8192,
+ )[0]
+ return self.output_embs
+
+ def query(self, query: str):
+        if not hasattr(self, "output_embs"):
+            raise ValueError("No embeddings calculated; call .run(document) first to create chunk embeddings.")
+ query_embedding = self.model.encode(query)
+ similarities = []
+ for emb in self.output_embs:
+ similarities.append(cosine_similarity(query_embedding, emb))
+
+ return similarities
+
+
+class ContextualRetrievalEmbedder():
+ def __init__(self,
+ model: AutoModel,
+ tokenizer: AutoTokenizer,
+ llm_name: str = "microsoft/Phi-3.5-mini-instruct",
+ chunking_strategy: str = "fixed"
+ ):
+
+ self.llm = setup_local_llm(llm_name)
+ # self.llm = request_anthropic_api
+
+ self.prompt = """
+
+ {{WHOLE_DOCUMENT}}
+
+ Here is the chunk we want to situate within the whole document
+
+ {{CHUNK_CONTENT}}
+
+ Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else.
+ """.strip()
+
+ self.model = model
+ self.tokenizer = tokenizer
+
+ self.chunker = Chunker(chunking_strategy = chunking_strategy)
+
+
+ def _add_context(self, chunk: str, document: str):
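+        # Contextual retrieval: ask the LLM for a short context that situates the chunk
+        # within the document, then prepend it to the chunk before embedding.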
+ prompt = self.prompt.replace("{{WHOLE_DOCUMENT}}", document).replace("{{CHUNK_CONTENT}}", chunk)
+ extra_context = self.llm(prompt)
+ return extra_context + " " + chunk
+
+ def _tokens_to_text(self, text: str, annotations: List[Tuple[int, int]]):
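+        # Map token-span annotations back to character offsets to recover each chunk's
+        # text from the original document.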
+ tokens = self.tokenizer.encode_plus(
+ text, return_offsets_mapping=True, add_special_tokens=False
+ )
+ token_offsets = tokens.offset_mapping
+ chunks = []
+ for start, end in annotations:
+ chunk = text[token_offsets[start][0]:token_offsets[end-1][1]]
+ chunks.append(chunk)
+ return chunks
+
+ def run(self, document: str):
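+        # Chunk the document, prepend LLM-generated context to every chunk, and embed
+        # each contextualised chunk independently.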
+ annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=1)]
+ self.chunks = self._tokens_to_text(text=document, annotations=annotations[0])
+ self.chunks = [self._add_context(chunk, document) for chunk in self.chunks]
+
+ model_outputs = self.model.encode(self.chunks)
+ self.output_embs = [model_outputs[i, :] for i in range(len(self.chunks))]
+ return self.output_embs
+
+ def query(self, query: str):
+        if not hasattr(self, "output_embs"):
+            raise ValueError("No embeddings calculated; call .run(document) first to create chunk embeddings.")
+ query_embedding = self.model.encode(query)
+ similarities = []
+ for emb in self.output_embs:
+ similarities.append(cosine_similarity(query_embedding, emb))
+
+ return similarities
+
+
+
+if __name__ == "__main__":
+
+ text = """
+ The recent SEC filing provided insights into ACME Corp's performance for Q2 2023.
+ It highlighted a 3% revenue growth over the previous quarter.
+ The company, which had a revenue of $314 million in the prior quarter, showed steady progress.
+ They attributed this growth to strategic initiatives and operational efficiencies.
+ The report emphasized the company's resilience and ability to navigate market challenges, reflecting positively on their financial health and future prospects.
+ """.strip().replace("\n", "")
+
+ llm_model_name = "microsoft/Phi-3.5-mini-instruct"
+ embedding_model_name = "jinaai/jina-embeddings-v2-small-en"
+
+ embedding_model, has_instructions = load_model(embedding_model_name)
+ embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True)
+
+ cr = ContextualRetrievalEmbedder(embedding_model, embedding_tokenizer, llm_model_name, chunking_strategy="sentences")
+    cr.run(text)
+ cr_cosine_similarities = cr.query("What is ACME Corp's revenue growth for Q2 2023?")
+
+ lc = LateChunkingEmbedder(embedding_model, embedding_tokenizer)
+ lc.run(text)
+ lc_cosine_similarities = lc.query("What is ACME Corp's revenue growth for Q2 2023?")
+
+ # import pandas as pd
+ for i, (cr_similarity, lc_similarity) in enumerate(zip(cr_cosine_similarities, lc_cosine_similarities)):
+ print(f"{text.split('.')[:-1][i].strip()}")
+ print(f"Similarities: Contextual Retrieval: {cr_similarity:.4f} | Late Chunking: {lc_similarity:.4f}")
+ print("")
\ No newline at end of file
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index 88494bd..adbe7fa 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -10,6 +10,9 @@
DEFAULT_CHUNK_SIZE = 256
DEFAULT_N_SENTENCES = 5
BATCH_SIZE = 1
+DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE = 256
+DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0 # set to 0 to disable long late chunking
+DEFAULT_TRUNCATE_MAX_LENGTH = 8192
@click.command()
@@ -37,9 +40,9 @@
)
@click.option(
'--truncate-max-length',
- default=None,
+ default=DEFAULT_TRUNCATE_MAX_LENGTH,
type=int,
- help='Maximum number of tokens; By default, no truncation is done.',
+    help='Maximum number of tokens; by default, inputs are truncated to 8192 tokens. Set to None when the long late chunking algorithm is enabled.',
)
@click.option(
'--chunk-size',
@@ -53,6 +56,18 @@
type=int,
help='Number of sentences per chunk for sentence strategy.',
)
+@click.option(
+ '--long-late-chunking-embed-size',
+ default=DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE,
+ type=int,
+    help='Window size (in tokens) of each embedding pass when the long late chunking algorithm is enabled. Values above zero embed long documents in overlapping windows; 0 disables it.',
+)
+@click.option(
+ '--long-late-chunking-overlap-size',
+ default=DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE,
+ type=int,
+    help='Number of tokens shared between neighbouring windows when the long late chunking algorithm is enabled.',
+)
def main(
model_name,
strategy,
@@ -62,12 +77,18 @@ def main(
truncate_max_length,
chunk_size,
n_sentences,
+ long_late_chunking_embed_size,
+ long_late_chunking_overlap_size
):
try:
task_cls = globals()[task_name]
except:
raise ValueError(f'Unknown task name: {task_name}')
-
+
+ if truncate_max_length is not None and (long_late_chunking_embed_size > 0):
+ truncate_max_length = None
+        print('Truncation is disabled because the long late chunking algorithm is enabled.')
+
model, has_instructions = load_model(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -92,6 +113,8 @@ def main(
tokenizer=tokenizer,
prune_size=None,
truncate_max_length=truncate_max_length,
+ long_late_chunking_embed_size=long_late_chunking_embed_size,
+ long_late_chunking_overlap_size=long_late_chunking_overlap_size,
**chunking_args,
)
]