Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPOTrainer: fix progress bar for num_mini_batches > 1 #2531

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/ppo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: lr: The current learning rate used by the optimizer.
* `episode`: episode: The current global step or episode count in the training process.
* `episode`: episode: The current episode count in the training process.


## Cookbook
Expand Down
6 changes: 2 additions & 4 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ def __init__(
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
self.accelerator = accelerator
args.world_size = accelerator.num_processes
args.local_batch_size = (
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
)
args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
args.batch_size = int(args.local_batch_size * args.world_size)
args.mini_batch_size = exact_div(
Expand Down Expand Up @@ -376,7 +374,7 @@ def repeat_generator():
# trainer state initialization
self.state.global_step = 0
self.state.episode = 0
self.state.max_steps = args.num_total_batches * args.num_mini_batches
self.state.max_steps = args.num_total_batches
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
Expand Down
Loading