
Commit

Ensure that requantization factor D is always power-of-two
FrancescoConti committed May 8, 2021
1 parent afcc2ea commit 0d98a56
Showing 1 changed file with 4 additions and 4 deletions.
nemo/quant/pact.py (4 additions, 4 deletions)
@@ -52,7 +52,7 @@ def pact_quantized_requantize(t, eps_in, eps_out, D=1, exclude_requant_rounding=
 def pact_integer_requantize(t, eps_in, eps_out, D=1):
     eps_ratio = (D*eps_in/eps_out).round()
     device = t.device
-    return torch.as_tensor((torch.as_tensor(t, dtype=torch.int64) * torch.as_tensor(eps_ratio, dtype=torch.int64) // torch.as_tensor(D, dtype=torch.int64)), dtype=torch.float32, device=device)
+    return torch.as_tensor((torch.as_tensor(t, dtype=torch.int64) * torch.as_tensor(eps_ratio, dtype=torch.int64) // D), dtype=torch.float32, device=device)

 # re-quantize from a lower precision (larger eps_in) to a higher precision (lower eps_out)
 def pact_integer_requantize_add(*t, eps_in_list, eps_out, D=1):
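
For reference, the function changed in the hunk above computes floor(x * round(D*eps_in/eps_out) / D); when D is a power of two, the final integer division is equivalent to an arithmetic right shift, which is what makes the power-of-two constraint attractive for integer hardware. A minimal sketch of that equivalence (standalone, all values hypothetical, not part of this commit):

import torch

x = torch.tensor([123, 45, 7], dtype=torch.int64)   # already-integerized activations (made up)
eps_in, eps_out = 0.01, 0.08                         # hypothetical input/output quanta
D = 2 ** 8                                           # power-of-two requantization factor

eps_ratio = round(D * eps_in / eps_out)              # integer multiplier, here 32
y_div   = (x * eps_ratio) // D                       # what pact_integer_requantize computes
y_shift = (x * eps_ratio) >> 8                       # same result via an arithmetic right shift
assert torch.equal(y_div, y_shift)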
@@ -480,7 +480,7 @@ def get_output_eps(self, eps_in_list):
         self.eps_out = max(self.eps_in_list)
         self.alpha_out = 2.0**(self.precision.get_bits())-1
         # D is selected as a power-of-two
-        self.D = 2.0**torch.ceil(torch.log2(self.requantization_factor * self.eps_out / min(self.eps_in_list)))
+        self.D = 2**torch.as_tensor(torch.ceil(torch.log2(self.requantization_factor * self.eps_out / min(self.eps_in_list))), dtype=torch.int64)
         return self.eps_out

     def forward(self, *x):
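
To make the selection rule in the hunk above concrete: D is chosen as the smallest power of two greater than or equal to requantization_factor * eps_out / min(eps_in). A plain-Python illustration with made-up values (the real code keeps everything as torch tensors):

import math

requantization_factor = 64   # hypothetical value of self.requantization_factor
eps_out = 0.05               # hypothetical output quantum
eps_in_min = 0.003           # hypothetical min(self.eps_in_list)

ratio = requantization_factor * eps_out / eps_in_min   # about 1066.7
D = 2 ** math.ceil(math.log2(ratio))                    # ceil(log2(1066.7)) = 11, so D = 2048
print(D)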
@@ -604,11 +604,11 @@ def set_output_eps(self, limit_at_32_bits=True, **kwargs):
         # self.eps_out = self.alpha.item()/(2.0**(self.precision.get_bits())-1)
         self.alpha_out = 2.0**(self.precision.get_bits())-1
         # D is selected as a power-of-two
-        D = 2.0**torch.ceil(torch.log2(self.requantization_factor * self.eps_out / self.eps_in))
+        D = 2**torch.as_tensor(torch.ceil(torch.log2(self.requantization_factor * self.eps_out / self.eps_in)), dtype=torch.int64)
         if not limit_at_32_bits:
             self.D = D
         else:
-            self.D = min(D, 2.0**(32-(self.precision.get_bits())))
+            self.D = min(D, 2**(32-(self.precision.get_bits())))

     def get_output_eps(self, eps_in):
         r"""Get the output quantum (:math:`\varepsilon`) given the input one.
