diff --git a/equinox/nn/_normalisation.py b/equinox/nn/_normalisation.py index 498e8705..f3e20a1b 100644 --- a/equinox/nn/_normalisation.py +++ b/equinox/nn/_normalisation.py @@ -1,7 +1,6 @@ import functools as ft import warnings from collections.abc import Sequence -from math import prod from typing import Optional, overload, Union import jax @@ -279,9 +278,9 @@ class RMSNorm(Module): $$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$ - where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a + where $\Vert x \Vert^2_2 = \sum_{i=1}^n x_i^2$, $n = \dim(x)$, and $\gamma$ is a learned array with the same shape as $x$ if `use_weight=True`, or - $\gamma = 1/\sqrt{n}$ if `use_weight=False`, as proposed in + $\gamma = 1$ if `use_weight=False`, as proposed in [this paper](https://browse.arxiv.org/abs/2307.14995). `\beta` is an optional bias term. @@ -298,8 +297,7 @@ class RMSNorm(Module): ``` """ - shape: tuple[int] = field(static=True) - dim: int = field(static=True) + shape: tuple[int, ...] = field(static=True) eps: float = field(static=True) use_weight: bool = field(static=True) use_bias: bool = field(static=True) @@ -319,6 +317,7 @@ def __init__( - `shape`: Shape of the input. - `eps`: Value added to denominator for numerical stability. - `use_weight`: Whether the module has learnable affine weights. + - `use_bias`: Whether the module has learnable affine shift. """ super().__init__(**kwargs) if isinstance(shape, int): @@ -326,12 +325,11 @@ def __init__( else: shape = tuple(shape) self.shape = shape - self.dim = prod(self.shape) self.eps = eps self.use_weight = use_weight self.use_bias = use_bias self.weight = jnp.ones(shape) if use_weight else None - self.bias = jnp.ones(shape) if use_bias else None + self.bias = jnp.zeros(shape) if use_bias else None @overload def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array: @@ -373,8 +371,8 @@ def __call__( f"Received `shape={self.shape} and `x.shape={x.shape}`. You might need " "to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n" ) - inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + self.eps) - out = jnp.sqrt(self.dim) * inv_rms * x + inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps) + out = inv_rms * x if self.use_weight: out = self.weight * out if self.use_bias: