Skip to content

Commit

Permalink
Fix labels & eos_token for SFT
Browse files Browse the repository at this point in the history
  • Loading branch information
lijiahao committed Nov 28, 2023
1 parent b116838 commit e7736e5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 2 additions & 1 deletion applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def __getitem__(self, idx):
return {
"input_ids": self.chosen_dataset[idx]["input_ids"],
"attention_mask": self.chosen_dataset[idx]["attention_mask"],
"labels": self.chosen_dataset[idx]["input_ids"]
"labels": torch.where(self.chosen_dataset[idx]["attention_mask"].bool(),
self.chosen_dataset[idx]["input_ids"], -100)
}
elif self.train_phase == 2:
return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def main():
args.seed,
tokenizer,
args.max_seq_len,
end_of_conversation_token=tokenizer.eos_token,
sft_only_data_path=args.sft_only_data_path)
# DataLoaders creation:
if args.local_rank == -1:
Expand Down

0 comments on commit e7736e5

Please sign in to comment.