Skip to content

Commit

Permalink
feat: add additional cmd args
Browse files Browse the repository at this point in the history
  • Loading branch information
guenthermi committed Sep 25, 2024
1 parent 11ad8d9 commit b66a13c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
19 changes: 19 additions & 0 deletions chunked_pooling/mteb_chunked_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
n_sentences: Optional[int] = None,
model_has_instructions: bool = False,
embedding_model_name: Optional[str] = None, # for semantic chunking
truncate_max_length: Optional[int] = 8192,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -48,6 +49,7 @@ def __init__(
'n_sentences': n_sentences,
'embedding_model_name': embedding_model_name,
}
self.truncate_max_length = truncate_max_length

def load_data(self, **kwargs):
self.retrieval_task.load_data(**kwargs)
Expand Down Expand Up @@ -97,6 +99,21 @@ def evaluate(

return scores

def _truncate_documents(self, corpus):
for k, v in corpus.items():
if 'title' in v:
raise NotImplementedError(
'Currently truncation is only implemented for documents without titles'
)
tokens = self.tokenizer(
v['text'],
return_offsets_mapping=True,
max_length=self.truncate_max_length,
)
last_token_span = tokens.offset_mapping[-2]
v['text'] = v['text'][: last_token_span[1]]
return corpus

def _evaluate_monolingual(
self,
model,
Expand All @@ -108,6 +125,8 @@ def _evaluate_monolingual(
encode_kwargs=None,
**kwargs,
):
if self.truncate_max_length:
corpus = self._truncate_documents(corpus)
# split corpus into chunks
if not self.chunked_pooling_enabled:
corpus = self._apply_chunking(corpus, self.tokenizer)
Expand Down
26 changes: 24 additions & 2 deletions run_chunked_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,27 @@
required=False,
help='The name of the model used for semantic chunking.',
)
def main(model_name, strategy, task_name, eval_split, chunking_model):
@click.option(
'--truncate-max-length',
default=None,
type=int,
help='Maximum number of tokens; By default, no truncation is done.',
)
@click.option(
'--chunk-size',
default=DEFAULT_CHUNK_SIZE,
type=int,
help='Number of tokens per chunk for fixed strategy.',
)
def main(
model_name,
strategy,
task_name,
eval_split,
chunking_model,
truncate_max_length,
chunk_size,
):
try:
task_cls = globals()[task_name]
except:
Expand All @@ -46,7 +66,7 @@ def main(model_name, strategy, task_name, eval_split, chunking_model):
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

chunking_args = {
'chunk_size': DEFAULT_CHUNK_SIZE,
'chunk_size': chunk_size,
'n_sentences': DEFAULT_N_SENTENCES,
'chunking_strategy': strategy,
'model_has_instructions': has_instructions,
Expand All @@ -64,6 +84,7 @@ def main(model_name, strategy, task_name, eval_split, chunking_model):
chunked_pooling_enabled=True,
tokenizer=tokenizer,
prune_size=None,
truncate_max_length=truncate_max_length,
**chunking_args,
)
]
Expand All @@ -90,6 +111,7 @@ def main(model_name, strategy, task_name, eval_split, chunking_model):
chunked_pooling_enabled=False,
tokenizer=tokenizer,
prune_size=None,
truncate_max_length=truncate_max_length,
**chunking_args,
)
]
Expand Down

0 comments on commit b66a13c

Please sign in to comment.