💡 GRPO vram-efficiency improvement; only compute relevant logprobs #2773
Conversation
Benchmarked before and after with:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Make everything deterministic
import torch

torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Dummy reward function: the closer the completion is to 20 characters, the higher the reward
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO-main",
    logging_steps=2,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=2,
    max_steps=20,
    report_to="wandb",
    bf16=True,
    max_completion_length=128,
    max_prompt_length=128,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
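As an aside, a single peak-VRAM number per run can also be read straight from PyTorch's allocator counters, independently of wandb; a minimal sketch, assuming a CUDA device:

# Read the allocator's high-water mark after training; call
# torch.cuda.reset_peak_memory_stats() beforehand if anything CUDA-related
# runs before trainer.train().
if torch.cuda.is_available():
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak VRAM allocated during training: {peak_gb:.2f} GB")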
Super cool! thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Trying to reproduce your method, I think I've found something even better:

import torch


def original_method(logits, input_ids, logits_to_keep):
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def new_method_1(logits, input_ids, logits_to_keep):
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits[:, -logits_to_keep:], input_ids[:, -logits_to_keep:]):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def new_method_2(logits, input_ids, logits_to_keep):
    input_ids = input_ids[:, -logits_to_keep:]
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs


def measure_memory_and_time(func, logits, input_ids, logits_to_keep):
    import time

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    result = func(logits, input_ids, logits_to_keep)
    end_time = time.time()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak


# Simulated data
torch.manual_seed(42)
vocab_size = 150000
seq_len = 512
batch_size = 8
logits_to_keep = 128
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, logits_to_keep, vocab_size, device=device)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

# Run all three methods
orig_result, orig_time, orig_mem = measure_memory_and_time(original_method, logits, input_ids, logits_to_keep)
new_result_1, new_time_1, new_mem_1 = measure_memory_and_time(new_method_1, logits, input_ids, logits_to_keep)
new_result_2, new_time_2, new_mem_2 = measure_memory_and_time(new_method_2, logits, input_ids, logits_to_keep)

# Check equivalence
print("Max absolute difference:", (orig_result - new_result_1).abs().max().item())
print("Max absolute difference:", (orig_result - new_result_2).abs().max().item())
print("Original time: {:.6f} sec, Memory peak: {:.2f} MB".format(orig_time, orig_mem / 1e6))
print("New method 1 time: {:.6f} sec, Memory peak: {:.2f} MB".format(new_time_1, new_mem_1 / 1e6))
print("New method 2 time: {:.6f} sec, Memory peak: {:.2f} MB".format(new_time_2, new_mem_2 / 1e6))

What do you think?
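For intuition about the gap these prints should show, a rough back-of-the-envelope estimate under the simulated shapes above (fp32, 4 bytes per element; real peaks will differ because of allocator behaviour and other intermediates):

# original_method materializes a full log-softmax tensor of shape
# (logits_to_keep, vocab_size) for each batch row:
per_row_full = logits_to_keep * vocab_size * 4       # 128 * 150_000 * 4 bytes ≈ 76.8 MB
# the selective variants only keep per-token values of shape (logits_to_keep,)
# plus a logsumexp output of the same size:
per_row_selective = 2 * logits_to_keep * 4           # ≈ 1 KB
print(f"~{per_row_full / 1e6:.1f} MB vs ~{per_row_selective / 1e3:.1f} KB of extra memory per row")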
Nice! I didn't think to pull the gather out of the loop, let me incorporate that
Ok, updated with
Co-authored-by: Quentin Gallouédec <[email protected]>
Thanks! Ok to merge now
Wow this is super nice! Thanks!
What does this PR do?
GRPOTrainer uses a method _get_per_token_logps to compute the per-token logprobs for every token in the input sequence. However, it uses seq_len * vocab_size additional memory in order to compute log_softmax, generating full log-probabilities for every possible token in the vocabulary at every index in the sequence. The next step then selects only the actual input tokens to get the per-token logprobs for the input sequence. This can be made more efficient by performing the selection first to get per-token logits, computing the softmax denominator (a reduction over the full set of logits), and then directly computing the log-probabilities only for the relevant tokens. This requires only seq_len additional memory.

Fixes # (issue)
N/A, but I'm happy to file an issue if needed.
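For illustration only, a minimal sketch of the change described above (not the exact diff; the function names are hypothetical, logits is assumed to have shape (batch, seq_len, vocab_size), and completion_ids holds the token ids to score):

import torch

def per_token_logps_full(logits, completion_ids):
    # Current approach: materialize log-probabilities for every vocabulary entry
    # at every position, then gather the ones matching the input tokens.
    log_probs = logits.log_softmax(dim=-1)  # (batch, seq_len, vocab_size)
    return torch.gather(log_probs, dim=-1, index=completion_ids.unsqueeze(-1)).squeeze(-1)

def per_token_logps_selective(logits, completion_ids):
    # Proposed approach: gather the per-token logits first, then subtract the
    # softmax denominator; the tensors kept are only of shape (batch, seq_len).
    token_logits = torch.gather(logits, dim=-1, index=completion_ids.unsqueeze(-1)).squeeze(-1)
    return token_logits - torch.logsumexp(logits, dim=-1)

# Quick equivalence check on random data
batch, seq_len, vocab = 2, 8, 1000
logits = torch.randn(batch, seq_len, vocab)
ids = torch.randint(0, vocab, (batch, seq_len))
assert torch.allclose(per_token_logps_full(logits, ids), per_token_logps_selective(logits, ids), atol=1e-5)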
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.