Skip to content

Commit

Permalink
Fix overriding optimize_device_cache with optimize_cuda_cache in PPOC…
Browse files Browse the repository at this point in the history
…onfig (#1690)

* Don't override optimize_device_cache when optimize_cuda_cache is not provided
Raise an exception when both optimize_cuda_cache and optimize_device_cache are set

* Minor fix
  • Loading branch information
alexisrozhkov authored Jun 3, 2024
1 parent f18253b commit 6c203f9
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ class PPOConfig:
warnings.warn(
"The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead."
)

if optimize_device_cache is True:
raise ValueError("Both `optimize_device_cache` and `optimize_cuda_cache` were provided")

optimize_device_cache = optimize_cuda_cache
else:
optimize_device_cache = False

def __post_init__(self):
if self.forward_batch_size is not None:
Expand Down

0 comments on commit 6c203f9

Please sign in to comment.