Skip to content

Commit

Permalink
Address code reviews.
Browse files Browse the repository at this point in the history
  • Loading branch information
jondeaton committed Dec 27, 2023
1 parent 790de08 commit bab8e4a
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions equinox/nn/_normalisation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools as ft
import warnings
from collections.abc import Sequence
from math import prod
from typing import Optional, overload, Union

import jax
Expand Down Expand Up @@ -279,9 +278,9 @@ class RMSNorm(Module):
$$\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta$$
where $\Vert \cdot \Vert_2$ is the 2-norm, $n = \dim(x)$, and $\gamma$ is a
where $\Vert x \Vert^2_2 = \sum_{i=1}^n x_i^2$, $n = \dim(x)$, and $\gamma$ is a
learned array with the same shape as $x$ if `use_weight=True`, or
$\gamma = 1/\sqrt{n}$ if `use_weight=False`, as proposed in
$\gamma = 1$ if `use_weight=False`, as proposed in
[this paper](https://browse.arxiv.org/abs/2307.14995). `\beta` is an optional bias
term.
Expand All @@ -298,8 +297,7 @@ class RMSNorm(Module):
```
"""

shape: tuple[int] = field(static=True)
dim: int = field(static=True)
shape: tuple[int, ...] = field(static=True)
eps: float = field(static=True)
use_weight: bool = field(static=True)
use_bias: bool = field(static=True)
Expand All @@ -319,19 +317,19 @@ def __init__(
- `shape`: Shape of the input.
- `eps`: Value added to denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
- `use_bias`: Whether the module has learnable affine shift.
"""
super().__init__(**kwargs)
if isinstance(shape, int):
shape = (shape,)
else:
shape = tuple(shape)
self.shape = shape
self.dim = prod(self.shape)
self.eps = eps
self.use_weight = use_weight
self.use_bias = use_bias
self.weight = jnp.ones(shape) if use_weight else None
self.bias = jnp.ones(shape) if use_bias else None
self.bias = jnp.zeros(shape) if use_bias else None

@overload
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
Expand Down Expand Up @@ -373,8 +371,8 @@ def __call__(
f"Received `shape={self.shape} and `x.shape={x.shape}`. You might need "
"to replace `rms_norm(x)` with `jax.vmap(rms_norm)(x)`.\n"
)
inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + self.eps)
out = jnp.sqrt(self.dim) * inv_rms * x
inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)
out = inv_rms * x
if self.use_weight:
out = self.weight * out
if self.use_bias:
Expand Down

0 comments on commit bab8e4a

Please sign in to comment.