[Bug]: Cutlass 2:4 Sparsity + FP8/Int8 Quant RuntimeError: Error Internal #11763

Open
leoyuppieqnew opened this issue Jan 6, 2025 · 6 comments
Labels: bug (Something isn't working)


@leoyuppieqnew

leoyuppieqnew commented Jan 6, 2025

Your current environment

PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Alibaba Group Enterprise Linux Server 7.2 (Paladin) (x86_64)
GCC version: (GCC) 10.2.1 20200825 (Alibaba 10.2.1-3 2.17)
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.32

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.134-16.3.al8.x86_64-x86_64-with-glibc2.32
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H20
GPU 1: NVIDIA H20
GPU 2: NVIDIA H20
GPU 3: NVIDIA H20

Nvidia driver version: 535.183.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                192
On-line CPU(s) list:   0-191
Thread(s) per core:    2
Core(s) per socket:    48
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 143
Model name:            Intel(R) Xeon(R) Platinum 8469C
Stepping:              8
CPU MHz:               3100.000
CPU max MHz:           3800.0000
CPU min MHz:           800.0000
BogoMIPS:              5200.00
Virtualization:        VT-x
L1d cache:             48K
L1i cache:             32K
L2 cache:              2048K
L3 cache:              99840K
NUMA node0 CPU(s):     0-47,96-143
NUMA node1 CPU(s):     48-95,144-191
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm uintr md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flake8==7.1.1
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.46.2
[pip3] triton==3.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-ml-py              12.560.30                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchaudio                2.5.1                    pypi_0    pypi
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] transformers              4.46.2                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.1.dev3896+ga491d6f.d20250103 (git sha: a491d6f.d20250103)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    NIC0    NIC1    NIC2    NIC3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NODE    NODE    SYS     SYS     0-47,96-143     0               N/A
GPU1    NV18     X      NV18    NV18    NODE    PIX     SYS     SYS     0-47,96-143     0               N/A
GPU2    NV18    NV18     X      NV18    SYS     SYS     PIX     NODE    48-95,144-191   1               N/A
GPU3    NV18    NV18    NV18     X      SYS     SYS     NODE    NODE    48-95,144-191   1               N/A
NIC0    NODE    NODE    SYS     SYS      X      NODE    SYS     SYS
NIC1    NODE    PIX     SYS     SYS     NODE     X      SYS     SYS
NIC2    SYS     SYS     PIX     NODE    SYS     SYS      X      NODE
NIC3    SYS     SYS     NODE    NODE    SYS     SYS     NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0
  NIC1: mlx5_bond_1
  NIC2: mlx5_bond_2
  NIC3: mlx5_bond_3

NVIDIA_VISIBLE_DEVICES=GPU-2e3c3bd8-5671-14b6-05bd-8b2e12a575f7,GPU-9cbca0f2-6fda-5286-a2d2-3a2627f071ad,GPU-a9850d16-a201-98e7-790b-e1fa485aa444,GPU-9367a5d4-02ff-f20f-3781-69b8f3cacb5a
LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/cv2/../../lib64::/opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs/:/opt/taobao/java/jre/lib/amd64/server/:/usr/local/cuda/lib64:/opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs/:/opt/taobao/java/jre/lib/amd64/server/:/usr/local/cuda/lib64:/opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs/:/opt/taobao/java/jre/lib/amd64/server/:/usr/local/cuda/lib64:/opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs/:/opt/taobao/java/jre/lib/amd64/server/:/usr/local/cuda/lib64:/opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs/:/opt/taobao/java/jre/lib/amd64/server/:/usr/local/cuda/lib64:/opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs/:/opt/taobao/java/jre/lib/amd64/server/:/usr/local/cuda/lib64
NVIDIA_DRIVER_CAPABILITIES=all
NCCL_NVLS_ENABLE=0
CUDA_MODULE_LOADING=LAZY

Model Input Dumps

No response

🐛 Describe the bug

[root workflow_47400355 /ossfs/workspace] Mon Jan 06 15:46:05
$CUDA_LAUNCH_BLOCKING=1 vllm serve /mntfn/yanyi/sparse/mntfn-FP8-Dynamic/
INFO 01-06 15:46:27 api_server.py:647] vLLM API server version 0.1.dev3896+ga491d6f.d20250103
INFO 01-06 15:46:27 api_server.py:648] args: Namespace(subparser='serve', model_tag='/mntfn/yanyi/sparse/mntfn-FP8-Dynamic/', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=[''], allowed_methods=[''], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/mntfn/yanyi/sparse/mntfn-FP8-Dynamic/', task='auto', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='xgrammar', logits_processor_pattern=None, distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', generation_config=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, dispatch_function=<function serve at 0x7fd3d1634c10>)
INFO 01-06 15:46:27 api_server.py:195] Started engine process with PID 65910
INFO 01-06 15:46:40 config.py:518] This model supports multiple tasks: {'reward', 'score', 'embed', 'classify', 'generate'}. Defaulting to 'generate'.
INFO 01-06 15:46:50 config.py:518] This model supports multiple tasks: {'reward', 'classify', 'generate', 'embed', 'score'}. Defaulting to 'generate'.
INFO 01-06 15:46:53 llm_engine.py:234] Initializing an LLM engine (v0.1.dev3896+ga491d6f.d20250103) with config: model='/mntfn/yanyi/sparse/mntfn-FP8-Dynamic/', speculative_config=None, tokenizer='/mntfn/yanyi/sparse/mntfn-FP8-Dynamic/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/mntfn/yanyi/sparse/mntfn-FP8-Dynamic/, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=True,
INFO 01-06 15:46:57 selector.py:120] Using Flash Attention backend.
INFO 01-06 15:46:59 model_runner.py:1094] Starting to load model /mntfn/yanyi/sparse/mntfn-FP8-Dynamic/...
Loading safetensors checkpoint shards: 0% Completed | 0/16 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 6% Completed | 1/16 [00:01<00:16, 1.08s/it]
Loading safetensors checkpoint shards: 12% Completed | 2/16 [00:02<00:17, 1.27s/it]
Loading safetensors checkpoint shards: 19% Completed | 3/16 [00:03<00:17, 1.38s/it]
Loading safetensors checkpoint shards: 25% Completed | 4/16 [00:05<00:16, 1.37s/it]
Loading safetensors checkpoint shards: 31% Completed | 5/16 [00:06<00:14, 1.32s/it]
Loading safetensors checkpoint shards: 38% Completed | 6/16 [00:07<00:12, 1.29s/it]
Loading safetensors checkpoint shards: 44% Completed | 7/16 [00:09<00:11, 1.26s/it]
Loading safetensors checkpoint shards: 50% Completed | 8/16 [00:10<00:09, 1.25s/it]
Loading safetensors checkpoint shards: 56% Completed | 9/16 [00:11<00:08, 1.24s/it]
Loading safetensors checkpoint shards: 62% Completed | 10/16 [00:12<00:07, 1.24s/it]
Loading safetensors checkpoint shards: 69% Completed | 11/16 [00:13<00:06, 1.25s/it]
Loading safetensors checkpoint shards: 75% Completed | 12/16 [00:15<00:05, 1.26s/it]
Loading safetensors checkpoint shards: 81% Completed | 13/16 [00:16<00:03, 1.24s/it]
Loading safetensors checkpoint shards: 88% Completed | 14/16 [00:17<00:02, 1.20s/it]
Loading safetensors checkpoint shards: 94% Completed | 15/16 [00:18<00:01, 1.16s/it]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:19<00:00, 1.03s/it]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:19<00:00, 1.21s/it]

INFO 01-06 15:47:21 model_runner.py:1099] Loading model weights took 45.6023 GB
INFO 01-06 15:47:24 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20250106-154724.pkl...
WARNING 01-06 15:47:24 model_runner_base.py:143] Failed to pickle inputs of failed execution: CUDA error: an illegal memory access was encountered
WARNING 01-06 15:47:24 model_runner_base.py:143] Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
WARNING 01-06 15:47:24 model_runner_base.py:143]
ERROR 01-06 15:47:24 engine.py:366] Error in model execution: Error Internal
ERROR 01-06 15:47:24 engine.py:366] Traceback (most recent call last):
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner_base.py", line 116, in _wrapper
ERROR 01-06 15:47:24 engine.py:366] return func(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner.py", line 1691, in execute_model
ERROR 01-06 15:47:24 engine.py:366] hidden_or_intermediate_states = model_executable(
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-06 15:47:24 engine.py:366] return self._call_impl(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-06 15:47:24 engine.py:366] return forward_call(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 477, in forward
ERROR 01-06 15:47:24 engine.py:366] hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/compilation/decorators.py", line 168, in call
ERROR 01-06 15:47:24 engine.py:366] return self.forward(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 340, in forward
ERROR 01-06 15:47:24 engine.py:366] hidden_states, residual = layer(
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-06 15:47:24 engine.py:366] return self._call_impl(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-06 15:47:24 engine.py:366] return forward_call(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 247, in forward
ERROR 01-06 15:47:24 engine.py:366] hidden_states = self.self_attn(
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-06 15:47:24 engine.py:366] return self._call_impl(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-06 15:47:24 engine.py:366] return forward_call(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 173, in forward
ERROR 01-06 15:47:24 engine.py:366] qkv, _ = self.qkv_proj(hidden_states)
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-06 15:47:24 engine.py:366] return self._call_impl(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in call_impl
ERROR 01-06 15:47:24 engine.py:366] return forward_call(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/layers/linear.py", line 373, in forward
ERROR 01-06 15:47:24 engine.py:366] output_parallel = self.quant_method.apply(self, input, bias)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 511, in apply
ERROR 01-06 15:47:24 engine.py:366] return scheme.apply_weights(layer, x, bias=bias)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py", line 172, in apply_weights
ERROR 01-06 15:47:24 engine.py:366] out = ops.cutlass_scaled_sparse_mm(a=q_input,
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/_custom_ops.py", line 628, in cutlass_scaled_sparse_mm
ERROR 01-06 15:47:24 engine.py:366] torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 1116, in call
ERROR 01-06 15:47:24 engine.py:366] return self._op(*args, **(kwargs or {}))
ERROR 01-06 15:47:24 engine.py:366] RuntimeError: Error Internal
ERROR 01-06 15:47:24 engine.py:366]
ERROR 01-06 15:47:24 engine.py:366] The above exception was the direct cause of the following exception:
ERROR 01-06 15:47:24 engine.py:366]
ERROR 01-06 15:47:24 engine.py:366] Traceback (most recent call last):
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
ERROR 01-06 15:47:24 engine.py:366] engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
ERROR 01-06 15:47:24 engine.py:366] return cls(ipc_path=ipc_path,
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 71, in init
ERROR 01-06 15:47:24 engine.py:366] self.engine = LLMEngine(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/engine/llm_engine.py", line 276, in init
ERROR 01-06 15:47:24 engine.py:366] self._initialize_kv_caches()
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/engine/llm_engine.py", line 416, in _initialize_kv_caches
ERROR 01-06 15:47:24 engine.py:366] self.model_executor.determine_num_available_blocks())
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/executor/gpu_executor.py", line 68, in determine_num_available_blocks
ERROR 01-06 15:47:24 engine.py:366] return self.driver_worker.determine_num_available_blocks()
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 01-06 15:47:24 engine.py:366] return func(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/worker/worker.py", line 202, in determine_num_available_blocks
ERROR 01-06 15:47:24 engine.py:366] self.model_runner.profile_run()
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 01-06 15:47:24 engine.py:366] return func(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner.py", line 1331, in profile_run
ERROR 01-06 15:47:24 engine.py:366] self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 01-06 15:47:24 engine.py:366] File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 01-06 15:47:24 engine.py:366] return func(*args, **kwargs)
ERROR 01-06 15:47:24 engine.py:366] File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner_base.py", line 146, in _wrapper
ERROR 01-06 15:47:24 engine.py:366] raise type(err)(f"Error in model execution: "
ERROR 01-06 15:47:24 engine.py:366] RuntimeError: Error in model execution: Error Internal
Process SpawnProcess-1:
Traceback (most recent call last):
File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner_base.py", line 116, in _wrapper
return func(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner.py", line 1691, in execute_model
hidden_or_intermediate_states = model_executable(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 477, in forward
hidden_states = self.model(input_ids, positions, kv_caches,
File "/ossfs/workspace/ant_vllm/vllm/compilation/decorators.py", line 168, in call
return self.forward(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 340, in forward
hidden_states, residual = layer(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 247, in forward
hidden_states = self.self_attn(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/models/qwen2.py", line 173, in forward
qkv, _ = self.qkv_proj(hidden_states)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in call_impl
return forward_call(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/layers/linear.py", line 373, in forward
output_parallel = self.quant_method.apply(self, input, bias)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 511, in apply
return scheme.apply_weights(layer, x, bias=bias)
File "/ossfs/workspace/ant_vllm/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py", line 172, in apply_weights
out = ops.cutlass_scaled_sparse_mm(a=q_input,
File "/ossfs/workspace/ant_vllm/vllm/_custom_ops.py", line 628, in cutlass_scaled_sparse_mm
torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 1116, in call
return self._op(*args, **(kwargs or {}))
RuntimeError: Error Internal

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 368, in run_mp_engine
raise e
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
return cls(ipc_path=ipc_path,
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/engine.py", line 71, in init
self.engine = LLMEngine(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/engine/llm_engine.py", line 276, in init
self._initialize_kv_caches()
File "/ossfs/workspace/ant_vllm/vllm/engine/llm_engine.py", line 416, in _initialize_kv_caches
self.model_executor.determine_num_available_blocks())
File "/ossfs/workspace/ant_vllm/vllm/executor/gpu_executor.py", line 68, in determine_num_available_blocks
return self.driver_worker.determine_num_available_blocks()
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/worker/worker.py", line 202, in determine_num_available_blocks
self.model_runner.profile_run()
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner.py", line 1331, in profile_run
self.execute_model(model_input, kv_caches, intermediate_tensors)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/ossfs/workspace/ant_vllm/vllm/worker/model_runner_base.py", line 146, in _wrapper
raise type(err)(f"Error in model execution: "
RuntimeError: Error in model execution: Error Internal
[rank0]:[W106 15:47:25.908723004 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
Task exception was never retrieved
future: <Task finished name='Task-2' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
File "/opt/conda/lib/python3.10/site-packages/zmq/_future.py", line 400, in poll
raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Task exception was never retrieved
future: <Task finished name='Task-3' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
File "/opt/conda/lib/python3.10/site-packages/zmq/_future.py", line 400, in poll
raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Task exception was never retrieved
future: <Task finished name='Task-4' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
File "/ossfs/workspace/ant_vllm/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
File "/opt/conda/lib/python3.10/site-packages/zmq/_future.py", line 400, in poll
raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Traceback (most recent call last):
File "/opt/conda/bin/vllm", line 8, in
sys.exit(main())
File "/ossfs/workspace/ant_vllm/vllm/scripts.py", line 201, in main
args.dispatch_function(args)
File "/ossfs/workspace/ant_vllm/vllm/scripts.py", line 42, in serve
uvloop.run(run_server(args))
File "/opt/conda/lib/python3.10/site-packages/uvloop/init.py", line 82, in run
return loop.run_until_complete(wrapper())
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/opt/conda/lib/python3.10/site-packages/uvloop/init.py", line 61, in wrapper
return await main
File "/ossfs/workspace/ant_vllm/vllm/entrypoints/openai/api_server.py", line 671, in run_server
async with build_async_engine_client(args) as engine_client:
File "/opt/conda/lib/python3.10/contextlib.py", line 199, in aenter
return await anext(self.gen)
File "/ossfs/workspace/ant_vllm/vllm/entrypoints/openai/api_server.py", line 114, in build_async_engine_client
async with build_async_engine_client_from_engine_args(
File "/opt/conda/lib/python3.10/contextlib.py", line 199, in aenter
return await anext(self.gen)
File "/ossfs/workspace/ant_vllm/vllm/entrypoints/openai/api_server.py", line 219, in build_async_engine_client_from_engine_args
raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.

leoyuppieqnew added the bug (Something isn't working) label on Jan 6, 2025
@mgoin
Member

mgoin commented Jan 7, 2025

Hi @leoyuppieqnew, is there any information you could share on how you prepared the compressed-tensors checkpoint? Sharing details like the model itself, or even just its config.json, would be useful. Unfortunately, the error message is unclear and we are not sure how to reproduce it at the moment.

@leoyuppieqnew
Author

leoyuppieqnew commented Jan 7, 2025

Hi @leoyuppieqnew, is there any information you could share on how you prepared the compressed-tensors checkpoint? Sharing details like the model itself, or even just its config.json, would be useful. Unfortunately, the error message is unclear and we are not sure how to reproduce it at the moment.

Sure, I used the llm-compressor one-shot method for 2:4 sparsity and FP8 quantization. The base model is Qwen2-72B-Instruct, and its config is as follows:

{
  "_name_or_path": "/home/qwen2_72b_tuwen_mix_24_10000/stage_sparsity",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 8192,
  "initializer_range": 0.02,
  "intermediate_size": 29568,
  "max_position_embeddings": 32768,
  "max_window_layers": 70,
  "model_type": "qwen2",
  "num_attention_heads": 64,
  "num_hidden_layers": 80,
  "num_key_value_heads": 8,
  "output_router_logits": false,
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": true,
          "group_size": null,
          "num_bits": 8,
          "observer": null,
          "observer_kwargs": {},
          "strategy": "token",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": null,
          "num_bits": 8,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "channel",
          "symmetric": true,
          "type": "float"
        }
      }
    },
    "format": "float-quantized",
    "global_compression_ratio": 1.4643128654975015,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "sparsity_config": {
      "format": "dense",
      "global_sparsity": 0.48285508867414173,
      "ignore": [
        "lm_head"
      ],
      "registry_requires_subclass": false,
      "sparsity_structure": "2:4",
      "targets": [
        "Linear"
      ]
    }
  },
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.1",
  "use_cache": false,
  "use_sliding_window": false,
  "vocab_size": 152064
}
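
For reference, a one-shot recipe of this shape is sketched below: 2:4 structured pruning followed by FP8 weight quantization with dynamic per-token activations, matching the quantization_config above. The module paths, argument names, calibration dataset, and output directory are assumptions and may differ from the exact script used and across llm-compressor versions.

# Hedged sketch of a 2:4 sparsity + FP8-dynamic one-shot recipe with llm-compressor.
# Dataset, sample count, and output directory are placeholders, not the reporter's values.
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.obcq import SparseGPTModifier
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen2-72B-Instruct"  # base model named above

recipe = [
    # Prune Linear weights (except lm_head) to a 2:4 structured pattern (50% sparsity).
    SparseGPTModifier(sparsity=0.5, mask_structure="2:4", targets=["Linear"], ignore=["lm_head"]),
    # FP8 per-channel weights with dynamic per-token FP8 activations.
    QuantizationModifier(targets=["Linear"], scheme="FP8_DYNAMIC", ignore=["lm_head"]),
]

oneshot(
    model=MODEL_ID,
    dataset="open_platypus",          # placeholder calibration set
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    output_dir="qwen2-72b-2of4-fp8-dynamic",
)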

@robertgshaw2-neuralmagic
Collaborator

Is this using the vLLM wheel?

@leoyuppieqnew
Author

Is this using the vLLM wheel?

Nope, it is compiled from source; the commit ID is a491d6f.

@robertgshaw2-neuralmagic
Collaborator

Thanks, we will take a look.

@LucasWilkinson
Contributor

LucasWilkinson commented Jan 7, 2025

@leoyuppieqnew thanks for reporting the bug! Since you are building from source, could you apply the following patch, rebuild, and then share the output? RuntimeError: Error Internal is reported by CUTLASS when it runs into issues, so the extra trace could help us out greatly: it should tell us fairly quickly whether this is an H20-related issue (a GPU we don't currently have access to).

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 83c803343..ae0c7c45e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -440,6 +440,7 @@ define_gpu_extension_target(
 # driver API. This causes problems when linking with earlier versions of CUDA.
 # Setting this variable sidesteps the issue by calling the driver directly.
 target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
+target_compile_definitions(_C PRIVATE CUTLASS_DEBUG_TRACE_LEVEL=1)
 
 #
 # _moe_C extension

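After rebuilding with the patch (for a source checkout, re-running pip install -e . in the repo root should recompile the _C extension), a minimal offline run along the lines of the sketch below should hit the same cutlass_scaled_sparse_mm path during profiling and surface the extra CUTLASS trace near the failure. The checkpoint path is the one from the log above; enforce_eager and CUDA_LAUNCH_BLOCKING are only there to keep the trace close to the launch site, and the generate call is a placeholder.

# Hedged repro sketch: load the reporter's checkpoint with vLLM's offline LLM API so the
# CUTLASS debug trace from the patched build appears right at the failing kernel.
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # set before CUDA is initialized

from vllm import LLM, SamplingParams

llm = LLM(
    model="/mntfn/yanyi/sparse/mntfn-FP8-Dynamic/",  # checkpoint path from the log above
    enforce_eager=True,                              # skip CUDA graph capture
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)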