!16059 Add an inner_precise parameter to the pseudo-quantization npu_weight_quant_batchmatmul interface
Merge pull request !16059 from chenxu/master
chenxu authored and it-is-a-robot committed Nov 25, 2024
1 parent 0f3ca47 commit b39b922
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion torch_npu/contrib/module/linear_weight_quant.py
@@ -63,6 +63,7 @@ def __init__(
quant_scale: bool = False,
quant_offset: bool = False,
antiquant_group_size: int = 0,
+       inner_precise: int = 0,
) -> None:
super(LinearWeightQuant, self).__init__()
self.weight = Parameter(torch.empty((out_features, in_features), device=device), False)
@@ -89,6 +90,7 @@ def __init__(
self.register_parameter('bias', None)

self.antiquant_group_size = antiquant_group_size
+       self.inner_precise = inner_precise

def forward(self, x: Tensor) -> Tensor:
antiquant_scale = self.antiquant_scale
@@ -100,4 +102,4 @@ def forward(self, x: Tensor) -> Tensor:
antiquant_offset = self.antiquant_offset.transpose(-1, -2)
return torch_npu.npu_weight_quant_batchmatmul(x, self.weight.transpose(-1, -2), antiquant_scale,
antiquant_offset, self.quant_scale, self.quant_offset,
-                                                      self.bias, self.antiquant_group_size)
+                                                      self.bias, self.antiquant_group_size, self.inner_precise)
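
For context, a minimal usage sketch of the updated call path in forward(). The shapes, dtypes, the antiquant tensor layout, and the inner_precise semantics (0 reportedly selecting a high-precision mode, 1 a high-performance mode) are illustrative assumptions, not taken from this diff, and the call only executes on an Ascend NPU with torch_npu installed.

import torch
import torch_npu

# Assumed problem size: x is (m, k) float16; the quantized weight is stored
# as (out_features, in_features) int8, matching the Parameter shape above.
m, k, n = 16, 1024, 4096
x = torch.randn(m, k, dtype=torch.float16).npu()
weight = torch.randint(-8, 8, (n, k), dtype=torch.int8).npu()
antiquant_scale = torch.ones(n, dtype=torch.float16).npu()      # per-channel dequant scale (assumed layout)
antiquant_offset = torch.zeros(n, dtype=torch.float16).npu()    # per-channel dequant offset (assumed layout)

# Mirrors the forward() call: weight is transposed to (in_features, out_features),
# quant_scale/quant_offset/bias are omitted, antiquant_group_size=0 (per-channel),
# and the new inner_precise flag is passed through as the last argument.
out = torch_npu.npu_weight_quant_batchmatmul(
    x, weight.transpose(-1, -2), antiquant_scale, antiquant_offset,
    None, None, None, 0, 1)
print(out.shape)  # expected: (m, n) == (16, 4096)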
2 changes: 1 addition & 1 deletion torch_npu/meta/_meta_registrations.py
@@ -340,7 +340,7 @@ def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, ant


@impl(m, "npu_weight_quant_batchmatmul")
-def npu_weight_quant_batchmatmul_meta(x, weight, antiquant_scale, antiquant_offset=None, quant_scale=None, quant_offset=None, bias=None, antiquant_group_size=0):
+def npu_weight_quant_batchmatmul_meta(x, weight, antiquant_scale, antiquant_offset=None, quant_scale=None, quant_offset=None, bias=None, antiquant_group_size=0, inner_precise=0):
dim_m = x.size(0)
if weight.dtype == torch.int32 and weight.is_contiguous():
dim_n = weight.size(1) * 8
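
A small worked example of the shape-inference branch shown above. The packed-layout interpretation (eight int4 values per int32 element) is this sketch's assumption, inferred from the * 8 factor in the meta kernel.

import torch

# Hypothetical int4 weight packed into int32: each int32 element holds eight
# int4 values, so a logical (k, n) = (1024, 4096) weight is stored as (1024, 512) int32.
weight_packed = torch.empty((1024, 512), dtype=torch.int32)
dim_n = weight_packed.size(1) * 8   # the meta kernel recovers n for the output shape
print(dim_n)                        # 4096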
