initial PRIME
kashif committed Jan 24, 2025
1 parent 4201299 commit bac0cd3
Showing 2 changed files with 606 additions and 137 deletions.
166 changes: 124 additions & 42 deletions trl/trainer/prime_config.py
@@ -1,47 +1,129 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Optional

-from .grpo_config import GRPOConfig
+from transformers import TrainingArguments


@dataclass
-class PrimeConfig(GRPOConfig):
-    """
-    Configuration class for the PrimeTrainer.
-    Extends GRPOConfig with PRIME-specific parameters.
-    """
+class PrimeConfig(TrainingArguments):
+    r"""
+    Configuration class for the [`PrimeTrainer`].
+
+    Only the parameters specific to PRIME training are listed here. For details on other parameters, refer to the
+    [`~transformers.TrainingArguments`] documentation.
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        > Parameters that control the model and reference model
+
+        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
+            argument of the [`PrimeTrainer`] is provided as a string.
+
+        > Parameters that control the data preprocessing
+
+        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
+        num_generations (`int` or `None`, *optional*, defaults to `8`):
+            Number of generations per prompt to sample.
+        temperature (`float`, *optional*, defaults to `0.9`):
+            Temperature for sampling. The higher the temperature, the more random the completions.
+        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
+            Maximum length of the generated completion.
+        use_vllm (`bool`, *optional*, defaults to `False`):
+            Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
+
+        > Parameters that control the training
+
+        learning_rate (`float`, *optional*, defaults to `1e-6`):
+            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
+            [`~transformers.TrainingArguments`].
+        beta (`float`, *optional*, defaults to `0.04`):
+            KL coefficient.
+    """

-    # Reward model parameters
-    reward_model_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the reward model or its name on the Hub"},
-    )
-    reward_model_kwargs: Optional[dict] = field(
-        default_factory=dict,
-        metadata={"help": "Additional kwargs for reward model initialization"},
-    )
-
-    # PRIME specific parameters
-    prime_granularity: str = field(
-        default="token",
-        metadata={"help": "Granularity of process rewards: 'token' or 'whole'"},
-    )
-    prime_norm: str = field(
-        default="batch_norm",
-        metadata={"help": "Normalization method for process rewards"},
-    )
-    prime_ref_type: str = field(
-        default="freeze",
-        metadata={"help": "Reference model type: 'freeze' or 'policy'"},
-    )
-    prime_beta_train: float = field(
-        default=0.05,
-        metadata={"help": "Beta coefficient for training"},
-    )
-    reward_model_coef: float = field(
-        default=0.0,
-        metadata={"help": "Weight for the reward model score"},
-    )
+    # Parameters that control the model and reference model
+    model_init_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
+            "argument of the `PrimeTrainer` is provided as a string."
+        },
+    )
+
+    # Parameters that control the data preprocessing
+    max_prompt_length: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
+        },
+    )
+    num_generations: Optional[int] = field(
+        default=8,
+        metadata={"help": "Number of generations to sample."},
+    )
+    temperature: Optional[float] = field(
+        default=0.9,
+        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
+    )
+    max_completion_length: Optional[int] = field(
+        default=256,
+        metadata={"help": "Maximum length of the generated completion."},
+    )
+    use_vllm: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed "
+            "(`pip install vllm`)."
+        },
+    )
+
+    # Parameters that control the training
+    learning_rate: float = field(
+        default=1e-6,
+        metadata={
+            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
+            "`transformers.TrainingArguments`."
+        },
+    )
+    beta: float = field(
+        default=0.04,
+        metadata={"help": "KL coefficient."},
+    )
+    correct_ratio_min: float = field(
+        default=0.2,
+        metadata={"help": "Minimum ratio of correct responses required to keep a prompt batch."},
+    )
+    correct_ratio_max: float = field(
+        default=0.8,
+        metadata={"help": "Maximum ratio of correct responses allowed to keep a prompt batch."},
+    )
+    num_ppo_epochs: int = field(
+        default=4,
+        metadata={"help": "Number of PPO epochs to run per batch (M in the PRIME paper)."},
+    )
+    gamma: float = field(
+        default=1.0,
+        metadata={"help": "Discount factor for future rewards in advantage estimation."},
+    )
+    epsilon: float = field(
+        default=0.2,
+        metadata={"help": "PPO clip range for policy updates."},
+    )
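For orientation, below is a minimal sketch of how the configuration added in this commit might be used. It assumes the module path `trl.trainer.prime_config` taken from the file location in this diff and a `PrimeTrainer` that consumes the config (the trainer lives in the other changed file and is not shown here); field names and defaults come from the dataclass above, while everything else (script name, output directory) is illustrative.

# Sketch only: the import path mirrors this commit's file layout and may differ in a
# released package; values shown are the defaults declared in PrimeConfig above.
from transformers import HfArgumentParser

from trl.trainer.prime_config import PrimeConfig

# Direct construction, overriding a few PRIME-specific knobs.
config = PrimeConfig(
    output_dir="prime-output",   # inherited from transformers.TrainingArguments
    max_prompt_length=512,
    num_generations=8,
    max_completion_length=256,
    learning_rate=1e-6,
    beta=0.04,                   # KL coefficient
    correct_ratio_min=0.2,       # keep a prompt batch only if its correct-response
    correct_ratio_max=0.8,       # ratio falls inside [0.2, 0.8]
    num_ppo_epochs=4,
    epsilon=0.2,                 # PPO clip range
)

# Or, as the docstring notes, expose the same fields as CLI arguments, e.g.
#   python train.py --output_dir prime-output --use_vllm
parser = HfArgumentParser(PrimeConfig)
(config,) = parser.parse_args_into_dataclasses()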