
[Bug]: Unable to load deepseek r1 on 8 x AMD MI300X AssertionError: FP8 weight padding is not supported in block quantization #375

Open
samos123 opened this issue Jan 21, 2025 · 21 comments
Labels: bug (Something isn't working)

@samos123

Your current environment

Image I'm using: rocm/vllm-dev:nightly_main_20250120

vLLM version and flags used:

INFO 01-21 19:40:33 api_server.py:768] vLLM API server version 0.6.7.dev215+gfaa1815c
INFO 01-21 19:40:33 api_server.py:769] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='deepseek-ai/DeepSeek-R1', task='auto', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='fp8', max_model_len=128000, guided_decoding_backend='xgrammar', logits_processor_pattern=None, distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=8, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=128000, max_num_seqs=1024, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=16384, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=15, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=False, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['deepseek-r1-mi300x'], qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', generation_config=None, calculate_kv_scales=False, disable_log_requests=True, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False)

Model Input Dumps

No response

🐛 Describe the bug

Error:

(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240] Traceback (most recent call last):
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 234, in _run_worker_process
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "vllm/utils.py", line 2379, in vllm.utils.run_method
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 182, in load_model
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.model_runner.load_model()
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/multi_step_model_runner.py", line 650, in load_model
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self._base_model_runner.load_model()
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1097, in load_model
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.model = get_model(vllm_config=self.vllm_config)
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     return loader.load_model(vllm_config=vllm_config)
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 376, in load_model
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     model = _initialize_model(vllm_config=vllm_config)
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 118, in _initialize_model
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     return model_class(vllm_config=vllm_config, prefix=prefix)
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v3.py", line 510, in __init__
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.model = DeepseekV3Model(vllm_config=vllm_config,
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v3.py", line 445, in __init__
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.start_layer, self.end_layer, self.layers = make_layers(
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                                                     ^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 556, in make_layers
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v3.py", line 447, in <lambda>
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     lambda prefix: DeepseekV3DecoderLayer(
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                    ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v3.py", line 354, in __init__
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.self_attn = DeepseekV3Attention(
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                      ^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v3.py", line 213, in __init__
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.q_a_proj = ReplicatedLinear(self.hidden_size,
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 214, in __init__
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     self.quant_method.create_weights(self,
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/fp8.py", line 169, in create_weights
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]     assert not envs.VLLM_FP8_PADDING, (
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=306) ERROR 01-21 19:41:08 multiproc_worker_utils.py:240] AssertionError: FP8 weight padding is not supported in block quantization.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
samos123 added the bug label on Jan 21, 2025
@gshtras
Collaborator

gshtras commented Jan 21, 2025

With DeepSeek V3, please set the VLLM_FP8_PADDING=0 environment variable, as its quantization method is currently incompatible with FP8 padding.
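
For anyone launching the rocm/vllm-dev image by hand rather than through an orchestrator, a sketch of passing that variable might look like this (the ROCm device/ipc docker flags and the HF cache mount shown are the usual ones; adjust the tag, mounts, and model to your setup):

# VLLM_FP8_PADDING=0 disables FP8 weight padding, which the block-quantized checkpoint does not support
docker run -it --rm \
  --device=/dev/kfd --device=/dev/dri \
  --group-add video --ipc=host \
  --security-opt seccomp=unconfined \
  -e VLLM_FP8_PADDING=0 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  -p 8000:8000 \
  rocm/vllm-dev:nightly_main_20250120 \
  python3 -m vllm.entrypoints.openai.api_server \
    --model deepseek-ai/DeepSeek-R1 \
    --tensor-parallel-size 8 \
    --trust-remote-code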

@samos123
Author

Getting much further now. Thanks for the quick suggestion:

Loading safetensors checkpoint shards: 100% Completed | 163/163 [01:30<00:00,  3.02it/s]
Loading safetensors checkpoint shards: 100% Completed | 163/163 [01:30<00:00,  1.79it/s]

WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=308) WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=305) WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=310) WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=309) WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=306) WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=307) WARNING 01-21 22:07:21 kv_cache.py:86] Using KV cache scaling factor 1.0 for fp8_e4m3. This may cause accuracy issues. Please make sure k/v_scale scaling factors are available in the fp8 checkpoint.
(VllmWorkerProcess pid=308) INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=305) INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=306) INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=310) INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=309) INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=307) INFO 01-21 22:07:22 model_runner.py:1100] Loading model weights took 79.3596 GB
WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=309) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=307) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=308) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=310) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=306) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=311) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=305) WARNING 01-21 22:07:27 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
Memory access fault by GPU node-4 (Agent handle: 0x23962b40) on address 0x7f602c86c000. Reason: Unknown.
Memory access fault by GPU node-9 (Agent handle: 0x2c625db0) on address 0x7f5ac4545000. Reason: Unknown.
Memory access fault by GPU node-7 (Agent handle: 0x48c22fa0) on address 0x7f03cabe5000. Reason: Unknown.
Memory access fault by GPU node-3 (Agent handle: 0x2e1223b0) on address 0x7f81f26ea000. Reason: Unknown.
Memory access fault by GPU node-5 (Agent handle: 0x196a0050) on address 0x7ee0861d9000. Reason: Unknown.
Memory access fault by GPU node-6 (Agent handle: 0x12a60390) on address 0x7f56c2fe6000. Reason: Unknown.
Memory access fault by GPU node-8 (Agent handle: 0x27217680) on address 0x7ede3917d000. Reason: Unknown.
Memory access fault by GPU node-2 (Agent handle: 0x36d6d160) on address 0x7ed8fb1ef000. Reason: Write access to a read-only page.

I will post an update if this ends up fully working or not.

@samos123
Author

samos123 commented Jan 21, 2025

Hmm, it looks like there is a failure and it tries to write a GPU core dump, but the dump fails because the disk is full. The core issue, however, is the memory access fault.

Any idea on how to debug further?

Current flags used in my KubeAI custom object:

  deepseek-r1-mi300x:
    enabled: true
    features: [TextGeneration]
    url: hf://deepseek-ai/DeepSeek-R1
    engine: VLLM
    env:
      HIP_FORCE_DEV_KERNARG: "1"
      NCCL_MIN_NCHANNELS: "112"
      TORCH_BLAS_PREFER_HIPBLASLT: "1"
      VLLM_USE_TRITON_FLASH_ATTN: "0"
      VLLM_FP8_PADDING: "0"
    args:
      - --trust-remote-code
      - --max-model-len=128000
      - --max-num-batched-token=128000
      - --max-num-seqs=1024
      - --num-scheduler-steps=15
      - --tensor-parallel-size=8
      - --gpu-memory-utilization=0.90
      - --disable-log-requests
      - --enable-chunked-prefill=false
      - --max-seq-len-to-capture=16384
      - --kv-cache-dtype=fp8
    resourceProfile: amd-gpu-mi300x:8
    targetRequests: 1024
    minReplicas: 1

Current logs:

Memory access fault by GPU node-2 (Agent handle: 0x36d6d160) on address 0x7ed8fb1ef000. Reason: Write access to a read-only page.




Failed to allocate file: No space left on device
GPU core dump failed
Failed to allocate file: No space left on device
GPU core dump failed
Failed to allocate file: No space left on device
GPU core dump failed
Task exception was never retrieved
future: <Task finished name='Task-2' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/client.py:180> exception=ZMQError('Operation not supported')>

@billcsm

billcsm commented Jan 21, 2025

I got the same issue while trying to load DeepSeek V3 on 8x AMD MI300X GPUs. I set the VLLM_FP8_PADDING=0 environment variable and then got stuck at the following while running. Thanks for any help.

(VllmWorkerProcess pid=949) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=946) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=943) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=945) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=944) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=947) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=948) INFO 01-21 23:11:24 model_runner.py:1100] Loading model weights took 79.3596 GB
(VllmWorkerProcess pid=948) WARNING 01-21 23:12:03 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=947) WARNING 01-21 23:12:03 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=945) WARNING 01-21 23:12:03 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=944) WARNING 01-21 23:12:04 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
Memory access fault by GPU node-14 (Agent handle: 0x1fb45b60) on address 0x7f2ba853c000. Reason: Unknown.
(VllmWorkerProcess pid=949) WARNING 01-21 23:12:05 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
WARNING 01-21 23:12:05 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
Memory access fault by GPU node-13 (Agent handle: 0x312a3840) on address 0x7ee6e5a61000. Reason: Unknown.
Memory access fault by GPU node-11 (Agent handle: 0x3448b380) on address 0x7f62b79c6000. Reason: Unknown.
(VllmWorkerProcess pid=943) WARNING 01-21 23:12:06 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
(VllmWorkerProcess pid=946) WARNING 01-21 23:12:06 fused_moe.py:375] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json
Memory access fault by GPU node-10 (Agent handle: 0x30c37980) on address 0x7f4cfe8c3000. Reason: Unknown.
Memory access fault by GPU node-8 (Agent handle: 0x45341fb0) on address 0x7f642725b000. Reason: Unknown.
Memory access fault by GPU node-15 (Agent handle: 0x380cb080) on address 0x7f7006b00000. Reason: Unknown.
Memory access fault by GPU node-9 (Agent handle: 0x3634efa0) on address 0x7eeb7da08000. Reason: Unknown.
Memory access fault by GPU node-12 (Agent handle: 0x296acfe0) on address 0x7f0174181000. Reason: Unknown.

@samos123
Author

It's now working after I switched to these flags:

  deepseek-r1-mi300x:
    enabled: true
    features: [TextGeneration]
    url: hf://deepseek-ai/DeepSeek-R1
    engine: VLLM
    env:
      HIP_FORCE_DEV_KERNARG: "1"
      NCCL_MIN_NCHANNELS: "112"
      TORCH_BLAS_PREFER_HIPBLASLT: "1"
      VLLM_USE_TRITON_FLASH_ATTN: "0"
      VLLM_FP8_PADDING: "0"
    args:
      - --trust-remote-code
      - --max-model-len=8096
      - --max-num-batched-token=8096
      - --max-num-seqs=1024
      - --num-scheduler-steps=10
      - --tensor-parallel-size=8
      - --gpu-memory-utilization=0.90
      - --disable-log-requests
      # - --enable-chunked-prefill=false
      # - --max-seq-len-to-capture=16384
      - --kv-cache-dtype=fp8
    resourceProfile: amd-gpu-mi300x:8
    targetRequests: 1024
    minReplicas: 1

Note that I decreased the context length and reduced the number of flags compared to before. Maybe this works for you too, @billcsm.

@gshtras
Collaborator

gshtras commented Jan 21, 2025

The memory access fault is a new issue that was discovered today; we hope to have a fix ready within the coming days.
Thank you for providing the workaround info.

@samos123
Author

I can reproduce the memory access fault again if I change my working config to use a 120k context length, so it seems to be related to larger context lengths. I'm trying to find something between 8k and 120k that doesn't result in a memory access fault.

@gshtras
Collaborator

gshtras commented Jan 22, 2025

Leaving the other parameters at their default values, max-model-len=32768 should work.
For custom values of the other params we currently need to mix and match to find a combination that works. We're investigating the real cause of the crash.
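
As a concrete example, a minimal invocation along those lines might be the following (everything else left at the engine defaults; treat this as a sketch rather than a validated config):

# inside the rocm/vllm-dev container, with FP8 padding disabled as above
VLLM_FP8_PADDING=0 python3 -m vllm.entrypoints.openai.api_server \
    --model deepseek-ai/DeepSeek-R1 \
    --tensor-parallel-size 8 \
    --trust-remote-code \
    --max-model-len 32768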

@samos123
Author

Testing this now:

      - --max-model-len=64768
      - --max-num-batched-token=64768

Maybe it has to be a multiple of 8? Note that's just a wild guess, since I know nothing about the internals of vLLM or GPUs.

@samos123
Author

Getting the same memory access fault when using a 64768 context length.

@billcsm

billcsm commented Jan 22, 2025

@samos123 and @gshtras,
max-model-len=8096 worked, but max-model-len=32768 and 64768 both failed.

max-model-len=32768 failed with HIP out of memory.

max-model-len=64768 failed with a memory access fault:

Memory access fault by GPU node-13 (Agent handle: 0x479962f0) on address 0x7ee718e79000. Reason: Unknown.
Memory access fault by GPU node-15 (Agent handle: 0xfa450e0) on address 0x7f4934aa0000. Reason: Write access to a read-only page.
Memory access fault by GPU node-12 (Agent handle: 0x33557af0) on address 0x7f3d3bca0000. Reason: Unknown.
Memory access fault by GPU node-10 (Agent handle: 0xec6e910) on address 0x7eccdc6a0000. Reason: Unknown.
Memory access fault by GPU node-9 (Agent handle: 0x2a990340) on address 0x7f4983b24000. Reason: Unknown.
Memory access fault by GPU node-11 (Agent handle: 0x36d57a00) on address 0x7f0b85164000. Reason: Unknown.
Memory access fault by GPU node-8 (Agent handle: 0x14361fd0) on address 0x7f48ae300000. Reason: Unknown.
Memory access fault by GPU node-14 (Agent handle: 0x2a07c740) on address 0x7f7c370d3000. Reason: Unknown.

@samos123
Author

samos123 commented Jan 22, 2025

32k context length is working for me with this config:

  deepseek-r1-mi300x:
    enabled: true
    features: [TextGeneration]
    url: hf://deepseek-ai/DeepSeek-R1
    engine: VLLM
    env:
      HIP_FORCE_DEV_KERNARG: "1"
      NCCL_MIN_NCHANNELS: "112"
      TORCH_BLAS_PREFER_HIPBLASLT: "1"
      VLLM_USE_TRITON_FLASH_ATTN: "0"
      VLLM_FP8_PADDING: "0"
    args:
      - --trust-remote-code
      - --max-model-len=32768
      - --max-num-batched-token=32768
      - --max-num-seqs=1024
      - --num-scheduler-steps=10
      - --tensor-parallel-size=8
      - --gpu-memory-utilization=0.90
      - --disable-log-requests
      - --enable-chunked-prefill=false
      - --max-seq-len-to-capture=16384
      - --kv-cache-dtype=fp8
    resourceProfile: amd-gpu-mi300x:8
    targetRequests: 1024
    minReplicas: 1

Performance when sending 1000 requests all at once:

============ Serving Benchmark Result ============
Successful requests:                     996       
Benchmark duration (s):                  322.58    
Total input tokens:                      234863    
Total generated tokens:                  168732    
Request throughput (req/s):              3.09      
Output token throughput (tok/s):         523.07    
Total Token throughput (tok/s):          1251.14   
---------------Time to First Token----------------
Mean TTFT (ms):                          29308.07  
Median TTFT (ms):                        24782.61  
P99 TTFT (ms):                           70899.39  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          357.19    
Median TPOT (ms):                        222.17    
P99 TPOT (ms):                           2152.76   
---------------Inter-token Latency----------------
Mean ITL (ms):                           1996.25   
Median ITL (ms):                         1825.19   
P99 ITL (ms):                            7196.28   
==================================================

It's unclear why 4 requests failed, since no errors were printed in the vLLM logs.

@billcsm

billcsm commented Jan 22, 2025

@samos123 and @gshtras,
You're both right. I re-ran with max-model-len=32768 and it worked. Last time it failed because I forgot to clean the cache before the test. Here is my test configuration.

env:
  HIP_FORCE_DEV_KERNARG: "1"
  NCCL_MIN_NCHANNELS: "112"
  TORCH_BLAS_PREFER_HIPBLASLT: "1"
  VLLM_USE_TRITON_FLASH_ATTN: "0"
  VLLM_FP8_PADDING: "0"
args:
  - --trust-remote-code
  - --max-model-len=32768
  - --max-num-batched-token=32768
  - --max-num-seqs=1024
  - --num-scheduler-steps=10
  - --tensor-parallel-size=8
  - --gpu-memory-utilization=0.90
  - --disable-log-requests
  - --kv-cache-dtype=fp8

Thank you for your help!

@billcsm

billcsm commented Jan 23, 2025

@samos123, I ran the performance benchmark with a 32k context length on DeepSeek-V3. Here is the result for 1000 requests (request-rate 10):

============ Serving Benchmark Result ============
Successful requests: 948
Benchmark duration (s): 239.03
Total input tokens: 199799
Total generated tokens: 183165
Request throughput (req/s): 3.97
Output token throughput (tok/s): 766.27
Total Token throughput (tok/s): 1602.14
---------------Time to First Token----------------
Mean TTFT (ms): 3044.16
Median TTFT (ms): 2846.34
P99 TTFT (ms): 14537.48
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 293.66
Median TPOT (ms): 263.60
P99 TPOT (ms): 1170.39
---------------Inter-token Latency----------------
Mean ITL (ms): 2427.62
Median ITL (ms): 2148.51
P99 ITL (ms): 5403.55

=========================================

My test lost 52 requests. No error was printed.

@gshtras
Collaborator

gshtras commented Jan 23, 2025

> @samos123, I ran the performance benchmark with a 32k context length on DeepSeek-V3. [...] My test lost 52 requests. No error was printed.

Were you by chance running the benchmark_serving with a random dataset as a param?

@samos123
Author

No, I was using the ShareGPT dataset with a fixed seed.
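
For reference, a ShareGPT run with a fixed seed uses vLLM's benchmarks/benchmark_serving.py roughly like this (the dataset path, seed value, and model name are illustrative; adjust --model to the served model name if it differs):

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --model deepseek-ai/DeepSeek-R1 \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 1000 \
    --request-rate inf \
    --seed 42
# --request-rate inf sends all requests at once; the paced runs in this thread use --request-rate 10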

@billcsm

billcsm commented Jan 23, 2025

For the previous test, I used the ShareGPT dataset (ShareGPT_V3_unfiltered_cleaned_split.json). Here are the results of two runs of benchmark_serving with a random dataset at 1000 requests (request-rate 10):

====== Serving Benchmark Result =======
Successful requests: 999
Benchmark duration (s): 216.71
Total input tokens: 1022976
Total generated tokens: 123252
Request throughput (req/s): 4.61
Output token throughput (tok/s): 568.73
Total Token throughput (tok/s): 5289.16
---------------Time to First Token----------------
Mean TTFT (ms): 55433.10
Median TTFT (ms): 56938.99
P99 TTFT (ms): 106623.27
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 281.53
Median TPOT (ms): 187.71
P99 TPOT (ms): 2173.08
---------------Inter-token Latency----------------
Mean ITL (ms): 2101.63
Median ITL (ms): 1944.64
P99 ITL (ms): 2637.23
================================

======== Serving Benchmark Result ========
Successful requests: 998
Benchmark duration (s): 216.70
Total input tokens: 1021952
Total generated tokens: 122661
Request throughput (req/s): 4.61
Output token throughput (tok/s): 566.05
Total Token throughput (tok/s): 5282.12
---------------Time to First Token----------------
Mean TTFT (ms): 50746.73
Median TTFT (ms): 53040.23
P99 TTFT (ms): 100442.46
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 311.05
Median TPOT (ms): 243.90
P99 TPOT (ms): 896.89
---------------Inter-token Latency----------------
Mean ITL (ms): 2661.29
Median ITL (ms): 1953.24
P99 ITL (ms): 12172.51
===================================

No error was printed.
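
For context, a random-dataset run of this shape would be launched roughly like this (the ~1024 input tokens per request in the results above suggest --random-input-len 1024; the exact flags used are an assumption):

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --model deepseek-ai/DeepSeek-R1 \
    --dataset-name random \
    --random-input-len 1024 \
    --num-prompts 1000 \
    --request-rate 10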

@billcsm

billcsm commented Jan 24, 2025

I changed to an 8k context length and re-ran benchmark_serving with a random dataset at 1000 requests (request-rate 10). I still got lost requests. No error was printed.

============ Serving Benchmark Result ============
Successful requests: 998
Benchmark duration (s): 228.42
Total input tokens: 1021952
Total generated tokens: 116473
Request throughput (req/s): 4.37
Output token throughput (tok/s): 509.90
Total Token throughput (tok/s): 4983.83
---------------Time to First Token----------------
Mean TTFT (ms): 54352.95
Median TTFT (ms): 58279.14
P99 TTFT (ms): 109505.51
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 325.31
Median TPOT (ms): 277.91
P99 TPOT (ms): 1112.34
---------------Inter-token Latency----------------
Mean ITL (ms): 2836.07
Median ITL (ms): 1999.27
P99 ITL (ms): 12828.03
=====================================

@samos123
Author

This seems to be an AMD-specific issue. I've run the exact same benchmark with the same flags on NVIDIA GPUs and always get 1000 successful requests.

@gshtras
Collaborator

gshtras commented Jan 24, 2025

Could you provide the complete logs from both the server and the client benchmark?

@samos123
Author

Yes, let me file a separate bug for this, since it's unrelated to DeepSeek R1. I've seen this on Llama 3.1 70B as well: https://substratus.ai/blog/benchmarking-llama-3.1-70b-amd-mi300x

For reference, here is the same benchmark on L4: https://substratus.ai/blog/benchmarking-llama-3.1-70b-on-l4
GH200:

Funnily enough, the 405B model on AMD seems to work fine: https://substratus.ai/blog/benchmarking-llama-3.1-405b-amd-mi300x
