From e2d7e38321ea7e75e0b0c706cef3f805067dc29f Mon Sep 17 00:00:00 2001
From: Aakash Kumar Nain
Date: Fri, 18 Oct 2024 13:33:43 +0530
Subject: [PATCH] Fix computation in normalization layers (#876)

* fix formatting

* fix dtype promotion during weight creation itself

* ignore false pyright type warning

* remove redundant dtype check
---
 equinox/nn/_normalisation.py | 42 ++++++++++++++++++++++++++++++----------
 tests/test_nn.py             |  8 ++++++++
 2 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/equinox/nn/_normalisation.py b/equinox/nn/_normalisation.py
index aef58fba..2ee1f37a 100644
--- a/equinox/nn/_normalisation.py
+++ b/equinox/nn/_normalisation.py
@@ -116,7 +116,7 @@ def __call__(
         """**Arguments:**
 
         - `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
-        - `state`: Ignored; provided for interchangability with the
+        - `state`: Ignored; provided for interchangeability with the
           [`equinox.nn.BatchNorm`][] API.
         - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
             (Keyword only argument.)
@@ -140,19 +140,24 @@ def __call__(
                 "`x.shape` ended with `shape`. However, this turned out to be a "
                 "frequent source of bugs, so we made the check stricter!"
             )
+        orig_dtype = x.dtype
+        with jax.numpy_dtype_promotion("standard"):
+            dtype = jnp.result_type(x.dtype, jnp.float32)
+
+        x = x.astype(dtype)
         mean = jnp.mean(x, keepdims=True)
         variance = jnp.var(x, keepdims=True)
         variance = jnp.maximum(0.0, variance)
         inv = jax.lax.rsqrt(variance + self.eps)
         out = (x - mean) * inv
         if self.use_weight:
-            out = self.weight * out
+            out = self.weight.astype(dtype) * out  # pyright: ignore
         if self.use_bias:
-            out = out + self.bias
+            out = out + self.bias.astype(dtype)  # pyright: ignore
         if state is sentinel:
-            return out
+            return out.astype(orig_dtype)
         else:
-            return out, state
+            return out.astype(orig_dtype), state
 
 
 class GroupNorm(Module, strict=True):
@@ -253,6 +258,12 @@ def __call__(
             is passed through unchanged. If `state` is not passed, then just the output
             is returned.
         """
+
+        orig_dtype = x.dtype
+        with jax.numpy_dtype_promotion("standard"):
+            dtype = jnp.result_type(x.dtype, jnp.float32)
+
+        x = x.astype(dtype)
         channels = x.shape[0]
         y = x.reshape(self.groups, channels // self.groups, *x.shape[1:])
         mean = jax.vmap(ft.partial(jnp.mean, keepdims=True))(y)
@@ -264,11 +275,11 @@ def __call__(
         if self.channelwise_affine:
             weight = left_broadcast_to(self.weight, out.shape)  # pyright: ignore
             bias = left_broadcast_to(self.bias, out.shape)  # pyright: ignore
-            out = weight * out + bias
+            out = weight.astype(dtype) * out + bias.astype(dtype)
         if state is sentinel:
-            return out
+            return out.astype(orig_dtype)
         else:
-            return out, state
+            return out.astype(orig_dtype), state
 
 
 class RMSNorm(Module, strict=True):
@@ -377,17 +388,20 @@ def __call__(
                 "to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
             )
 
+        orig_dtype = x.dtype
+
         with jax.numpy_dtype_promotion("standard"):
             dtype = jnp.result_type(x.dtype, jnp.float32)
-            inv_rms = jax.lax.rsqrt(jnp.mean(x.astype(dtype) ** 2) + self.eps)
-        out = (inv_rms * x.astype(dtype)).astype(x.dtype)
+        x = x.astype(dtype)
+        inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)
+        out = inv_rms * x
         if self.use_weight:
-            out = self.weight * out
+            out = self.weight.astype(dtype) * out  # pyright: ignore
         if self.use_bias:
-            out = out + self.bias
+            out = out + self.bias.astype(dtype)  # pyright: ignore
         if state is sentinel:
-            return out
+            return out.astype(orig_dtype)
         else:
-            return out, state
+            return out.astype(orig_dtype), state
 
 
diff --git a/tests/test_nn.py b/tests/test_nn.py
index 7e09e451..d20ac9bd 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -886,12 +886,20 @@ def test_layer_norm(getkey):
     assert jnp.allclose(ln(x1), ln(x2), atol=1e-4)
     assert jnp.allclose(ln(x1), x3, atol=1e-4)
 
+    ln = eqx.nn.LayerNorm(128, dtype=jnp.bfloat16)
+    x = jrandom.uniform(getkey(), (128,), dtype=jnp.bfloat16)
+    assert ln(x).dtype == jnp.bfloat16
+
 
 def test_group_norm(getkey):
     gn = eqx.nn.GroupNorm(groups=4, channels=128)
     x = jrandom.uniform(getkey(), (128,))
     assert gn(x).shape == (128,)
 
+    gn = eqx.nn.GroupNorm(groups=4, channels=128, dtype=jnp.bfloat16)
+    x = jrandom.uniform(getkey(), (128,), dtype=jnp.bfloat16)
+    assert gn(x).dtype == jnp.bfloat16
+
     gn = eqx.nn.GroupNorm(groups=4, channels=128)
     x = jrandom.uniform(getkey(), (128, 4, 5))
     assert gn(x).shape == (128, 4, 5)