Currently LLM-VM does not support multi-GPU setups. Using RunPod, I rented a setup with 2 RTX 3090 GPUs. While running the local Bloom model example from the docs, I ran into this error:
`EleutherAI/pythia-70m-deduped` loaded on 2 GPUs.
Using model: bloom
Running with an empty context
Exception in thread Thread-2 (new_thread):
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/dist-packages/llm_vm/completion/optimize.py", line 45, in new_thread
t[0] = foo()
File "/usr/local/lib/python3.10/dist-packages/llm_vm/completion/optimize.py", line 259, in promiseCompletion
best_completion = self.call_big(prompt, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/llm_vm/client.py", line 102, in CALL_BIG
return self.teacher.generate(prompt, max_len,**kwargs)
File "/usr/local/lib/python3.10/dist-packages/llm_vm/onsite_llm.py", line 153, in generate
generate_ids=self.model.generate(inputs.input_ids, max_length=max_length, **generation_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1695, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DataParallel' object has no attribute 'generate'
{'status': 0, 'resp': 'cannot unpack non-iterable NoneType object'}
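The root cause is that torch.nn.DataParallel only proxies __call__/forward() to the wrapped module; other methods such as generate() are not forwarded, so the attribute lookup falls through to nn.Module.__getattr__ and raises. A minimal sketch of the failure mode (using bigscience/bloom-560m as a small stand-in for whatever model LLM-VM actually loads):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
wrapped = torch.nn.DataParallel(model)

# DataParallel only forwards __call__/forward to the wrapped module, so
# generate() is not visible on the wrapper, only on wrapped.module.
print(hasattr(wrapped, "generate"))         # False -> the AttributeError above
print(hasattr(wrapped.module, "generate"))  # True
```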
One fix I had for this error was adding the following change to src/llm_vm/onsite_llm.py (the if/else statement under the "account for cases where the model is wrapped..." comment):
diff --git a/src/llm_vm/onsite_llm.py b/src/llm_vm/onsite_llm.py
index 9fcfe3c..613acbf 100644
--- a/src/llm_vm/onsite_llm.py
+++ b/src/llm_vm/onsite_llm.py
@@ -151,7 +141,13 @@ class BaseOnsiteLLM(ABC):
inputs = self.tokenizer(prompt, return_tensors="pt", **tokenizer_kwargs).to(device[0])
else:
inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
- generate_ids=self.model.generate(inputs.input_ids, max_length=max_length, **generation_kwargs)
+
+ # account for cases where the model is wrapped in DataParallel
+ if isinstance(self.model, torch.nn.DataParallel):
+ generate_ids = self.model.module.generate(inputs.input_ids, max_length=max_length, **generation_kwargs)
+ else:
+ generate_ids = self.model.generate(inputs.input_ids, max_length=max_length, **generation_kwargs)
+
resp= self.tokenizer.batch_decode(generate_ids,skip_special_tokens=True,clean_up_tokenization_spaces=False)[0]
# need to drop the len(prompt) prefix with these sequences generally
# because they include the prompt.
☝️ This change resolves the error, but it does not fix the core issue. It uses DataParallel.module to reach the underlying model's generate() function, which bypasses the DataParallel wrapper and skips the parallelism that DataParallel is supposed to provide, so generation still runs on a single GPU. I believe this is the case, and that we should implement a new solution.
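One direction that would actually use both GPUs, sketched below under the assumption that the accelerate package is installed (again with bigscience/bloom-560m as a placeholder model, not LLM-VM's actual loading code): load the model with device_map="auto" so its layers are sharded across the visible GPUs, instead of wrapping it in DataParallel. generate() then stays available on the model object itself:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

# device_map="auto" (backed by accelerate) splits the model's layers across
# all visible GPUs, so a single generate() call spans both cards without any
# DataParallel wrapper.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    device_map="auto",
    torch_dtype="auto",
)

inputs = tokenizer("The weather today is", return_tensors="pt").to(model.device)
generate_ids = model.generate(inputs.input_ids, max_length=40)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```

Worth noting: DataParallel is data parallelism; it replicates the full model on every GPU and splits a batch across them, so for single-prompt generation it would give no speedup even if generate() were forwarded. Sharding the layers across GPUs (model parallelism) is usually what multi-GPU inference actually wants.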