↔️ GRPO: Set max_model_len when initializing vLLM instance (#2728)
* Set max_model_len when initializing vLLM instance

* Introduce vllm_max_model_len arg

* Replace vllm args with vllm_init_kwargs

* Update docstring

* Add missing import

* Remove default values from newly deprecated args

* Docs update

* Reverted to adding single arg for max_model_len

* Remove spurious import

* Remove spurious line

* style

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
mirceapricop and qgallouedec authored Feb 5, 2025
1 parent af4ad47 commit 78c5ce2
Showing 2 changed files with 13 additions and 0 deletions.
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
@@ -73,6 +73,10 @@ class GRPOConfig(TrainingArguments):
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
> Parameters that control the training
@@ -181,6 +185,14 @@ class GRPOConfig(TrainingArguments):
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
vllm_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)

# Parameters that control the training
learning_rate: float = field(
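For context, a minimal usage sketch of the new option (the model name, reward function, dataset, and values below are placeholders for illustration, not recommendations):

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    output_dir="grpo-output",
    use_vllm=True,
    vllm_gpu_memory_utilization=0.3,  # leave most of the GPU memory for training
    vllm_max_model_len=2048,  # cap the vLLM context so it fits the smaller KV cache
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    reward_funcs=my_reward_func,  # placeholder reward function
    args=training_args,
    train_dataset=my_dataset,  # placeholder dataset
)
trainer.train()

Leaving vllm_max_model_len unset keeps the previous behaviour: vLLM sizes its context to the model's full context length.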
1 change: 1 addition & 0 deletions trl/trainer/grpo_trainer.py
@@ -322,6 +322,7 @@ def data_collator(features): # No data collation is needed in GRPO
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
)
self.sampling_params = SamplingParams(
n=self.num_generations,
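For reference, a standalone sketch of the vLLM initialization this change feeds into (the model name and values are placeholders); passing max_model_len=None preserves the old behaviour, while an integer caps the context length:

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    gpu_memory_utilization=0.3,
    enable_prefix_caching=True,  # reuse the KV cache across shared prompt prefixes
    max_model_len=2048,  # None (the default) falls back to the model's context size
)
sampling_params = SamplingParams(n=8, temperature=0.7, max_tokens=128)
outputs = llm.generate(["Write a haiku about GPUs."], sampling_params)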
