Process reward models #2241

Merged: 26 commits, Jan 29, 2025
Changes from 6 commits
Commits (26)
2ca689c
adding model_cfg to set num_labels
SalmanMohammadi Jan 7, 2025
9f5a8e0
using a num_labels field instead
SalmanMohammadi Jan 7, 2025
9229435
Merge branch 'main' into fix_reward_model
SalmanMohammadi Jan 7, 2025
3baaa76
linting
SalmanMohammadi Jan 7, 2025
f81b174
WIP stepwise prompt tokenizer
SalmanMohammadi Jan 8, 2025
0630baa
this should work?
SalmanMohammadi Jan 8, 2025
796fd14
trainer working?
SalmanMohammadi Jan 8, 2025
a6ee075
pushing to runpod
SalmanMohammadi Jan 9, 2025
57050d4
fixing saving
SalmanMohammadi Jan 20, 2025
0f94239
updating conf
SalmanMohammadi Jan 21, 2025
1291d22
merging main
SalmanMohammadi Jan 22, 2025
3107e2a
updating config, adding docs
SalmanMohammadi Jan 22, 2025
034f303
adding stepwise supervision docpage
SalmanMohammadi Jan 22, 2025
0f0662b
updating tests
SalmanMohammadi Jan 22, 2025
a302348
adding test for dataset
SalmanMohammadi Jan 23, 2025
71b0f39
fixing tests
SalmanMohammadi Jan 24, 2025
5ee3876
linting
SalmanMohammadi Jan 24, 2025
bba11c7
addressing some comments
SalmanMohammadi Jan 25, 2025
f55fb7d
adding additional cfg fields support
SalmanMohammadi Jan 27, 2025
88fddfc
updating tests, fixing cfg
SalmanMohammadi Jan 27, 2025
f27ca55
fixing tests
SalmanMohammadi Jan 28, 2025
9d5bb17
updating loss
SalmanMohammadi Jan 28, 2025
b2e5ac7
Update test_process_reward_model_smollm2.py
SalmanMohammadi Jan 28, 2025
5b53586
updating loss values and seed
SalmanMohammadi Jan 28, 2025
b88d37a
dumb pre-commit
SalmanMohammadi Jan 28, 2025
2ccf31f
Merge branch 'main' into fix_reward_model
SalmanMohammadi Jan 28, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
1 change: 1 addition & 0 deletions examples/gemma2/reward-model.yaml
@@ -1,6 +1,7 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
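As a brief, hedged aside (not code from this PR): the new num_labels field maps onto the size of the classification head on the transformers side. A sequence-classification reward model gets a single scalar head, while the token-classification process reward model added in this PR gets two per-token logits. The model ids below are simply the ones used in these example configs.

from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)

# Outcome/preference reward model: one scalar score per sequence.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2-2b", num_labels=1
)

# Process reward model: two logits per token (e.g. incorrect vs. correct step).
process_reward_model = AutoModelForTokenClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B", num_labels=2
)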
67 changes: 67 additions & 0 deletions examples/qwen2/prm.yaml
@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

process_reward_model: true
chat_template:
datasets:
- path: trl-lib/prm800k
type: stepwise_supervised
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16:
fp16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: false
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
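The stepwise_supervised dataset type used above expects records with prompt, completions, and labels columns, matching the new prompt strategy later in this diff. A rough, illustrative sketch of one such record (the arithmetic content is invented, not taken from trl-lib/prm800k):

example_record = {
    "prompt": "What is 13 * 7?",
    "completions": [
        "13 * 7 = 70 + 21 = 91.",
        "So the answer is 91.",
    ],
    # one correctness label per completion step
    "labels": [True, True],
}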
67 changes: 67 additions & 0 deletions examples/qwen2/reward-model.yaml
@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

reward_model: true
chat_template: qwen_25
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
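For comparison, the bradley_terry.chat_template type above trains a sequence-classification reward model from preference pairs rather than per-step labels. A hedged sketch of the kind of record involved; the field names mirror the argilla dataset and are an assumption for illustration, not a specification of the loader:

preference_record = {
    "system": "You are a helpful assistant.",
    "input": "Name the largest planet in the solar system.",
    "chosen": "The largest planet in the solar system is Jupiter.",
    "rejected": "The largest planet in the solar system is Saturn.",
}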
2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
@@ -411,7 +411,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
36 changes: 29 additions & 7 deletions src/axolotl/core/trainer_builder.py
@@ -45,6 +45,8 @@
KTOTrainer,
ORPOConfig,
ORPOTrainer,
PRMConfig,
PRMTrainer,
RewardConfig,
RewardTrainer,
)
@@ -339,6 +341,13 @@ class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""


@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
"""
Config for process reward model (PRM) training
"""


class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -1300,6 +1309,14 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call


class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

tag_names = ["axolotl", "prm"]


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1433,7 +1450,8 @@ def hook_post_create_trainer(self, trainer):

class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
"""

def get_callbacks(self):
@@ -1508,6 +1526,8 @@ def _get_trainer_cls(self):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer

def build(self, total_num_steps):
@@ -1897,11 +1917,13 @@ def build(self, total_num_steps):
"accelerator_config"
] = self.cfg.accelerator_config

training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments

training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
@@ -2035,7 +2057,7 @@ def build_collator(

class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""

def get_callbacks(self):
2 changes: 2 additions & 0 deletions src/axolotl/datasets.py
@@ -52,6 +52,8 @@ def process(self, dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
118 changes: 118 additions & 0 deletions src/axolotl/prompt_strategies/stepwise_supervised.py
@@ -0,0 +1,118 @@
"""
Module for stepwise datasets, typically including a prompt and reasoning traces,
and (optionally) per-step or per-prompt-trace labels for reward modelling.
"""

from itertools import chain
from typing import Dict, Generator, List, Optional, Union

from transformers import BatchEncoding, PreTrainedTokenizer

from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.dict import DictDefault


class StepwiseSupervisedPromptTokenizingStrategy:
"""
Tokenizing strategy for supervised stepwise datasets, typically used for chain-of-thought (CoT) reasoning.
These datasets should include the following columns:
- prompt: the prompt text
- completions: a list of `n` completion steps
- labels: a list of `n` labels indicating the "correctness" of each step
"""

def __init__(
self,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
step_separator: str = "\n",
max_completion_length: Optional[int] = None,
train_on_last_step_only: bool = False,
is_eval: bool = False,
):
self.tokenizer = tokenizer
self.train_on_inputs = train_on_inputs
self.sequence_len = sequence_len
self.step_separator = step_separator
self.max_completion_length = max_completion_length
self.train_on_last_step_only = train_on_last_step_only
self.is_eval = is_eval

def tokenize_prompt(
self, prompt: Dict[str, Union[str, List[str]]]
) -> BatchEncoding:
# Inspired by TRL's PRMTrainer
# https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
"input_ids"
]

completions_ids = [
self.tokenizer(completion, add_special_tokens=False)["input_ids"]
for completion in prompt["completions"]
]

# Handle labels
if self.train_on_last_step_only and not self.is_eval:
labels = [-100] * (len(prompt["labels"]) - 1) + [int(prompt["labels"][-1])]
else:
labels = [int(label) for label in prompt["labels"]]

# Add step separators
separator_ids = self.tokenizer.encode(
self.step_separator, add_special_tokens=False
)
completions_ids = [completion + separator_ids for completion in completions_ids]

# Create step-wise labels
labels = [
[-100] * (len(completion) - 1) + [label]
for completion, label in zip(completions_ids, labels)
]

# Join all steps
completion_ids = list(chain(*completions_ids))
labels = list(chain(*labels))

# Handle max lengths
if self.max_completion_length:
completion_ids = completion_ids[: self.max_completion_length]
labels = labels[: self.max_completion_length]

# Add BOS token if model has one
if self.tokenizer.bos_token_id is not None:
prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids

# Combine prompt and completion
input_ids = prompt_ids + completion_ids
full_labels = [-100] * len(prompt_ids) + labels

# Apply max sequence length
if self.sequence_len:
input_ids = input_ids[: self.sequence_len]
full_labels = full_labels[: self.sequence_len]

return BatchEncoding(
{
"input_ids": input_ids,
"labels": full_labels,
"attention_mask": [1] * len(input_ids),
}
)

@property
def supports_batched(self):
return False


def load(
tokenizer: PreTrainedTokenizer, cfg: DictDefault
) -> StepwiseSupervisedPromptTokenizingStrategy:
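"""Load the stepwise-supervised tokenizing strategy using fields from the axolotl config."""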
return StepwiseSupervisedPromptTokenizingStrategy(
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
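As a usage sketch (not part of the PR) of how the strategy above lays out labels: every position is -100 except the final token of each completion step (the appended step separator), which carries that step's 0/1 label. The tokenizer choice and the toy record below are assumptions for illustration.

from transformers import AutoTokenizer

from axolotl.prompt_strategies.stepwise_supervised import (
    StepwiseSupervisedPromptTokenizingStrategy,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
strategy = StepwiseSupervisedPromptTokenizingStrategy(
    tokenizer,
    train_on_inputs=False,
    sequence_len=2048,
    step_separator="\n",
)

encoding = strategy.tokenize_prompt(
    {
        "prompt": "What is 2 + 2?",
        "completions": ["2 + 2 = 4.", "So the answer is 4."],
        "labels": [True, True],
    }
)

# One non-ignored label per step, sitting on that step's separator token.
step_label_positions = [
    i for i, label in enumerate(encoding["labels"]) if label != -100
]
assert len(step_label_positions) == 2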
4 changes: 3 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -382,7 +382,7 @@ class ModelInputConfig(BaseModel):

base_model: str
base_model_config: Optional[str] = None
cls_model_config: Optional[str] = None
tokenizer_config: Optional[str] = None
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
@@ -608,6 +608,8 @@ class Config:

rl: Optional[RLType] = None
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.