diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py index 2433e7f..b119deb 100644 --- a/chunked_pooling/mteb_chunked_eval.py +++ b/chunked_pooling/mteb_chunked_eval.py @@ -26,6 +26,7 @@ def __init__( n_sentences: Optional[int] = None, model_has_instructions: bool = False, embedding_model_name: Optional[str] = None, # for semantic chunking + truncate_max_length: Optional[int] = 8192, **kwargs, ): super().__init__(**kwargs) @@ -48,6 +49,7 @@ def __init__( 'n_sentences': n_sentences, 'embedding_model_name': embedding_model_name, } + self.truncate_max_length = truncate_max_length def load_data(self, **kwargs): self.retrieval_task.load_data(**kwargs) @@ -97,6 +99,21 @@ def evaluate( return scores + def _truncate_documents(self, corpus): + for k, v in corpus.items(): + if 'title' in v: + raise NotImplementedError( + 'Currently truncation is only implemented for documents without titles' + ) + tokens = self.tokenizer( + v['text'], + return_offsets_mapping=True, + max_length=self.truncate_max_length, + ) + last_token_span = tokens.offset_mapping[-2] + v['text'] = v['text'][: last_token_span[1]] + return corpus + def _evaluate_monolingual( self, model, @@ -108,6 +125,8 @@ def _evaluate_monolingual( encode_kwargs=None, **kwargs, ): + if self.truncate_max_length: + corpus = self._truncate_documents(corpus) # split corpus into chunks if not self.chunked_pooling_enabled: corpus = self._apply_chunking(corpus, self.tokenizer) diff --git a/run_chunked_eval.py b/run_chunked_eval.py index ff49da0..88494bd 100644 --- a/run_chunked_eval.py +++ b/run_chunked_eval.py @@ -35,7 +35,34 @@ required=False, help='The name of the model used for semantic chunking.', ) -def main(model_name, strategy, task_name, eval_split, chunking_model): +@click.option( + '--truncate-max-length', + default=None, + type=int, + help='Maximum number of tokens; By default, no truncation is done.', +) +@click.option( + '--chunk-size', + default=DEFAULT_CHUNK_SIZE, + type=int, + help='Number of tokens per chunk for fixed strategy.', +) +@click.option( + '--n-sentences', + default=DEFAULT_N_SENTENCES, + type=int, + help='Number of sentences per chunk for sentence strategy.', +) +def main( + model_name, + strategy, + task_name, + eval_split, + chunking_model, + truncate_max_length, + chunk_size, + n_sentences, +): try: task_cls = globals()[task_name] except: @@ -46,8 +73,8 @@ def main(model_name, strategy, task_name, eval_split, chunking_model): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) chunking_args = { - 'chunk_size': DEFAULT_CHUNK_SIZE, - 'n_sentences': DEFAULT_N_SENTENCES, + 'chunk_size': chunk_size, + 'n_sentences': n_sentences, 'chunking_strategy': strategy, 'model_has_instructions': has_instructions, 'embedding_model_name': chunking_model if chunking_model else model_name, @@ -64,6 +91,7 @@ def main(model_name, strategy, task_name, eval_split, chunking_model): chunked_pooling_enabled=True, tokenizer=tokenizer, prune_size=None, + truncate_max_length=truncate_max_length, **chunking_args, ) ] @@ -90,6 +118,7 @@ def main(model_name, strategy, task_name, eval_split, chunking_model): chunked_pooling_enabled=False, tokenizer=tokenizer, prune_size=None, + truncate_max_length=truncate_max_length, **chunking_args, ) ]