* adding model_cfg to set num_labels
* using a num_labels field instead
* linting
* WIP stepwise prompt tokenizer
* this should work?
* trainer working?
* pushing to runpod
* fixing saving
* updating conf
* updating config, adding docs
* adding stepwise supervision docpage
* updating tests
* adding test for dataset
* fixing tests
* linting
* addressing some comments
* adding additional cfg fields support
* updating tests, fixing cfg
* fixing tests
* updating loss
* Update test_process_reward_model_smollm2.py
* updating loss values and seed
* dumb pre-commit
1 parent: c071a53. Commit: 54dd7ab. Showing 17 changed files with 542 additions and 25 deletions.
@@ -0,0 +1,18 @@
---
title: Stepwise Supervised Format
description: Format for datasets with stepwise completions and labels
order: 3
---

## Stepwise Supervised

The stepwise supervised format is designed for chain-of-thought (CoT) reasoning datasets where each example contains multiple completion steps and a preference label for each step.

### Example

Here's a simple example of a stepwise supervised dataset entry:

```json
{
  "prompt": "Which number is larger, 9.8 or 9.11?",
  "completions": [
    "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
    "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
  ],
  "labels": [true, false]
}
```
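As a quick illustration of how such entries might be assembled programmatically, here is a minimal sketch using the Hugging Face `datasets` library. The field names (`prompt`, `completions`, `labels`) follow the format above; the example rows and the output path are placeholders, not part of this commit.

```python
# Minimal sketch: building a stepwise supervised dataset with the
# Hugging Face `datasets` library. Field names follow the format above;
# the rows and output filename are illustrative placeholders.
from datasets import Dataset

rows = [
    {
        "prompt": "Which number is larger, 9.8 or 9.11?",
        "completions": [
            "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
            "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
        ],
        # One label per completion step: the first step is correct, the second is not.
        "labels": [True, False],
    },
]

ds = Dataset.from_list(rows)
ds.to_json("stepwise_dataset.jsonl")  # writes one JSON object per line
```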
@@ -0,0 +1,47 @@
---
title: "Reward Modelling"
description: "Reward models are used to guide models towards behaviors preferred by humans by training over large datasets annotated with human preferences."
---

### Overview

Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.

We support the reward modelling techniques supported by `trl`.

### (Outcome) Reward Models

Outcome reward models are trained on data containing preference annotations for an entire interaction between the user and the model (rather than per-turn or per-step).

```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer

reward_model: true
chat_template: gemma
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    type: bradley_terry.chat_template

val_set_size: 0.1
eval_steps: 100
```
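The `bradley_terry.chat_template` dataset type pairs a chosen and a rejected response for the same prompt, and the outcome reward model is trained so that the scalar score of the chosen response exceeds that of the rejected one. The following PyTorch sketch shows that pairwise objective conceptually; it is not the trainer code used by `trl` or axolotl.

```python
# Illustrative Bradley-Terry pairwise loss for an outcome reward model.
# `chosen_scores` and `rejected_scores` are the scalar outputs of a
# sequence-classification head (num_labels: 1) for paired responses.
# Conceptual sketch only, not the trl/axolotl trainer implementation.
import torch
import torch.nn.functional as F

def bradley_terry_loss(chosen_scores: torch.Tensor,
                       rejected_scores: torch.Tensor) -> torch.Tensor:
    # Maximise the log-probability that the chosen response outranks the rejected one:
    # loss = -log(sigmoid(r_chosen - r_rejected))
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

# Example: scores for a batch of three preference pairs.
chosen = torch.tensor([1.2, 0.3, 2.0])
rejected = torch.tensor([0.4, 0.9, -0.5])
print(bradley_terry_loss(chosen, rejected))
```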
### Process Reward Models (PRM)

Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.

```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2

process_reward_model: true
datasets:
  - path: trl-lib/math_shepherd
    type: stepwise_supervised
    split: train

val_set_size: 0.1
eval_steps: 100
```
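Because the head is `AutoModelForTokenClassification` with `num_labels: 2`, each reasoning step contributes a classification target at a single token position while all other positions are ignored. The sketch below illustrates one way to lay out those labels, under the assumption that each step's label sits on the step's final token (e.g. the step separator) and every other token is set to -100; the actual logic lives in axolotl's stepwise prompt tokenizer and may differ.

```python
# Rough sketch of mapping stepwise labels onto token positions for a PRM.
# Assumption: each step's label is placed on the last token of that step and
# all other positions are ignored (-100). Illustrative only; axolotl's
# stepwise prompt tokenizer may implement this differently.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")

def tokenize_stepwise(prompt, completions, labels, step_separator="\n"):
    input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    token_labels = [-100] * len(input_ids)
    for step, label in zip(completions, labels):
        step_ids = tokenizer(step + step_separator, add_special_tokens=False)["input_ids"]
        input_ids += step_ids
        # Ignore every token of the step except the last one, which carries the label.
        token_labels += [-100] * (len(step_ids) - 1) + [int(label)]
    return {"input_ids": input_ids, "labels": token_labels}

example = tokenize_stepwise(
    "Which number is larger, 9.8 or 9.11?",
    ["The fractional part of 9.8 is 0.8, while 9.11 has 0.11.",
     "Since 0.11 is greater than 0.8, 9.11 is larger."],
    [True, False],
)
```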
@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-3B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

process_reward_model: true
chat_template:
datasets:
  - path: trl-lib/math_shepherd
    type: stepwise_supervised
    step_separator: "\n"
    max_completion_length:
    train_on_last_step_only: false

val_set_size: 0.2
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
eval_steps: 100
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

reward_model: true
chat_template: qwen_25
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens: