Commit

fix
DrownFish19 committed Feb 27, 2025
1 parent 8dbc396 commit 87ba5d7
Showing 4 changed files with 62 additions and 54 deletions.
6 changes: 3 additions & 3 deletions llm/alignment/ppo/infer_utils.py
@@ -24,7 +24,7 @@
from comm_utils import offload_tensor_to_cpu, reload_tensor_to_gpu
from paddle.utils import try_import
from predict.predictor import (
InferencePredictorMixin,
DygraphInferencePredictor,
PdArgumentParser,
PredictorArgument,
)
@@ -53,8 +53,8 @@ def __init__(self, config, model: PretrainedModel = None, tokenizer: PretrainedT
# multi time prediction. define caches and extra inputs creation method
# instead of using predictor.__init__

self._buffer_maker = types.MethodType(InferencePredictorMixin.__init__, self)
self._inputs_processer = types.MethodType(InferencePredictorMixin._preprocess, self)
self._buffer_maker = types.MethodType(DygraphInferencePredictor.__init__, self)
self._inputs_processer = types.MethodType(DygraphInferencePredictor._preprocess, self)

@staticmethod
def create_predictor(trainer):
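The hunk above relies on `types.MethodType`, which binds a function defined on one class to an arbitrary instance; that is how the predictor's buffer creation and preprocessing are reused without running the full predictor `__init__` for every prediction. A minimal sketch of the pattern, with illustrative names (`Predictor`, `Host`) that are not from the repository:

```python
import types


class Predictor:
    def __init__(self, config):
        # stand-in for cache/buffer allocation driven by the predictor config
        self.cache = {"max_len": config["max_len"]}

    def _preprocess(self, text):
        # stand-in for tokenization / padding of raw text
        return {"input_ids": list(text), "max_len": self.cache["max_len"]}


class Host:
    """Reuses Predictor's buffer creation and preprocessing without inheriting from it."""

    def __init__(self, config):
        # bind Predictor's functions to *this* instance
        self._buffer_maker = types.MethodType(Predictor.__init__, self)
        self._inputs_processer = types.MethodType(Predictor._preprocess, self)
        self._buffer_maker(config)  # fills self.cache exactly as Predictor.__init__ would


host = Host({"max_len": 8})
print(host._inputs_processer("abc"))  # runs Predictor._preprocess against Host's state
```

The commit only swaps the class whose methods are borrowed, from `InferencePredictorMixin` to `DygraphInferencePredictor`; the binding mechanism itself is unchanged.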
98 changes: 48 additions & 50 deletions llm/alignment/ppo/ppo_trainer.py
@@ -132,6 +132,7 @@ def __init__(
)
# criterion is only used for non-PipelineParallel models. criterion is
# included in model for PipelineParallel.
self.info_buffer = {}
if getattr(self, "loss_cls", None) and self.criterion is None:
self.criterion = self.create_criterion()

@@ -150,7 +151,7 @@ def create_criterion(self):
whose label arguments are merged into one argument, this is useful to
PipelineParallel and trainer.criterion which limit loss format.
"""
criterion = create_loss(self.loss_cls, self.model.config, self.args, merge_labels=True)
criterion = create_loss(self.loss_cls, self.model.config, self.args, self.info_buffer, merge_labels=True)
return criterion

def loss_identifier(self, inputs: Dict) -> str:
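`create_loss` now also receives `self.info_buffer`: a mutable dict shared between trainer and criterion lets the loss write auxiliary values back without changing its return type. The sketch below is hypothetical; the real `create_loss` and loss classes live elsewhere in the PPO code and are not reproduced here:

```python
from typing import Dict


class DummyLoss:
    """Hypothetical loss that writes side-channel statistics into a shared dict."""

    def __init__(self, config, args, info_buffer: Dict):
        self.info_buffer = info_buffer

    def __call__(self, logits, labels):
        loss = sum(logits) - sum(labels)  # placeholder arithmetic, not a real PPO loss
        self.info_buffer["last_loss"] = loss  # the trainer can read this after the step
        return loss


def create_loss(loss_cls, config, args, info_buffer: Dict, merge_labels: bool = False):
    # hypothetical factory mirroring the call shape used above
    return loss_cls(config, args, info_buffer)


info_buffer: Dict = {}
criterion = create_loss(DummyLoss, config=None, args=None, info_buffer=info_buffer, merge_labels=True)
print(criterion([3.0, 2.0], [1.0, 1.0]), info_buffer)  # 3.0 {'last_loss': 3.0}
```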
@@ -1004,7 +1005,6 @@ def __init__(
), # workaround for pipeline parallel model check
},
):

self.reference_trainer = StepTrainer(
reference_model,
criterion,
@@ -1240,9 +1240,7 @@ def prediction_step(
padding="max_length",
max_length=self._model_config.max_sequence_length,
return_attention_mask=False,
)[
"input_ids"
] # pad to max_sequence_length
)["input_ids"] # pad to max_sequence_length
else:
seq = generated_seq

@@ -1272,9 +1270,7 @@ def prediction_step(
attention_mask=reward_attention_mask,
position_ids=reward_position_ids,
# return_dict=True,
)[
1
] # .end_scores
)[1] # .end_scores
else:
prompt_len = inputs["input_ids"].shape[-1]
if "label_ids" not in inputs:
@@ -1552,12 +1548,12 @@ def gen_epoch_data():
self.prompt_only_dataloader,
itertools.cycle(self.ptx_dataloader),
):

# generate batches
self.set_eval()

with ema(self.policy_trainer), (
ema(self.value_trainer) if self.args.rl_algorithm == "ppo" else contextlib.nullcontext()
with (
ema(self.policy_trainer),
ema(self.value_trainer) if self.args.rl_algorithm == "ppo" else contextlib.nullcontext(),
):
with guard_set_args(self._model_config, {"use_fused_head_and_loss_fn": False}):
rl_batches = self.split_rl_micro_batches(prompt_only_batch)
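The change above only regroups the two context managers into the parenthesized `with (...)` form (supported since Python 3.10), so the conditional `contextlib.nullcontext()` branch gets its own line. A self-contained sketch of the same shape, using a dummy `ema` stand-in:

```python
import contextlib


@contextlib.contextmanager
def ema(name):
    # stand-in for the trainers' EMA weight-swap context manager
    print(f"enter ema({name})")
    yield
    print(f"exit ema({name})")


rl_algorithm = "grpo"  # anything other than "ppo" skips the value-model EMA

with (
    ema("policy"),
    ema("value") if rl_algorithm == "ppo" else contextlib.nullcontext(),
):
    print("generate rollouts under EMA weights")
```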
@@ -1708,34 +1704,40 @@ def train(

# ##### training data and related num setting #####
# TODO(guosheng): remove the binding method get_collator of dataset
with guard_set_args(
args,
{"per_device_train_batch_size": self.args.per_device_prompt_batch_size},
), guard_set_args(
self,
{
"train_dataset": self.train_dataset,
"data_collator": self.train_dataset.get_collator(),
},
):
train_dataloader = self.prompt_only_dataloader = self.get_train_dataloader()

if self.use_ptx:
with guard_set_args(
with (
guard_set_args(
args,
{
"per_device_train_batch_size": (
1
if getattr(self.ptx_dataset, "is_intokens", False)
else self.args.per_device_prompt_batch_size * self.args.num_return_sequences
)
},
), guard_set_args(
{"per_device_train_batch_size": self.args.per_device_prompt_batch_size},
),
guard_set_args(
self,
{
"train_dataset": self.ptx_dataset,
"data_collator": self.ptx_dataset.get_collator(),
"train_dataset": self.train_dataset,
"data_collator": self.train_dataset.get_collator(),
},
),
):
train_dataloader = self.prompt_only_dataloader = self.get_train_dataloader()

if self.use_ptx:
with (
guard_set_args(
args,
{
"per_device_train_batch_size": (
1
if getattr(self.ptx_dataset, "is_intokens", False)
else self.args.per_device_prompt_batch_size * self.args.num_return_sequences
)
},
),
guard_set_args(
self,
{
"train_dataset": self.ptx_dataset,
"data_collator": self.ptx_dataset.get_collator(),
},
),
):
self.ptx_dataloader = self.get_train_dataloader()
else:
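The same parenthesized `with (...)` regrouping is applied in the hunk above; the behaviour of `guard_set_args` (temporarily overriding attributes while the dataloaders are built and restoring them afterwards) is unchanged. A plausible sketch of that pattern, not the repository's actual helper:

```python
import contextlib


@contextlib.contextmanager
def guard_set_args(obj, overrides):
    # temporarily override attributes on obj, restoring the originals on exit
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)


class Args:
    per_device_train_batch_size = 4


args = Args()
with (
    guard_set_args(args, {"per_device_train_batch_size": 1}),
):
    print(args.per_device_train_batch_size)  # 1 while the PTX dataloader is built
print(args.per_device_train_batch_size)  # back to 4 afterwards
```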
@@ -1768,7 +1770,11 @@ def train(
# ##### set training state and resume #####
# consumed_samples used to set train_dataloader.batch_sampler may not be
# correct. Thus, data cannot be resumed perfectly when not breaking at epoch end.
(epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar,) = self.init_train_state(
(
epochs_trained,
steps_trained_in_current_epoch,
steps_trained_progress_bar,
) = self.init_train_state(
resume_from_checkpoint,
train_dataloader,
max_steps,
@@ -1883,14 +1889,10 @@ def train(

best_model_checkpoint = json.loads(self.state.best_model_checkpoint)

logger.info(
f"Loading best model from {best_model_checkpoint['value']}" f"(score: {self.state.best_metric})."
)
logger.info(f"Loading best model from {best_model_checkpoint['value']}(score: {self.state.best_metric}).")
self.load_best_ckpt(best_model_checkpoint["value"], self.value_trainer)

logger.info(
f"Loading best model from {best_model_checkpoint['policy']}" f"(score: {self.state.best_metric})."
)
logger.info(f"Loading best model from {best_model_checkpoint['policy']}(score: {self.state.best_metric}).")
self.load_best_ckpt(best_model_checkpoint["policy"], self.policy_trainer)

metrics = speed_metrics(
@@ -1967,7 +1969,6 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
None.
"""
if self.control.should_log and tr_loss is not None:

logs: Dict[str, float] = {}
# use_ptx would double the gradient_accumulation_steps which causes
# policy_loss and ptx_loss reduced by half. Moreover, ptx_loss should
@@ -2484,7 +2485,8 @@ def generate(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any
if do_eval:
self.args.num_return_sequences = train_num_return_sequences
sequences = sequences.transpose([1, 0, 2])

if not isinstance(sequences, list):
sequences = [sequences]
# prompt, sequence, attention_mask
return [
{
@@ -2629,9 +2631,7 @@ def rollout_reward_value(
attention_mask=reward_attention_mask,
position_ids=reward_position_ids,
# return_dict=True,
)[
1
] # .end_scores
)[1] # .end_scores
else:
prompt_len = kwargs["prompt"].shape[-1]
if "label_ids" not in kwargs:
@@ -2651,9 +2651,7 @@ def rollout_reward_value(
attention_mask=attention_mask,
position_ids=position_ids,
# return_dict=True,
)[
0
] # .scores
)[0] # .scores
reward_value = reward_value.squeeze(axis=-1)
reward_value = reward_value[:, :-1]

2 changes: 1 addition & 1 deletion llm/alignment/ppo/run_ppo.py
@@ -386,7 +386,7 @@ def main():
reward_tokenizer,
reward_critic_tokenizer if training_args.rl_algorithm == "ppo" else None,
]:
if isinstance(tokenizer, AutoTokenizer) and tokenizer.pad_token_id is None:
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

if training_args.should_load_dataset:
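A hedged reading of the change above (the commit message does not spell it out): `AutoTokenizer.from_pretrained` usually returns a concrete tokenizer class, so an `isinstance(tokenizer, AutoTokenizer)` gate would rarely pass and the pad-token fallback could be skipped; without the gate the fallback applies to every tokenizer. Minimal illustration with a dummy tokenizer:

```python
class DummyTokenizer:
    # stand-in: real tokenizers expose pad_token_id / eos_token_id attributes
    def __init__(self):
        self.pad_token_id = None
        self.eos_token_id = 2


for tokenizer in (DummyTokenizer(), DummyTokenizer()):
    if tokenizer.pad_token_id is None:  # no isinstance gate: applies to every tokenizer
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print(tokenizer.pad_token_id)  # 2
```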
10 changes: 10 additions & 0 deletions llm/alignment/ppo/trainer_utils.py
@@ -362,6 +362,16 @@ class ModelArgument:
"can be selected as `full` or `full_attn` or `core_attn`. "
},
)
chat_template: str = field(
default="none",
metadata={
"help": "the path of `chat_template.json` file to handle multi-rounds conversation. "
"If is None(do not set --chat_template argument), it will use the default `chat_template.json`;"
"If is equal with `model_name_or_path`, it will use the default loading; "
"If is directory, it will find the `chat_template.json` under the directory; If is file, it will load it."
"If is none string, it will not use chat_template.json."
},
)


@dataclass
Expand Down
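A hedged sketch of the resolution order the new `chat_template` help text describes; the function name and logic below are illustrative, not the repository's actual loading code:

```python
import os


def resolve_chat_template(chat_template, model_name_or_path):
    """Illustrative resolution of the --chat_template argument (assumed behaviour)."""
    if chat_template is None or chat_template == model_name_or_path:
        # unset, or equal to the model path: fall back to the default template
        return os.path.join(model_name_or_path, "chat_template.json")
    if chat_template.lower() == "none":
        # the literal string "none": chat templating is disabled
        return None
    if os.path.isdir(chat_template):
        # a directory: look for chat_template.json inside it
        return os.path.join(chat_template, "chat_template.json")
    if os.path.isfile(chat_template):
        # a concrete file: load it directly
        return chat_template
    raise ValueError(f"cannot resolve chat_template: {chat_template!r}")


print(resolve_chat_template("none", "path/to/model"))  # None -> templating disabled
```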
