
fix: gptattentionplugin onnxparser compatibility #2712

Open · wants to merge 1 commit into main

Conversation

@jl749 jl749 commented Jan 23, 2025

Issue

#2685

Changes made

Added the missing field declarations required by the tensorrt.OnnxParser conversion API.
I also removed the in_flight_batching plugin field, since I could not find it being used directly anywhere.

# relevant part of the attached ONNX compile log
[01/23/2025-03:27:56] [TRT] [W] onnxOpImporters.cpp:6507: Attribute in_flight_batching not found in plugin node! Ensure that the plugin creator has a default value defined or the engine may fail to build.

pfc = trt.PluginFieldCollection([
layer_idx, nheads, vision_start, vision_length, num_kv_heads,
layer_idx_in_cache_pool, head_size, unidirectional, q_scaling,
attn_logit_softcapping_scale, position_embedding_type,
rotary_embedding_dim, rotary_embedding_base,
rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale,
rotary_embedding_max_positions, rotary_embedding_original_max_positions,
tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type,
kv_cache_quant_mode_field, remove_input_padding, mask_type_filed,
block_sparse_block_size, block_sparse_homo_head_pattern,
block_sparse_num_local_blocks, block_sparse_vertical_stride,
paged_kv_cache, tokens_per_block, pf_type, max_context_length,
qkv_bias_enabled, do_cross_attention_field, max_distance,
pos_shift_enabled, dense_context_fmha, use_paged_context_fmha_field,
use_fp8_context_fmha_field, has_full_attention_mask_field, use_cache_pf,
is_spec_decoding_enabled, spec_decoding_is_generation_length_variable,
spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank,
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
skip_attn_pf, cp_size, cp_rank, cp_group, use_logn_scaling
])
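
For reference, each entry in the collection above is a trt.PluginField that has to be declared before it is collected. A minimal sketch of how a few of them might be built and handed to the plugin creator follows; the field names, values, and dtypes here are illustrative assumptions, not the exact values used in this PR.

import numpy as np
import tensorrt as trt

# Hypothetical field declarations; real names/values are dictated by the
# GPTAttention plugin creator's expected fields and the model configuration.
layer_idx = trt.PluginField("layer_idx", np.array([0], dtype=np.int32),
                            trt.PluginFieldType.INT32)
nheads = trt.PluginField("num_heads", np.array([32], dtype=np.int32),
                         trt.PluginFieldType.INT32)
q_scaling = trt.PluginField("q_scaling", np.array([1.0], dtype=np.float32),
                            trt.PluginFieldType.FLOAT32)

# The collection is then passed to the GPTAttention creator registered
# under the tensorrt_llm namespace.
creator = trt.get_plugin_registry().get_plugin_creator(
    "GPTAttention", "1", "tensorrt_llm")
# plugin = creator.create_plugin("gpt_attention", pfc)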

To test

  1. Download this ONNX file
  2. Use the tensorrt.OnnxParser API with the strongly_typed flag (see the sketch after this list)
  3. Check that the engine file is generated without errors
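
A minimal sketch of step 2, assuming the downloaded model is saved as ./model.onnx and that importing tensorrt_llm registers the GPTAttention plugin creator (adjust the path and plugin loading to your setup):

import tensorrt as trt
import tensorrt_llm  # assumed to register the tensorrt_llm plugin creators

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
# strongly_typed: create the network with the STRONGLY_TYPED flag
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
parser = trt.OnnxParser(network, logger)

if not parser.parse_from_file("./model.onnx"):
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise RuntimeError("failed to parse ./model.onnx")

config = builder.create_builder_config()
engine = builder.build_serialized_network(network, config)
assert engine is not None, "engine build failed"
with open("model.engine", "wb") as f:
    f.write(engine)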
ONNX compile log
[01/23/2025-03:27:48] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 217, GPU 423 (MiB)
[01/23/2025-03:27:56] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +2038, GPU +374, now: CPU 2395, GPU 797 (MiB)
config.default_device_type = DeviceType.GPU
config.max_aux_streams = -1
config.plugins_to_serialize = []
config.get_memory_pool_limit(trt.MemoryPoolType.WORKSPACE) = 42298834944 Byte (39.4 GiB)
[01/23/2025-03:27:56] [TRT] [I] ----------------------------------------------------------------
[01/23/2025-03:27:56] [TRT] [I] Input filename:   ./model.onnx
[01/23/2025-03:27:56] [TRT] [I] ONNX IR version:  0.0.8
[01/23/2025-03:27:56] [TRT] [I] Opset version:    17
[01/23/2025-03:27:56] [TRT] [I] Producer name:    pytorch
[01/23/2025-03:27:56] [TRT] [I] Producer version: 2.4.1
[01/23/2025-03:27:56] [TRT] [I] Domain:           
[01/23/2025-03:27:56] [TRT] [I] Model version:    0
[01/23/2025-03:27:56] [TRT] [I] Doc string:       
[01/23/2025-03:27:56] [TRT] [I] ----------------------------------------------------------------
[01/23/2025-03:27:56] [TRT] [W] ModelImporter.cpp:459: Make sure input input_ids has Int64 binding.
[01/23/2025-03:27:56] [TRT] [W] ModelImporter.cpp:459: Make sure input host_kv_cache_pool_pointers has Int64 binding.
[01/23/2025-03:27:56] [TRT] [W] ModelImporter.cpp:459: Make sure input host_runtime_perf_knobs has Int64 binding.
[01/23/2025-03:27:56] [TRT] [W] ModelImporter.cpp:459: Make sure input host_context_progress has Int64 binding.
[01/23/2025-03:27:56] [TRT] [I] No checker registered for op: GPTAttention. Attempting to check as plugin.
[01/23/2025-03:27:56] [TRT] [I] No importer registered for op: GPTAttention. Attempting to import as plugin.
[01/23/2025-03:27:56] [TRT] [I] Searching for plugin: GPTAttention, plugin_version: 1, plugin_namespace: tensorrt_llm
[01/23/2025-03:27:56] [TRT] [W] onnxOpImporters.cpp:6507: Attribute in_flight_batching not found in plugin node! Ensure that the plugin creator has a default value defined or the engine may fail to build.
[01/23/2025-03:27:56] [TRT] [I] Successfully created plugin: GPTAttention
[01/23/2025-03:27:56] [TRT] [W] Unused Input: position_ids
[01/23/2025-03:27:56] [TRT] [W] [RemoveDeadLayers] Input Tensor position_ids is unused or used only at compile-time, but is not being removed.
[01/23/2025-03:27:56] [TRT] [I] Global timing cache in use. Profiling results in this builder pass will be stored.
[01/23/2025-03:27:56] [TRT] [I] Compiler backend is used during engine build.
[01/23/2025-03:28:12] [TRT] [I] Detected 16 inputs and 1 output network tensors.
[01/23/2025-03:28:14] [TRT] [I] Total Host Persistent Memory: 1792 bytes
[01/23/2025-03:28:14] [TRT] [I] Total Device Persistent Memory: 0 bytes
[01/23/2025-03:28:14] [TRT] [I] Max Scratch Memory: 33555072 bytes
[01/23/2025-03:28:14] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 6 steps to complete.
[01/23/2025-03:28:14] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.024757ms to assign 4 blocks to 6 nodes requiring 40895488 bytes.
[01/23/2025-03:28:14] [TRT] [I] Total Activation Memory: 40895488 bytes
[01/23/2025-03:28:17] [TRT] [I] Detected 16 inputs and 1 output network tensors.
[01/23/2025-03:28:17] [TRT] [I] Total Host Persistent Memory: 1792 bytes
[01/23/2025-03:28:17] [TRT] [I] Total Device Persistent Memory: 0 bytes
[01/23/2025-03:28:17] [TRT] [I] Max Scratch Memory: 33555072 bytes
[01/23/2025-03:28:17] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 6 steps to complete.
[01/23/2025-03:28:17] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.024627ms to assign 4 blocks to 6 nodes requiring 33562624 bytes.
[01/23/2025-03:28:17] [TRT] [I] Total Activation Memory: 33562624 bytes
[01/23/2025-03:28:18] [TRT] [I] Total Weights Memory: 6339328 bytes
[01/23/2025-03:28:18] [TRT] [I] Compiler backend is used during engine execution.
[01/23/2025-03:28:18] [TRT] [I] Engine generation completed in 21.9013 seconds.
[01/23/2025-03:28:18] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 6 MiB
[01/23/2025-03:28:19] [TRT] [I] Loaded engine size: 7 MiB
TRT Engine uses: 40895488 bytes of Memory
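
As an extra sanity check (not part of the PR itself), the generated engine can be deserialized afterwards; the plugin library still needs to be registered at load time, so importing tensorrt_llm is assumed here as well:

import tensorrt as trt
import tensorrt_llm  # assumed to register the GPTAttention plugin creator

logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)
with open("model.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
assert engine is not None, "engine failed to deserialize"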
