
Added RMSNorm #545

Closed
wants to merge 2 commits into from

Conversation

packquickly (Contributor)

Added RMS normalisation, a simplified variant of layer norm which computes

$$ \frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

where $n = \dim(x)$ and $\gamma$ is a learnable array if `use_weight=True`, or $\sqrt{n}$ if `use_weight=False`. $\beta$ is an optional bias term. This has become somewhat popular in transformers for being simpler than layer norm while having similar or better performance.

Originally proposed in this paper; see Lucidrain's discussion about this for more details.

The defaults `use_weight=False` and `use_bias=False` are intentional. This is meant to be a faster and simpler version of layer norm, so leaving these off by default made the most sense to me.
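For reference, a minimal functional sketch of the normalisation described above (illustrative only, not the `equinox.nn` module itself; the argument names and the `eps` default here are assumptions, and the `use_weight=False` scaling is exactly the point discussed in the review below, so this sketch simply divides by the RMS):

```python
import jax.numpy as jnp


def rms_norm(x, weight=None, bias=None, eps=1e-6):
    # Divide x by its root-mean-square: sqrt(eps + mean(x**2)).
    out = x / jnp.sqrt(eps + jnp.mean(x**2))
    if weight is not None:  # corresponds to use_weight=True
        out = out * weight
    if bias is not None:  # corresponds to use_bias=True
        out = out + bias
    return out
```

As with `equinox.nn.LayerNorm`, this is written for a single (unbatched) input and would be mapped over a batch with `jax.vmap(rms_norm)`.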

@patrick-kidger (Owner) commented Oct 9, 2023

Nice! This would be good to have.
For `use_weight=False`, I don't think your description and the code quite match up -- you have both a $1/n$ and a $\gamma = \sqrt{n}$ factor; I think you only want one of the two. (I also checked both papers, and only the one scaling factor seems to be consistent with them; they basically just do $x/\mathrm{rms}(x)$.)

Also, what should the default value of `eps` be? In my quick skim of the papers I didn't see it mentioned, i.e. `eps=0`. In practice I imagine having something here is a good idea; I'm just wary of divergences from canonical implementations.
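To spell out the scaling issue: with $\varepsilon = 0$, the formula above with $\gamma = 1$ already reduces to $x / \mathrm{rms}(x)$, since

$$\frac{x}{\sqrt{\frac{1}{n}\Vert x \Vert^2_2}} = \frac{\sqrt{n}\, x}{\Vert x \Vert_2} = \frac{x}{\mathrm{rms}(x)},$$

so also multiplying by $\gamma = \sqrt{n}$ would introduce an extra factor of $\sqrt{n}$ (giving $n x / \Vert x \Vert_2$).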

equinox/nn/_normalisation.py (review thread, outdated, resolved)
equinox/nn/_normalisation.py (review thread, resolved)
@packquickly (Contributor, Author)

Ah! Good catch on the numerics! If this looks right I'll rebase into 1 commit to merge.


$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
patrick-kidger (Owner):

I'd probably write it out explicitly, to avoid ambiguity over whether the "2-norm" is a mean or a sum.
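For example, writing the mean of squares explicitly removes the ambiguity:

$$\frac{1}{n}\Vert x \Vert^2_2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2,$$

i.e. a mean over the $n$ entries of $x$, not a sum.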


where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
learned array with the same shape as $x$ if `use_weight=True`, or
$\gamma = 1/\sqrt{n}$ if `use_weight=False`, as proposed in
patrick-kidger (Owner):

Not true as per discussion?


- `shape`: Shape of the input.
- `eps`: Value added to denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
patrick-kidger (Owner):

`use_bias` is missing?

```
"""

shape: tuple[int] = field(static=True)
patrick-kidger (Owner):

This annotation should be `tuple[int, ...]`.

"to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
)
inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + self.eps)
out = jnp.sqrt(self.dim) * inv_rms * x
patrick-kidger (Owner):

Note that `inv_rms` is badly named here: as written it's an inverse root-sum-of-squares, not an inverse RMS.

Can you just replace `jnp.sum` with `jnp.mean`?
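i.e. something along these lines (a sketch of the suggested fix, using a hypothetical `_normalise` helper rather than the actual method):

```python
import jax
import jax.numpy as jnp


def _normalise(x, eps):
    # With jnp.mean, inv_rms really is 1 / rms(x), so no extra sqrt(dim) factor is needed.
    inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + eps)
    return inv_rms * x
```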

else:
shape = tuple(shape)
self.shape = shape
self.dim = prod(self.shape)
patrick-kidger (Owner):

As this is easily derivable from existing attributes (and, after my suggested change in `__call__`, unused), I think it can be removed.

self.use_weight = use_weight
self.use_bias = use_bias
self.weight = jnp.ones(shape) if use_weight else None
self.bias = jnp.ones(shape) if use_bias else None
Contributor:

`bias` should be initialized with `jnp.zeros` so that at initialization the elementwise affine is a no-op.
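i.e. (sketch of the suggested initialisation):

```python
self.weight = jnp.ones(shape) if use_weight else None
# Zero-initialised bias makes the elementwise affine an identity at init.
self.bias = jnp.zeros(shape) if use_bias else None
```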

patrick-kidger (Owner):

Oh, hey Jon!

It looks like this PR isn't going anywhere at the moment -- if you or anyone else wants to pick it up then I'd still be happy to see this merged.

Contributor:

Hey Patrick :)

I'd be happy to pick this up. Would you/Jason prefer that I make a new PR against main, or should I create a PR against packquickly:rmsnorm?

patrick-kidger (Owner):

No strong feelings! As you've done it is fine.

@jondeaton mentioned this pull request Dec 27, 2023
@patrick-kidger (Owner)

Closing in favour of #629.
