Fix computation in normalization layers (#876)
* fix formatting

* fix dtype promotion during weight creation itself

* ignore false pyright type warning

* remove redundant dtype check
AakashKumarNain authored Oct 18, 2024
1 parent d9b3ffd commit e2d7e38
Showing 2 changed files with 36 additions and 14 deletions.
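All three layers (LayerNorm, GroupNorm, RMSNorm) get the same treatment: record the input dtype, promote the computation to at least float32, cast the affine parameters to the promoted dtype, and cast back on return. A minimal sketch of that pattern, written as a LayerNorm-style free function (illustrative only, not the Equinox source; the function name and eps default are placeholders):

import jax
import jax.numpy as jnp

def layer_norm_f32(x, weight, bias, eps=1e-5):
    orig_dtype = x.dtype
    with jax.numpy_dtype_promotion("standard"):
        dtype = jnp.result_type(x.dtype, jnp.float32)  # e.g. bfloat16 -> float32
    x = x.astype(dtype)
    mean = jnp.mean(x, keepdims=True)
    variance = jnp.maximum(0.0, jnp.var(x, keepdims=True))
    out = (x - mean) * jax.lax.rsqrt(variance + eps)
    # Cast the parameters too, so a low-precision weight cannot drag the
    # arithmetic back down to bfloat16.
    out = weight.astype(dtype) * out + bias.astype(dtype)
    return out.astype(orig_dtype)  # output dtype matches input dtype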
42 changes: 28 additions & 14 deletions equinox/nn/_normalisation.py
@@ -116,7 +116,7 @@ def __call__(
         """**Arguments:**

         - `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
-        - `state`: Ignored; provided for interchangability with the
+        - `state`: Ignored; provided for interchangeability with the
             [`equinox.nn.BatchNorm`][] API.
         - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
             (Keyword only argument.)
@@ -140,19 +140,24 @@ def __call__(
                 "`x.shape` ended with `shape`. However, this turned out to be a "
                 "frequent source of bugs, so we made the check stricter!"
             )
+        orig_dtype = x.dtype
+        with jax.numpy_dtype_promotion("standard"):
+            dtype = jnp.result_type(x.dtype, jnp.float32)
+
+        x = x.astype(dtype)
         mean = jnp.mean(x, keepdims=True)
         variance = jnp.var(x, keepdims=True)
         variance = jnp.maximum(0.0, variance)
         inv = jax.lax.rsqrt(variance + self.eps)
         out = (x - mean) * inv
         if self.use_weight:
-            out = self.weight * out
+            out = self.weight.astype(dtype) * out  # pyright: ignore
         if self.use_bias:
-            out = out + self.bias
+            out = out + self.bias.astype(dtype)  # pyright: ignore
         if state is sentinel:
-            return out
+            return out.astype(orig_dtype)
         else:
-            return out, state
+            return out.astype(orig_dtype), state


 class GroupNorm(Module, strict=True):
@@ -253,6 +258,12 @@ def __call__(
         is passed through unchanged. If `state` is not passed, then just the output is
         returned.
         """
+
+        orig_dtype = x.dtype
+        with jax.numpy_dtype_promotion("standard"):
+            dtype = jnp.result_type(x.dtype, jnp.float32)
+
+        x = x.astype(dtype)
         channels = x.shape[0]
         y = x.reshape(self.groups, channels // self.groups, *x.shape[1:])
         mean = jax.vmap(ft.partial(jnp.mean, keepdims=True))(y)
@@ -264,11 +275,11 @@
         if self.channelwise_affine:
             weight = left_broadcast_to(self.weight, out.shape)  # pyright: ignore
             bias = left_broadcast_to(self.bias, out.shape)  # pyright: ignore
-            out = weight * out + bias
+            out = weight.astype(dtype) * out + bias.astype(dtype)
         if state is sentinel:
-            return out
+            return out.astype(orig_dtype)
         else:
-            return out, state
+            return out.astype(orig_dtype), state


 class RMSNorm(Module, strict=True):
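For context, left_broadcast_to aligns a per-channel parameter of shape (channels,) with an activation of shape (channels, *rest) by padding trailing singleton axes. A simplified sketch of the idea (an assumption for illustration; Equinox ships its own helper):

import jax.numpy as jnp

def left_broadcast_to(arr, shape):
    # (128,) against (128, 4, 5) -> reshape to (128, 1, 1), then broadcast
    arr = arr.reshape(arr.shape + (1,) * (len(shape) - arr.ndim))
    return jnp.broadcast_to(arr, shape)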
@@ -377,17 +388,20 @@ def __call__(
                 "to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
             )

+        orig_dtype = x.dtype
+
         with jax.numpy_dtype_promotion("standard"):
             dtype = jnp.result_type(x.dtype, jnp.float32)

-        inv_rms = jax.lax.rsqrt(jnp.mean(x.astype(dtype) ** 2) + self.eps)
-        out = (inv_rms * x.astype(dtype)).astype(x.dtype)
+        x = x.astype(dtype)
+        inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)
+        out = inv_rms * x

         if self.use_weight:
-            out = self.weight * out
+            out = self.weight.astype(dtype) * out  # pyright: ignore
         if self.use_bias:
-            out = out + self.bias
+            out = out + self.bias.astype(dtype)  # pyright: ignore
         if state is sentinel:
-            return out
+            return out.astype(orig_dtype)
         else:
-            return out, state
+            return out.astype(orig_dtype), state
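Before this commit, LayerNorm and GroupNorm computed their statistics at the input precision, and RMSNorm cast its intermediate back to the input dtype before applying weight and bias. An illustrative look at why statistics in bfloat16 are a problem (the input values are arbitrary assumptions):

import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 4096, dtype=jnp.float32)
ms_f32 = jnp.mean(x**2)                          # float32 accumulation
ms_bf16 = jnp.mean(x.astype(jnp.bfloat16) ** 2)  # bfloat16 accumulation
# With only ~8 mantissa bits, the bfloat16 result can drift from the float32
# reference; computing in float32 and casting back at the end avoids this.
print(ms_f32, ms_bf16.astype(jnp.float32))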
8 changes: 8 additions & 0 deletions tests/test_nn.py
@@ -886,12 +886,20 @@ def test_layer_norm(getkey):
     assert jnp.allclose(ln(x1), ln(x2), atol=1e-4)
     assert jnp.allclose(ln(x1), x3, atol=1e-4)

+    ln = eqx.nn.LayerNorm(128, dtype=jnp.bfloat16)
+    x = jrandom.uniform(getkey(), (128,), dtype=jnp.bfloat16)
+    assert ln(x).dtype == jnp.bfloat16
+

 def test_group_norm(getkey):
     gn = eqx.nn.GroupNorm(groups=4, channels=128)
     x = jrandom.uniform(getkey(), (128,))
     assert gn(x).shape == (128,)

+    gn = eqx.nn.GroupNorm(groups=4, channels=128, dtype=jnp.bfloat16)
+    x = jrandom.uniform(getkey(), (128,), dtype=jnp.bfloat16)
+    assert gn(x).dtype == jnp.bfloat16
+
     gn = eqx.nn.GroupNorm(groups=4, channels=128)
     x = jrandom.uniform(getkey(), (128, 4, 5))
     assert gn(x).shape == (128, 4, 5)
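In use, the round trip now looks like this (a sketch mirroring the new tests; the PRNG seed is arbitrary):

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

ln = eqx.nn.LayerNorm(128, dtype=jnp.bfloat16)
x = jrandom.uniform(jrandom.PRNGKey(0), (128,), dtype=jnp.bfloat16)
out = ln(x)
assert out.dtype == jnp.bfloat16  # normalisation ran in float32 internally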
