diff --git a/qkeras/qlayers.py b/qkeras/qlayers.py index 298a77a3..f5cff710 100644 --- a/qkeras/qlayers.py +++ b/qkeras/qlayers.py @@ -103,7 +103,10 @@ def __call__(self, shape, dtype=None): max_x = np.max(abs(x)) std_x = np.std(x) - delta = self.quantizer.max() * 2**-self.quantizer.bits + if hasattr(self.quantizer, 'delta') and callable(getattr(self.quantizer, 'delta')): + delta = self.quantizer.delta() + else: + delta = self.quantizer.max() * 2**-self.quantizer.bits # delta is the minimum resolution of the number system. # we want to make sure we have enough values.