Added RMSNorm #545
Changes from 1 commit
The documentation diff registers the new layer in the API docs:

```diff
@@ -29,3 +29,11 @@
         members:
             - __init__
             - __call__
+
+---
+
+::: equinox.nn.RMSNorm
+    selection:
+        members:
+            - __init__
+            - __call__
```
The implementation diff adds one import:

```diff
@@ -1,6 +1,7 @@
 import functools as ft
 import warnings
 from collections.abc import Sequence
+from math import prod
 from typing import Optional, overload, Union

 import jax
```
The new class is appended after the existing normalisation layers:

```diff
@@ -267,3 +268,125 @@ def __call__(
             return out
         else:
             return out, state
```
It begins with the class docstring:

```python
class RMSNorm(Module):
    r"""
    A simplified version of LayerNorm which rescales the inputs, but does not
    center them. Optionally applies a learned reweighting of the transformed
    array afterward.

    Given an input array $x$, this layer computes

    $$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

    where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
    learned array with the same shape as $x$ if `use_weight=True`, or
    $\gamma = 1/\sqrt{n}$ if `use_weight=False`, as proposed in
```
Review comment: Not true as per discussion?
The docstring continues:

````python
    [this paper](https://browse.arxiv.org/abs/2307.14995). $\beta$ is an optional
    bias term.

    ??? cite
        [Root Mean Square Layer Normalization](https://browse.arxiv.org/abs/1910.07467)

        ```bibtex
        @article{zhang2019root,
            title={Root Mean Square Layer Normalization},
            author={Biao Zhang and Rico Sennrich},
            year={2019},
            journal={arXiv:1910.07467}
        }
        ```
    """
````
Then the fields:

```python
    shape: tuple[int] = field(static=True)
```
Review comment: this annotation should be `tuple[int, ...]`.
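For reference on that remark: in Python's typing syntax, `tuple[int]` means a tuple of exactly one `int`, while a variable-length homogeneous tuple is spelled `tuple[int, ...]`:

```python
shape_1d: tuple[int] = (3,)            # exactly one int
shape_nd: tuple[int, ...] = (3, 4, 5)  # any number of ints
```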
```python
    dim: int = field(static=True)
    eps: float = field(static=True)
    use_weight: bool = field(static=True)
    use_bias: bool = field(static=True)
    weight: Optional[Float[Array, "*shape"]]
    bias: Optional[Float[Array, "*shape"]]
```
```python
    def __init__(
        self,
        shape: Union[int, Sequence[int]],
        eps: float = 1e-5,
        use_weight: bool = False,
        use_bias: bool = False,
        **kwargs,
    ):
        """**Arguments:**

        - `shape`: Shape of the input.
        - `eps`: Value added to the denominator for numerical stability.
        - `use_weight`: Whether the module has learnable affine weights.
        - `use_bias`: Whether the module has a learnable additive bias.
        """
```
```python
        super().__init__(**kwargs)
        if isinstance(shape, int):
            shape = (shape,)
        else:
            shape = tuple(shape)
        self.shape = shape
        self.dim = prod(self.shape)
```
Review comment: As this is easily derivable from existing attributes (and after my suggested change in the annotation above), I'd suggest not storing `dim` as an attribute at all.
```python
        self.eps = eps
        self.use_weight = use_weight
        self.use_bias = use_bias
        self.weight = jnp.ones(shape) if use_weight else None
        self.bias = jnp.zeros(shape) if use_bias else None
```

Review comment: Oh, hey Jon! It looks like this PR isn't going anywhere at the moment -- if you or anyone else wants to pick it up then I'd still be happy to see this merged.

Reply: Hey Patrick :) I'd be happy to pick this up. Would you/Jason prefer that I make a new PR against [...]?

Reply: No strong feelings! As you've done it is fine.
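As a quick construction sketch (assuming the PR lands as `equinox.nn.RMSNorm`, with the usual Equinox filtering machinery), the `field(static=True)` attributes stay out of the pytree leaves, while `weight` and `bias` are the only arrays:

```python
import equinox as eqx
import jax

norm = eqx.nn.RMSNorm(shape=4, use_weight=True, use_bias=True)

# `shape`, `dim`, `eps`, `use_weight` and `use_bias` are static metadata;
# only `weight` and `bias` are array leaves.
params, static = eqx.partition(norm, eqx.is_array)
print([leaf.shape for leaf in jax.tree_util.tree_leaves(params)])  # [(4,), (4,)]
```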
```python
    @overload
    def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
        ...

    @overload
    def __call__(
        self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None
    ) -> tuple[Array, State]:
        ...

    @jax.named_scope("eqx.nn.RMSNorm")
    def __call__(
        self,
        x: Float[Array, "*shape"],
        state: State = sentinel,
        *,
        key: Optional[PRNGKeyArray] = None,
    ) -> Union[Array, tuple[Array, State]]:
        """**Arguments:**

        - `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
        - `state`: Ignored; provided for interchangeability with the
            [`equinox.nn.BatchNorm`][] API.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox
            API. (Keyword only argument.)

        **Returns:**

        The output is a JAX array of the same shape as `x`.

        If `state` is passed, then a 2-tuple of `(output, state)` is returned. The
        state is passed through unchanged. If `state` is not passed, then just the
        output is returned.
        """
        if x.shape != self.shape:
            raise ValueError(
                "`RMSNorm(shape)(x)` must satisfy the invariant `shape == x.shape`. "
                f"Received `shape={self.shape}` and `x.shape={x.shape}`. You might "
                "need to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
                "\n"
                "If this is a new error for you, it might be because this became "
                "stricter in Equinox v0.11.0. Previously all that was required is "
                "that `x.shape` ended with `shape`. However, this turned out to be "
                "a frequent source of bugs, so we made the check stricter!"
            )
        inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + self.eps)
        out = inv_rms * x
        if self.use_weight:
            out = self.weight * out
        else:
            out = jnp.sqrt(self.dim) * out
        if self.use_bias:
            out = out + self.bias
        if state is sentinel:
            return out
        else:
            return out, state
```
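A hedged usage sketch of the call semantics above (again assuming the API as written in this PR): the layer acts on a single array of exactly `shape`, so batched inputs go through `jax.vmap`, just as the error message suggests:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

rms_norm = eqx.nn.RMSNorm(shape=4)

x = jnp.arange(4.0)
y = rms_norm(x)              # unbatched: x.shape == (4,) matches `shape`

xs = jnp.ones((8, 4))
ys = jax.vmap(rms_norm)(xs)  # batched: map over the leading axis
print(y.shape, ys.shape)     # (4,) (8, 4)
```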
Review comment (on the `inv_rms` line): I'd probably write it out explicitly, to avoid ambiguity whether "2-norm" is mean or sum.
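One way to write it out explicitly, per that suggestion (my sketch, not the author's): compute the mean square directly, so the $1/n$ factor is visible rather than recovered through the later `jnp.sqrt(self.dim)` rescaling:

```python
import jax
import jax.numpy as jnp

def rms_normalise(x, eps=1e-5):
    # Explicit root-mean-square: unambiguous about sum vs. mean.
    inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + eps)
    return x * inv_rms
```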