diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index 78b2edbcb5..9227ff52cf 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -24,8 +24,6 @@ This example demonstrates how to train a model using the GRPO method. We train a
 
 > Below is the script to train the model.
 
-Note that the input tensor for the forward pass has a size of `num_generations * per_device_train_batch_size` because GRPO generates `num_generations` completions for each prompt in the batch. Adjusting these values appropriately can help prevent OOM errors.
-Consequently, the effective train batch size is `num_generations * per_device_train_batch_size * gradient_accumulation_steps`.
 
 ```python
 # train_grpo.py
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 234b4e7ba1..6b00ed1ed0 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -47,7 +47,7 @@ def test_grpo(self):
         from trl.cli import main
 
         with tempfile.TemporaryDirectory() as tmp_dir:  # Create a temporary directory
-            command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 3 --max_completion_length 32 --report_to none"
+            command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 4 --max_completion_length 32 --report_to none"
             with patch("sys.argv", command.split(" ")):
                 main()
 
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 6c85ddfc4f..54d2b3965d 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -49,7 +49,7 @@ def test_training(self, config_name):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -78,8 +78,8 @@ def test_training_with_eval(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
-                per_device_eval_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                per_device_eval_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 eval_strategy="steps",
@@ -106,7 +106,7 @@ def test_training_peft(self):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -149,7 +149,7 @@ def test_training_different_reward_model(self):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -185,7 +185,7 @@ def reward_func(completions, **kwargs):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -221,7 +221,7 @@ def reward_func(completions, **kwargs):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -260,7 +260,7 @@ def reward_func2(completions, **kwargs):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -295,7 +295,7 @@ def reward_func(completions, **kwargs):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -334,7 +334,7 @@ def reward_func(completions, some_values, **kwargs):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -367,7 +367,7 @@ def test_training_vllm(self):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
@@ -400,7 +400,7 @@ def test_training_torch_compile(self):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 torch_compile=True,
@@ -431,7 +431,7 @@ def test_training_with_sync_ref_model(self):
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
-                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 sync_ref_model=True,
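The test updates above follow from the new requirement introduced by this patch: completions are no longer generated `num_generations` times per prompt inside each device batch, so the global batch size (number of processes × per-device batch size) must now be divisible by `num_generations`. A rough sanity check of the numbers used in the tests, assuming a single-process CI run and, for the CLI test, the stock `transformers` default per-device batch size of 8 (this snippet is illustrative only, not part of the patch):

```python
# Why per_device_train_batch_size moved from 2 to 3, and the CLI test from --num_generations 3 to 4.
num_processes = 1  # single-GPU CI run assumed

assert (2 * num_processes) % 3 != 0  # old value: a global batch of 2 is not divisible by num_generations=3
assert (3 * num_processes) % 3 == 0  # new value: a global batch of 3 works with num_generations=3
assert (8 * num_processes) % 3 != 0  # CLI test: with an assumed default per-device batch size of 8, 3 generations fail
assert (8 * num_processes) % 4 == 0  # ...which is presumably why the CLI test now asks for 4 generations
```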
diff --git a/trl/models/utils.py b/trl/models/utils.py
index 22a30c0afb..dce9d60228 100644
--- a/trl/models/utils.py
+++ b/trl/models/utils.py
@@ -137,6 +137,8 @@ def setup_chat_format(
 
 def remove_hooks(model: "DeepSpeedEngine") -> None:
     """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+        return
     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
         optimizer_offload = model.optimizer.parameter_offload
     elif model.optimizer is not None:
@@ -164,6 +166,8 @@ def iter_params(module, recurse=False):
 
 def add_hooks(model: "DeepSpeedEngine") -> None:
     """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+        return
     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
         optimizer_offload = model.optimizer.parameter_offload
     elif model.optimizer is not None:
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index e641065968..b30cf9899d 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -45,7 +45,8 @@ class GRPOConfig(TrainingArguments):
         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
             Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
         num_generations (`int` or `None`, *optional*, defaults to `8`):
-            Number of generations per prompt to sample.
+            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
+            must be divisible by this value.
         temperature (`float`, *optional*, defaults to `0.9`):
             Temperature for sampling. The higher the temperature, the more random the completions.
         max_completion_length (`int` or `None`, *optional*, defaults to `256`):
@@ -83,11 +84,6 @@
         learning_rate (`float`, *optional*, defaults to `1e-6`):
             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
             [`~transformers.TrainingArguments`].
-        per_device_train_batch_size (`int`, *optional*, defaults to `1`):
-            Number of prompts sampled per device for training. The actual batch passed into the model will be this
-            value multiplied by `num_generations`.
-        gradient_accumulation_steps (`int`, *optional*, defaults to `8`):
-            Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
         beta (`float`, *optional*, defaults to `0.04`):
             KL coefficient.
         sync_ref_model (`bool`, *optional*, defaults to `False`):
@@ -132,7 +128,10 @@
     )
     num_generations: Optional[int] = field(
         default=8,
-        metadata={"help": "Number of generations to sample."},
+        metadata={
+            "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
+            "must be divisible by this value."
+        },
     )
     temperature: Optional[float] = field(
         default=0.9,
@@ -202,23 +201,6 @@
             "`transformers.TrainingArguments`."
         },
     )
-    # GRPO generates multiple completions per prompt, increasing memory usage.
-    # To accommodate this, the per-device train batch size is decreased (overriden from the parent class),
-    # and the number gradient accumulation steps is increased to maintain the effective batch size.
-    per_device_train_batch_size: int = field(
-        default=1,
-        metadata={
-            "help": "Number of prompts sampled per device for training. The actual batch passed into the model will "
-            "be this value multiplied by `num_generations`."
-        },
-    )
-    gradient_accumulation_steps: int = field(
-        default=8,
-        metadata={
-            "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update "
-            "pass."
-        },
-    )
     beta: float = field(
         default=0.04,
         metadata={"help": "KL coefficient."},
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index c6c5f4fa76..fa55a4bb8e 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -16,17 +16,18 @@
 import textwrap
 import warnings
 from collections import defaultdict
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Sized, Union
 from unittest.mock import patch
 
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather_object
+from accelerate.utils import broadcast_object_list, gather, gather_object
 from accelerate.utils.other import is_compiled_module
 from datasets import Dataset, IterableDataset
 from packaging import version
 from torch import nn
+from torch.utils.data import Sampler
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
@@ -63,6 +64,37 @@
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
 
 
+class RepeatRandomSampler(Sampler):
+    """
+    Sampler that repeats the indices of a dataset N times.
+
+    Args:
+        data_source (`Sized`):
+            Dataset to sample from.
+        repeat_count (`int`):
+            Number of times to repeat each index.
+
+    Example:
+    ```python
+    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
+    >>> list(sampler)
+    [2, 2, 0, 0, 3, 3, 1, 1]
+    ```
+    """
+
+    def __init__(self, data_source: Sized, repeat_count: int):
+        self.data_source = data_source
+        self.repeat_count = repeat_count
+        self.num_samples = len(data_source)
+
+    def __iter__(self):
+        indexes = [idx for idx in torch.randperm(self.num_samples).tolist() for _ in range(self.repeat_count)]
+        return iter(indexes)
+
+    def __len__(self):
+        return self.num_samples * self.repeat_count
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
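The sampler added above is what lets the trainer drop `num_return_sequences`/`n` from the generation settings later in this diff: every dataset index is emitted `num_generations` times in a row, so consecutive rows of a batch share the same prompt and form one group for reward normalization. A short exercise of the class, assuming this patch is applied so that `RepeatRandomSampler` is importable from `trl.trainer.grpo_trainer`:

```python
from torch.utils.data import DataLoader

from trl.trainer.grpo_trainer import RepeatRandomSampler  # added by this patch

prompts = ["p0", "p1", "p2", "p3"]
sampler = RepeatRandomSampler(prompts, repeat_count=2)

indices = list(sampler)
assert len(indices) == len(prompts) * 2
# Consecutive indices are identical, e.g. [2, 2, 0, 0, 3, 3, 1, 1], so each prompt appears
# `repeat_count` times in a row and its copies land next to each other in the batch.
assert all(indices[i] == indices[i + 1] for i in range(0, len(indices), 2))

# Through a DataLoader, a batch therefore carries repeated prompts, one group per prompt.
loader = DataLoader(prompts, batch_size=4, sampler=sampler)
print(next(iter(loader)))  # e.g. ['p2', 'p2', 'p0', 'p0']
```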
@@ -280,6 +312,26 @@ def data_collator(features):  # No data collation is needed in GRPO
             optimizers=optimizers,
         )
 
+        # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
+        num_processes = self.accelerator.num_processes
+        global_batch_size = args.per_device_train_batch_size * num_processes
+        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+        if self.num_generations not in possible_values:
+            raise ValueError(
+                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
+                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
+                f"batch size, the valid values for the number of generations are: {possible_values}."
+            )
+        if self.args.eval_strategy != "no":
+            global_batch_size = args.per_device_eval_batch_size * num_processes
+            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+            if self.num_generations not in possible_values:
+                raise ValueError(
+                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
+                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
+                )
+
         if self.use_vllm:
             if not is_vllm_available():
                 raise ImportError(
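The constructor check above enumerates the divisors of the global batch size so the error message can list valid settings. The same arithmetic in isolation, handy for choosing `num_generations` for a given world size (hypothetical helper name, not part of the patch):

```python
def valid_num_generations(per_device_batch_size: int, num_processes: int) -> list[int]:
    # Mirrors the check added in GRPOTrainer.__init__: num_generations must evenly divide the global batch size.
    global_batch_size = per_device_batch_size * num_processes
    return [n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0]


print(valid_num_generations(per_device_batch_size=4, num_processes=2))  # [2, 4, 8]
print(valid_num_generations(per_device_batch_size=3, num_processes=1))  # [3]
```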
@@ -325,12 +377,11 @@ def data_collator(features):  # No data collation is needed in GRPO
                         max_model_len=self.args.vllm_max_model_len,
                     )
             self.sampling_params = SamplingParams(
-                n=self.num_generations,
                 temperature=args.temperature,
                 max_tokens=self.max_completion_length,
             )
 
-            self._last_loaded_step = 0  # tag to avoid useless loading during grad checkpointing
+            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
 
             # When using vLLM, the main process is responsible for loading the model weights. This can cause process
             # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
@@ -341,7 +392,6 @@ def data_collator(features):  # No data collation is needed in GRPO
                 max_new_tokens=self.max_completion_length,
                 do_sample=True,
                 temperature=args.temperature,
-                num_return_sequences=self.num_generations,
                 pad_token_id=processing_class.pad_token_id,
             )
 
@@ -374,12 +424,17 @@ def _set_signature_columns_if_needed(self):
         if self._signature_columns is None:
             self._signature_columns = ["prompt"]
 
+    # We need a custom sampler that samples the same prompt multiple times
+    def _get_train_sampler(self) -> Sampler:
+        return RepeatRandomSampler(self.train_dataset, self.num_generations)
+
+    def _get_eval_sampler(self, eval_dataset) -> Sampler:
+        return RepeatRandomSampler(eval_dataset, self.num_generations)
+
     # Get the per-token log probabilities for the completions for the model and the reference model
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-        logits = model(
-            input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
-        ).logits  # (B, L, V)
+        logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
         logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
         input_ids = input_ids[:, -logits_to_keep:]
@@ -389,8 +444,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
 
         # Compute the log probabilities for the input tokens.
         token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
-        # use a loop to reduce memory peak
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])  # loop to reduce memory peak
        token_log_probs = token_logits - logsumexp_values  # log_softmax = logits - log(sum(exp(logits)))
         return token_log_probs
 
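The reshaped `_get_per_token_logps` above relies on the identity `log_softmax(logits)[token] = logits[token] - logsumexp(logits)`, with the logsumexp computed one sequence at a time to keep the memory peak down. A standalone sketch of that identity with toy shapes (not the trainer code itself):

```python
import torch

B, L, V = 2, 5, 11  # toy batch size, sequence length, vocab size
logits = torch.randn(B, L, V)
input_ids = torch.randint(V, (B, L))

# Gather only the logits of the realized tokens, then subtract the per-position logsumexp.
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)  # (B, L)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])  # loop over batch to limit peak memory
token_log_probs = token_logits - logsumexp_values  # log_softmax evaluated only at the realized tokens

# Same result as materializing the full (B, L, V) log_softmax tensor.
reference = torch.log_softmax(logits, dim=-1).gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
torch.testing.assert_close(token_log_probs, reference)
```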
@@ -430,22 +484,19 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                 completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
             else:
-                completion_ids = [None] * len(all_prompts_text) * self.num_generations
-
+                completion_ids = [None] * len(all_prompts_text)
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
             # corresponding slice.
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
             process_slice = slice(
-                self.accelerator.process_index * len(prompts) * self.num_generations,
-                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
+                self.accelerator.process_index * len(prompts),
+                (self.accelerator.process_index + 1) * len(prompts),
             )
             completion_ids = completion_ids[process_slice]
 
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
-            prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
-            prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             # Regular generation path
@@ -458,7 +509,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             prompt_length = prompt_ids.size(1)
             prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
-            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
 
         # Mask everything after the first EOS token
         is_eos = completion_ids == self.processing_class.eos_token_id
@@ -488,9 +538,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         if is_conversational(inputs[0]):
             completions = [[{"role": "assistant", "content": completion}] for completion in completions]
 
-        # Compute the rewards
-        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]  # repeat prompts
-
         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
         for i, (reward_func, reward_processing_class) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes)
@@ -509,14 +556,15 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
             else:
                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
-                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
-                for key in reward_kwargs:
-                    for example in inputs:
-                        # Repeat each value in the column for `num_generations` times
-                        reward_kwargs[key].extend([example[key]] * self.num_generations)
+                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
+                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                 output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
 
+        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+        # completions may be distributed across processes
+        rewards_per_func = gather(rewards_per_func)
+
         # Sum the rewards from all reward functions
         rewards = rewards_per_func.sum(dim=1)
 
@@ -529,8 +577,15 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
 
+        # Slice to keep only the local part of the data
+        process_slice = slice(
+            self.accelerator.process_index * len(prompts),
+            (self.accelerator.process_index + 1) * len(prompts),
+        )
+        advantages = advantages[process_slice]
+
         # Log the metrics
-        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+        reward_per_func = rewards_per_func.mean(0)
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
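The hunk above is the heart of the distributed change: rewards are gathered from every process, normalized within each group of `num_generations` completions of the same prompt, and each process then keeps only its own slice of the advantages. A self-contained sketch of that arithmetic, with toy values standing in for the gathered rewards:

```python
import torch

num_generations = 4
# Pretend rewards were gathered from 2 processes, each holding the 4 completions of one prompt.
rewards = torch.tensor([0.0, 1.0, 2.0, 3.0, 10.0, 10.0, 10.0, 14.0])

# Group-relative normalization over each block of `num_generations` completions.
mean_grouped = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations)
std_grouped = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)

# Each process keeps only its local slice, as in the added `process_slice` logic.
process_index, local_len = 0, 4  # e.g. rank 0 generated the first 4 completions
local_advantages = advantages[process_index * local_len : (process_index + 1) * local_len]
print(local_advantages)  # roughly [-1.16, -0.39, 0.39, 1.16]
```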
@@ -538,8 +593,8 @@
                 reward_func_name = reward_func.__name__
             self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
 
-        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
-        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+        self._metrics["reward"].append(rewards.mean().item())
+        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
 
         return {
             "prompt_ids": prompt_ids,