initial
kashif committed Jan 24, 2025
1 parent fe4b5ef commit 4201299
Showing 7 changed files with 214 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -80,6 +80,8 @@
title: PPO
- local: prm_trainer
title: PRM
- local: prime_trainer
title: PRIME
- local: reward_trainer
title: Reward
- local: rloo_trainer
54 changes: 54 additions & 0 deletions docs/source/prime_trainer.md
@@ -0,0 +1,54 @@
# PRIME Trainer

The Process Reinforcement through IMplicit rEwards (PRIME) algorithm by Cui et al. starts from an SFT policy, a Process Reward Model (PRM), a frozen reference model, and a ground-truth outcome verifier. In each RL iteration, the policy model first generates rollouts. The implicit PRM and the outcome verifier then score the rollouts, and the implicit PRM is updated on the rollouts with the outcome rewards. Finally, the outcome and process rewards are combined and used to update the policy model via a PPO loss. More information can be found in the [PRIME Blog](https://curvy-check-498.notion.site/Process-Reinforcement-through-Implicit-Rewards-15f4fcb9c42180f1b498cc9b2eaf896f).

## Algorithm

The algorithm for PRIME is described as follows:

**Input**: a supervised fine-tuned model $\pi_{\theta_{\mathrm{init}}}$; a ground-truth outcome reward verifier $R_\mathrm{Gt}$; an implicit PRM $\pi_\phi$; a frozen reference model $\pi_{\mathrm{ref}}$; an instruction dataset with ground-truth outputs $\mathcal{D} = \{(x, y_\textrm{gt})\}$; and hyperparameters $\beta, \epsilon, M, N, K, \gamma$.

**Notation**: $r_i$ denotes the outcome reward of the $i$-th candidate output $y_i$, and $r_i^t$ denotes its process reward at token step $t$. Define the implicit reward $r_\phi (y) = \beta \log \frac{\pi_\phi (y)}{\pi_{\mathrm{ref}} (y)}$, where the context $x$ is omitted for simplicity.
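
As a reading aid, here is a minimal PyTorch sketch of these quantities; the function names and tensor layout (`prm_logprobs`, `ref_logprobs` holding the per-token log-probabilities of the completion $y$) are illustrative and not part of the trainer API:

```python
import torch

def implicit_reward(prm_logprobs: torch.Tensor, ref_logprobs: torch.Tensor, beta: float = 0.05) -> torch.Tensor:
    """Sequence-level implicit reward r_phi(y) = beta * log(pi_phi(y) / pi_ref(y)).

    Both inputs hold the per-token log-probabilities of the completion y, shape (seq_len,);
    summing over tokens gives the sequence log-probability, so the ratio becomes a difference.
    """
    return beta * (prm_logprobs.sum() - ref_logprobs.sum())

def token_process_rewards(prm_logprobs: torch.Tensor, ref_logprobs: torch.Tensor, beta: float = 0.05) -> torch.Tensor:
    """Per-token process rewards r^t = beta * log(pi_phi(y_t | y_<t) / pi_ref(y_t | y_<t))."""
    return beta * (prm_logprobs - ref_logprobs)
```

Summing the per-token process rewards over a completion recovers $r_\phi(y)$, which is the quantity the implicit PRM is trained on below. The default `beta=0.05` mirrors `prime_beta_train` in `PrimeConfig`.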

**Steps**:

1. Initialize the policy model $\pi_\theta \leftarrow \pi_{\theta_{\mathrm{init}}}$, the implicit PRM $\pi_{\phi} \leftarrow \pi_{\theta_{\mathrm{init}}}$, and the reference model $\pi_{\mathrm{ref}} \leftarrow \pi_{\theta_{\mathrm{init}}}$ from the initial SFT model $\pi_{\theta_{\mathrm{init}}}$

2. For iterations $1, \ldots, N$ do:
    1. Sample a batch of prompts $\mathcal{B} \sim \mathcal{D}$
    2. Initialize the buffer $\mathcal{T} = \emptyset$
    3. For each prompt $(x, y_\textrm{gt}) \in \mathcal{B}$ do:
        1. Sample $K$ candidate outputs from the current policy $\pi_\theta$: $y_1, \ldots, y_K \sim \pi_{\theta}(\cdot | x)$
        2. Compute ground truth rewards $r_i = R_\mathrm{Gt}(x, y_i, y_\textrm{gt})$ for $i = 1, \ldots, K$
        3. Record the number of correct responses $|\mathcal{C}_x| = |\{y_i \mid r_i = 1\}|$
        4. If the correct-response ratio satisfies $0.2 < |\mathcal{C}_x| / K < 0.8$:
            1. Add all $K$ tuples of prompt, candidate output, and ground truth reward $\{(x, y_i, r_i)\}_{i=1}^K$ to $\mathcal{T}$
        5. Else:
            1. Drop the prompt $x$ and continue with the next prompt
    4. For PPO epoch $1, \ldots, M$ do:
        1. Forward pass the implicit PRM $\pi_\phi$ on each $(x, y, r) \in \mathcal{T}$ to get the implicit process rewards $r^t = \beta \log \frac{\pi_\phi(y_t | y_{<t})}{\pi_{\mathrm{ref}}(y_t | y_{<t})}$ for each token $t$ of $y$
        2. Update the implicit PRM $\pi_\phi$ using the cross-entropy loss over the tuples $(x, y, r)$: $\mathcal{L}_{\mathrm{CE}}(\phi) = - \mathbb{E}_{(x, y, r) \sim \mathcal{T}} \left[ r \log \sigma (r_\phi (y)) + (1-r) \log (1-\sigma (r_\phi (y))) \right]$
        3. Compute the RLOO advantage for the prompt $x$, its $K$ candidate outputs $\{y_1, \ldots, y_K\}$, and outcome rewards $\{ r_1, \ldots, r_K \}$: for each token $t$ of $y_i$, the advantage $A_{i}^t$ is the sum of the RLOO advantage of the outcome rewards and the discounted RLOO advantage of the implicit process rewards, $A_{i}^t = r_i - \frac{1}{K -1} \sum_{j \neq i} r_j + \sum_{s=t}^{|y_i|} \gamma^{s-t} \left[ r_i^s - \frac{1}{K -1} \sum_{j \neq i} \frac{r_\phi(y_j)}{|y_j|} \right]$ (a PyTorch sketch of this computation follows the algorithm)
        4. Update the policy $\pi_\theta$ using the PPO loss with respect to the RLOO advantage $A$ and $\pi_{\mathrm{ref}}$ with clip coefficient $\epsilon$.

**Output**: the final policy $\pi_\theta$.
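
As a concrete reading of step 2.4.3, here is a PyTorch sketch of the combined advantage for a single prompt; it assumes the outcome rewards $r_i$ and per-token process rewards $r_i^t$ have already been computed, and the function name and defaults are illustrative rather than the trainer's implementation:

```python
import torch

def prime_advantages(
    outcome_rewards: torch.Tensor,          # shape (K,): outcome rewards r_i in {0, 1}
    process_rewards: list[torch.Tensor],    # K tensors: r_i^t for each token t of y_i
    gamma: float = 1.0,                     # discount over token steps
) -> list[torch.Tensor]:
    """Per-token advantages A_i^t for one prompt with K >= 2 candidate completions."""
    K = outcome_rewards.numel()
    # Leave-one-out outcome baseline: mean outcome reward of the other K - 1 completions.
    loo_outcome = (outcome_rewards.sum() - outcome_rewards) / (K - 1)

    # Leave-one-out process baseline: mean length-normalized implicit reward of the
    # other completions, r_phi(y_j) / |y_j| with r_phi(y_j) = sum_t r_j^t.
    seq_rewards = torch.stack([r.sum() / r.numel() for r in process_rewards])
    loo_process = (seq_rewards.sum() - seq_rewards) / (K - 1)

    advantages = []
    for i in range(K):
        centered = process_rewards[i] - loo_process[i]   # r_i^s minus the process baseline
        # Discounted reward-to-go: sum_{s >= t} gamma^(s - t) * centered[s]
        reward_to_go = torch.zeros_like(centered)
        running = torch.tensor(0.0)
        for t in reversed(range(centered.numel())):
            running = centered[t] + gamma * running
            reward_to_go[t] = running
        advantages.append(outcome_rewards[i] - loo_outcome[i] + reward_to_go)
    return advantages
```

Each completion is compared against the average of the other $K-1$ completions, as in RLOO, so no learned value function is needed; the PPO update in step 2.4.4 then uses these per-token advantages.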

## Get started

To quickly verify that the trainer runs, you can launch a PRIME training script with the following command:

```bash
```

## PrimeTrainer

[[autodoc]] PrimeTrainer

## PrimeConfig

[[autodoc]] PrimeConfig
Empty file added tests/test_prime_trainer.py
Empty file.
Empty file added trl/scripts/prime.py
Empty file.
4 changes: 4 additions & 0 deletions trl/trainer/__init__.py
@@ -62,6 +62,8 @@
"ppo_trainer": ["PPOTrainer"],
"prm_config": ["PRMConfig"],
"prm_trainer": ["PRMTrainer"],
"prime_config": ["PrimeConfig"],
"prime_trainer": ["PrimeTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer"],
"rloo_config": ["RLOOConfig"],
@@ -131,6 +133,8 @@
from .orpo_trainer import ORPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .prime_config import PrimeConfig
from .prime_trainer import PrimeTrainer
from .prm_config import PRMConfig
from .prm_trainer import PRMTrainer
from .reward_config import RewardConfig
47 changes: 47 additions & 0 deletions trl/trainer/prime_config.py
@@ -0,0 +1,47 @@
from dataclasses import dataclass, field
from typing import Optional

from .grpo_config import GRPOConfig

@dataclass
class PrimeConfig(GRPOConfig):
"""
Configuration class for the PrimeTrainer.
Extends GRPOConfig with PRIME-specific parameters.
"""

# Reward model parameters
reward_model_name_or_path: Optional[str] = field(
default=None,
metadata={"help": "Path to the reward model or its name on the Hub"},
)

reward_model_kwargs: Optional[dict] = field(
default_factory=dict,
metadata={"help": "Additional kwargs for reward model initialization"},
)

# PRIME specific parameters
prime_granularity: str = field(
default="token",
metadata={"help": "Granularity of process rewards: 'token' or 'whole'"},
)

prime_norm: str = field(
default="batch_norm",
metadata={"help": "Normalization method for process rewards"},
)

prime_ref_type: str = field(
default="freeze",
metadata={"help": "Reference model type: 'freeze' or 'policy'"},
)

prime_beta_train: float = field(
default=0.05,
metadata={"help": "Beta coefficient for training"},
)

    reward_model_coef: float = field(
        default=0.0,
        metadata={"help": "Weight for the reward model score"},
    )

    verifier_reward_coef: float = field(
        default=1.0,  # assumed default; this field is referenced by PrimeTrainer.compute_loss
        metadata={"help": "Weight for the ground-truth verifier reward"},
    )

    correct_ratio_threshold: tuple = field(
        default=(0.2, 0.8),  # matches the 0.2 < |C_x|/K < 0.8 filtering rule; referenced by PrimeTrainer.filter_batch
        metadata={"help": "(min, max) thresholds on the correct-response ratio used to filter prompts"},
    )
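
For orientation, a minimal construction sketch; the values are illustrative only, and `output_dir` comes from the `TrainingArguments` base that `GRPOConfig` extends:

```python
from trl.trainer.prime_config import PrimeConfig

# Illustrative values; any GRPOConfig / TrainingArguments field can also be passed.
config = PrimeConfig(
    output_dir="my-model-PRIME",   # inherited from TrainingArguments
    prime_granularity="token",     # 'token' or 'whole'
    prime_norm="batch_norm",       # normalization applied to process rewards
    prime_ref_type="freeze",       # keep the reference model frozen
    prime_beta_train=0.05,
    reward_model_coef=0.0,         # keep the external reward model score disabled
)
```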
107 changes: 107 additions & 0 deletions trl/trainer/prime_trainer.py
@@ -0,0 +1,107 @@
from typing import Callable, Optional, Union

import torch
from transformers import PreTrainedModel

from .grpo_trainer import GRPOTrainer
from .prime_config import PrimeConfig

class PrimeTrainer(GRPOTrainer):
def __init__(
self,
model: Union[str, PreTrainedModel],
args: PrimeConfig = None,
reward_function: Optional[Callable] = None,
reward_model: Optional[PreTrainedModel] = None,
**kwargs
):
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = PrimeConfig(f"{model_name}-PRIME")

self.reward_function = reward_function
super().__init__(model=model, args=args, reward_model=reward_model, **kwargs)

# Initialize metrics specific to PRIME
self._metrics.update({
"verifier_reward": [],
"rm_reward": [],
"process_reward": [],
"correct_ratio": []
})

def filter_batch(self, rewards, batch):
"""Filter batch based on correct response ratio thresholds"""
        # Number of distinct prompts; each prompt has `num_generations` sampled completions
        num_prompts = rewards.numel() // self.args.num_generations
        rewards = rewards.view(num_prompts, self.args.num_generations)

        # Calculate the correct-response ratio per prompt (a reward > 0.5 counts as correct)
        correct_responses = (rewards > 0.5).float().sum(dim=1) / self.args.num_generations

# Apply thresholds
min_thresh, max_thresh = self.args.correct_ratio_threshold
valid_mask = (correct_responses > min_thresh) & (correct_responses < max_thresh)

# Expand mask for all generations
final_mask = valid_mask.repeat_interleave(self.args.num_generations)

filtered_batch = {k: v[final_mask] for k, v in batch.items() if isinstance(v, torch.Tensor)}
return filtered_batch, valid_mask

def compute_loss(self, model, inputs, return_outputs=False):
# Get completions and compute base rewards similar to GRPO
loss, rewards, completions = super().compute_base_rewards(model, inputs)

# Compute verifier rewards using reward function
if self.reward_function is not None:
verifier_rewards = self.reward_function(completions)
rewards += self.args.verifier_reward_coef * verifier_rewards
self._metrics["verifier_reward"].append(verifier_rewards.mean().item())

        # Filter the batch based on the correct-response ratio; skip the update if every prompt was filtered out
        filtered_batch, valid_mask = self.filter_batch(rewards, inputs)
        if not valid_mask.any():
            return loss

# Compute process rewards using implicit PRM
process_rewards = self.compute_process_rewards(
model,
filtered_batch,
granularity=self.args.prime_granularity,
norm_type=self.args.prime_norm
)

# Update metrics
self._metrics["correct_ratio"].append(valid_mask.float().mean().item())
self._metrics["process_reward"].append(process_rewards.mean().item())

# Combine rewards and compute PPO loss
total_rewards = rewards + process_rewards
ppo_loss = self.compute_ppo_loss(model, filtered_batch, total_rewards)

return ppo_loss

def compute_process_rewards(self, model, batch, granularity="token", norm_type="batch_norm"):
"""Compute process rewards using the implicit PRM"""
with torch.no_grad():
# Get logits from current policy and reference model
policy_logits = model(**batch).logits
ref_logits = self.ref_model(**batch).logits

# Compute KL divergence
kl_div = torch.nn.functional.kl_div(
policy_logits.log_softmax(-1),
ref_logits.softmax(-1),
reduction='none'
).sum(-1)

# Apply normalization
if norm_type == "batch_norm":
kl_div = (kl_div - kl_div.mean()) / (kl_div.std() + 1e-8)

# Convert to process rewards
            process_rewards = -self.args.prime_beta_train * kl_div

if granularity == "whole":
process_rewards = process_rewards.mean(dim=1, keepdim=True).expand(-1, process_rewards.size(1))

return process_rewards
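
Finally, a hypothetical end-to-end sketch of how these pieces are meant to fit together; the model, dataset, and reward function below are placeholders, and training will only run once the GRPO base-class hooks referenced in `compute_loss` (such as `compute_base_rewards` and `compute_ppo_loss`) are available:

```python
import torch
from datasets import load_dataset

from trl.trainer.prime_config import PrimeConfig
from trl.trainer.prime_trainer import PrimeTrainer

def contains_answer_reward(completions):
    # Placeholder ground-truth verifier: reward 1.0 if the completion mentions "42".
    return torch.tensor([1.0 if "42" in c else 0.0 for c in completions])

dataset = load_dataset("trl-lib/tldr", split="train")  # any prompt dataset usable with GRPOTrainer

trainer = PrimeTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    args=PrimeConfig(output_dir="Qwen2-0.5B-PRIME"),
    reward_function=contains_answer_reward,
    train_dataset=dataset,   # forwarded to GRPOTrainer through **kwargs
)
# trainer.train()  # uncomment once the reward and PPO hooks are implemented
```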
