From 0d98a56898accc01451d0cbb671088acfb6898dc Mon Sep 17 00:00:00 2001 From: Francesco Conti Date: Wed, 5 May 2021 12:33:21 +0200 Subject: [PATCH] Ensure that requantization factor D is always power-of-two --- nemo/quant/pact.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/quant/pact.py b/nemo/quant/pact.py index 5a121f7..0597271 100644 --- a/nemo/quant/pact.py +++ b/nemo/quant/pact.py @@ -52,7 +52,7 @@ def pact_quantized_requantize(t, eps_in, eps_out, D=1, exclude_requant_rounding= def pact_integer_requantize(t, eps_in, eps_out, D=1): eps_ratio = (D*eps_in/eps_out).round() device = t.device - return torch.as_tensor((torch.as_tensor(t, dtype=torch.int64) * torch.as_tensor(eps_ratio, dtype=torch.int64) // torch.as_tensor(D, dtype=torch.int64)), dtype=torch.float32, device=device) + return torch.as_tensor((torch.as_tensor(t, dtype=torch.int64) * torch.as_tensor(eps_ratio, dtype=torch.int64) // D), dtype=torch.float32, device=device) # re-quantize from a lower precision (larger eps_in) to a higher precision (lower eps_out) def pact_integer_requantize_add(*t, eps_in_list, eps_out, D=1): @@ -480,7 +480,7 @@ def get_output_eps(self, eps_in_list): self.eps_out = max(self.eps_in_list) self.alpha_out = 2.0**(self.precision.get_bits())-1 # D is selected as a power-of-two - self.D = 2.0**torch.ceil(torch.log2(self.requantization_factor * self.eps_out / min(self.eps_in_list))) + self.D = 2**torch.as_tensor(torch.ceil(torch.log2(self.requantization_factor * self.eps_out / min(self.eps_in_list))), dtype=torch.int64) return self.eps_out def forward(self, *x): @@ -604,11 +604,11 @@ def set_output_eps(self, limit_at_32_bits=True, **kwargs): # self.eps_out = self.alpha.item()/(2.0**(self.precision.get_bits())-1) self.alpha_out = 2.0**(self.precision.get_bits())-1 # D is selected as a power-of-two - D = 2.0**torch.ceil(torch.log2(self.requantization_factor * self.eps_out / self.eps_in)) + D = 2**torch.as_tensor(torch.ceil(torch.log2(self.requantization_factor * self.eps_out / self.eps_in)), dtype=torch.int64) if not limit_at_32_bits: self.D = D else: - self.D = min(D, 2.0**(32-(self.precision.get_bits()))) + self.D = min(D, 2**(32-(self.precision.get_bits()))) def get_output_eps(self, eps_in): r"""Get the output quantum (:math:`\varepsilon`) given the input one.