-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrmscaleCUDA.patch
38 lines (37 loc) · 2.38 KB
/
rmscaleCUDA.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h
index bf91a7bd..192a8d26 100644
--- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h
@@ -71,12 +71,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
GMemIterator<Mandatory, AccessTypeW, CtaN, Details::kAccessNumW, uint8_t> weight_iterator(weight,
(interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW,
interleaved_k / Details::kElemsPerByteW);
- GMemIterator<Mandatory, TypeA, CtaN, 1, TypeA> scales_iterator(scales,
- (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
- (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
- GMemIterator<EnableZero, TypeA, CtaN, 1, TypeA> zeros_iterator(zeros,
- (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
- (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
+ // GMemIterator<Mandatory, TypeA, CtaN, 1, TypeA> scales_iterator(scales,
+ // (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
+ // (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
+ // GMemIterator<EnableZero, TypeA, CtaN, 1, TypeA> zeros_iterator(zeros,
+ // (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n,
+ // (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave);
out += offset_m * n + tile_id_n * CtaN * Details::kInterleave;
if constexpr (EnableBias)
@@ -96,9 +96,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca
#pragma unroll
for (int i = 0; i < CtaN; ++i)
{
- scales_iterator.load(vec_scale + i, iter, i);
- zeros_iterator.load(vec_zero + i, iter, i);
+ // scales_iterator.load(vec_scale + i, iter, i);
+ // zeros_iterator.load(vec_zero + i, iter, i);
+ vec_scale[i] = static_cast<TypeA>(1);
+ vec_zero[i] = static_cast<TypeA>(0);
}
+
act_scale_iterator.load(vec_act_scale, iter);
#pragma unroll
for (int i = 0; i < CtaN; ++i)