diff --git a/chunked_pooling/chunked_eval_tasks.py b/chunked_pooling/chunked_eval_tasks.py index 23dbcbf..8cd18fb 100644 --- a/chunked_pooling/chunked_eval_tasks.py +++ b/chunked_pooling/chunked_eval_tasks.py @@ -228,7 +228,7 @@ class LEMBWikimQARetrievalChunked(AbsTaskChunkedRetrieval): name="LEMBWikimQARetrievalChunked", dataset={ "path": "dwzhu/LongEmbed", - "revision": "6e346642246bfb4928c560ee08640dc84d074e8c", + "revision": "10039a580487dacecf79db69166e17ace3ede392", "name": "LEMBWikimQARetrieval", }, reference="https://huggingface.co/datasets/dwzhu/LongEmbed", @@ -297,6 +297,166 @@ def load_data(self, **kwargs): self.data_loaded = True +class LEMBSummScreenFDRetrievalChunked(AbsTaskChunkedRetrieval): + """ + modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py + """ + + _EVAL_SPLIT = "test" + + metadata = TaskMetadata( + name="LEMBSummScreenFDRetrievalChunked", + dataset={ + "path": "dwzhu/LongEmbed", + "revision": "10039a580487dacecf79db69166e17ace3ede392", + "name": "LEMBSummScreenFDRetrieval", + }, + reference="https://huggingface.co/datasets/dwzhu/LongEmbed", + description=("summ_screen_fd subset of dwzhu/LongEmbed dataset."), + type="Retrieval", + category="s2p", + modalities=["text"], + eval_splits=[_EVAL_SPLIT], + eval_langs=["eng-Latn"], + main_score="ndcg_at_10", + date=("1950-01-01", "2019-12-31"), + domains=None, + socioeconomic_status=None, + n_samples=None, + avg_character_length=None, + form=None, + text_creation=None, + task_subtypes=["Article retrieval"], + license="not specified", + annotations_creators="derived", + dialect=[], + sample_creation="found", + bibtex_citation=""" + @inproceedings{ho2020constructing, + title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps}, + author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko}, + booktitle={Proceedings of the 28th International Conference on Computational Linguistics}, + pages={6609--6625}, + year={2020} + } + """, + descriptive_stats={ + "n_samples": {_EVAL_SPLIT: 500}, + "avg_character_length": { + "test": { + "average_document_length": 30854.327, + "average_query_length": 591.49, + "num_documents": 300, + "num_queries": 300, + "average_relevant_docs_per_query": 1.0, + } + }, + }, + ) + + def load_data(self, **kwargs): + if self.data_loaded: + return + + dataset_dict = {**self.metadata.dataset} + dataset_dict['name'] = 'summ_screen_fd' + + query_list = datasets.load_dataset(**dataset_dict)["queries"] + queries = {row["qid"]: row["text"] for row in query_list} + + corpus_list = datasets.load_dataset(**dataset_dict)["corpus"] + corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} + + qrels_list = datasets.load_dataset(**dataset_dict)["qrels"] + qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} + + self.corpus = {self._EVAL_SPLIT: corpus} + self.queries = {self._EVAL_SPLIT: queries} + self.relevant_docs = {self._EVAL_SPLIT: qrels} + + self.data_loaded = True + + +class LEMBQMSumRetrievalChunked(AbsTaskChunkedRetrieval): + """ + modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py + """ + + _EVAL_SPLIT = "test" + + metadata = TaskMetadata( + name="LEMBQMSumRetrievalChunked", + dataset={ + "path": "dwzhu/LongEmbed", + "revision": "10039a580487dacecf79db69166e17ace3ede392", + "name": "LEMBQMSumRetrieval", + }, + reference="https://huggingface.co/datasets/dwzhu/LongEmbed", + description=("qmsum subset of dwzhu/LongEmbed dataset."), + type="Retrieval", + category="s2p", + modalities=["text"], + eval_splits=[_EVAL_SPLIT], + eval_langs=["eng-Latn"], + main_score="ndcg_at_10", + date=("1950-01-01", "2019-12-31"), + domains=None, + socioeconomic_status=None, + n_samples=None, + avg_character_length=None, + form=None, + text_creation=None, + task_subtypes=["Article retrieval"], + license="not specified", + annotations_creators="derived", + dialect=[], + sample_creation="found", + bibtex_citation=""" + @inproceedings{ho2020constructing, + title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps}, + author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko}, + booktitle={Proceedings of the 28th International Conference on Computational Linguistics}, + pages={6609--6625}, + year={2020} + } + """, + descriptive_stats={ + "n_samples": {_EVAL_SPLIT: 500}, + "avg_character_length": { + "test": { + "average_document_length": 53335.817, + "average_query_length": 433.50, + "num_documents": 300, + "num_queries": 300, + "average_relevant_docs_per_query": 1.0, + } + }, + }, + ) + + def load_data(self, **kwargs): + if self.data_loaded: + return + + dataset_dict = {**self.metadata.dataset} + dataset_dict['name'] = 'qmsum' + + query_list = datasets.load_dataset(**dataset_dict)["queries"] + queries = {row["qid"]: row["text"] for row in query_list} + + corpus_list = datasets.load_dataset(**dataset_dict)["corpus"] + corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} + + qrels_list = datasets.load_dataset(**dataset_dict)["qrels"] + qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} + + self.corpus = {self._EVAL_SPLIT: corpus} + self.queries = {self._EVAL_SPLIT: queries} + self.relevant_docs = {self._EVAL_SPLIT: qrels} + + self.data_loaded = True + + class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval): """ modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py diff --git a/chunked_pooling/chunking.py b/chunked_pooling/chunking.py index facf1b0..4585aa7 100644 --- a/chunked_pooling/chunking.py +++ b/chunked_pooling/chunking.py @@ -150,8 +150,8 @@ def chunk( tokenizer=tokenizer, ) elif chunking_strategy == "fixed": - if chunk_size < 10: - raise ValueError("Chunk size must be greater than 10.") + if chunk_size < 4: + raise ValueError("Chunk size must be >= 4.") return self.chunk_by_tokens(text, chunk_size, tokenizer) elif chunking_strategy == "sentences": return self.chunk_by_sentences(text, n_sentences, tokenizer) diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py index b119deb..827fc1c 100644 --- a/chunked_pooling/mteb_chunked_eval.py +++ b/chunked_pooling/mteb_chunked_eval.py @@ -27,6 +27,8 @@ def __init__( model_has_instructions: bool = False, embedding_model_name: Optional[str] = None, # for semantic chunking truncate_max_length: Optional[int] = 8192, + long_late_chunking_embed_size: Optional[int] = 0, + long_late_chunking_overlap_size: Optional[int] = 512, **kwargs, ): super().__init__(**kwargs) @@ -51,6 +53,9 @@ def __init__( } self.truncate_max_length = truncate_max_length + self.long_late_chunking_embed_size = long_late_chunking_embed_size + self.long_late_chunking_overlap_size = long_late_chunking_overlap_size + def load_data(self, **kwargs): self.retrieval_task.load_data(**kwargs) self.corpus = self.retrieval_task.corpus @@ -114,6 +119,34 @@ def _truncate_documents(self, corpus): v['text'] = v['text'][: last_token_span[1]] return corpus + def _embed_with_overlap(self, model, model_inputs): + + len_tokens = len(model_inputs["input_ids"][0]) + + if len_tokens > self.long_late_chunking_embed_size: + indices = [] + for i in range(0, len_tokens, self.long_late_chunking_embed_size - self.long_late_chunking_overlap_size): + start = i + end = min(i + self.long_late_chunking_embed_size, len_tokens) + indices.append((start, end)) + else: + indices = [(0, len_tokens)] + + outputs = [] + for start, end in indices: + + batch_inputs = {k: v[:, start:end] for k, v in model_inputs.items()} + + with torch.no_grad(): + model_output = model(**batch_inputs) + + if start > 0: + outputs.append(model_output[0][:, self.long_late_chunking_overlap_size:]) + else: + outputs.append(model_output[0]) + + return torch.cat(outputs, dim=1).to(model.device) + def _evaluate_monolingual( self, model, @@ -181,17 +214,24 @@ def _evaluate_monolingual( text_inputs, return_tensors='pt', padding=True, - truncation=True, - max_length=8192, + truncation=self.truncate_max_length is not None, + max_length=self.truncate_max_length, ) if model.device.type == 'cuda': model_inputs = { k: v.to(model.device) for k, v in model_inputs.items() } - model_outputs = model(**model_inputs) - output_embs = chunked_pooling( - model_outputs, annotations, max_length=8192 - ) + + if self.long_late_chunking_embed_size > 0: + model_outputs = self._embed_with_overlap(model, model_inputs) + output_embs = chunked_pooling( + [model_outputs], annotations, max_length=None + ) + else: # truncation + model_outputs = model(**model_inputs) + output_embs = chunked_pooling( + model_outputs, annotations, max_length=self.truncate_max_length + ) corpus_embs.extend(output_embs) max_chunks = max([len(x) for x in corpus_embs]) diff --git a/explanatory_contextual_retrieval.py b/explanatory_contextual_retrieval.py new file mode 100644 index 0000000..269b518 --- /dev/null +++ b/explanatory_contextual_retrieval.py @@ -0,0 +1,197 @@ +# experiments/explanatory_contextual_retrieval.py +# +# a simple example with a trivial piece of text to showcase the late chunking method against +# contextual retrieval method. contextual retrieval manually inserts context to each +# chunk, i.e. forces context to be around each chunk. so works as a good comparison +# to late chunking to see if the similarities are similar (which they appear to be) + +from chunked_pooling.wrappers import load_model +from transformers import AutoModel, AutoTokenizer, pipeline, AutoModelForCausalLM +import torch +import numpy as np + +import chunked_pooling +from chunked_pooling import chunked_pooling +from chunked_pooling.chunking import Chunker + +from typing import List, Tuple +from transformers import AutoModel, AutoTokenizer, pipeline + +import requests +import os + +def request_anthropic_api(prompt: str): + url = "https://api.anthropic.com/v1/messages" + headers = { + "x-api-key": os.getenv("ANTHROPIC_API_KEY"), + "anthropic-version": "2023-06-01", + "content-type": "application/json" + } + data = { + "model": "claude-3-haiku-20240307", + "max_tokens": 2048, + "messages": [ + {"role": "user", "content": prompt} + ] + } + response = requests.post(url, headers=headers, json=data) + return response.json()["content"][0]["text"] + +def setup_local_llm(llm_name): + + model = AutoModelForCausalLM.from_pretrained(llm_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True) + + def llm(prompt): + messages = [{"role": "user", "content": prompt}] + inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") + inputs = inputs.to(model.device) + outputs = model.generate(inputs, max_new_tokens=512) + text_output = tokenizer.batch_decode(outputs)[0] + if "<|assistant|>" in text_output: + text_output = text_output.split("<|assistant|>")[1].strip() + return text_output + + return llm + +def cosine_similarity(vector1, vector2): + vector1_norm = vector1 / np.linalg.norm(vector1) + vector2_norm = vector2 / np.linalg.norm(vector2) + return np.dot(vector1_norm, vector2_norm) + +class LateChunkingEmbedder: + + def __init__(self, + model: AutoModel, + tokenizer: AutoTokenizer, + chunking_strategy: str = "sentences", + n_sentences: int = 1 + ): + + self.model = model + self.tokenizer = tokenizer + + self.chunker = Chunker(chunking_strategy = chunking_strategy) + self.n_sentences = n_sentences + + + def run(self, document: str): + annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=self.n_sentences)] + model_inputs = self.tokenizer( + document, + return_tensors='pt', + padding=True, + truncation=True, + max_length=8192, + ) + model_outputs = self.model(**model_inputs) + self.output_embs = chunked_pooling( + model_outputs, annotations, max_length=8192, + )[0] + return self.output_embs + + def query(self, query: str): + if "output_embs" not in dir(self): + raise ValueError("no embeddings calculated, use .run(document) to create chunk embeddings") + query_embedding = self.model.encode(query) + similarities = [] + for emb in self.output_embs: + similarities.append(cosine_similarity(query_embedding, emb)) + + return similarities + + +class ContextualRetrievalEmbedder(): + def __init__(self, + model: AutoModel, + tokenizer: AutoTokenizer, + llm_name: str = "microsoft/Phi-3.5-mini-instruct", + chunking_strategy: str = "fixed" + ): + + self.llm = setup_local_llm(llm_name) + # self.llm = request_anthropic_api + + self.prompt = """ + + {{WHOLE_DOCUMENT}} + + Here is the chunk we want to situate within the whole document + + {{CHUNK_CONTENT}} + + Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else. + """.strip() + + self.model = model + self.tokenizer = tokenizer + + self.chunker = Chunker(chunking_strategy = chunking_strategy) + + + def _add_context(self, chunk: str, document: str): + prompt = self.prompt.replace("{{WHOLE_DOCUMENT}}", document).replace("{{CHUNK_CONTENT}}", chunk) + extra_context = self.llm(prompt) + return extra_context + " " + chunk + + def _tokens_to_text(self, text: str, annotations: List[Tuple[int, int]]): + tokens = self.tokenizer.encode_plus( + text, return_offsets_mapping=True, add_special_tokens=False + ) + token_offsets = tokens.offset_mapping + chunks = [] + for start, end in annotations: + chunk = text[token_offsets[start][0]:token_offsets[end-1][1]] + chunks.append(chunk) + return chunks + + def run(self, document: str): + annotations = [self.chunker.chunk(text=document, tokenizer=self.tokenizer, n_sentences=1)] + self.chunks = self._tokens_to_text(text=document, annotations=annotations[0]) + self.chunks = [self._add_context(chunk, document) for chunk in self.chunks] + + model_outputs = self.model.encode(self.chunks) + self.output_embs = [model_outputs[i, :] for i in range(len(self.chunks))] + return self.output_embs + + def query(self, query: str): + if "output_embs" not in dir(self): + raise ValueError("no embeddings calculated, use .run(document) to create chunk embeddings") + query_embedding = self.model.encode(query) + similarities = [] + for emb in self.output_embs: + similarities.append(cosine_similarity(query_embedding, emb)) + + return similarities + + + +if __name__ == "__main__": + + text = """ + The recent SEC filing provided insights into ACME Corp's performance for Q2 2023. + It highlighted a 3% revenue growth over the previous quarter. + The company, which had a revenue of $314 million in the prior quarter, showed steady progress. + They attributed this growth to strategic initiatives and operational efficiencies. + The report emphasized the company's resilience and ability to navigate market challenges, reflecting positively on their financial health and future prospects. + """.strip().replace("\n", "") + + llm_model_name = "microsoft/Phi-3.5-mini-instruct" + embedding_model_name = "jinaai/jina-embeddings-v2-small-en" + + embedding_model, has_instructions = load_model(embedding_model_name) + embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) + + cr = ContextualRetrievalEmbedder(embedding_model, embedding_tokenizer, llm_model_name, chunking_strategy="sentences") + cr.run(text); + cr_cosine_similarities = cr.query("What is ACME Corp's revenue growth for Q2 2023?") + + lc = LateChunkingEmbedder(embedding_model, embedding_tokenizer) + lc.run(text) + lc_cosine_similarities = lc.query("What is ACME Corp's revenue growth for Q2 2023?") + + # import pandas as pd + for i, (cr_similarity, lc_similarity) in enumerate(zip(cr_cosine_similarities, lc_cosine_similarities)): + print(f"{text.split('.')[:-1][i].strip()}") + print(f"Similarities: Contextual Retrieval: {cr_similarity:.4f} | Late Chunking: {lc_similarity:.4f}") + print("") \ No newline at end of file diff --git a/run_chunked_eval.py b/run_chunked_eval.py index 88494bd..adbe7fa 100644 --- a/run_chunked_eval.py +++ b/run_chunked_eval.py @@ -10,6 +10,9 @@ DEFAULT_CHUNK_SIZE = 256 DEFAULT_N_SENTENCES = 5 BATCH_SIZE = 1 +DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE = 256 +DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0 # set to 0 to disable long late chunking +DEFAULT_TRUNCATE_MAX_LENGTH = 8192 @click.command() @@ -37,9 +40,9 @@ ) @click.option( '--truncate-max-length', - default=None, + default=DEFAULT_TRUNCATE_MAX_LENGTH, type=int, - help='Maximum number of tokens; By default, no truncation is done.', + help='Maximum number of tokens; by default, truncation to 8192 tokens. If None, Long Late Chunking algorithm should be enabled.', ) @click.option( '--chunk-size', @@ -53,6 +56,18 @@ type=int, help='Number of sentences per chunk for sentence strategy.', ) +@click.option( + '--long-late-chunking-embed-size', + default=DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE, + type=int, + help='Token length of the embeddings that come before/after soft boundaries (i.e. overlapping embeddings). Above zero, overlap is used between neighbouring embeddings.', +) +@click.option( + '--long-late-chunking-overlap-size', + default=DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE, + type=int, + help='Number of tokens per chunk for fixed strategy.', +) def main( model_name, strategy, @@ -62,12 +77,18 @@ def main( truncate_max_length, chunk_size, n_sentences, + long_late_chunking_embed_size, + long_late_chunking_overlap_size ): try: task_cls = globals()[task_name] except: raise ValueError(f'Unknown task name: {task_name}') - + + if truncate_max_length is not None and (long_late_chunking_embed_size > 0): + truncate_max_length = None + print(f'Truncation is disabled because Long Late Chunking algorithm is enabled.') + model, has_instructions = load_model(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) @@ -92,6 +113,8 @@ def main( tokenizer=tokenizer, prune_size=None, truncate_max_length=truncate_max_length, + long_late_chunking_embed_size=long_late_chunking_embed_size, + long_late_chunking_overlap_size=long_late_chunking_overlap_size, **chunking_args, ) ]