
Commit

update
harborn committed Jul 8, 2024
1 parent de0eb21 commit 6d7a06c
Showing 2 changed files with 25 additions and 1 deletion.
24 changes: 24 additions & 0 deletions llm_on_ray/finetune/finetune.py
@@ -28,6 +28,7 @@

 import torch
 import datasets
+import evaluate
 import transformers
 
 from peft import get_peft_model, LoraConfig
@@ -340,6 +341,25 @@ def get_trainer(config: Dict, training_args, model, tokenizer, tokenized_dataset
         gaudi_config.use_fused_adam = True
         gaudi_config.use_fused_clip_norm = True
 
+        def preprocess_logits_for_metrics(logits, labels):
+            if isinstance(logits, tuple):
+                # Depending on the model and config, logits may contain extra tensors,
+                # like past_key_values, but logits always come first
+                logits = logits[0]
+            result = logits.argmax(dim=-1)
+            return result
+
+        metric = evaluate.load("accuracy")
+
+        def compute_metrics(eval_preds):
+            preds, labels = eval_preds
+            # preds have the same shape as the labels, after the argmax(-1) has been calculated
+            # by preprocess_logits_for_metrics but we need to shift the labels
+            labels = labels[:, 1:].reshape(-1)
+            preds = preds[:, :-1].reshape(-1)
+            result = metric.compute(predictions=preds, references=labels)
+            return result
+
         trainer = GaudiTrainer(
             model=model,
             args=training_args,
@@ -350,6 +370,10 @@ def get_trainer(config: Dict, training_args, model, tokenizer, tokenized_dataset
             else None,
             tokenizer=tokenizer,
             data_collator=data_collator,
+            compute_metrics=compute_metrics if training_args.do_eval else None,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics
+            if training_args.do_eval
+            else None,
         )
         return trainer
     return None
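For readers unfamiliar with the two hooks added above: preprocess_logits_for_metrics collapses the evaluation logits to token ids (so vocabulary-sized logit tensors are not accumulated during eval), and compute_metrics shifts predictions and labels by one position before scoring, because a causal LM's output at position t predicts the token at position t + 1. A small self-contained check of that shift, using made-up token ids (not taken from this commit), is sketched below:

# Illustrative check (not part of the commit) of the shift done in compute_metrics:
# a causal LM's prediction at position t is scored against the label at t + 1.
import numpy as np

labels = np.array([[5, 6, 7, 8]])  # made-up ground-truth token ids, shape (batch, seq)
preds = np.array([[6, 7, 8, 0]])   # argmax'd ids from a model that always predicts
                                   # the next token correctly (last position is arbitrary)

shifted_labels = labels[:, 1:].reshape(-1)  # [6, 7, 8]
shifted_preds = preds[:, :-1].reshape(-1)   # [6, 7, 8]

print((shifted_preds == shifted_labels).mean())  # 1.0 -- the shift lines them up
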
2 changes: 1 addition & 1 deletion llm_on_ray/finetune/finetune_config.py
@@ -78,7 +78,7 @@ class Dataset(BaseModel):
     truncation_side: str = "right"
     max_seq_length: int = 512
     truncation: bool = True
-    padding: bool = True
+    padding: bool = False
     mask_input: bool = True
     mask_response: bool = True
     data_preprocess_type: str = "default"
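The config change above flips the Dataset.padding default from True to False. Assuming the flag is forwarded to a Hugging Face tokenizer's padding argument (an assumption about this repo, not verified here), the practical difference is whether examples are padded during preprocessing or left at their natural lengths for a collator to pad per batch. A minimal sketch with a placeholder model name:

# Illustrative sketch only: what a boolean `padding` argument typically toggles
# in a Hugging Face tokenizer call. "gpt2" is a placeholder model name, not
# necessarily what this repo uses.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

text = ["a short line", "a considerably longer line of text for comparison"]

# padding=True: every sequence in the call is padded to the longest one.
padded = tokenizer(text, truncation=True, padding=True, max_length=512)
# padding=False: sequences keep their natural lengths; a data collator
# (e.g. DataCollatorWithPadding) can still pad each batch dynamically later.
unpadded = tokenizer(text, truncation=True, padding=False, max_length=512)

print([len(ids) for ids in padded["input_ids"]])    # equal lengths
print([len(ids) for ids in unpadded["input_ids"]])  # different natural lengths
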
