
Support ReMax Algorithm #2955

Open
liziniu wants to merge 4 commits into main

Conversation

@liziniu commented Feb 25, 2025

What does this PR do?

This PR adds the ReMax component to TRL, implementing the algorithm from the ICML 2024 paper ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models.

Key features include:

  • Implementation of ReMaxTrainer for training language models using the ReMax algorithm
  • Support for both standard and conversational format datasets
  • Flexible reward function system (interface sketched after this list) that supports:
    • Pre-trained reward models
    • Custom reward functions
    • Multiple reward functions with optional weighting
  • Integration with vLLM for accelerated generation
  • Comprehensive test coverage including PEFT compatibility and torch.compile support
  • Detailed documentation with examples and usage guidelines
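
To illustrate the reward-function interface listed above, here is a minimal sketch. Both functions below are hypothetical examples (not part of this PR's code); each receives the generated completions (conversational format here) plus any extra dataset columns as keyword arguments, and returns one float per completion.

def length_penalty(completions, **kwargs):
    # Hypothetical example: discourage overly long answers.
    return [-0.001 * len(c[0]["content"]) for c in completions]

def format_reward(completions, **kwargs):
    # Hypothetical example: reward answers that end with a period.
    return [1.0 if c[0]["content"].strip().endswith(".") else 0.0 for c in completions]

# Several functions can be passed together, e.g.
# ReMaxTrainer(..., reward_funcs=[length_penalty, format_reward]);
# optional per-function weighting is a config option described in the docs.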

Before submitting

  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? Documentation includes:
    • Full ReMax algorithm explanation
    • Quick start guide
    • Detailed examples for custom reward functions
    • API reference for ReMaxTrainer and ReMaxConfig
  • Did you write necessary tests? Tests cover:
    • Basic training scenarios
    • PEFT integration
    • vLLM acceleration
    • Multiple reward functions
    • Custom reward functions
    • Conversational format support

Who can review?

Anyone with experience in reinforcement learning for language models would be a great reviewer for this PR. The implementation involves both training and generation components.

Test Results

All tests have been successfully executed:

tests/test_remax_trainer.py ................ss.....s...
24 passed, 3 skipped, 36 warnings in 1129.19s (0:18:49)
Click to see the test details
+ python -m pytest tests/test_remax_trainer.py
======================================== test session starts =========================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /220049033/project/github_repo/trl
configfile: pyproject.toml
plugins: cov-6.0.0, anyio-4.8.0, rerunfailures-15.0, xdist-3.6.1
collected 27 items                                                                                   

tests/test_remax_trainer.py ................ss.....s...                                        [100%]

========================================== warnings summary ==========================================
tests/test_remax_trainer.py: 17 warnings
  /220049033/venvs/trl/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
    warnings.warn(

tests/test_remax_trainer.py: 17 warnings
  /220049033/venvs/trl/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
    warnings.warn(

tests/test_remax_trainer.py::ReMaxTrainerTester::test_training_vllm
tests/test_remax_trainer.py::ReMaxTrainerTester::test_training_vllm_guided_decoding
  /220049033/project/github_repo/trl/trl/trainer/remax_trainer.py:456: UserWarning: The requested device cuda:0 is also being used for training. For higher throughput and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. If this is intentional, you may ignore this warning but should adjust `vllm_gpu_memory_utilization` accordingly.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================== 24 passed, 3 skipped, 36 warnings in 1129.19s (0:18:49) =======================
[rank0]:[W225 07:43:25.960804091 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

Empirical Effectiveness

The effectiveness of ReMax is compared with GRPO by fine-tuning Qwen-2.5-3B-Instruct on the GSM8K dataset with an accuracy reward.

[Two screenshots (2025-02-25) comparing ReMax and GRPO training results]
Click to see the training code
import re
import torch
import time

from datasets import load_dataset
from trl import ReMaxConfig, ReMaxTrainer

from math_verify import parse, verify

dataset = load_dataset("gsm8k", "main")
train_dataset = dataset["train"].map(
    lambda x: {
        "prompt": [{"role": "user", "content": x["question"]}], 
        "ground_truth": x["answer"].split("####")[-1].strip(),
        }, 
    num_proc=64,
)
eval_dataset = dataset["test"].map(
    lambda x: {
        "prompt": [{"role": "user", "content": x["question"]}], 
        "ground_truth": x["answer"].split("####")[-1].strip()
        }, 
    num_proc=64,
)
eval_dataset = eval_dataset.select(range(256))

def reward_func(completions, ground_truth, **kwargs):
    # Binary accuracy reward: 1.0 if the parsed completion matches the ground-truth answer.
    completion_answers = [parse(completion[0]["content"]) for completion in completions]
    ground_truth_answers = [parse(truth) for truth in ground_truth]
    return [1.0 if verify(answer, truth) else 0.0 for answer, truth in zip(completion_answers, ground_truth_answers)]

training_args = ReMaxConfig(
    output_dir="Qwen2.5-ReMax", 
    logging_steps=5, 
    report_to="none", 
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16,
    use_vllm=False, 
    vllm_gpu_memory_utilization=0.85, 
    num_generations=2, 
    max_completion_length=1024,
    max_prompt_length=512,
    log_completions=True,
    model_init_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.bfloat16},
    bf16=True,
    gradient_checkpointing=True,
    run_name="Qwen2.5-3B-ReMax" + "-" + time.strftime("%Y%m%d-%H%M%S"), 
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    learning_rate=2e-6,
    adam_beta1=0.9,
    adam_beta2=0.95,
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=50,
    beta=0.01,
    temperature=1.0,
)

trainer = ReMaxTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

@qgallouedec
Member

Thanks for this great work! It seems very close to GRPO; can you summarize the key differences to make the review a bit easier for me?

@liziniu
Author

liziniu commented Feb 25, 2025

Hi, ReMax differs from GRPO in two key aspects: baseline estimation and advantage calculation.

Key Conceptual Differences

Baseline Estimation:

  • GRPO uses the average empirical reward of the sampled group as the baseline
  • ReMax simply uses the reward of a greedy-decoded completion as the baseline

Advantage Calculation:

  • GRPO normalizes each reward by the group mean and standard deviation
  • ReMax does not require this normalization step (a minimal sketch of both follows below)
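
To make the difference concrete, here is a minimal PyTorch sketch of the two advantage computations (illustrative tensor and function names only, not the exact trainer code; the eps constant is an assumption):

import torch

# rewards:        (num_prompts, num_generations) rewards of the sampled completions
# greedy_rewards: (num_prompts,)                 reward of the greedy completion per prompt

def grpo_advantages(rewards, eps=1e-4):
    # GRPO: normalize each reward by the group mean and standard deviation.
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

def remax_advantages(rewards, greedy_rewards):
    # ReMax: subtract the greedy-decoding reward as the baseline; no normalization.
    return rewards - greedy_rewards.unsqueeze(1)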

Implementation Details

The implementation of remax_config.py is basically the same as grpo_config.py, with modifications primarily in the trainer code's (remax_trainer.py) _generate_and_score_completions function (lines 690-880):

Key Modifications

  1. Lines 690-760: Modified generation to incorporate greedy decoding for baseline estimation

    • Sampling parameters for vLLM/HF generation are slightly changed to accommodate this
  2. Lines 760-808: Minimal changes to existing code

  3. Lines 810-860: Added calculation of rewards for greedy completions

  4. Lines 860-870: Direct advantage calculation without additional operations like gathering

Additional Changes

  • __init__ method:

    • Changed the vLLM sampling parameters by setting n = 1
    • Multiple completions per prompt are still generated because prompts are repeated in lines 690-760
  • compute_loss method (line 951):

    • Added dim=1 when calculating the averaged loss
    • The loss is first averaged across timesteps within each sequence, then across the batch (see the sketch below)
    • This implementation follows the description in the paper
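
As a rough illustration of that change (illustrative names only, not the exact trainer code): the per-token loss is first averaged over valid timesteps within each sequence (dim=1), then averaged across the batch.

import torch

# per_token_loss:  (batch_size, seq_len) policy-gradient loss per completion token
# completion_mask: (batch_size, seq_len) 1 for completion tokens, 0 for padding

def remax_loss(per_token_loss, completion_mask):
    # Average over timesteps within each sequence first (dim=1) ...
    per_sequence_loss = (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
    # ... then average across the batch, following the description in the paper.
    return per_sequence_loss.mean()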

I also provide an introduction to ReMax in the docs.

If you have additional questions, please feel free to let me know.

@qgallouedec
Member

Thanks! Can you try to integrate the very latest changes in GRPO?

@liziniu
Author

liziniu commented Mar 3, 2025

Hi, I’ve integrated the latest changes from GRPO. Below is a summary of the updates:

  • Lines 736-802: Modified the sampling strategy to incorporate greedy decoding, which ReMax uses for baseline estimation (see the sketch after this list).

  • Lines 857-906: Added reward calculations for greedy completions. These rewards will be used to compute the advantage function.

  • Lines 908-915: Implemented a customized advantage calculation specifically tailored for ReMax.
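
Conceptually, the modified sampling step looks like the following sketch (hypothetical helper names; generate stands in for the HF/vLLM generation call): each prompt is decoded once greedily to obtain the baseline completion, in addition to the usual sampled completions.

def generate_with_baseline(prompts, num_generations, generate):
    # `generate` is a stand-in for the HF/vLLM generation call; prompts are
    # repeated so that several sampled completions are produced per prompt.
    sampled, greedy = [], []
    for prompt in prompts:
        # Sampled completions, as in GRPO.
        sampled.append(generate(prompt, do_sample=True, num_return_sequences=num_generations))
        # One greedy completion whose reward becomes the ReMax baseline.
        greedy.append(generate(prompt, do_sample=False, num_return_sequences=1))
    return sampled, greedy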

Let me know if you have any questions or need further details!

@kashif
Collaborator

kashif commented Mar 3, 2025

currently, the remax trainer file is a copy of the remax config file... is that a mistake?

@liziniu
Author

liziniu commented Mar 3, 2025

Hi @kashif, thank you for pointing that out! It was my mistake to copy the wrong content. I’ve now fixed it.
