From 4201299d6fc68cca5085e0419f77eb653ff683e0 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 24 Jan 2025 09:57:38 +0100
Subject: [PATCH] initial

---
 docs/source/_toctree.yml     |   2 +
 docs/source/prime_trainer.md |  54 ++++++++++++++++++
 tests/test_prime_trainer.py  |   0
 trl/scripts/prime.py         |   0
 trl/trainer/__init__.py      |   4 ++
 trl/trainer/prime_config.py  |  47 +++++++++++++++
 trl/trainer/prime_trainer.py | 107 +++++++++++++++++++++++++++++++++++
 7 files changed, 214 insertions(+)
 create mode 100644 docs/source/prime_trainer.md
 create mode 100644 tests/test_prime_trainer.py
 create mode 100644 trl/scripts/prime.py
 create mode 100644 trl/trainer/prime_config.py
 create mode 100644 trl/trainer/prime_trainer.py

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 4ccc57ae8f..a71cf22a01 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -80,6 +80,8 @@
     title: PPO
   - local: prm_trainer
     title: PRM
+  - local: prime_trainer
+    title: PRIME
   - local: reward_trainer
     title: Reward
   - local: rloo_trainer
diff --git a/docs/source/prime_trainer.md b/docs/source/prime_trainer.md
new file mode 100644
index 0000000000..118aa712d2
--- /dev/null
+++ b/docs/source/prime_trainer.md
@@ -0,0 +1,54 @@
+# PRIME Trainer
+
+The Process Reinforcement through IMplicit rEwards (PRIME) algorithm by Cui et al. starts from an SFT policy, a Process Reward Model (PRM), and a frozen reference model, together with a ground truth outcome verifier. In each RL iteration, the policy model first generates rollouts. The implicit PRM and the outcome verifier then score the rollouts, and the implicit PRM is updated on the rollouts with the outcome reward. Finally, the outcome reward and the process reward are combined and used to update the policy model via a PPO loss. More information can be found in the [PRIME Blog](https://curvy-check-498.notion.site/Process-Reinforcement-through-Implicit-Rewards-15f4fcb9c42180f1b498cc9b2eaf896f).
+
+## Algorithm
+
+The PRIME algorithm is described as follows:
+
+**Input**: supervised fine-tuned model $\pi_{\theta_{\mathrm{init}}}$; ground truth outcome reward verifier $R_\mathrm{Gt}$; implicit PRM $\pi_\phi$; frozen reference model $\pi_{\mathrm{ref}}$; instruction dataset with ground truth outputs $\mathcal{D} = \{(x, y_\textrm{gt})\}$; and hyperparameters $\beta, \epsilon, M, N, K, \gamma$.
+
+**Notation**: $r_i$ denotes the outcome reward of the $i$-th candidate output $y_i$, and $r_i^t$ denotes its process reward at token step $t$. Define $r_\phi (y) = \beta \log \frac{\pi_\phi (y)}{\pi_{\mathrm{ref}} (y)}$, where the context $x$ is omitted for simplicity.
+
+**Steps**:
+
+1. Initialize the policy model $\pi_\theta \leftarrow \pi_{\theta_{\mathrm{init}}}$, the implicit PRM $\pi_{\phi} \leftarrow \pi_{\theta_{\mathrm{init}}}$, and the reference model $\pi_{\mathrm{ref}} \leftarrow \pi_{\theta_{\mathrm{init}}}$ from the initial SFT model $\pi_{\theta_{\mathrm{init}}}$.
+
+2. For iterations $1, \ldots, N$ do:
+   1. Sample a batch of prompts $\mathcal{B} \sim \mathcal{D}$.
+   2. Initialize the buffer $\mathcal{T} = \emptyset$.
+   3. For each prompt $(x, y_\textrm{gt}) \in \mathcal{B}$ do:
+      1. Sample $K$ candidate outputs from the current policy $\pi_\theta$: $y_1, \ldots, y_K \sim \pi_{\theta}(\cdot \mid x)$.
+      2. Compute ground truth rewards $r_i = R_\mathrm{Gt}(x, y_i, y_\textrm{gt})$ for $i = 1, \ldots, K$.
+      3. Record the number of correct responses $|\mathcal{C}_x| = |\{y_i \mid r_i = 1\}|$.
+      4. If the correct response ratio satisfies $0.2 < |\mathcal{C}_x| / K < 0.8$:
+         1. Add all $K$ tuples of the prompt $x$, candidate outputs, and ground truth rewards $\{(x, y_i, r_i)\}_{i=1}^K$ to $\mathcal{T}$.
+      5. Else:
+         1. Drop the prompt $x$ and continue with the next prompt.
+   4. For PPO epoch $1, \ldots, M$ do:
+      1. Forward pass the implicit PRM $\pi_\phi$ on each $(x, y, r) \in \mathcal{T}$ to get the implicit process rewards $r^t = \beta \log \frac{\pi_\phi(y_t \mid y_{<t})}{\pi_{\mathrm{ref}}(y_t \mid y_{<t})}$.

+    def filter_batch(self, rewards, batch):
+        correct_responses = (rewards.view(-1, self.args.num_generations) > 0.5).float().sum(dim=1) / self.args.num_generations
+
+        # Apply thresholds
+        min_thresh, max_thresh = self.args.correct_ratio_threshold
+        valid_mask = (correct_responses > min_thresh) & (correct_responses < max_thresh)
+
+        # Expand mask for all generations
+        final_mask = valid_mask.repeat_interleave(self.args.num_generations)
+
+        filtered_batch = {k: v[final_mask] for k, v in batch.items() if isinstance(v, torch.Tensor)}
+        return filtered_batch, valid_mask
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        # Get completions and compute base rewards similar to GRPO
+        loss, rewards, completions = super().compute_base_rewards(model, inputs)
+
+        # Compute verifier rewards using reward function
+        if self.reward_function is not None:
+            verifier_rewards = self.reward_function(completions)
+            rewards += self.args.verifier_reward_coef * verifier_rewards
+            self._metrics["verifier_reward"].append(verifier_rewards.mean().item())
+
+        # Filter batch based on correct ratio
+        filtered_batch, valid_mask = self.filter_batch(rewards, inputs)
+        if len(filtered_batch) == 0:
+            return loss
+
+        # Compute process rewards using implicit PRM
+        process_rewards = self.compute_process_rewards(
+            model,
+            filtered_batch,
+            granularity=self.args.prime_granularity,
+            norm_type=self.args.prime_norm
+        )
+
+        # Update metrics
+        self._metrics["correct_ratio"].append(valid_mask.float().mean().item())
+        self._metrics["process_reward"].append(process_rewards.mean().item())
+
+        # Combine rewards and compute PPO loss
+        total_rewards = rewards + process_rewards
+        ppo_loss = self.compute_ppo_loss(model, filtered_batch, total_rewards)
+
+        return ppo_loss
+
+    def compute_process_rewards(self, model, batch, granularity="token", norm_type="batch_norm"):
+        """Compute process rewards using the implicit PRM"""
+        with torch.no_grad():
+            # Get logits from current policy and reference model
+            policy_logits = model(**batch).logits
+            ref_logits = self.ref_model(**batch).logits
+
+            # Compute KL divergence
+            kl_div = torch.nn.functional.kl_div(
+                policy_logits.log_softmax(-1),
+                ref_logits.softmax(-1),
+                reduction='none'
+            ).sum(-1)
+
+            # Apply normalization
+            if norm_type == "batch_norm":
+                kl_div = (kl_div - kl_div.mean()) / (kl_div.std() + 1e-8)
+
+            # Convert to process rewards
+            process_rewards = -self.args.beta_train * kl_div
+
+            if granularity == "whole":
+                process_rewards = process_rewards.mean(dim=1, keepdim=True).expand(-1, process_rewards.size(1))
+
+        return process_rewards
\ No newline at end of file
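As a reading aid for the reward shaping described in the algorithm above and implemented in `compute_loss`, here is a minimal, self-contained sketch: per-token implicit process rewards are the β-scaled log-ratios between the implicit PRM and the frozen reference model, and they are combined with the verifier's outcome reward after the prompt-level correct-ratio filter. The tensor values, shapes, and the choice to credit the outcome reward to the last token are illustrative assumptions, not code from this patch.

```python
import torch

# Illustrative setup (assumed, not from the patch): K sampled completions of T tokens each.
K, T, beta = 4, 16, 0.05
torch.manual_seed(0)

# Stand-ins for the gathered per-token log-probs log pi_phi(y_t | y_<t) and log pi_ref(y_t | y_<t).
logp_prm = -torch.rand(K, T)
logp_ref = -torch.rand(K, T)

# 0/1 outcome rewards from the ground truth verifier R_Gt, one per completion.
outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Prompt-level filter from the algorithm: keep the prompt only if 0.2 < |C_x| / K < 0.8.
correct_ratio = outcome.mean().item()
keep_prompt = 0.2 < correct_ratio < 0.8

# Implicit process reward per token: r^t = beta * log(pi_phi(y_t | y_<t) / pi_ref(y_t | y_<t)).
process_rewards = beta * (logp_prm - logp_ref)  # shape (K, T)

# One simple combination: token-level process rewards plus the outcome reward credited
# at the final token; the patch instead combines them inside compute_loss.
total_rewards = process_rewards.clone()
total_rewards[:, -1] += outcome

if keep_prompt:
    print(f"correct ratio: {correct_ratio:.2f}, reward tensor shape: {tuple(total_rewards.shape)}")
```

The token-level log-ratio mirrors the definition of $r_\phi(y)$ in the Notation section; in the trainer code above, the relative weighting of the two signals is controlled by `verifier_reward_coef` and `beta_train`.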