Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

💡 GRPO vram-efficiency improvement; only compute relevant logprobs #2773

Merged
merged 12 commits into from
Feb 6, 2025

Conversation

tyler-romero
Copy link
Contributor

@tyler-romero tyler-romero commented Feb 5, 2025

What does this PR do?

GRPOTrainer uses a method _get_per_token_logps to compute the per-token logprobs for every token in the input sequence. However, it uses seq_len * vocab_size additional memory in order to compute log_softmax, generating full log-probabilities for every possible token in the vocabulary at every index in the sequence. The next step then selects only the the actual input tokens to get the per-token logprobs for the input sequence.

This can be made more efficient by performing the selection first to get per-token logits, computing the softmax denominator (a reduction over the full set of logits), and then directly computing the logits only for the relevant tokens. This requires only seq_len additional memory.

Fixes # (issue)
NA but I'm happy to file an issue if needed

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@tyler-romero
Copy link
Contributor Author

image

Benchmarked before and after with:
CUDA_VISIBLE_DEVICES="0" python bench.py

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("trl-lib/tldr", split="train")

# Make everything deterministic
import torch


torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Dummy reward function: the closer the completion is to 20 characters, the higher the reward
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO-main",
    logging_steps=2,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=2,
    max_steps=20,
    report_to="wandb",
    bf16=True,
    max_completion_length=128,
    max_prompt_length=128,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

@tyler-romero
Copy link
Contributor Author

@qgallouedec

@tyler-romero tyler-romero changed the title GRPO memory-efficiency improvement; only compute relevant logprobs GRPO vram-efficiency improvement; only compute relevant logprobs Feb 5, 2025
@tyler-romero tyler-romero marked this pull request as ready for review February 5, 2025 19:41
@tyler-romero
Copy link
Contributor Author

~~> CUDA_VISIBLE_DEVICES="0" pytest tests/test_grpo_trainer.py                                            
=========================================================================== test session starts ===========================================================================
platform linux -- Python 3.12.5, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tromero/workspace/trl
configfile: pyproject.toml
plugins: anyio-4.8.0, xdist-3.6.1, rerunfailures-15.0, cov-6.0.0
collected 14 items                                                                                                                                                        

tests/test_grpo_trainer.py ..............                                                                                                                           [100%]

============================================================================ warnings summary =============================================================================
.venv/lib/python3.12/site-packages/mergekit/architecture.py:345
  /home/tromero/workspace/trl/.venv/lib/python3.12/site-packages/mergekit/architecture.py:345: DeprecationWarning: contents is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    for f in importlib.resources.contents(mergekit._data.architectures):

.venv/lib/python3.12/site-packages/mergekit/architecture.py:335: 40 warnings
  /home/tromero/workspace/trl/.venv/lib/python3.12/site-packages/mergekit/architecture.py:335: DeprecationWarning: read_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    text = importlib.resources.read_text(mergekit._data.architectures, name)

../../.local/share/uv/python/cpython-3.12.5-linux-x86_64-gnu/lib/python3.12/importlib/resources/_legacy.py:79: 40 warnings
  /home/tromero/.local/share/uv/python/cpython-3.12.5-linux-x86_64-gnu/lib/python3.12/importlib/resources/_legacy.py:79: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    with open_text(package, resource, encoding, errors) as fp:

tests/test_grpo_trainer.py::GRPOTrainerTester::test_training_vllm
  /home/tromero/workspace/trl/trl/trainer/grpo_trainer.py:304: UserWarning: The requested device cuda:0 is also used for training. This may lead to unexpected behavior. It is recommended to use a dedicated device for vLLM.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================== 14 passed, 82 warnings in 329.12s (0:05:29) ===============================================================

@qgallouedec
Copy link
Member

Super cool! thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Copy link
Member

Trying to reproduce your method I think I've found something even better:

import torch


def original_method(logits, input_ids, logits_to_keep):
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def new_method_1(logits, input_ids, logits_to_keep):
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits[:, -logits_to_keep:], input_ids[:, -logits_to_keep:]):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def new_method_2(logits, input_ids, logits_to_keep):
    input_ids = input_ids[:, -logits_to_keep:]
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs


def measure_memory_and_time(func, logits, input_ids, logits_to_keep):
    import time

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    result = func(logits, input_ids, logits_to_keep)
    end_time = time.time()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak


# Simulated data
torch.manual_seed(42)
vocab_size = 150000
seq_len = 512
batch_size = 8
logits_to_keep = 128

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, logits_to_keep, vocab_size, device=device)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

# Run both methods
orig_result, orig_time, orig_mem = measure_memory_and_time(original_method, logits, input_ids, logits_to_keep)
new_result_1, new_time_1, new_mem_1 = measure_memory_and_time(new_method_1, logits, input_ids, logits_to_keep)
new_result_2, new_time_2, new_mem_2 = measure_memory_and_time(new_method_1, logits, input_ids, logits_to_keep)

# Check equivalence
print("Max absolute difference:", (orig_result - new_result_1).abs().max().item())
print("Max absolute difference:", (orig_result - new_result_2).abs().max().item())
print("Original time:     {:.6f} sec, Memory peak: {:.2f} MB".format(orig_time, orig_mem / 1e6))
print("New method 1 time: {:.6f} sec, Memory peak: {:.2f} MB".format(new_time_1, new_mem_1 / 1e6))
print("New method 2 time: {:.6f} sec, Memory peak: {:.2f} MB".format(new_time_2, new_mem_2 / 1e6))
Max absolute difference: 1.9073486328125e-06
Max absolute difference: 1.9073486328125e-06
Original time:     0.017965 sec, Memory peak: 769.69 MB
New method 1 time: 0.074697 sec, Memory peak: 692.11 MB
New method 2 time: 0.000901 sec, Memory peak: 692.11 MB

What do you think?

@tyler-romero
Copy link
Contributor Author

Nice! I didnt think to pull the gather out of the loop, let me incorporate that

@tyler-romero
Copy link
Contributor Author

Ok, updated with new_method_2

trl/trainer/grpo_trainer.py Outdated Show resolved Hide resolved
@tyler-romero
Copy link
Contributor Author

Thanks! Ok to merge now

@Superskyyy
Copy link
Contributor

Wow this is super nice! Thanks!

@qgallouedec qgallouedec linked an issue Feb 6, 2025 that may be closed by this pull request
5 tasks
@qgallouedec qgallouedec changed the title GRPO vram-efficiency improvement; only compute relevant logprobs 💡 GRPO vram-efficiency improvement; only compute relevant logprobs Feb 6, 2025
@qgallouedec qgallouedec merged commit a85768f into huggingface:main Feb 6, 2025
13 checks passed
@mst272
Copy link
Contributor

mst272 commented Feb 6, 2025

def new_method_2(logits, input_ids, logits_to_keep):
    input_ids = input_ids[:, -logits_to_keep:]
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs

In new_method_2(logits, input_ids, logits_to_keep), mabye we could use logsumexp_values = torch.logsumexp(logits, dim=-1) instead of logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])

@qgallouedec
Copy link
Member

actually with logsumexp_values = torch.logsumexp(logits, dim=-1) you get a high memory peak, hence the inner loop.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

model.forward requires num_logits_to_keep, not logits_to_keep
5 participants