Adapt to ray 2.5.0 for RLHF pipeline (intel#25)
* [rlhf] adapt to ray 2.5.0

* [rlhf] add post_init for RewardModel

* [rlhf] support neox reward model

* [rlhf] adapt values dims

* add tensorboard for rm trainer

* remove duplicated code of reward model
zhangjian94cn authored Jul 12, 2023
1 parent 1f84466 commit e6e9f93
Showing 6 changed files with 52 additions and 43 deletions.
Finetune/plugin/agentenv/rlhf_env.py (4 changes: 2 additions & 2 deletions)
@@ -162,11 +162,11 @@ def step(self, action):

         reward = r_align - self.kl_coeff * r_kl

-        info = [{
+        info = {
             "r_align": r_align,
             "r_kl": r_kl,
             "n_response_tokens": n_response_tokens
-        }]
+        }

         # Produce a random reward when we reach the goal.
         return self.observation_space.sample(), reward, True, False, info
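For context on this change: Gymnasium's `Env.step()` contract returns `(obs, reward, terminated, truncated, info)` with `info` as a plain dict, and the companion change in `Finetune/rl_algo/ppo/ppo_rlhf.py` below drops the `[0]` index accordingly. A minimal sketch of that contract, using a toy stand-in rather than the repo's environment:

import gymnasium as gym

class ToyRLHFEnv(gym.Env):
    """Toy stand-in (not the repo's env) showing the Gymnasium step contract."""
    observation_space = gym.spaces.Discrete(2)
    action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        reward = 0.0  # placeholder; the repo computes r_align - kl_coeff * r_kl
        # `info` is a plain dict per step, not a list of dicts.
        info = {"r_align": 0.0, "r_kl": 0.0, "n_response_tokens": 0}
        return self.observation_space.sample(), reward, True, False, info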
Finetune/plugin/model/reward_model.py (11 changes: 9 additions & 2 deletions)
@@ -34,7 +34,14 @@ def __init__(self, config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)

         # The additional value head.
-        self.value_head = nn.Linear(self.config.n_embd, 1)
+        if hasattr(self.config, 'hidden_size'):
+            self.value_head = nn.Linear(self.config.hidden_size, 1)
+        elif hasattr(self.config, 'n_embd'):
+            self.value_head = nn.Linear(self.config.n_embd, 1)
+        else:
+            raise ValueError("unsupported model: config exposes neither `hidden_size` nor `n_embd`")
+
+        self.post_init()

     def forward(
         self,
@@ -65,7 +72,7 @@ def value(self, input_ids, attention_mask) -> torch.Tensor:

         values = self.value_head(last_hidden_state)
         # Remove the last dimension, since there is only a single value per token.
-        value = values.mean(dim=1).squeeze(-1)
+        value = values.squeeze(-1)

         return value
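For context on the new `hasattr` check: GPT-2-style configs name the hidden width `n_embd`, while GPT-NeoX-style configs use `hidden_size`, so the value head's input dimension has to be looked up per architecture. A minimal sketch of the same lookup outside the class (`build_value_head` is an illustration, not a function in this repo):

import torch.nn as nn
from transformers import AutoConfig

def build_value_head(model_name: str) -> nn.Linear:
    """Size a scalar value head from whichever width attribute the config exposes."""
    config = AutoConfig.from_pretrained(model_name)
    hidden = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None)
    if hidden is None:
        raise ValueError(f"{model_name}: config exposes neither `hidden_size` nor `n_embd`")
    return nn.Linear(hidden, 1)

# e.g. build_value_head("gpt2") or build_value_head("EleutherAI/gpt-neox-20b")

The second hunk likewise stops pooling over the sequence: `values.squeeze(-1)` keeps one value per token (shape `[batch, seq_len]`) instead of averaging with `mean(dim=1)`.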
Finetune/plugin/trainer/rm_trainer.py (9 changes: 8 additions & 1 deletion)
@@ -1,6 +1,8 @@
 from .trainer import Trainer
 from itertools import chain
+import os
 import torch
+from torch.utils.tensorboard import SummaryWriter
 import transformers
 import math
 import time
@@ -57,14 +59,17 @@ def compute_loss(self, batch, return_outputs=False):
     def train(self):
         num_train_epochs = self.config.get("num_train_epochs", 1)
         log_step = self.config.get("log_step", 1)
+        if not os.path.exists(self.config.get("log_path", ".")):
+            os.makedirs(self.config.get("log_path", "."), exist_ok=True)
+        writer = SummaryWriter(self.config.get("log_path", "."))
         for idx in range(num_train_epochs):
             logger.info(f"start train epoch {idx}")
             self.model.train()
             start = time.time()
             for step, batch in enumerate(self.train_dataloader):
                 with self.accelerator.accumulate(self.model):
                     batch = dict(zip(batch.keys(), map(lambda x: x.unsqueeze(1), batch.values())))
                     loss = self.compute_loss(batch)
+                    writer.add_scalar('training loss', loss, step)
                     self.accelerator.backward(loss)
                     self.optimizer.step()
                     if self.lr_scheduler is not None:
@@ -92,3 +97,5 @@ def train(self):
                 eval_loss = float("inf")
                 perplexity = float("inf")
             logger.info(f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]")
+            writer.add_scalar('eval loss', eval_loss, idx)
+            writer.add_scalar('perplexity', perplexity, idx)
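Two small notes on the logging added here: `os.makedirs(..., exist_ok=True)` already tolerates an existing directory, so the `os.path.exists` guard is redundant, and since `step` comes from `enumerate` it restarts at 0 each epoch, so the training curve will overlap across epochs unless a global step counter is used. A minimal, self-contained sketch of the pattern with a hypothetical log path and placeholder values:

import os
from torch.utils.tensorboard import SummaryWriter

log_path = "./runs/rm_trainer"           # hypothetical path, not from the repo config
os.makedirs(log_path, exist_ok=True)     # exist_ok=True already handles an existing dir

writer = SummaryWriter(log_path)
for step in range(3):                    # placeholder loop standing in for the train loop
    loss = 1.0 / (step + 1)              # placeholder scalar
    writer.add_scalar("training loss", loss, step)
writer.close()                           # flush events so TensorBoard picks them up

View the scalars with `tensorboard --logdir ./runs`.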
Finetune/rl_algo/ppo/ppo_rlhf.py (4 changes: 2 additions & 2 deletions)
@@ -117,7 +117,7 @@ def training_step(self):

         policies_to_update = {DEFAULT_POLICY_ID}
         kl_dict = {
-            pid: train_results[pid].get("mean_kl_loss")
+            pid: train_results[pid][LEARNER_STATS_KEY].get("kl")
             for pid in policies_to_update
         }
         self.learner_group.additional_update(
@@ -132,7 +132,7 @@ def evaluate(self):
         # breakpoint()

         train_batch = self.sampler.sample(batch_size=1)
-        rewards = train_batch[SampleBatch.INFOS][0]['r_align']
+        rewards = train_batch[SampleBatch.INFOS]['r_align']

         self.evaluation_metrics = {"evaluation":
             {
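For orientation on the new lookup: `LEARNER_STATS_KEY` is RLlib's `"learner_stats"` key, so the per-policy KL is now read from a nested stats dict rather than a top-level `mean_kl_loss` entry; the `"kl"` value is the one the learner below returns. A sketch with a hypothetical results dict (shape assumed for illustration, not dumped from a run):

from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

# Hypothetical per-policy training results, shaped the way the lookup expects.
train_results = {"default_policy": {LEARNER_STATS_KEY: {"kl": 0.01}}}

policies_to_update = {"default_policy"}
kl_dict = {
    pid: train_results[pid][LEARNER_STATS_KEY].get("kl")
    for pid in policies_to_update
}
print(kl_dict)  # {'default_policy': 0.01}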
Finetune/rl_algo/ppo/rlhf_ppo_torch_learner.py (63 changes: 28 additions & 35 deletions)
@@ -32,12 +32,10 @@
 class RLHFPPOTorchLearner(PPOTorchLearner):

     @override(PPOTorchLearner)
-    def compute_loss_for_module(
+    def compute_loss_per_module(
         self,
-        *,
         module_id: str,
-        hps: PPOLearnerHyperparameters,
-        batch: NestedDict,
+        batch: SampleBatch,
         fwd_out: Mapping[str, TensorType]
     ) -> TensorType:
         """Extension of PPO loss function to support RLHF.
@@ -62,10 +60,9 @@ def compute_loss_for_module(
         logp_ratio = masked_mean(logp_ratio_unmasked, attention_mask, dim=-1)

         # Only calculate kl loss if necessary (kl-coeff > 0.0).
-        # if self.hps.kl_coeff > 0.0:
-        if hps.use_kl_loss:
+        if self.hps.kl_coeff > 0.0:
             action_kl = prev_action_dist.kl(curr_action_dist)
-            mean_kl_loss = masked_mean(action_kl, attention_mask, dim=-1).mean()
+            mean_kl_loss = torch.mean(action_kl)
             if mean_kl_loss.isinf():
                 logger.warning(
                     "KL divergence is non-finite, this will likely destabilize "
@@ -74,8 +71,8 @@ def compute_loss_for_module(
                     "This can happen naturally in deterministic "
                     "environments where the optimal policy has zero mass "
                     "for a specific action. To fix this issue, consider "
-                    "setting the coefficient for the KL loss term to "
-                    "zero or increasing policy entropy."
+                    "setting `kl_coeff` to 0.0 or increasing `entropy_coeff` in your "
+                    "config."
                 )
         else:
             mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)
@@ -91,7 +88,7 @@ def compute_loss_for_module(
         )

         # Compute a value function loss.
-        if hps.use_critic:
+        if self.hps.use_critic:
             value_fn_out = fwd_out[SampleBatch.VF_PREDS]
             vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0)
             vf_loss_clipped = torch.clamp(vf_loss, 0, self.hps.vf_clip_param)
@@ -104,34 +101,30 @@ def compute_loss_for_module(
             vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device)

         total_loss = torch.mean(
-            -surrogate_loss
-            + hps.vf_loss_coeff * vf_loss_clipped
-            - (
-                self.entropy_coeff_schedulers_per_module[module_id].get_current_value()
-                * curr_entropy
-            )
+            surrogate_loss
+            + self.hps.vf_loss_coeff * vf_loss_clipped
+            - self.entropy_coeff_scheduler.get_current_value(module_id) * curr_entropy
         )

         # Add mean_kl_loss (already processed through `reduce_mean_valid`),
         # if necessary.
-        if hps.use_kl_loss:
-            # if self.hps.kl_coeff > 0.0:
-            # total_loss += self.kl_coeff * mean_kl_loss
+        if self.hps.kl_coeff > 0.0:
             total_loss += self.curr_kl_coeffs_per_module[module_id] * mean_kl_loss

-        # Register important loss stats.
-        self.register_metrics(
-            module_id,
-            {
-                POLICY_LOSS_KEY: -torch.mean(surrogate_loss),
-                VF_LOSS_KEY: mean_vf_loss,
-                LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss,
-                LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance(
-                    batch[Postprocessing.VALUE_TARGETS], value_fn_out
-                ),
-                ENTROPY_KEY: mean_entropy,
-                LEARNER_RESULTS_KL_KEY: mean_kl_loss,
-            },
-        )
-
-        return total_loss
+        return {
+            self.TOTAL_LOSS_KEY: total_loss,
+            "policy_loss": torch.mean(surrogate_loss),
+            "vf_loss": mean_vf_loss,
+            "unclipped_vf_loss": mean_vf_unclipped_loss,
+            "vf_explained_var": explained_variance(
+                batch[Postprocessing.VALUE_TARGETS], value_fn_out
+            ),
+            "entropy": mean_entropy,
+            "kl": mean_kl_loss,
+            "entropy_coeff": self.entropy_coeff_scheduler.get_current_value(module_id),
+            "cur_kl_coeff": self.curr_kl_coeffs_per_module[module_id],
+            "mean_reward_total": batch[SampleBatch.REWARDS].mean(),
+            "mean_reward_rm": batch[SampleBatch.INFOS]["r_align"].mean(),
+            "mean_reward_kl": batch[SampleBatch.INFOS]["r_kl"].mean(),
+        }
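`masked_mean` is used but not defined in this excerpt; a minimal sketch of such a helper (an illustration, not the repo's implementation) makes the KL change concrete: the old line averaged the per-token KL only over non-padding positions, while `torch.mean(action_kl)` also counts padded positions.

import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Average x over `dim`, counting only positions where mask == 1."""
    mask = mask.to(x.dtype)
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1.0)

# Per-token KL for 2 sequences, the second padded to length 4.
action_kl = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                          [0.5, 0.6, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

print(masked_mean(action_kl, attention_mask).mean())  # ignores padding -> 0.4000
print(torch.mean(action_kl))                          # counts padding  -> 0.2625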
requirements.txt (4 changes: 3 additions & 1 deletion)
@@ -13,4 +13,6 @@ gradio
 gymnasium
 dm-tree
 scikit-image
-einops
+tensorboard
+einops