Added RMSNorm #545
Conversation
Nice! This would be good to have. Also, what should the default value of eps be? In my quick skim of the papers I didn't see it mentioned, i.e. eps=0. In practice I imagine having something here is a good idea, I'm just wary of divergences from canonical implementations.
Ah! Good catch on the numerics! If this looks right I'll rebase into 1 commit to merge
$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
I'd probably write it out explicitly, to avoid ambiguity over whether the "2-norm" is a mean or a sum.
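For reference, the fully written-out form would read something like the following (a sketch of the suggested rewrite, expanding the norm as an explicit sum over the $n$ entries of $x$; not necessarily the exact wording that ended up in the docstring):

$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\sum_{i=1}^{n} x_i^2}} \, \gamma + \beta$$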
where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
learned array with the same shape as $x$ if `use_weight=True`, or
$\gamma = 1/\sqrt{n}$ if `use_weight=False`, as proposed in
Not true as per discussion?
- `shape`: Shape of the input.
- `eps`: Value added to denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
`use_bias` is missing?
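A possible entry to add alongside the existing bullets (hypothetical wording, simply mirroring the `use_weight` line above):

- `use_bias`: Whether the module has learnable affine biases.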
shape: tuple[int] = field(static=True)
This annotation should be `tuple[int, ...]`.
"to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n" | ||
) | ||
inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + self.eps) | ||
out = jnp.sqrt(self.dim) * inv_rms * x |
Note that `inv_rms` is badly named here, as it's actually an inv-sum-of-squares right now. Can you just replace `jnp.sum` with `jnp.mean`?
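A minimal standalone sketch of what the corrected computation might look like, assuming the suggested switch to `jnp.mean` (the `_rms_normalise` helper name is mine, for illustration only):

```python
import jax
import jax.numpy as jnp


def _rms_normalise(x, eps):
    # With the mean of squares, `inv_rms` really is the inverse
    # root-mean-square of x, so no extra sqrt(n) factor is needed.
    inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + eps)
    return inv_rms * x
```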
else:
    shape = tuple(shape)
self.shape = shape
self.dim = prod(self.shape)
As this is easily derivable from existing attributes (and, after my suggested change in `__call__`, unused), I think it can be removed.
self.use_weight = use_weight
self.use_bias = use_bias
self.weight = jnp.ones(shape) if use_weight else None
self.bias = jnp.ones(shape) if use_bias else None
`bias` should be initialized with `jnp.zeros` so that at initialization the elementwise affine is a no-op.
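In other words, something along these lines (a sketch with placeholder values; only the `bias` line changes from the diff above):

```python
import jax.numpy as jnp

shape = (4,)  # placeholder shape, for illustration only
use_weight, use_bias = True, True
weight = jnp.ones(shape) if use_weight else None
# Zeros rather than ones, so that `weight * out + bias` is the identity at init.
bias = jnp.zeros(shape) if use_bias else None
```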
Oh, hey Jon!
It looks like this PR isn't going anywhere at the moment -- if you or anyone else wants to pick it up then I'd still be happy to see this merged.
Hey Patrick :)
I'd be happy to pick this up. Would you/Jason prefer that I make a new PR against `main`, or should I create a PR against `packquickly:rmsnorm`?
No strong feelings! As you've done it is fine.
Closing in favour of #629.
Added RMS normalisation, a simplified variant of layer norm which computes

$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$

where $n = \dim(x)$, $\gamma$ is a learnable array if `use_weight=True` or $\sqrt{n}$ if `use_weight=False`, and $\beta$ is an optional bias term. This has become somewhat popular in transformers for being simpler than layer norm but having similar/better performance.

Originally proposed in this paper; see Lucidrain's discussion about this for more details.

The defaults `use_weight=False` and `use_bias=False` are intentional. This is meant to be a faster and simpler version of layer norm, so leaving these off by default made the most sense to me.
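For anyone picking this up, here is a rough sketch of how the module might look after folding in the review comments above. It is only a sketch: the `eps` default, field names, and error message wording are my own guesses, not the API that was eventually merged in #629.

```python
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array


class RMSNorm(eqx.Module):
    """Sketch only; see #629 for the merged implementation."""

    shape: tuple[int, ...] = eqx.field(static=True)
    eps: float = eqx.field(static=True)
    use_weight: bool = eqx.field(static=True)
    use_bias: bool = eqx.field(static=True)
    weight: Optional[Array]
    bias: Optional[Array]

    def __init__(self, shape, eps=1e-6, use_weight=False, use_bias=False):
        if isinstance(shape, int):
            shape = (shape,)
        self.shape = tuple(shape)
        self.eps = eps
        self.use_weight = use_weight
        self.use_bias = use_bias
        self.weight = jnp.ones(self.shape) if use_weight else None
        # Zeros, so the elementwise affine is a no-op at initialisation.
        self.bias = jnp.zeros(self.shape) if use_bias else None

    def __call__(self, x: Array) -> Array:
        if x.shape != self.shape:
            raise ValueError(
                "`RMSNorm(shape)(x)` expects `x.shape == shape`. To normalise a "
                "batch, replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`."
            )
        # Mean of squares: `inv_rms` is the inverse root-mean-square of x.
        inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)
        out = inv_rms * x
        if self.use_weight:
            out = self.weight * out
        if self.use_bias:
            out = out + self.bias
        return out
```

With this shape-per-instance design, a batch of vectors would be handled with something like `jax.vmap(RMSNorm(64))(xs)`, matching the error message above.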