Commit

Process reward models (#2241)
* adding model_cfg to set num_labels

* using a num_labels field instead

* linting

* WIP stepwise prompt tokenizer

* this should work?

* trainer working?

* pushing to runpod

* fixing saving

* updating conf

* updating config, adding docs

* adding stepwise supervision docpage

* updating tests

* adding test for dataset

* fixing tests

* linting

* addressing some comments

* adding additional cfg fields support

* updating tests, fixing cfg

* fixing tests

* updating loss

* Update test_process_reward_model_smollm2.py

* updating loss values and seed

* dumb pre-commit
SalmanMohammadi authored Jan 29, 2025
1 parent c071a53 commit 54dd7ab
Showing 17 changed files with 542 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
6 changes: 6 additions & 0 deletions docs/config.qmd
@@ -187,6 +187,12 @@ rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:

# reward modelling: `True` or `False`
reward_model:

# process reward modelling: `True` or `False`
process_reward_model:

# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
18 changes: 18 additions & 0 deletions docs/dataset-formats/stepwise_supervised.qmd
@@ -0,0 +1,18 @@
---
title: Stepwise Supervised Format
description: Format for datasets with stepwise completions and labels
order: 3
---

## Stepwise Supervised

The stepwise supervised format is designed for chain-of-thought (CoT) reasoning datasets where each example contains multiple completion steps and a preference label for each step.

### Example

Here's a simple example of a stepwise supervised dataset entry:

```json
{
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": [
"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
],
"labels": [true, false]
}
```
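
As a rough illustration (not part of this commit), the sketch below shows how a toy dataset in this format could be assembled locally with the `datasets` library; the field names mirror the JSON entry above, and the output path is purely illustrative.

```python
# Minimal sketch: building a toy stepwise-supervised dataset with the
# `datasets` library. Field names follow the example entry above; the
# save path is just an illustration.
from datasets import Dataset

examples = {
    "prompt": ["Which number is larger, 9.8 or 9.11?"],
    "completions": [[
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
    ]],
    "labels": [[True, False]],  # one label per completion step
}

dataset = Dataset.from_dict(examples)
print(dataset[0]["completions"])
dataset.save_to_disk("./stepwise_toy")  # can then be referenced from a config
```
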
47 changes: 47 additions & 0 deletions docs/reward_modelling.qmd
@@ -0,0 +1,47 @@
---
title: "Reward Modelling"
description: "Reward models are used to guide models towards behaviors which is preferred by humans, by training over large datasets annotated with human preferences. "
---

### Overview

Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.
We support the reward modelling techniques implemented in `trl`.

### (Outcome) Reward Models

Outcome reward models are trained on data that contains preference annotations for an entire interaction between the user and model (rather than per-turn or per-step).

```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer

reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template

val_set_size: 0.1
eval_steps: 100
```
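
As a hedged illustration of how a trained outcome reward model might be used (not part of this commit; the checkpoint path is a placeholder), the single logit from the `num_labels: 1` sequence-classification head can be read as a scalar reward for the whole conversation:

```python
# Hedged sketch: scoring a full conversation with a trained outcome reward
# model. The checkpoint path is a placeholder; the single logit from the
# `num_labels: 1` head is interpreted as the reward for the interaction.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "./outputs/out"  # placeholder path to a trained reward model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)

chat = [
    {"role": "user", "content": "Which number is larger, 9.8 or 9.11?"},
    {"role": "assistant", "content": "9.8 is larger, since 0.8 > 0.11."},
]
text = tokenizer.apply_chat_template(chat, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    reward = model(**inputs).logits[0, 0].item()  # scalar reward
print(f"reward: {reward:.3f}")
```
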
### Process Reward Models (PRM)

Process reward models are trained on data that contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.

```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2

process_reward_model: true
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
split: train

val_set_size: 0.1
eval_steps: 100
```
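
For intuition, here is a hedged sketch of PRM inference (not part of this commit): the checkpoint path is a placeholder, and the separator handling assumes a newline step separator, matching the `step_separator: "\n"` shown in the example `prm.yaml` added by this commit. Per-step scores are read from the token-classification logits at each step's closing separator token.

```python
# Hedged sketch: scoring each reasoning step with a trained PRM. Token IDs
# are built step by step so that every step ends with the separator token;
# scores are read from the token-classification logits at those positions.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

checkpoint = "./outputs/out"  # placeholder path to a trained PRM
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=2)

prompt = "Which number is larger, 9.8 or 9.11?"
steps = [
    "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
    "Since 0.8 is greater than 0.11, the number 9.8 is larger than 9.11.",
]
separator_ids = tokenizer.encode("\n", add_special_tokens=False)

input_ids = tokenizer.encode(prompt, add_special_tokens=False)
step_end_positions = []
for step in steps:
    input_ids += tokenizer.encode(step, add_special_tokens=False) + separator_ids
    step_end_positions.append(len(input_ids) - 1)  # index of this step's separator

with torch.no_grad():
    logits = model(input_ids=torch.tensor([input_ids])).logits[0]

# Probability that each step is labelled "good" (class 1).
step_scores = torch.softmax(logits[step_end_positions], dim=-1)[:, 1]
print(step_scores.tolist())
```
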
1 change: 1 addition & 0 deletions examples/gemma2/reward-model.yaml
@@ -1,6 +1,7 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
72 changes: 72 additions & 0 deletions examples/qwen2/prm.yaml
@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-3B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

process_reward_model: true
chat_template:
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
step_separator: "\n"
max_completion_length:
train_on_last_step_only: false

val_set_size: 0.2
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
eval_steps: 100
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
67 changes: 67 additions & 0 deletions examples/qwen2/reward-model.yaml
@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

reward_model: true
chat_template: qwen_25
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
46 changes: 35 additions & 11 deletions src/axolotl/core/trainer_builder.py
@@ -44,6 +44,8 @@
KTOTrainer,
ORPOConfig,
ORPOTrainer,
PRMConfig,
PRMTrainer,
RewardConfig,
RewardTrainer,
)
@@ -342,6 +344,13 @@ class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""


@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
"""
PRM config for PRM training
"""


class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -1244,6 +1253,14 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]


class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

tag_names = ["axolotl", "prm"]


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1377,7 +1394,8 @@ def hook_post_create_trainer(self, trainer):

class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
"""

def get_callbacks(self):
@@ -1452,6 +1470,8 @@ def _get_trainer_cls(self):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer

def build(self, total_num_steps):
@@ -1842,11 +1862,13 @@ def build(self, total_num_steps):
"accelerator_config"
] = self.cfg.accelerator_config

training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments

training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
@@ -1880,9 +1902,9 @@ def build(self, total_num_steps):
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -1893,8 +1915,10 @@
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer

if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
@@ -1984,7 +2008,7 @@ def build_collator(

class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""

def get_callbacks(self):
1 change: 1 addition & 0 deletions src/axolotl/datasets.py
@@ -52,6 +52,7 @@ def process(self, dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100

return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,