diff --git a/torch_npu/contrib/module/linear_weight_quant.py b/torch_npu/contrib/module/linear_weight_quant.py index aa100e49b..3943fed65 100644 --- a/torch_npu/contrib/module/linear_weight_quant.py +++ b/torch_npu/contrib/module/linear_weight_quant.py @@ -63,6 +63,7 @@ def __init__( quant_scale: bool = False, quant_offset: bool = False, antiquant_group_size: int = 0, + inner_precise: int = 0, ) -> None: super(LinearWeightQuant, self).__init__() self.weight = Parameter(torch.empty((out_features, in_features), device=device), False) @@ -89,6 +90,7 @@ def __init__( self.register_parameter('bias', None) self.antiquant_group_size = antiquant_group_size + self.inner_precise = inner_precise def forward(self, x: Tensor) -> Tensor: antiquant_scale = self.antiquant_scale @@ -100,4 +102,4 @@ def forward(self, x: Tensor) -> Tensor: antiquant_offset = self.antiquant_offset.transpose(-1, -2) return torch_npu.npu_weight_quant_batchmatmul(x, self.weight.transpose(-1, -2), antiquant_scale, antiquant_offset, self.quant_scale, self.quant_offset, - self.bias, self.antiquant_group_size) + self.bias, self.antiquant_group_size, self.inner_precise) diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py index aadb64be4..38a7ef036 100644 --- a/torch_npu/meta/_meta_registrations.py +++ b/torch_npu/meta/_meta_registrations.py @@ -340,7 +340,7 @@ def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, ant @impl(m, "npu_weight_quant_batchmatmul") -def npu_weight_quant_batchmatmul_meta(x, weight, antiquant_scale, antiquant_offset=None, quant_scale=None, quant_offset=None, bias=None, antiquant_group_size=0): +def npu_weight_quant_batchmatmul_meta(x, weight, antiquant_scale, antiquant_offset=None, quant_scale=None, quant_offset=None, bias=None, antiquant_group_size=0, inner_precise=0): dim_m = x.size(0) if weight.dtype == torch.int32 and weight.is_contiguous(): dim_n = weight.size(1) * 8