Add RMSNorm #629

Merged: 4 commits, Dec 28, 2023
Changes from 3 commits
8 changes: 8 additions & 0 deletions docs/api/nn/normalisation.md
@@ -29,3 +29,11 @@
        members:
            - __init__
            - __call__

---

::: equinox.nn.RMSNorm
    selection:
        members:
            - __init__
            - __call__
6 changes: 5 additions & 1 deletion equinox/nn/__init__.py
@@ -16,7 +16,11 @@
from ._inference import inference_mode as inference_mode
from ._linear import Identity as Identity, Linear as Linear
from ._mlp import MLP as MLP
from ._normalisation import GroupNorm as GroupNorm, LayerNorm as LayerNorm
from ._normalisation import (
    GroupNorm as GroupNorm,
    LayerNorm as LayerNorm,
    RMSNorm as RMSNorm,
)
from ._pool import (
    AdaptiveAvgPool1d as AdaptiveAvgPool1d,
    AdaptiveAvgPool2d as AdaptiveAvgPool2d,
114 changes: 114 additions & 0 deletions equinox/nn/_normalisation.py
@@ -267,3 +267,117 @@ def __call__(
            return out
        else:
            return out, state


class RMSNorm(Module):
r"""
A simplified version of LayerNorm which rescales the inputs, but does not center
them. Optionally applies a learned reweighting of the transformed array afterward.

Given an input array $x$, this layer computes

$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

where $\Vert x \Vert^2_2 = \sum_{i=1}^n x_i^2$, $n = \dim(x)$, and $\gamma$ is a
learned array with the same shape as $x$ if `use_weight=True`, or
$\gamma = 1$ if `use_weight=False`, as proposed in
[this paper](https://browse.arxiv.org/abs/2307.14995). `\beta` is an optional bias
term.

??? cite

Owner:

Nit: this will need a new line here in order to render correctly in the docs.

i.e.

??? cite

    foo

Owner:

Also, whilst we're here -- maybe worth emphasising that `use_{weight,bias}=False` is the default.

(Should we keep it as the default? I appreciate that's maybe more useful, but it's inconsistent with the other normalisation layers. WDYT?)

Contributor Author:

Thanks for raising the point about the defaults. I feel it's good to maintain consistency, especially with LayerNorm. Changed the defaults to on.

        [Root Mean Square Layer Normalization](https://browse.arxiv.org/abs/1910.07467)

        ```bibtex
        @article{zhang2019root,
            title={Root Mean Square Layer Normalization},
            author={Biao Zhang and Rico Sennrich},
            year={2019},
            journal={arXiv:1910.07467}
        }
        ```
    """

    shape: tuple[int, ...] = field(static=True)
    eps: float = field(static=True)
    use_weight: bool = field(static=True)
    use_bias: bool = field(static=True)
    weight: Optional[Float[Array, "*shape"]]
    bias: Optional[Float[Array, "*shape"]]

    def __init__(
        self,
        shape: Union[int, Sequence[int]],
        eps: float = 1e-5,
        use_weight: bool = False,
        use_bias: bool = False,
        **kwargs,
    ):
        """**Arguments:**

        - `shape`: Shape of the input.
        - `eps`: Value added to denominator for numerical stability.
        - `use_weight`: Whether the module has learnable affine weights.
        - `use_bias`: Whether the module has learnable affine shift.
        """
        super().__init__(**kwargs)
        if isinstance(shape, int):
            shape = (shape,)
        else:
            shape = tuple(shape)
        self.shape = shape
        self.eps = eps
        self.use_weight = use_weight
        self.use_bias = use_bias
        self.weight = jnp.ones(shape) if use_weight else None
        self.bias = jnp.zeros(shape) if use_bias else None

    @overload
    def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
        ...

    @overload
    def __call__(
        self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None
    ) -> tuple[Array, State]:
        ...

    @jax.named_scope("eqx.nn.RMSNorm")
    def __call__(
        self,
        x: Float[Array, "*shape"],
        state: State = sentinel,
        *,
        key: Optional[PRNGKeyArray] = None,
    ) -> Union[Array, tuple[Array, State]]:
        """**Arguments:**

        - `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
        - `state`: Ignored; provided for interchangeability with the
            [`equinox.nn.BatchNorm`][] API.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)

        **Returns:**

        The output is a JAX array of the same shape as `x`.

        If `state` is passed, then a 2-tuple of `(output, state)` is returned. The
        state is passed through unchanged. If `state` is not passed, then just the
        output is returned.
        """
        if x.shape != self.shape:
            raise ValueError(
                "`RMSNorm(shape)(x)` must satisfy the invariant `shape == x.shape`. "
                f"Received `shape={self.shape}` and `x.shape={x.shape}`. You might need "
                "to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
            )
        inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)
        out = inv_rms * x
        if self.use_weight:
            out = self.weight * out
        if self.use_bias:
            out = out + self.bias
        if state is sentinel:
            return out
        else:
            return out, state
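
For reference, a minimal usage sketch of the layer as added in this diff (not part of the PR itself). It checks a single call against the formula in the docstring, and uses the `jax.vmap` pattern from the shape-check error message for batched inputs. Note that as of this commit `use_weight` and `use_bias` default to `False`, so they are passed explicitly here; the review discussion above suggests flipping the defaults.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# Construct the layer for inputs of shape (8,), with a learned gamma and beta.
norm = eqx.nn.RMSNorm(shape=8, use_weight=True, use_bias=True)

x = jnp.arange(8.0)

# The layer operates on a single unbatched array...
y = norm(x)

# ...and matches the docstring formula: x / sqrt(eps + mean(x**2)) * gamma + beta.
expected = x / jnp.sqrt(norm.eps + jnp.mean(x**2)) * norm.weight + norm.bias
assert jnp.allclose(y, expected)

# Batched inputs are handled by vmapping over the leading axis, as the
# shape-mismatch error message suggests.
xs = jnp.ones((32, 8))
ys = jax.vmap(norm)(xs)
assert ys.shape == (32, 8)
```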