From 6d7a06cc2ff1b86180b40f4054e843b02d6722b9 Mon Sep 17 00:00:00 2001
From: "Wu, Gangsheng"
Date: Mon, 8 Jul 2024 18:10:37 +0000
Subject: [PATCH] update

---
 llm_on_ray/finetune/finetune.py        | 24 ++++++++++++++++++++++++
 llm_on_ray/finetune/finetune_config.py |  2 +-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index 84b7700b..8f7fee20 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -28,6 +28,7 @@
 import torch
 
 import datasets
+import evaluate
 import transformers
 
 from peft import get_peft_model, LoraConfig
@@ -340,6 +341,25 @@ def get_trainer(config: Dict, training_args, model, tokenizer, tokenized_dataset
             gaudi_config.use_fused_adam = True
             gaudi_config.use_fused_clip_norm = True
 
+        def preprocess_logits_for_metrics(logits, labels):
+            if isinstance(logits, tuple):
+                # Depending on the model and config, logits may contain extra tensors,
+                # like past_key_values, but logits always come first
+                logits = logits[0]
+            result = logits.argmax(dim=-1)
+            return result
+
+        metric = evaluate.load("accuracy")
+
+        def compute_metrics(eval_preds):
+            preds, labels = eval_preds
+            # preds have the same shape as the labels, after the argmax(-1) has been calculated
+            # by preprocess_logits_for_metrics but we need to shift the labels
+            labels = labels[:, 1:].reshape(-1)
+            preds = preds[:, :-1].reshape(-1)
+            result = metric.compute(predictions=preds, references=labels)
+            return result
+
         trainer = GaudiTrainer(
             model=model,
             args=training_args,
@@ -350,6 +370,10 @@ def get_trainer(config: Dict, training_args, model, tokenizer, tokenized_dataset
             else None,
             tokenizer=tokenizer,
             data_collator=data_collator,
+            compute_metrics=compute_metrics if training_args.do_eval else None,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics
+            if training_args.do_eval
+            else None,
         )
         return trainer
     return None
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index 96fe0323..c19bda34 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -78,7 +78,7 @@ class Dataset(BaseModel):
     truncation_side: str = "right"
     max_seq_length: int = 512
     truncation: bool = True
-    padding: bool = True
+    padding: bool = False
     mask_input: bool = True
     mask_response: bool = True
     data_preprocess_type: str = "default"
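
Note: below is a minimal sketch of the evaluation path this patch wires up,
runnable outside the Trainer with dummy tensors. The two steps mirror
preprocess_logits_for_metrics and compute_metrics above; the shapes and
variable names here are illustrative only, not part of the patch.

    import torch
    import evaluate

    metric = evaluate.load("accuracy")

    # Stand-in eval output: batch of 2 sequences of length 5 over a 4-token vocab.
    logits = torch.randn(2, 5, 4)
    labels = torch.randint(0, 4, (2, 5))

    # Step 1 (preprocess_logits_for_metrics): collapse logits to predicted
    # token ids, so the Trainer accumulates (batch, seq_len) integers instead
    # of full (batch, seq_len, vocab) logit tensors during evaluation.
    preds = logits.argmax(dim=-1)

    # Step 2 (compute_metrics): for a causal LM the logit at position i
    # predicts the token at position i + 1, so drop the last prediction and
    # the first label before flattening and comparing.
    labels = labels[:, 1:].reshape(-1)
    preds = preds[:, :-1].reshape(-1)
    print(metric.compute(predictions=preds.tolist(), references=labels.tolist()))

One caveat: if the collator masks prompt tokens with -100 (mask_input defaults
to True in finetune_config.py), those positions would also need to be filtered
out before calling metric.compute, since the accuracy metric otherwise counts
-100 as an ordinary reference label.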