diff --git a/FlagEmbedding/evaluation/air_bench/__main__.py b/FlagEmbedding/evaluation/air_bench/__main__.py index a875c9a7..8a2412c2 100644 --- a/FlagEmbedding/evaluation/air_bench/__main__.py +++ b/FlagEmbedding/evaluation/air_bench/__main__.py @@ -6,23 +6,27 @@ ) -parser = HfArgumentParser(( - AIRBenchEvalArgs, - AIRBenchEvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: AIRBenchEvalArgs -model_args: AIRBenchEvalModelArgs - -runner = AIRBenchEvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + AIRBenchEvalArgs, + AIRBenchEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: AIRBenchEvalArgs + model_args: AIRBenchEvalModelArgs + + runner = AIRBenchEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() -runner.run() -print("==============================================") -print("Search results have been generated.") -print("For computing metrics, please refer to the official AIR-Bench docs:") -print("- https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md") +if __name__ == "__main__": + main() + print("==============================================") + print("Search results have been generated.") + print("For computing metrics, please refer to the official AIR-Bench docs:") + print("- https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md") diff --git a/FlagEmbedding/evaluation/beir/__main__.py b/FlagEmbedding/evaluation/beir/__main__.py index 80228c15..37558ead 100644 --- a/FlagEmbedding/evaluation/beir/__main__.py +++ b/FlagEmbedding/evaluation/beir/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - BEIREvalArgs, - BEIREvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: BEIREvalArgs -model_args: BEIREvalModelArgs - -runner = BEIREvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + BEIREvalArgs, + BEIREvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: BEIREvalArgs + model_args: BEIREvalModelArgs + + runner = BEIREvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/custom/__main__.py b/FlagEmbedding/evaluation/custom/__main__.py index 8279d19d..05aca236 100644 --- a/FlagEmbedding/evaluation/custom/__main__.py +++ b/FlagEmbedding/evaluation/custom/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - CustomEvalArgs, - CustomEvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: CustomEvalArgs -model_args: CustomEvalModelArgs - -runner = CustomEvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + CustomEvalArgs, + CustomEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: CustomEvalArgs + model_args: CustomEvalModelArgs + + runner = CustomEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/miracl/__main__.py b/FlagEmbedding/evaluation/miracl/__main__.py index 0abdd0eb..811c3d79 100644 --- a/FlagEmbedding/evaluation/miracl/__main__.py +++ b/FlagEmbedding/evaluation/miracl/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - MIRACLEvalArgs, - MIRACLEvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: MIRACLEvalArgs -model_args: MIRACLEvalModelArgs - -runner = MIRACLEvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + MIRACLEvalArgs, + MIRACLEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: MIRACLEvalArgs + model_args: MIRACLEvalModelArgs + + runner = MIRACLEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/mkqa/__main__.py b/FlagEmbedding/evaluation/mkqa/__main__.py index 5521671f..c335cbba 100644 --- a/FlagEmbedding/evaluation/mkqa/__main__.py +++ b/FlagEmbedding/evaluation/mkqa/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - MKQAEvalArgs, - MKQAEvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: MKQAEvalArgs -model_args: MKQAEvalModelArgs - -runner = MKQAEvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + MKQAEvalArgs, + MKQAEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: MKQAEvalArgs + model_args: MKQAEvalModelArgs + + runner = MKQAEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/mldr/__main__.py b/FlagEmbedding/evaluation/mldr/__main__.py index 8cbd5c7b..2ce7c9ac 100644 --- a/FlagEmbedding/evaluation/mldr/__main__.py +++ b/FlagEmbedding/evaluation/mldr/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - MLDREvalArgs, - MLDREvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: MLDREvalArgs -model_args: MLDREvalModelArgs - -runner = MLDREvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + MLDREvalArgs, + MLDREvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: MLDREvalArgs + model_args: MLDREvalModelArgs + + runner = MLDREvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/msmarco/__main__.py b/FlagEmbedding/evaluation/msmarco/__main__.py index c2f2f946..5d22ec48 100644 --- a/FlagEmbedding/evaluation/msmarco/__main__.py +++ b/FlagEmbedding/evaluation/msmarco/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - MSMARCOEvalArgs, - MSMARCOEvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: MSMARCOEvalArgs -model_args: MSMARCOEvalModelArgs - -runner = MSMARCOEvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + MSMARCOEvalArgs, + MSMARCOEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: MSMARCOEvalArgs + model_args: MSMARCOEvalModelArgs + + runner = MSMARCOEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/evaluation/mteb/__main__.py b/FlagEmbedding/evaluation/mteb/__main__.py index 2887e488..e906862f 100644 --- a/FlagEmbedding/evaluation/mteb/__main__.py +++ b/FlagEmbedding/evaluation/mteb/__main__.py @@ -6,18 +6,23 @@ ) -parser = HfArgumentParser(( - MTEBEvalArgs, - MTEBEvalModelArgs -)) - -eval_args, model_args = parser.parse_args_into_dataclasses() -eval_args: MTEBEvalArgs -model_args: MTEBEvalModelArgs - -runner = MTEBEvalRunner( - eval_args=eval_args, - model_args=model_args -) +def main(): + parser = HfArgumentParser(( + MTEBEvalArgs, + MTEBEvalModelArgs + )) + + eval_args, model_args = parser.parse_args_into_dataclasses() + eval_args: MTEBEvalArgs + model_args: MTEBEvalModelArgs + + runner = MTEBEvalRunner( + eval_args=eval_args, + model_args=model_args + ) + + runner.run() + -runner.run() +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py b/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py index a87428a2..bfef14d8 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py @@ -8,19 +8,24 @@ ) -parser = HfArgumentParser(( - DecoderOnlyEmbedderModelArguments, - DecoderOnlyEmbedderDataArguments, - DecoderOnlyEmbedderTrainingArguments -)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: DecoderOnlyEmbedderModelArguments -data_args: DecoderOnlyEmbedderDataArguments -training_args: DecoderOnlyEmbedderTrainingArguments +def main(): + parser = HfArgumentParser(( + DecoderOnlyEmbedderModelArguments, + DecoderOnlyEmbedderDataArguments, + DecoderOnlyEmbedderTrainingArguments + )) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: DecoderOnlyEmbedderModelArguments + data_args: DecoderOnlyEmbedderDataArguments + training_args: DecoderOnlyEmbedderTrainingArguments -runner = DecoderOnlyEmbedderRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = DecoderOnlyEmbedderRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py b/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py index 4354f440..58fe46ae 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py @@ -8,19 +8,24 @@ ) -parser = HfArgumentParser(( - DecoderOnlyEmbedderICLModelArguments, - DecoderOnlyEmbedderICLDataArguments, - DecoderOnlyEmbedderICLTrainingArguments -)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: DecoderOnlyEmbedderICLModelArguments -data_args: DecoderOnlyEmbedderICLDataArguments -training_args: DecoderOnlyEmbedderICLTrainingArguments +def main(): + parser = HfArgumentParser(( + DecoderOnlyEmbedderICLModelArguments, + DecoderOnlyEmbedderICLDataArguments, + DecoderOnlyEmbedderICLTrainingArguments + )) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: DecoderOnlyEmbedderICLModelArguments + data_args: DecoderOnlyEmbedderICLDataArguments + training_args: DecoderOnlyEmbedderICLTrainingArguments -runner = DecoderOnlyEmbedderICLRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = DecoderOnlyEmbedderICLRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py b/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py index a31a8ef8..5c43abfe 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py @@ -8,19 +8,24 @@ ) -parser = HfArgumentParser(( - EncoderOnlyEmbedderModelArguments, - EncoderOnlyEmbedderDataArguments, - EncoderOnlyEmbedderTrainingArguments -)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: EncoderOnlyEmbedderModelArguments -data_args: EncoderOnlyEmbedderDataArguments -training_args: EncoderOnlyEmbedderTrainingArguments +def main(): + parser = HfArgumentParser(( + EncoderOnlyEmbedderModelArguments, + EncoderOnlyEmbedderDataArguments, + EncoderOnlyEmbedderTrainingArguments + )) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: EncoderOnlyEmbedderModelArguments + data_args: EncoderOnlyEmbedderDataArguments + training_args: EncoderOnlyEmbedderTrainingArguments -runner = EncoderOnlyEmbedderRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = EncoderOnlyEmbedderRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py b/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py index cd556573..5a8cba7a 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py +++ b/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py @@ -8,15 +8,20 @@ ) -parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: EncoderOnlyEmbedderM3ModelArguments -data_args: EncoderOnlyEmbedderM3DataArguments -training_args: EncoderOnlyEmbedderM3TrainingArguments +def main(): + parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: EncoderOnlyEmbedderM3ModelArguments + data_args: EncoderOnlyEmbedderM3DataArguments + training_args: EncoderOnlyEmbedderM3TrainingArguments -runner = EncoderOnlyEmbedderM3Runner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = EncoderOnlyEmbedderM3Runner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py b/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py index e3a4a063..3243ecb6 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py @@ -6,5 +6,6 @@ __all__ = [ "CrossDecoderModel", "DecoderOnlyRerankerRunner", - "DecoderOnlyRerankerTrainer" + "DecoderOnlyRerankerTrainer", + "RerankerModelArguments", ] diff --git a/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py b/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py index f0cf400e..447e6dc7 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py @@ -5,18 +5,26 @@ AbsRerankerTrainingArguments ) -from FlagEmbedding.finetune.reranker.decoder_only.base.runner import DecoderOnlyRerankerRunner -from FlagEmbedding.finetune.reranker.decoder_only.base.arguments import RerankerModelArguments +from FlagEmbedding.finetune.reranker.decoder_only.base import ( + DecoderOnlyRerankerRunner, + RerankerModelArguments +) -parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: RerankerModelArguments -data_args: AbsRerankerDataArguments -training_args: AbsRerankerTrainingArguments -runner = DecoderOnlyRerankerRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() +def main(): + parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: RerankerModelArguments + data_args: AbsRerankerDataArguments + training_args: AbsRerankerTrainingArguments + + runner = DecoderOnlyRerankerRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py index e3a4a063..3243ecb6 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py @@ -6,5 +6,6 @@ __all__ = [ "CrossDecoderModel", "DecoderOnlyRerankerRunner", - "DecoderOnlyRerankerTrainer" + "DecoderOnlyRerankerTrainer", + "RerankerModelArguments", ] diff --git a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py index d1077fbb..64774fc0 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py @@ -1,23 +1,30 @@ from transformers import HfArgumentParser from FlagEmbedding.abc.finetune.reranker import ( - AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments ) -from FlagEmbedding.finetune.reranker.decoder_only.layerwise.runner import DecoderOnlyRerankerRunner -from FlagEmbedding.finetune.reranker.decoder_only.layerwise.arguments import RerankerModelArguments +from FlagEmbedding.finetune.reranker.decoder_only.layerwise import ( + DecoderOnlyRerankerRunner, + RerankerModelArguments +) -parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: RerankerModelArguments -data_args: AbsRerankerDataArguments -training_args: AbsRerankerTrainingArguments -runner = DecoderOnlyRerankerRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() +def main(): + parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: RerankerModelArguments + data_args: AbsRerankerDataArguments + training_args: AbsRerankerTrainingArguments + + runner = DecoderOnlyRerankerRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py b/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py index 0cccfe93..76778d65 100644 --- a/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py @@ -5,18 +5,23 @@ AbsRerankerDataArguments, AbsRerankerTrainingArguments ) -from FlagEmbedding.finetune.reranker.encoder_only.base.runner import EncoderOnlyRerankerRunner +from FlagEmbedding.finetune.reranker.encoder_only.base import EncoderOnlyRerankerRunner -parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: AbsRerankerModelArguments -data_args: AbsRerankerDataArguments -training_args: AbsRerankerTrainingArguments +def main(): + parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: AbsRerankerModelArguments + data_args: AbsRerankerDataArguments + training_args: AbsRerankerTrainingArguments -runner = EncoderOnlyRerankerRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = EncoderOnlyRerankerRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/encoder_only/base/trainer.py b/FlagEmbedding/finetune/reranker/encoder_only/base/trainer.py index c7f0a09b..59120341 100644 --- a/FlagEmbedding/finetune/reranker/encoder_only/base/trainer.py +++ b/FlagEmbedding/finetune/reranker/encoder_only/base/trainer.py @@ -29,7 +29,6 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): if not hasattr(self.model, 'save_pretrained'): raise NotImplementedError(f'MODEL {self.model.__class__.__name__} ' f'does not support save_pretrained interface') else: - print(self.model) self.model.save_pretrained(output_dir) if self.tokenizer is not None and self.is_world_process_zero(): self.tokenizer.save_pretrained(output_dir)