[Bug]: CPU Offload fails when enable_lora=True #11748

Closed · Fixed by #11810
Neko-nos opened this issue Jan 5, 2025 · 1 comment
Labels: bug (Something isn't working)

Neko-nos commented Jan 5, 2025

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.11.9 (main, Jun 19 2024, 22:14:19) [GCC 13.2.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4080 SUPER
Nvidia driver version: 560.94
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        39 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               28
On-line CPU(s) list:                  0-27
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Core(TM) i7-14700F
CPU family:                           6
Model:                                183
Thread(s) per core:                   2
Core(s) per socket:                   14
Socket(s):                            1
Stepping:                             1
BogoMIPS:                             4223.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    Microsoft
Virtualization type:                  full
L1d cache:                            672 KiB (14 instances)
L1i cache:                            448 KiB (14 instances)
L2 cache:                             28 MiB (14 instances)
L3 cache:                             33 MiB (1 instance)
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.47.1
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.6.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 				N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

LD_LIBRARY_PATH=/home/neko/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/cv2/../../lib64:/usr/local/lib:/usr/local/cuda/lib64:
CUDA_MODULE_LOADING=LAZY

Model Input Dumps

err_execute_model_input_20250105-235132.zip

🐛 Describe the bug

When I ran the following code, cpu_offload_gb=8 worked correctly.

import vllm

llm = vllm.LLM(
    "princeton-nlp/gemma-2-9b-it-SimPO",
    tensor_parallel_size=1,
    # ref: https://docs.vllm.ai/en/latest/quantization/bnb.html#inflight-quantization-load-as-4bit-quantization
    # quantization="bitsandbytes",
    # load_format="bitsandbytes",
    # enable_lora=True,
    # ref: https://github.com/vllm-project/vllm/issues/2847#issuecomment-2009845554
    # max_lora_rank=64,
    dtype="bfloat16",
    enforce_eager=True,
    max_model_len=3201,
    enable_prefix_caching=True,
    gpu_memory_utilization=0.9,
    cpu_offload_gb=8,
    swap_space=0,
)

[Screenshot: vllm_without_cpu_offload]

However, when I set enable_lora=True, the following error occurred.
I tried downgrading vLLM to earlier versions, but that did not help.
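
For completeness, this is the failing call, reproduced from the traceback below; the only functional change from the working configuration above is enable_lora=True.

import vllm

llm = vllm.LLM(
    "princeton-nlp/gemma-2-9b-it-SimPO",
    tensor_parallel_size=1,
    enable_lora=True,  # the only change from the working call above
    dtype="bfloat16",
    enforce_eager=True,
    max_model_len=3201,
    enable_prefix_caching=True,
    gpu_memory_utilization=0.9,
    cpu_offload_gb=8,
    swap_space=0,
)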

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/worker/model_runner_base.py:116, in dump_input_when_exception.<locals>._inner.<locals>._wrapper(*args, **kwargs)
    115 try:
--> 116     return func(*args, **kwargs)
    117 except Exception as err:

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/worker/model_runner.py:1691, in ModelRunner.execute_model(self, model_input, kv_caches, intermediate_tensors, num_steps)
   1689     with set_forward_context(model_input.attn_metadata,
   1690                              self.vllm_config):
-> 1691         hidden_or_intermediate_states = model_executable(
   1692             input_ids=model_input.input_tokens,
   1693             positions=model_input.input_positions,
   1694             kv_caches=kv_caches,
   1695             attn_metadata=model_input.attn_metadata,
   1696             intermediate_tensors=intermediate_tensors,
   1697             **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
   1698                                          device=self.device),
   1699             **seqlen_agnostic_kwargs)
   1701 if (self.observability_config is not None
   1702         and self.observability_config.collect_model_forward_time):

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma2.py:442, in Gemma2ForCausalLM.forward(self, input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds)
    433 def forward(
    434     self,
    435     input_ids: torch.Tensor,
   (...)
    440     inputs_embeds: Optional[torch.Tensor] = None,
    441 ) -> Union[torch.Tensor, IntermediateTensors]:
--> 442     hidden_states = self.model(input_ids, positions, kv_caches,
    443                                attn_metadata, intermediate_tensors,
    444                                inputs_embeds)
    445     return hidden_states

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/compilation/decorators.py:168, in _support_torch_compile.<locals>.__call__(self, *args, **kwargs)
    167 if self.do_not_compile or torch.compiler.is_compiling():
--> 168     return self.forward(*args, **kwargs)
    170 # the first compilation needs to have dynamic shapes marked

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma2.py:304, in Gemma2Model.forward(self, input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds)
    303     layer = self.layers[i]
--> 304     hidden_states, residual = layer(
    305         positions,
    306         hidden_states,
    307         kv_caches[i - self.start_layer],
    308         attn_metadata,
    309         residual,
    310     )
    311 if not get_pp_group().is_last_rank:

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/model_executor/models/utils.py:524, in maybe_offload_to_cpu.<locals>.forward(*args, **kwargs)
    518 device_state = {
    519     # here we blindly call `to(device)`
    520     # if the parameter is already on the device, it will be a no-op
    521     k: v.to(device, non_blocking=True)
    522     for k, v in module.state_dict().items()
    523 }
--> 524 output = functional_call(module,
    525                          device_state,
    526                          args=args,
    527                          kwargs=kwargs)
    528 module.forward = forward

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/_functorch/functional_call.py:148, in functional_call(module, parameter_and_buffer_dicts, args, kwargs, tie_weights, strict)
    143     raise ValueError(
    144         f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
    145         f"but got {type(parameter_and_buffer_dicts)}"
    146     )
--> 148 return nn.utils.stateless._functional_call(
    149     module,
    150     parameters_and_buffers,
    151     args,
    152     kwargs,
    153     tie_weights=tie_weights,
    154     strict=strict,
    155 )

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py:298, in _functional_call(module, parameters_and_buffers, args, kwargs, tie_weights, strict)
    295 with _reparametrize_module(
    296     module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
    297 ):
--> 298     return module(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma2.py:233, in Gemma2DecoderLayer.forward(self, positions, hidden_states, kv_cache, attn_metadata, residual)
    231     hidden_states, residual = self.input_layernorm(
    232         hidden_states, residual)
--> 233 hidden_states = self.self_attn(
    234     positions=positions,
    235     hidden_states=hidden_states,
    236     kv_cache=kv_cache,
    237     attn_metadata=attn_metadata,
    238 )
    239 hidden_states = self.post_attention_layernorm(hidden_states)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/model_executor/models/gemma2.py:170, in Gemma2Attention.forward(self, positions, hidden_states, kv_cache, attn_metadata)
    163 def forward(
    164     self,
    165     positions: torch.Tensor,
   (...)
    168     attn_metadata: AttentionMetadata,
    169 ) -> torch.Tensor:
--> 170     qkv, _ = self.qkv_proj(hidden_states)
    171     q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/layers.py:513, in ColumnParallelLinearWithLoRA.forward(self, input_)
    512 # Matrix multiply.
--> 513 output_parallel = self.apply(input_, bias)
    514 if self.base_layer.gather_output:
    515     # All-gather across the partitions.

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/layers.py:392, in BaseLinearLayerWithLoRA.apply(self, x, bias)
    391 output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
--> 392 self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
    393                                     self.lora_b_stacked,
    394                                     self.lora_bias_stacked, 1.0,
    395                                     self.output_slices)
    396 return output

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py:308, in PunicaWrapperGPU.add_lora_linear(self, y, x, lora_a_stacked, lora_b_stacked, lora_bias_stacked, scale, output_slices, buffer, **kwargs)
    304     buffer = tuple(
    305         torch.zeros(
    306             (x.size(0), r), dtype=torch.float32, device=x.device)
    307         for _ in range(len(output_slices)))
--> 308 self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
    309 self.add_expand(y,
    310                 buffer,
    311                 lora_b_stacked,
   (...)
    314                 add_inputs=True,
    315                 **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py:187, in PunicaWrapperGPU.add_shrink(self, y, x, lora_a_stacked, scale, **kwargs)
    186 for slice_idx in range(len(lora_a_stacked)):
--> 187     self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
    188                        scale)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py:160, in PunicaWrapperGPU._apply_shrink(self, y, x, w_t_all, scale)
    158 shrink_fun: Callable = (self._shrink_prefill
    159                         if self.is_prefill else self._shrink_decode)
--> 160 shrink_fun(y, x, w_t_all, scale)
    161 y = y.view_as(y_org)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/punica_wrapper/punica_gpu.py:48, in PunicaWrapperGPU._shrink_prefill(self, y, x, w_t_all, scale)
     47     return
---> 48 sgmv_shrink(
     49     x,
     50     w_t_all,
     51     y,
     52     *self.prefill_metadata,
     53     scale,
     54 )

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/_ops.py:1116, in OpOverloadPacket.__call__(self, *args, **kwargs)
   1115     return _call_overload_packet_from_python(self, args, kwargs)
-> 1116 return self._op(*args, **(kwargs or {}))

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    115 with ctx_factory():
--> 116     return func(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/lora/ops/sgmv_shrink.py:169, in _sgmv_shrink(inputs, lora_a_weights, output_tensor, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, batches, max_seq_length, token_nums, scaling)
    163 grid = (
    164     triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
    165     SPLIT_K,
    166     batches,
    167 )
--> 169 _sgmv_shrink_kernel[grid](
    170     inputs,
    171     lora_a_weights,
    172     output_tensor,
    173     N,
    174     K,
    175     b_seq_start_loc,
    176     seq_len_tensor,
    177     lora_indices_tensor,
    178     scaling,
    179     inputs.stride(0),
    180     inputs.stride(1),
    181     lora_a_weights.stride(0),
    182     lora_a_weights.stride(1),
    183     lora_a_weights.stride(2),
    184     output_tensor.stride(0),
    185     output_tensor.stride(1),
    186     BLOCK_M,
    187     BLOCK_N,
    188     BLOCK_K,
    189     EVEN_K,
    190     SPLIT_K,
    191 )
    192 return

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    340 """
    341 A JIT function is launched with: fn[grid](*args, **kwargs).
    342 Hence JITFunction.__getitem__ returns a callable proxy that
    343 memorizes the grid.
    344 """
--> 345 return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/triton/runtime/jit.py:691, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    690     launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
--> 691     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
    692                self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
    693 return kernel

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/triton/backends/nvidia/driver.py:365, in CudaLauncher.__call__(self, *args, **kwargs)
    364 def __call__(self, *args, **kwargs):
--> 365     self.launch(*args, **kwargs)

ValueError: Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?)

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[1], line 3
      1 import vllm
----> 3 llm = vllm.LLM(
      4     "princeton-nlp/gemma-2-9b-it-SimPO",
      5     tensor_parallel_size=1,
      6     # ref: https://docs.vllm.ai/en/latest/quantization/bnb.html#inflight-quantization-load-as-4bit-quantization
      7     # quantization="bitsandbytes",
      8     # load_format="bitsandbytes",
      9     enable_lora=True,
     10     # ref: https://github.com/vllm-project/vllm/issues/2847#issuecomment-2009845554
     11     # max_lora_rank=64,
     12     dtype="bfloat16",
     13     enforce_eager=True,
     14     max_model_len=3201,
     15     enable_prefix_caching=True,
     16     gpu_memory_utilization=0.9,
     17     cpu_offload_gb=8,
     18     swap_space=0,
     19 )

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/utils.py:986, in deprecate_args.<locals>.wrapper.<locals>.inner(*args, **kwargs)
    979             msg += f" {additional_message}"
    981         warnings.warn(
    982             DeprecationWarning(msg),
    983             stacklevel=3,  # The inner function takes up one level
    984         )
--> 986 return fn(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py:230, in LLM.__init__(self, model, tokenizer, tokenizer_mode, skip_tokenizer_init, trust_remote_code, allowed_local_media_path, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, cpu_offload_gb, enforce_eager, max_seq_len_to_capture, disable_custom_all_reduce, disable_async_output_proc, hf_overrides, mm_processor_kwargs, task, override_pooler_config, compilation_config, **kwargs)
    227 self.engine_class = self.get_engine_class()
    229 # TODO(rob): enable mp by default (issue with fork vs spawn)
--> 230 self.llm_engine = self.engine_class.from_engine_args(
    231     engine_args, usage_context=UsageContext.LLM_CLASS)
    233 self.request_counter = Counter()

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py:517, in LLMEngine.from_engine_args(cls, engine_args, usage_context, stat_loggers)
    515 executor_class = cls._get_executor_cls(engine_config)
    516 # Create the LLM engine.
--> 517 engine = cls(
    518     vllm_config=engine_config,
    519     executor_class=executor_class,
    520     log_stats=not engine_args.disable_log_stats,
    521     usage_context=usage_context,
    522     stat_loggers=stat_loggers,
    523 )
    525 return engine

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py:276, in LLMEngine.__init__(self, vllm_config, executor_class, log_stats, usage_context, stat_loggers, input_registry, mm_registry, use_cached_outputs)
    273 self.model_executor = executor_class(vllm_config=vllm_config, )
    275 if self.model_config.runner_type != "pooling":
--> 276     self._initialize_kv_caches()
    278 # If usage stat is enabled, collect relevant info.
    279 if is_usage_stats_enabled():

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py:416, in LLMEngine._initialize_kv_caches(self)
    409 """Initialize the KV cache in the worker(s).
    410 
    411 The workers will determine the number of blocks in both the GPU cache
    412 and the swap CPU cache.
    413 """
    414 start = time.time()
    415 num_gpu_blocks, num_cpu_blocks = (
--> 416     self.model_executor.determine_num_available_blocks())
    418 if self.cache_config.num_gpu_blocks_override is not None:
    419     num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/executor/gpu_executor.py:68, in GPUExecutor.determine_num_available_blocks(self)
     64 def determine_num_available_blocks(self) -> Tuple[int, int]:
     65     """Determine the number of available KV blocks by invoking the
     66     underlying worker.
     67     """
---> 68     return self.driver_worker.determine_num_available_blocks()

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/worker/worker.py:202, in Worker.determine_num_available_blocks(self)
    196 # Execute a forward pass with dummy inputs to profile the memory usage
    197 # of the model.
    198 with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
    199                       self.init_gpu_memory,
    200                       weights_memory_in_bytes=self.model_runner.
    201                       model_memory_usage) as result:
--> 202     self.model_runner.profile_run()
    203     torch.cuda.synchronize()
    205 self._assert_memory_footprint_increased_during_profiling()

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/worker/model_runner.py:1331, in GPUModelRunnerBase.profile_run(self)
   1325 if not get_pp_group().is_first_rank:
   1326     intermediate_tensors = self.model.make_empty_intermediate_tensors(
   1327         batch_size=batch_size,
   1328         dtype=self.model_config.dtype,
   1329         device=self.device)
-> 1331 self.execute_model(model_input, kv_caches, intermediate_tensors)
   1332 torch.cuda.synchronize()
   1333 return

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/programming/Kaggle/wsdm/.venv/lib/python3.11/site-packages/vllm/worker/model_runner_base.py:152, in dump_input_when_exception.<locals>._inner.<locals>._wrapper(*args, **kwargs)
    146         raise type(err)(f"Error in model execution: "
    147                         f"{str(err)}") from err
    149     logger.info(
    150         "Completed writing input of failed execution to %s.",
    151         filename)
--> 152 raise type(err)(
    153     f"Error in model execution (input dumped to {filename}): "
    154     f"{str(err)}") from err

ValueError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250105-235132.pkl): Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?)
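
For context, the maybe_offload_to_cpu frame in the traceback shows how cpu_offload_gb is implemented: on each forward call, the wrapped module's state_dict() is copied to the GPU and the module is invoked through functional_call. Below is a simplified, standalone sketch of that pattern; it is not vLLM's actual code, and the helper name offload_module_to_cpu is made up for illustration. One thing such a wrapper cannot move is any tensor that is not registered in the module's state_dict(), which is my guess (not a confirmed diagnosis) for why the Triton kernel ends up seeing a CPU tensor when LoRA is enabled.

import torch
from torch.func import functional_call


def offload_module_to_cpu(module: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    """Keep the module's weights on the CPU; move them to `device` only for each forward call."""
    module.to("cpu")

    def forward(*args, **kwargs):
        # Re-materialize parameters/buffers on the target device for this call only;
        # if a tensor is already on `device`, .to() is a no-op.
        device_state = {
            k: v.to(device, non_blocking=True)
            for k, v in module.state_dict().items()
        }
        return functional_call(module, device_state, args=args, kwargs=kwargs)

    module.forward = forward
    return module


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = offload_module_to_cpu(torch.nn.Linear(16, 16), device)
    x = torch.randn(4, 16, device=device)
    print(layer(x).device)  # computed on `device`; the stored weights remain on the CPU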

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Neko-nos added the bug label on Jan 5, 2025
jeejeelee self-assigned this on Jan 6, 2025
jeejeelee (Collaborator) commented:

I will look at this issue asap
