diff --git a/cllm/cllm_trainer_global.py b/cllm/cllm_trainer_global.py
index 9580f86..4644d63 100644
--- a/cllm/cllm_trainer_global.py
+++ b/cllm/cllm_trainer_global.py
@@ -14,6 +14,7 @@ def __init__(self, *args, **kwargs):
         args = kwargs["args"]
         self.train_step_cnt = 0
         self.max_new_tokens = args.max_new_tokens
+        self.use_gt_labels = args.use_gt_labels
 
     def training_step(self, model, inputs):
         self.train_step_cnt += 1
@@ -45,7 +46,10 @@ def consistency_training_step(self, model, inputs):
 
         ### compute AutoRegression loss ###
         # use labels to avoid pattern collapse
-        labels = inputs['complete_teacher_output_ids']
+        if self.use_gt_labels:
+            labels = inputs['labels_ids']
+        else:
+            labels = inputs['complete_teacher_output_ids']
         # TODO: check if it's right when batch size > 1
         labels = torch.tensor(labels).to(model.device)
         attention_mask = torch.full_like(labels, 1).to(model.device)
diff --git a/cllm/train_cllm_global.py b/cllm/train_cllm_global.py
index 7663d4f..53563e8 100644
--- a/cllm/train_cllm_global.py
+++ b/cllm/train_cllm_global.py
@@ -68,6 +68,7 @@ class TrainingArguments(transformers.TrainingArguments):
             "help": "Size of n_token_sequence in Jacobi trajectory."
         },
     )
+    use_gt_labels: bool = True
     report_to: str = field(
         default='wandb',
         metadata={
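
Regarding the `# TODO: check if it's right when batch size > 1` left in the first hunk: `torch.tensor(labels)` only succeeds when every sequence in the batch has the same length, so ragged label sequences at batch sizes above 1 would raise an error, and `torch.full_like(labels, 1)` would then attend to padding. Below is a minimal sketch of one way to handle this, assuming a right-padding convention; `pad_token_id` and the example `label_lists` are placeholders (in practice you would use the tokenizer's `pad_token_id`), not part of the patch itself.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical batch of two ragged label sequences (placeholder ids).
label_lists = [
    [101, 7592, 2088, 102],
    [101, 2129, 102],
]
pad_token_id = 0  # assumption: use tokenizer.pad_token_id in practice

# Right-pad every sequence to the longest one in the batch.
labels = pad_sequence(
    [torch.tensor(seq, dtype=torch.long) for seq in label_lists],
    batch_first=True,
    padding_value=pad_token_id,
)

# Attend only to real tokens, rather than torch.full_like(labels, 1).
attention_mask = (labels != pad_token_id).long()

# Set padded label positions to -100 so CrossEntropyLoss ignores them.
labels = labels.masked_fill(attention_mask == 0, -100)
```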