Commit cc2b7b9: fix peft training

kashif committed Jan 15, 2025
1 parent 8ae06b1
Showing 1 changed file with 34 additions and 2 deletions.
trl/trainer/dpo_trainer.py: 36 changes (34 additions & 2 deletions)
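For context, the fix targets DPO training of a PEFT adapter without a separate reference model, where the trainer must recover the reference from the frozen base weights. A minimal sketch of that setup, assuming the standard TRL DPOTrainer API (model, dataset, and config values are illustrative, not from this commit):

```python
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # no explicit reference model: the code in this commit
                     # disables the adapters to recover the base model instead
    args=DPOConfig(output_dir="dpo-peft"),
    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
    processing_class=tokenizer,
    peft_config=LoraConfig(r=16, lora_alpha=32, target_modules="all-linear"),
)
trainer.train()
```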
@@ -1169,6 +1169,21 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
                     use_cache=False,
                 )
                 ref_hidden_states = ref_decoder_outputs.last_hidden_state
+            elif not self.reference_free:
+                with self.null_ref_context():
+                    ref_encoder_outputs = model.get_encoder()(
+                        concatenated_batch["prompt_input_ids"],
+                        attention_mask=concatenated_batch["prompt_attention_mask"],
+                        return_dict=True,
+                    )
+                    ref_decoder_outputs = model.get_decoder()(
+                        input_ids=decoder_input_ids,
+                        attention_mask=concatenated_batch["completion_attention_mask"],
+                        encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
+                        encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
+                        use_cache=False,
+                    )
+                    ref_hidden_states = ref_decoder_outputs.last_hidden_state

             labels = concatenated_batch["completion_input_ids"]
         else:
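Both new branches lean on null_ref_context. Paraphrased from the TRL source (approximate, not part of this diff), it is a context manager that disables the PEFT adapters, or swaps in a dedicated reference adapter, so the forward passes above see the frozen base weights:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def null_ref_context(self):
    """Yield with the policy's adapters disabled so it acts as the reference model."""
    with (
        self.accelerator.unwrap_model(self.model).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        yield
        if self.ref_adapter_name:
            self.model.set_adapter(self.model_adapter_name or "default")
```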
@@ -1210,6 +1225,19 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
                     **model_kwargs,
                 )
                 ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
+            elif not self.reference_free:
+                if hasattr(model, "get_decoder"):
+                    ref_base_model = model.get_decoder()
+                else:
+                    ref_base_model = getattr(model, self.args.base_model_attribute_name, model)
+                with self.null_ref_context():
+                    ref_outputs = ref_base_model(
+                        input_ids,
+                        attention_mask=attention_mask,
+                        use_cache=False,
+                        **model_kwargs,
+                    )
+                    ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]

             labels = input_ids[:, 1:]  # Shift right for causal LM
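The [:, :-1] slice on the hidden states mirrors the [:, 1:] shift on the labels: for a causal LM, the hidden state at position t scores the token at position t + 1. A toy check:

```python
import torch

# Toy check of the one-token offset used above: hidden states drop the last
# position and labels drop the first, leaving aligned (position, target) pairs.
input_ids = torch.tensor([[10, 11, 12, 13]])
hidden_states = torch.randn(1, 4, 8)        # (batch, seq_len, hidden_dim)
hidden_for_loss = hidden_states[:, :-1]     # positions 0..2 predict tokens 1..3
labels = input_ids[:, 1:]                   # tokens 1..3
assert hidden_for_loss.shape[1] == labels.shape[1] == 3
```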

@@ -1219,8 +1247,12 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
         # Get reference model weights if needed
         ref_weight = None
         ref_bias = None
-        if not self.reference_free and self.ref_model is not None:
-            ref_lm_head = self.ref_model.get_output_embeddings()
+        if not self.reference_free:
+            if self.ref_model is not None:
+                ref_lm_head = self.ref_model.get_output_embeddings()
+            else:
+                with self.null_ref_context():
+                    ref_lm_head = model.get_output_embeddings()
             ref_weight = ref_lm_head.weight
             ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
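With ref_weight and ref_bias extracted, per-token reference log-probabilities can be computed from the shared hidden states without running a second lm_head-equipped model. A hypothetical helper (ref_logps is illustrative; the trainer itself hands these tensors on to its loss computation):

```python
from typing import Optional

import torch
import torch.nn.functional as F

def ref_logps(
    hidden_states: torch.Tensor,   # (batch, seq_len, hidden_dim), e.g. ref_hidden_states
    weight: torch.Tensor,          # lm_head weight, e.g. ref_weight above
    bias: Optional[torch.Tensor],  # lm_head bias, e.g. ref_bias above (may be None)
    labels: torch.Tensor,          # (batch, seq_len) target token ids
) -> torch.Tensor:
    # Project hidden states through the reference lm_head, then gather the
    # log-probability assigned to each target token.
    logits = F.linear(hidden_states, weight, bias)
    logprobs = logits.log_softmax(dim=-1)
    return logprobs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
```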
