# PRIME Trainer

The Process Reinforcement through IMplicit rEwards (PRIME) algorithm by Cui et al. starts from an SFT policy, an implicit Process Reward Model (PRM), and a frozen reference model, together with a ground-truth outcome verifier. In each RL iteration, the policy model first generates rollouts. The implicit PRM and the outcome verifier then score the rollouts, and the implicit PRM is updated on the rollouts using the outcome reward. Finally, the outcome reward and the process reward are combined and used to update the policy model via a PPO loss. More information can be found in the [PRIME Blog](https://curvy-check-498.notion.site/Process-Reinforcement-through-Implicit-Rewards-15f4fcb9c42180f1b498cc9b2eaf896f).

## Algorithm

The PRIME algorithm proceeds as follows:

**Input**: supervised fine-tuned model $\pi_{\theta_{\mathrm{init}}}$; ground-truth outcome reward verifier $R_\mathrm{Gt}$; implicit PRM $\pi_\phi$; frozen reference model $\pi_{\mathrm{ref}}$; instruction dataset with ground-truth outputs $\mathcal{D} = \{(x, y_\textrm{gt})\}$; and hyperparameters $\beta, \epsilon, M, N, K, \gamma$.

**Notation**: $r_i$ denotes the outcome reward of the $i$-th candidate output $y_i$, and $r_i^t$ denotes its process reward at token step $t$. Define $r_\phi (y) = \beta \log \frac{\pi_\phi (y)}{\pi_{\mathrm{ref}} (y)}$, where the context $x$ is omitted for simplicity.
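
As a concrete illustration of this notation, the per-token process reward and the sequence-level implicit reward can be computed directly from token log-probabilities. The sketch below uses random tensors in place of real model outputs; the variable names are illustrative and not part of the PRIME code.

```python
import torch

beta = 0.05            # the beta hyperparameter
batch, seq_len = 4, 16

# Log-probabilities of the sampled tokens y_t under the implicit PRM pi_phi and
# the frozen reference model pi_ref (placeholders for real model outputs).
logp_prm = torch.randn(batch, seq_len)
logp_ref = torch.randn(batch, seq_len)

# Per-token process reward: r^t = beta * log( pi_phi(y_t | y_<t) / pi_ref(y_t | y_<t) )
process_rewards = beta * (logp_prm - logp_ref)

# Sequence-level implicit reward: r_phi(y) = beta * log( pi_phi(y) / pi_ref(y) ),
# which is simply the sum of the per-token process rewards.
r_phi = process_rewards.sum(dim=-1)
```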

**Steps**:

1. Initialize the policy model $\pi_\theta \leftarrow \pi_{\theta_{\mathrm{init}}}$, the implicit PRM $\pi_{\phi} \leftarrow \pi_{\theta_{\mathrm{init}}}$, and the reference model $\pi_{\mathrm{ref}} \leftarrow \pi_{\theta_{\mathrm{init}}}$ from the initial SFT model $\pi_{\theta_{\mathrm{init}}}$.

2. For iterations $1, \ldots, N$ do:
    1. Sample a batch of prompts $\mathcal{B} \sim \mathcal{D}$
    2. Initialize the buffer $\mathcal{T} = \emptyset$
    3. For each prompt instruction $(x, y_\textrm{gt}) \in \mathcal{B}$ do:
        1. Sample $K$ candidate outputs from the current policy $\pi_\theta$: $y_1, \ldots, y_K \sim \pi_{\theta}(\cdot \mid x)$
        2. Compute ground-truth rewards $r_i = R_\mathrm{Gt}(x, y_i, y_\textrm{gt})$ for $i = 1, \ldots, K$
        3. Record the number of correct responses $|\mathcal{C}_x| = |\{y_i \mid r_i = 1\}|$
        4. If the correct-response ratio satisfies $0.2 < |\mathcal{C}_x| / K < 0.8$:
            1. Add all $K$ tuples of prompt, candidate output, and ground-truth reward $\{(x, y_i, r_i)\}_{i=1}^K$ to $\mathcal{T}$
        5. Else:
            1. Drop the prompt instruction $x$ and continue to the next prompt
    4. For PPO epochs $1, \ldots, M$ do:
        1. Forward the implicit PRM $\pi_\phi$ on each $(x, y, r) \in \mathcal{T}$ to obtain the implicit process rewards $r^t = \beta \log \frac{\pi_\phi(y_t \mid y_{<t})}{\pi_{\mathrm{ref}}(y_t \mid y_{<t})}$ for each token $t$ of $y$
        2. Update the implicit PRM $\pi_\phi$ using the cross-entropy loss on the tuples $(x, y, r)$:
           $\mathcal{L}_{\mathrm{CE}}(\phi) = -\mathbb{E}_{(x, y, r) \sim \mathcal{T}} \left[ r \log \sigma (r_\phi (y)) + (1-r) \log (1-\sigma (r_\phi (y))) \right]$
        3. Compute the RLOO advantage $A$ for the prompt $x$, its $K$ candidate outputs $\{y_1, \ldots, y_K\}$, and ground-truth outcome rewards $\{ r_1, \ldots, r_K \}$: for each token $t$ of $y_i$, let $A_i^t$ be the sum of the RLOO term with outcome rewards and the discounted RLOO term with implicit process rewards (see the sketch after this list):
           $A_{i}^t = r_i - \frac{1}{K-1} \sum_{j \neq i} r_j + \sum_{s=t}^{|y_i|} \gamma^{s-t} \left[ r_i^s - \frac{1}{K-1} \sum_{j \neq i} \frac{r_\phi(y_j)}{|y_j|} \right]$
        4. Update the policy $\pi_\theta$ using the PPO loss with respect to the RLOO advantage $A$ and $\pi_{\mathrm{ref}}$, with clip coefficient $\epsilon$.
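
To make step 2.4.3 concrete, here is a minimal sketch of the combined RLOO advantage for a single prompt with $K$ candidate outputs, using synthetic rewards and equal-length completions for simplicity; the variable names are illustrative only.

```python
import torch

K, T, gamma = 4, 8, 1.0   # candidates per prompt, tokens per candidate, discount

outcome_r = torch.tensor([1.0, 0.0, 1.0, 0.0])  # r_i from the outcome verifier
process_r = 0.05 * torch.randn(K, T)            # r_i^t from the implicit PRM

# Leave-one-out baselines (mean over the other K - 1 candidates).
loo_outcome = (outcome_r.sum() - outcome_r) / (K - 1)   # baseline for the outcome term
seq_r = process_r.sum(dim=1) / T                        # r_phi(y_j) / |y_j|
loo_process = (seq_r.sum() - seq_r) / (K - 1)           # baseline for the process term

# Discounted suffix sums of (r_i^s - leave-one-out process baseline) over s >= t.
centered = process_r - loo_process[:, None]             # (K, T)
suffix = torch.zeros_like(centered)
running = torch.zeros(K)
for s in reversed(range(T)):
    running = centered[:, s] + gamma * running
    suffix[:, s] = running

# A_i^t = (r_i - LOO outcome baseline) + discounted process term, shape (K, T).
advantage = (outcome_r - loo_outcome)[:, None] + suffix
```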

**Output**: return the final policy $\pi_\theta$ for saving.

## Get started

To quickly check that the trainer runs, you can launch a PRIME training run with the following command:

```bash
```
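
Alternatively, the trainer can be used directly from Python. The snippet below is a minimal sketch based on the `PrimeTrainer` and `PrimeConfig` definitions added in this commit; the import path, model name, dataset, and dummy verifier are placeholders, and forwarding `train_dataset` through `**kwargs` to the underlying GRPO trainer is an assumption.

```python
import torch
from datasets import load_dataset

from trl import PrimeConfig, PrimeTrainer  # assumed export path

# Placeholder prompt dataset; a real run would use a dataset with verifiable answers.
dataset = load_dataset("trl-lib/tldr", split="train")

def reward_function(completions):
    # Dummy ground-truth outcome verifier: 1.0 if the completion is non-empty.
    return torch.tensor([float(len(c) > 0) for c in completions])

training_args = PrimeConfig(output_dir="Qwen2-0.5B-PRIME", prime_beta_train=0.05)
trainer = PrimeTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    args=training_args,
    reward_function=reward_function,
    train_dataset=dataset,  # forwarded to the GRPO trainer via **kwargs (assumption)
)
trainer.train()
```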

## PrimeTrainer

[[autodoc]] PrimeTrainer

## PrimeConfig

[[autodoc]] PrimeConfig

`prime_config.py`:

from dataclasses import dataclass, field
from typing import Optional

from .grpo_config import GRPOConfig


@dataclass
class PrimeConfig(GRPOConfig):
    """
    Configuration class for the PrimeTrainer.
    Extends GRPOConfig with PRIME-specific parameters.
    """

    # Reward model parameters
    reward_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reward model or its name on the Hub"},
    )

    reward_model_kwargs: Optional[dict] = field(
        default_factory=dict,
        metadata={"help": "Additional kwargs for reward model initialization"},
    )

    # PRIME-specific parameters
    prime_granularity: str = field(
        default="token",
        metadata={"help": "Granularity of process rewards: 'token' or 'whole'"},
    )

    prime_norm: str = field(
        default="batch_norm",
        metadata={"help": "Normalization method for process rewards"},
    )

    prime_ref_type: str = field(
        default="freeze",
        metadata={"help": "Reference model type: 'freeze' or 'policy'"},
    )

    prime_beta_train: float = field(
        default=0.05,
        metadata={"help": "Beta coefficient used when training the implicit PRM"},
    )

    reward_model_coef: float = field(
        default=0.0,
        metadata={"help": "Weight for the reward model score"},
    )

    # Fields referenced by PrimeTrainer. The (0.2, 0.8) thresholds follow the PRIME
    # prompt-filtering rule; the verifier weight default of 1.0 is an assumption.
    verifier_reward_coef: float = field(
        default=1.0,
        metadata={"help": "Weight for the ground-truth verifier reward"},
    )

    correct_ratio_threshold: tuple = field(
        default=(0.2, 0.8),
        metadata={"help": "(min, max) correct-response ratio used to filter prompts"},
    )

`prime_trainer.py`:

import torch
from typing import Callable, Optional, Union

from transformers import PreTrainedModel

from .grpo_trainer import GRPOTrainer
from .prime_config import PrimeConfig


class PrimeTrainer(GRPOTrainer):
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        args: PrimeConfig = None,
        reward_function: Optional[Callable] = None,
        reward_model: Optional[PreTrainedModel] = None,
        **kwargs,
    ):
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = PrimeConfig(f"{model_name}-PRIME")

        self.reward_function = reward_function
        super().__init__(model=model, args=args, reward_model=reward_model, **kwargs)

        # Initialize metrics specific to PRIME
        self._metrics.update({
            "verifier_reward": [],
            "rm_reward": [],
            "process_reward": [],
            "correct_ratio": [],
        })

    def filter_batch(self, rewards, batch):
        """Filter the batch based on the correct-response ratio thresholds."""
        # One prompt corresponds to `num_generations` candidate completions
        batch_size = rewards.numel() // self.args.num_generations
        rewards = rewards.view(batch_size, self.args.num_generations)

        # Correct-response ratio per prompt (a reward above 0.5 counts as correct)
        correct_ratio = (rewards > 0.5).float().sum(dim=1) / self.args.num_generations

        # Keep prompts whose ratio lies strictly between the two thresholds
        min_thresh, max_thresh = self.args.correct_ratio_threshold
        valid_mask = (correct_ratio > min_thresh) & (correct_ratio < max_thresh)

        # Expand the prompt-level mask to all of its generations
        final_mask = valid_mask.repeat_interleave(self.args.num_generations)

        filtered_batch = {k: v[final_mask] for k, v in batch.items() if isinstance(v, torch.Tensor)}
        return filtered_batch, valid_mask

    def compute_loss(self, model, inputs, return_outputs=False):
        # Get completions and base rewards, as in GRPO (assumes the base class
        # provides a `compute_base_rewards` helper)
        loss, rewards, completions = super().compute_base_rewards(model, inputs)

        # Compute verifier (outcome) rewards using the ground-truth reward function
        if self.reward_function is not None:
            verifier_rewards = self.reward_function(completions)
            rewards += self.args.verifier_reward_coef * verifier_rewards
            self._metrics["verifier_reward"].append(verifier_rewards.mean().item())

        # Filter the batch based on the correct-response ratio
        filtered_batch, valid_mask = self.filter_batch(rewards, inputs)
        if not valid_mask.any():
            return loss

        # Compute process rewards using the implicit PRM
        process_rewards = self.compute_process_rewards(
            model,
            filtered_batch,
            granularity=self.args.prime_granularity,
            norm_type=self.args.prime_norm,
        )

        # Update metrics
        self._metrics["correct_ratio"].append(valid_mask.float().mean().item())
        self._metrics["process_reward"].append(process_rewards.mean().item())

        # Keep only the rewards of the retained generations, broadcast them over
        # tokens, and combine with the process rewards before the PPO loss (assumes
        # the base class provides a `compute_ppo_loss` helper)
        gen_mask = valid_mask.repeat_interleave(self.args.num_generations)
        total_rewards = rewards[gen_mask].unsqueeze(1) + process_rewards
        ppo_loss = self.compute_ppo_loss(model, filtered_batch, total_rewards)

        return ppo_loss

    def compute_process_rewards(self, model, batch, granularity="token", norm_type="batch_norm"):
        """Compute process rewards using the implicit PRM.

        Note: this implementation uses the per-token KL divergence between the policy
        and the reference distributions as a proxy for the implicit reward
        r^t = beta * log(pi_phi(y_t | y_<t) / pi_ref(y_t | y_<t)).
        """
        with torch.no_grad():
            # Get logits from the current policy and the reference model
            policy_logits = model(**batch).logits
            ref_logits = self.ref_model(**batch).logits

            # Per-token KL divergence between the policy and reference distributions
            kl_div = torch.nn.functional.kl_div(
                policy_logits.log_softmax(-1),
                ref_logits.softmax(-1),
                reduction="none",
            ).sum(-1)

            # Normalize across the batch if requested
            if norm_type == "batch_norm":
                kl_div = (kl_div - kl_div.mean()) / (kl_div.std() + 1e-8)

            # Convert to per-token process rewards
            process_rewards = -self.args.prime_beta_train * kl_div

            # For 'whole' granularity, spread the sequence-level mean over all tokens
            if granularity == "whole":
                process_rewards = process_rewards.mean(dim=1, keepdim=True).expand(-1, process_rewards.size(1))

        return process_rewards