Added RMSNorm #545

Closed · wants to merge 2 commits

8 changes: 8 additions & 0 deletions docs/api/nn/normalisation.md
@@ -29,3 +29,11 @@
members:
- __init__
- __call__

---

::: equinox.nn.RMSNorm
selection:
members:
- __init__
- __call__
6 changes: 5 additions & 1 deletion equinox/nn/__init__.py
@@ -16,7 +16,11 @@
from ._inference import inference_mode as inference_mode
from ._linear import Identity as Identity, Linear as Linear
from ._mlp import MLP as MLP
from ._normalisation import GroupNorm as GroupNorm, LayerNorm as LayerNorm
from ._normalisation import (
GroupNorm as GroupNorm,
LayerNorm as LayerNorm,
RMSNorm as RMSNorm,
)
from ._pool import (
AdaptiveAvgPool1d as AdaptiveAvgPool1d,
AdaptiveAvgPool2d as AdaptiveAvgPool2d,
123 changes: 123 additions & 0 deletions equinox/nn/_normalisation.py
@@ -1,6 +1,7 @@
import functools as ft
import warnings
from collections.abc import Sequence
from math import prod
from typing import Optional, overload, Union

import jax
@@ -267,3 +268,125 @@ def __call__(
return out
else:
return out, state


class RMSNorm(Module):
r"""
A simplified version of LayerNorm which rescales the inputs, but does not center
them. Optionally applies a learned reweighting of the transformed array afterward.

Given an input array $x$, this layer computes

$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
Owner: I'd probably write it out explicitly, to avoid ambiguity whether "2-norm" is mean or sum.

learned array with the same shape as $x$ if `use_weight=True`, or
$\gamma = 1/\sqrt{n}$ if `use_weight=False`, as proposed in
Owner: Not true as per discussion?

[this paper](https://browse.arxiv.org/abs/2307.14995). $\beta$ is an optional bias
term.

??? cite
[Root Mean Square Layer Normalization](https://browse.arxiv.org/abs/1910.07467)

```bibtex
@article{zhang2019root,
title={Root Mean Square Layer Normalization},
author={Biao Zhang and Rico Sennrich},
year={2019},
journal={arXiv:1910.07467}
}
```
"""

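As an editorial aside (not part of the diff): written out explicitly, with the mean taken over all $n$ elements of $x$ and $\varepsilon$ inside the square root, the docstring's formula corresponds roughly to the sketch below. Note that the `__call__` further down in this diff is written differently (a sum of squares plus a `sqrt(dim)` rescaling), which is part of what the review comments touch on.

```python
import jax.numpy as jnp


def rms_norm_reference(x, weight=None, bias=None, eps=1e-5):
    # (1/n) * ||x||_2^2, i.e. the mean of the squared elements of x.
    mean_sq = jnp.mean(x**2)
    out = x / jnp.sqrt(eps + mean_sq)
    if weight is not None:
        # use_weight=True: gamma is a learned array with the same shape as x.
        out = weight * out
    else:
        # use_weight=False: gamma = 1/sqrt(n), per the docstring above.
        out = out / jnp.sqrt(x.size)
    if bias is not None:
        # use_bias=True: beta is a learned offset.
        out = out + bias
    return out
```
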
shape: tuple[int] = field(static=True)
Owner: this annotation should be `tuple[int, ...]`
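
For context (editorial): the distinction the comment points at is that `tuple[int]` annotates a 1-tuple containing exactly one int, whereas `tuple[int, ...]` annotates an int tuple of any length, which is what `shape` actually is here.

```python
# Annotations are not enforced at runtime; this just illustrates the typing semantics.
shape_1d: tuple[int] = (8,)              # a 1-tuple: matches tuple[int]
shape_3d: tuple[int, ...] = (2, 3, 4)    # any length: needs the `...` form
```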

dim: int = field(static=True)
eps: float = field(static=True)
use_weight: bool = field(static=True)
use_bias: bool = field(static=True)
weight: Optional[Float[Array, "*shape"]]
bias: Optional[Float[Array, "*shape"]]

def __init__(
self,
shape: Union[int, Sequence[int]],
eps: float = 1e-5,
use_weight: bool = False,
use_bias: bool = False,
**kwargs,
):
"""**Arguments:**

- `shape`: Shape of the input.
- `eps`: Value added to denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
Owner: `use_bias` is missing?

"""
super().__init__(**kwargs)
if isinstance(shape, int):
shape = (shape,)
else:
shape = tuple(shape)
self.shape = shape
self.dim = prod(self.shape)
Owner: As this is easily derivable from existing attributes (and after my suggested change in `__call__`, unused) then I think it can be removed.

self.eps = eps
self.use_weight = use_weight
self.use_bias = use_bias
self.weight = jnp.ones(shape) if use_weight else None
self.bias = jnp.ones(shape) if use_bias else None
Contributor: bias should be initialized with `jnp.zeros` so that at initialization the elementwise affine is a no-op.
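
A quick standalone illustration of that point (editorial; the names below are hypothetical, not part of the diff): with the scale initialised to one and the offset to zero, the elementwise affine starts out as the identity.

```python
import jax.numpy as jnp

shape = (8,)
weight = jnp.ones(shape)   # scale initialised to 1
bias = jnp.zeros(shape)    # offset initialised to 0 (rather than jnp.ones)
x = jnp.arange(8.0)
# At initialisation, the affine transform leaves its input unchanged.
assert jnp.allclose(weight * x + bias, x)
```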

Owner: Oh, hey Jon!

It looks like this PR isn't going anywhere at the moment -- if you or anyone else wants to pick it up then I'd still be happy to see this merged.

Contributor: Hey Patrick :)

I'd be happy to pick this up. Would you/Jason prefer that I make a new PR against main, or should I create a PR against packquickly:rmsnorm?

Owner: No strong feelings! As you've done it is fine.


@overload
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
...

@overload
def __call__(
self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None
) -> tuple[Array, State]:
...

@jax.named_scope("eqx.nn.RMSNorm")
def __call__(
self,
x: Float[Array, "*shape"],
state: State = sentinel,
*,
key: Optional[PRNGKeyArray] = None,
) -> Union[Array, tuple[Array, State]]:
"""**Arguments:**

- `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
- `state`: Ignored; provided for interchangeability with the
[`equinox.nn.BatchNorm`][] API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)

**Returns:**

The output is a JAX array of the same shape as `x`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state
is passed through unchanged. If `state` is not passed, then just the output is
returned.
"""
if x.shape != self.shape:
raise ValueError(
"`RMSNorm(shape)(x)` must satisfy the invariant `shape == x.shape`. "
f"Received `shape={self.shape}` and `x.shape={x.shape}`. You might need "
"to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
"\n"
"If this is a new error for you, it might be because this became "
"stricter in Equinox v0.11.0. Previously all that was required is that "
"`x.shape` ended with `shape`. However, this turned out to be a "
"frequent source of bugs, so we made the check stricter!"
)
inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + self.eps)
out = inv_rms * x
if self.use_weight:
out = self.weight * out
else:
out = jnp.sqrt(self.dim) * out
if self.use_bias:
out = out + self.bias
if state is sentinel:
return out
else:
return out, state
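
As a closing editorial note (not part of the diff): assuming this lands as `equinox.nn.RMSNorm` with the constructor shown above, typical use would look roughly like the sketch below. Because `__call__` requires `x.shape == shape` exactly, batched inputs go through `jax.vmap`, as the error message suggests.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# Construct the layer; `shape` must match each unbatched input exactly.
norm = eqx.nn.RMSNorm(shape=64, use_weight=True, use_bias=True)

x = jnp.ones((64,))
y = norm(x)              # single input: shapes match, no vmap needed

xs = jnp.ones((32, 64))  # a batch of 32 inputs
ys = jax.vmap(norm)(xs)  # map the layer over the leading batch axis
```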