Issue when running on multiple GPUs #1438

Closed
zyzhang1130 opened this issue Mar 18, 2024 · 7 comments

@zyzhang1130

I used PPOTrainer in a setting with 2 GPUs and got the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[24], line 233
    231 rewards.append(torch.tensor(reward).to(device))
    232 print('reward',reward)
--> 233 train_stats = ppo_trainer.step(query_tensors,negative_qna, rewards)
    234 del log_prob
    235 del sampled_ids

File ~/anaconda3/envs/rlgaf/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:711, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    708 full_kl_penalty = self.config.kl_penalty == "full"
    710 with torch.no_grad():
--> 711     all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
    712         self.model,
    713         queries,
    714         responses,
    715         model_inputs,
    716         response_masks=response_masks,
    717         return_logits=full_kl_penalty,
    718     )
    719     with self.optional_peft_ctx():
    720         ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
    721             self.model if self.is_peft_model else self.ref_model,
    722             queries,
   (...)
    725             return_logits=full_kl_penalty,
    726         )

File ~/anaconda3/envs/rlgaf/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:984, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    982 if response_masks is not None:
    983     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 984 logits, _, values = model(**input_kwargs)
    986 if self.is_encoder_decoder:
    987     input_ids = input_kwargs["decoder_input_ids"]

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/trl/models/modeling_value_head.py:170, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    167 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
    168     kwargs.pop("past_key_values")
--> 170 base_model_output = self.pretrained_model(
    171     input_ids=input_ids,
    172     attention_mask=attention_mask,
    173     **kwargs,
    174 )
    176 last_hidden_state = base_model_output.hidden_states[-1]
    177 lm_logits = base_model_output.logits

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py:1067, in GemmaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1064 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1066 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1067 outputs = self.model(
   1068     input_ids=input_ids,
   1069     attention_mask=attention_mask,
   1070     position_ids=position_ids,
   1071     past_key_values=past_key_values,
   1072     inputs_embeds=inputs_embeds,
   1073     use_cache=use_cache,
   1074     output_attentions=output_attentions,
   1075     output_hidden_states=output_hidden_states,
   1076     return_dict=return_dict,
   1077     cache_position=cache_position,
   1078 )
   1080 hidden_states = outputs[0]
   1081 logits = self.lm_head(hidden_states)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py:860, in GemmaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    857     use_cache = False
    859 if inputs_embeds is None:
--> 860     inputs_embeds = self.embed_tokens(input_ids)
    862 past_seen_tokens = 0
    863 if use_cache:  # kept for BC (cache positions)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162, in Embedding.forward(self, input)
    161 def forward(self, input: Tensor) -> Tensor:
--> 162     return F.embedding(
    163         input, self.weight, self.padding_idx, self.max_norm,
    164         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/anaconda3/envs/rlgaf/lib/python3.10/site-packages/torch/nn/functional.py:2233, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2227     # Note [embedding_renorm set_grad_enabled]
   2228     # XXX: equivalent to
   2229     # with torch.no_grad():
   2230     #   torch.embedding_renorm_
   2231     # remove once script supports set_grad_enabled
   2232     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

But upon checking, query_tensors, negative_qna, and rewards are all on cuda:1, so I don't know why it says something is on cuda:0.
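
(Debugging note: the traceback above shows the mismatch happens inside F.embedding, i.e. between the model's embedding weights and the input_ids that PPOTrainer builds internally, not directly between the tensors passed to step(). A minimal checking sketch, assuming ppo_trainer.model is the AutoModelForCausalLMWithValueHead from the traceback and reusing the variable names from the snippet above:)

import torch

# Where do the embedding weights actually live?
embed_device = ppo_trainer.model.pretrained_model.get_input_embeddings().weight.device
print("embedding weights on:", embed_device)

# Which devices do the inputs to step() live on?
for name, tensors in [("query_tensors", query_tensors),
                      ("negative_qna", negative_qna),
                      ("rewards", rewards)]:
    print(name, sorted({str(t.device) for t in tensors}))

If the model was loaded with device_map="auto" (or is otherwise split across both GPUs), the inputs generally need to be on the device that holds the embedding layer (often cuda:0), not necessarily on cuda:1.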

@younesbelkada
Contributor

Hi @zyzhang1130
Thanks for the issue! Can you share a reproducible snippet?


github-actions bot commented May 2, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@yananchen1989

I face the same issue too.

@zyzhang1130
Author

zyzhang1130 commented May 13, 2024

@younesbelkada can you please show me a simple example script where either SFTTrainer or PPOTrainer is moved to a device other than 'cuda:0' for training? If this is possible, it would solve my issue. I worry my code is too complex to look at.
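
(Side note, not an official TRL example, just a common workaround that is independent of the trainer: hide GPU 0 from the process with CUDA_VISIBLE_DEVICES so that "cuda:0" inside Python maps to the second physical GPU. This has to run before torch/CUDA is initialized:)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # physical GPU 1 becomes cuda:0 in this process

import torch
print(torch.cuda.device_count())    # -> 1
print(torch.cuda.current_device())  # -> 0, which is now physical GPU 1

# Any trainer that defaults to cuda:0 (SFTTrainer, PPOTrainer, ...) now effectively
# runs on physical GPU 1.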


github-actions bot commented Jun 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@vwxyzjn
Contributor

vwxyzjn commented Jun 6, 2024

Hi all. Would you like to give the new PPOv2Trainer a try? It should work well with multiple GPUs. Feel free to re-open the issue if you run into this problem again.
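
(For reference, a single-node multi-GPU launch would typically go through accelerate; the script name below is a placeholder for your own PPOv2Trainer training script, not a file this comment points to:)

accelerate launch --num_processes 2 my_ppov2_script.py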

vwxyzjn closed this as completed Jun 6, 2024
@ldh127

ldh127 commented Aug 1, 2024

> Hi all. Would you like to give the new PPOv2Trainer a try? It should work well with multiple GPUs. Feel free to re-open the issue if you run into this problem again.

Hello, I want to know whether PPOv2Trainer supports multiple machines and multiple GPUs. I use torchrun to run the Python scripts. Thanks!
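
(For context: torchrun sets the standard distributed environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), which accelerate-based trainers read, so a multi-node launch along these lines should work at the launcher level. The node counts, address, and script name below are placeholders, and whether PPOv2Trainer itself handles multi-node training correctly is exactly the question asked above:)

# on node 0
torchrun --nnodes 2 --nproc_per_node 8 --node_rank 0 --master_addr <MASTER_NODE_IP> --master_port 29500 my_ppov2_script.py
# on node 1
torchrun --nnodes 2 --nproc_per_node 8 --node_rank 1 --master_addr <MASTER_NODE_IP> --master_port 29500 my_ppov2_script.py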
