From dfe4bec3e2fe41e272e5ff1bb2b527e4f518291b Mon Sep 17 00:00:00 2001 From: fruitea Date: Fri, 5 Jul 2024 04:41:19 -0700 Subject: [PATCH] fix: SSM don't use cudaGraph --- src/ops/fused.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 4ef5f84460..95a4cc8da6 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -612,7 +612,8 @@ __host__ void //graph_params.Print(); // int shard_id = task->index_point.point_data[0]; - bool use_cuda_graph = (bc->prompt_phase == false && bc->get_mode() == TREE_SEARCH_MODE); + // bool use_cuda_graph = (bc->prompt_phase == false && bc->get_mode() == TREE_SEARCH_MODE); + bool use_cuda_graph = false; bool captured = false; if(use_cuda_graph && metas->graph_collections.count(graph_params) != 0) { @@ -961,7 +962,7 @@ __host__ void case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); - SpecIncMultiHeadSelfAttentionMeta const *m = + SpecIncMultiHeadSelfAttentionMeta *m = (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; // TreeSearchBatchConfig const *search_bc = // (TreeSearchBatchConfig *)task->args;