Skip to content

Commit

Permalink
Merge branch 'main' of github.com:hao-ai-lab/Consistency_LLM into main
Browse files Browse the repository at this point in the history
  • Loading branch information
snyhlxde1 committed May 13, 2024
2 parents 24dfe45 + c89fc98 commit 9fa4d29
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
6 changes: 5 additions & 1 deletion cllm/cllm_trainer_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, *args, **kwargs):
args = kwargs["args"]
self.train_step_cnt = 0
self.max_new_tokens = args.max_new_tokens
self.use_gt_labels = args.use_gt_labels

def training_step(self, model, inputs):
self.train_step_cnt += 1
Expand Down Expand Up @@ -45,7 +46,10 @@ def consistency_training_step(self, model, inputs):

### compute AutoRegression loss ###
# use labels to avoid pattern collapse
labels = inputs['complete_teacher_output_ids']
if self.use_gt_labels:
labels = inputs['labels_ids']
else:
labels = inputs['complete_teacher_output_ids']
# TODO: check if it's right when batch size > 1
labels = torch.tensor(labels).to(model.device)
attention_mask = torch.full_like(labels, 1).to(model.device)
Expand Down
1 change: 1 addition & 0 deletions cllm/train_cllm_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class TrainingArguments(transformers.TrainingArguments):
"help": "Size of n_token_sequence in Jacobi trajectory."
},
)
use_gt_labels: bool = True
report_to: str = field(
default='wandb',
metadata={
Expand Down

0 comments on commit 9fa4d29

Please sign in to comment.