diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py
index 827fc1c..e522bac 100644
--- a/chunked_pooling/mteb_chunked_eval.py
+++ b/chunked_pooling/mteb_chunked_eval.py
@@ -106,26 +106,35 @@ def evaluate(
 
     def _truncate_documents(self, corpus):
         for k, v in corpus.items():
+            title_tokens = 0
             if 'title' in v:
-                raise NotImplementedError(
-                    'Currently truncation is only implemented for documents without titles'
+                tokens = self.tokenizer(
+                    v['title'] + ' ',
+                    return_offsets_mapping=True,
+                    max_length=self.truncate_max_length,
                 )
+                title_tokens = len(tokens.input_ids)
             tokens = self.tokenizer(
                 v['text'],
                 return_offsets_mapping=True,
-                max_length=self.truncate_max_length,
+                max_length=self.truncate_max_length - title_tokens,
             )
             last_token_span = tokens.offset_mapping[-2]
             v['text'] = v['text'][: last_token_span[1]]
         return corpus
 
     def _embed_with_overlap(self, model, model_inputs):
-        
+
         len_tokens = len(model_inputs["input_ids"][0])
-        
+
         if len_tokens > self.long_late_chunking_embed_size:
             indices = []
-            for i in range(0, len_tokens, self.long_late_chunking_embed_size - self.long_late_chunking_overlap_size):
+            for i in range(
+                0,
+                len_tokens,
+                self.long_late_chunking_embed_size
+                - self.long_late_chunking_overlap_size,
+            ):
                 start = i
                 end = min(i + self.long_late_chunking_embed_size, len_tokens)
                 indices.append((start, end))
@@ -138,10 +147,12 @@ def _embed_with_overlap(self, model, model_inputs):
             batch_inputs = {k: v[:, start:end] for k, v in model_inputs.items()}
 
             with torch.no_grad():
-                model_output = model(**batch_inputs)
+                model_output = model(**batch_inputs)
 
             if start > 0:
-                outputs.append(model_output[0][:, self.long_late_chunking_overlap_size:])
+                outputs.append(
+                    model_output[0][:, self.long_late_chunking_overlap_size :]
+                )
             else:
                 outputs.append(model_output[0])
 
@@ -227,10 +238,12 @@ def _evaluate_monolingual(
                     output_embs = chunked_pooling(
                         [model_outputs], annotations, max_length=None
                     )
-                else: # truncation
+                else:  # truncation
                     model_outputs = model(**model_inputs)
                     output_embs = chunked_pooling(
-                        model_outputs, annotations, max_length=self.truncate_max_length
+                        model_outputs,
+                        annotations,
+                        max_length=self.truncate_max_length,
                     )
                 corpus_embs.extend(output_embs)
 
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index d2d913d..95de94a 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -11,8 +11,8 @@
 DEFAULT_N_SENTENCES = 5
 BATCH_SIZE = 1
 DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE = 256
-DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0 # set to 0 to disable long late chunking
-DEFAULT_TRUNCATE_MAX_LENGTH = 8192
+DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0  # set to 0 to disable long late chunking
+DEFAULT_TRUNCATE_MAX_LENGTH = None
 
 
 @click.command()
@@ -65,13 +65,13 @@
     '--long-late-chunking-embed-size',
     default=DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE,
     type=int,
-    help='Token length of the embeddings that come before/after soft boundaries (i.e. overlapping embeddings). Above zero, overlap is used between neighbouring embeddings.',
+    help='Number of tokens per chunk for fixed strategy.',
 )
 @click.option(
     '--long-late-chunking-overlap-size',
     default=DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE,
     type=int,
-    help='Number of tokens per chunk for fixed strategy.',
+    help='Token length of the embeddings that come before/after soft boundaries (i.e. overlapping embeddings). Above zero, overlap is used between neighbouring embeddings.',
 )
 def main(
     model_name,
@@ -84,17 +84,19 @@ def main(
     chunk_size,
     n_sentences,
     long_late_chunking_embed_size,
-    long_late_chunking_overlap_size
+    long_late_chunking_overlap_size,
 ):
     try:
         task_cls = globals()[task_name]
     except:
         raise ValueError(f'Unknown task name: {task_name}')
-
+
     if truncate_max_length is not None and (long_late_chunking_embed_size > 0):
         truncate_max_length = None
-        print(f'Truncation is disabled because Long Late Chunking algorithm is enabled.')
-
+        print(
+            f'Truncation is disabled because Long Late Chunking algorithm is enabled.'
+        )
+
     model, has_instructions = load_model(model_name, model_weights)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
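Note (not part of the patch): a minimal standalone sketch of the windowing arithmetic `_embed_with_overlap` relies on above. Long inputs are cut into windows of `--long-late-chunking-embed-size` tokens, stepped by embed size minus `--long-late-chunking-overlap-size`, and the overlapping prefix of every window after the first is dropped before the outputs are concatenated. The helper names and toy numbers below are illustrative only.

```python
# Illustrative sketch only; helper names are hypothetical and the toy numbers
# stand in for --long-late-chunking-embed-size / --long-late-chunking-overlap-size.

def overlap_windows(len_tokens: int, embed_size: int, overlap_size: int):
    """(start, end) token spans covering len_tokens tokens, overlapping by overlap_size."""
    return [
        (i, min(i + embed_size, len_tokens))
        for i in range(0, len_tokens, embed_size - overlap_size)
    ]

def stitch(window_outputs, overlap_size: int):
    """Drop the first overlap_size positions of every window after the first, then concatenate."""
    kept = [window_outputs[0]] + [w[overlap_size:] for w in window_outputs[1:]]
    return [tok for w in kept for tok in w]

if __name__ == '__main__':
    spans = overlap_windows(len_tokens=10, embed_size=6, overlap_size=2)
    print(spans)  # [(0, 6), (4, 10), (8, 10)]
    # Use the token indices themselves as stand-ins for per-token embeddings:
    outputs = [list(range(s, e)) for s, e in spans]
    print(stitch(outputs, overlap_size=2))  # [0, 1, ..., 9] -- each token exactly once
```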