diff --git a/Finetune/plugin/agentenv/rlhf_env.py b/Finetune/plugin/agentenv/rlhf_env.py
index 2ce978c80..82048ab11 100644
--- a/Finetune/plugin/agentenv/rlhf_env.py
+++ b/Finetune/plugin/agentenv/rlhf_env.py
@@ -162,11 +162,11 @@ def step(self, action):
 
         reward = r_align - self.kl_coeff * r_kl
 
-        info = [{
+        info = {
             "r_align": r_align,
             "r_kl": r_kl,
             "n_response_tokens": n_response_tokens
-        }]
+        }
 
         # Produce a random reward when we reach the goal.
         return self.observation_space.sample(), reward, True, False, info
\ No newline at end of file
diff --git a/Finetune/plugin/model/reward_model.py b/Finetune/plugin/model/reward_model.py
index 73a6c4f9c..294c9f7df 100644
--- a/Finetune/plugin/model/reward_model.py
+++ b/Finetune/plugin/model/reward_model.py
@@ -34,7 +34,14 @@ def __init__(self, config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
 
         # The additional value head.
-        self.value_head = nn.Linear(self.config.n_embd, 1)
+        if hasattr(self.config, 'hidden_size'):
+            self.value_head = nn.Linear(self.config.hidden_size, 1)
+        elif hasattr(self.config, 'n_embd'):
+            self.value_head = nn.Linear(self.config.n_embd, 1)
+        else:
+            raise ValueError("Unsupported model: config has neither 'hidden_size' nor 'n_embd' to size the value head.")
+
+        self.post_init()
 
     def forward(
         self,
@@ -65,7 +72,7 @@ def value(self, input_ids, attention_mask) -> torch.Tensor:
         values = self.value_head(last_hidden_state)
 
         # Remove the last dimension, since there is only a single value per token.
-        value = values.mean(dim=1).squeeze(-1)
+        value = values.squeeze(-1)
 
         return value
 
diff --git a/Finetune/plugin/trainer/rm_trainer.py b/Finetune/plugin/trainer/rm_trainer.py
index c69d0c1f2..8ae0cd381 100644
--- a/Finetune/plugin/trainer/rm_trainer.py
+++ b/Finetune/plugin/trainer/rm_trainer.py
@@ -1,6 +1,8 @@
 from .trainer import Trainer
 from itertools import chain
+import os
 import torch
+from torch.utils.tensorboard import SummaryWriter
 import transformers
 import math
 import time
@@ -57,14 +59,17 @@ def compute_loss(self, batch, return_outputs=False):
     def train(self):
         num_train_epochs = self.config.get("num_train_epochs", 1)
         log_step = self.config.get("log_step", 1)
+        if not os.path.exists(self.config.get("log_path", ".")):
+            os.makedirs(self.config.get("log_path", "."), exist_ok=True)
+        writer = SummaryWriter(self.config.get("log_path", "."))
         for idx in range(num_train_epochs):
             logger.info(f"start train epoch {idx}")
             self.model.train()
             start = time.time()
             for step, batch in enumerate(self.train_dataloader):
                 with self.accelerator.accumulate(self.model):
-                    batch = dict(zip(batch.keys(), map(lambda x: x.unsqueeze(1), batch.values())))
                     loss = self.compute_loss(batch)
+                    writer.add_scalar('training loss', loss, step)
                     self.accelerator.backward(loss)
                     self.optimizer.step()
                     if self.lr_scheduler is not None:
@@ -92,3 +97,5 @@ def train(self):
                 eval_loss = float("inf")
                 perplexity = float("inf")
             logger.info(f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]")
+            writer.add_scalar('eval loss', eval_loss, idx)
+            writer.add_scalar('perplexity', perplexity, idx)
\ No newline at end of file
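
Note on the reward_model.py hunk above: with `.mean(dim=1)` dropped, `value()` now returns one value per token rather than a single pooled value per sequence. A minimal standalone sketch of the shape difference (toy sizes and names, not the repo's RewardModel class):

    import torch
    import torch.nn as nn

    # Toy dimensions for illustration only; the real model sizes the head from
    # config.hidden_size or config.n_embd, as in the hunk above.
    hidden_dim, batch, seq_len = 8, 2, 5
    value_head = nn.Linear(hidden_dim, 1)

    last_hidden_state = torch.randn(batch, seq_len, hidden_dim)
    values = value_head(last_hidden_state)         # [batch, seq_len, 1]

    per_token = values.squeeze(-1)                 # new behavior: [batch, seq_len]
    pooled = values.mean(dim=1).squeeze(-1)        # old behavior: [batch]
    print(per_token.shape, pooled.shape)           # torch.Size([2, 5]) torch.Size([2])
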
diff --git a/Finetune/rl_algo/ppo/ppo_rlhf.py b/Finetune/rl_algo/ppo/ppo_rlhf.py
index 4c4ba32c8..44e8921ad 100644
--- a/Finetune/rl_algo/ppo/ppo_rlhf.py
+++ b/Finetune/rl_algo/ppo/ppo_rlhf.py
@@ -117,7 +117,7 @@ def training_step(self):
         policies_to_update = {DEFAULT_POLICY_ID}
 
         kl_dict = {
-            pid: train_results[pid].get("mean_kl_loss")
+            pid: train_results[pid][LEARNER_STATS_KEY].get("kl")
             for pid in policies_to_update
         }
         self.learner_group.additional_update(
@@ -132,7 +132,7 @@ def evaluate(self):
         # breakpoint()
         train_batch = self.sampler.sample(batch_size=1)
 
-        rewards = train_batch[SampleBatch.INFOS][0]['r_align']
+        rewards = train_batch[SampleBatch.INFOS]['r_align']
 
         self.evaluation_metrics = {"evaluation": {
 
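
The kl_dict lookup above assumes per-policy learner stats are nested under RLlib's LEARNER_STATS_KEY and that the KL term is reported as "kl", matching the stats dict returned by rlhf_ppo_torch_learner.py below. A toy sketch of that lookup, with string stand-ins for the RLlib constants and made-up numbers:

    # Stand-ins for the RLlib constants; the stat values are illustrative only.
    LEARNER_STATS_KEY = "learner_stats"
    DEFAULT_POLICY_ID = "default_policy"

    train_results = {
        DEFAULT_POLICY_ID: {LEARNER_STATS_KEY: {"kl": 0.012, "policy_loss": -0.03}},
    }
    policies_to_update = {DEFAULT_POLICY_ID}

    kl_dict = {
        pid: train_results[pid][LEARNER_STATS_KEY].get("kl")
        for pid in policies_to_update
    }
    print(kl_dict)  # {'default_policy': 0.012}
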
diff --git a/Finetune/rl_algo/ppo/rlhf_ppo_torch_learner.py b/Finetune/rl_algo/ppo/rlhf_ppo_torch_learner.py
index 4e77004a6..c4eb5b587 100644
--- a/Finetune/rl_algo/ppo/rlhf_ppo_torch_learner.py
+++ b/Finetune/rl_algo/ppo/rlhf_ppo_torch_learner.py
@@ -32,12 +32,10 @@ class RLHFPPOTorchLearner(PPOTorchLearner):
 
     @override(PPOTorchLearner)
-    def compute_loss_for_module(
+    def compute_loss_per_module(
         self,
-        *,
         module_id: str,
-        hps: PPOLearnerHyperparameters,
-        batch: NestedDict,
+        batch: SampleBatch,
         fwd_out: Mapping[str, TensorType]
     ) -> TensorType:
         """Extention of PPO loss function to support RLHF.
 
@@ -62,10 +60,9 @@ def compute_loss_for_module(
         logp_ratio = masked_mean(logp_ratio_unmasked, attention_mask, dim=-1)
 
         # Only calculate kl loss if necessary (kl-coeff > 0.0).
-        # if self.hps.kl_coeff > 0.0:
-        if hps.use_kl_loss:
+        if self.hps.kl_coeff > 0.0:
             action_kl = prev_action_dist.kl(curr_action_dist)
-            mean_kl_loss = masked_mean(action_kl, attention_mask, dim=-1).mean()
+            mean_kl_loss = torch.mean(action_kl)
             if mean_kl_loss.isinf():
                 logger.warning(
                     "KL divergence is non-finite, this will likely destabilize "
@@ -74,8 +71,8 @@ def compute_loss_for_module(
                     "This can happen naturally in deterministic "
                     "environments where the optimal policy has zero mass "
                     "for a specific action. To fix this issue, consider "
-                    "setting the coefficient for the KL loss term to "
-                    "zero or increasing policy entropy."
+                    "setting `kl_coeff` to 0.0 or increasing `entropy_coeff` in your "
+                    "config."
                 )
             else:
                 mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)
@@ -91,7 +88,7 @@ def compute_loss_for_module(
         )
 
         # Compute a value function loss.
-        if hps.use_critic:
+        if self.hps.use_critic:
             value_fn_out = fwd_out[SampleBatch.VF_PREDS]
             vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0)
             vf_loss_clipped = torch.clamp(vf_loss, 0, self.hps.vf_clip_param)
@@ -104,34 +101,30 @@ def compute_loss_for_module(
             vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device)
 
         total_loss = torch.mean(
-            -surrogate_loss
-            + hps.vf_loss_coeff * vf_loss_clipped
-            - (
-                self.entropy_coeff_schedulers_per_module[module_id].get_current_value()
-                * curr_entropy
-            )
+            surrogate_loss
+            + self.hps.vf_loss_coeff * vf_loss_clipped
+            - self.entropy_coeff_scheduler.get_current_value(module_id) * curr_entropy
         )
+
         # Add mean_kl_loss (already processed through `reduce_mean_valid`),
         # if necessary.
-        if hps.use_kl_loss:
-            # if self.hps.kl_coeff > 0.0:
-            # total_loss += self.kl_coeff * mean_kl_loss
+        if self.hps.kl_coeff > 0.0:
             total_loss += self.curr_kl_coeffs_per_module[module_id] * mean_kl_loss
 
-        # Register important loss stats.
-        self.register_metrics(
-            module_id,
-            {
-                POLICY_LOSS_KEY: -torch.mean(surrogate_loss),
-                VF_LOSS_KEY: mean_vf_loss,
-                LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss,
-                LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance(
-                    batch[Postprocessing.VALUE_TARGETS], value_fn_out
-                ),
-                ENTROPY_KEY: mean_entropy,
-                LEARNER_RESULTS_KL_KEY: mean_kl_loss,
-            },
-        )
-
-        return total_loss
\ No newline at end of file
+        return {
+            self.TOTAL_LOSS_KEY: total_loss,
+            "policy_loss": torch.mean(surrogate_loss),
+            "vf_loss": mean_vf_loss,
+            "unclipped_vf_loss": mean_vf_unclipped_loss,
+            "vf_explained_var": explained_variance(
+                batch[Postprocessing.VALUE_TARGETS], value_fn_out
+            ),
+            "entropy": mean_entropy,
+            "kl": mean_kl_loss,
+            "entropy_coeff": self.entropy_coeff_scheduler.get_current_value(module_id),
+            "cur_kl_coeff": self.curr_kl_coeffs_per_module[module_id],
+            "mean_reward_total": batch[SampleBatch.REWARDS].mean(),
+            "mean_reward_rm": batch[SampleBatch.INFOS]["r_align"].mean(),
+            "mean_reward_kl": batch[SampleBatch.INFOS]["r_kl"].mean(),
+        }
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 943883010..d0256a8d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,6 @@ gradio
 gymnasium
 dm-tree
 scikit-image
-einops
\ No newline at end of file
+tensorboard
+einops
+
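
requirements.txt now pulls in tensorboard for the logging added in rm_trainer.py. A minimal standalone sketch of that SummaryWriter pattern, with a hypothetical log_path and stand-in loss values (the trainer reads the real path from its config):

    import os
    from torch.utils.tensorboard import SummaryWriter

    log_path = "./runs/rm_demo"                      # hypothetical; the trainer defaults to "."
    os.makedirs(log_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    for step, loss in enumerate([0.9, 0.7, 0.55]):   # stand-in training losses
        writer.add_scalar('training loss', loss, step)

    writer.add_scalar('eval loss', 0.6, 0)           # logged once per epoch in the trainer
    writer.add_scalar('perplexity', 1.8, 0)
    writer.close()

    # Inspect the resulting event files with: tensorboard --logdir ./runs/rm_demo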