enable jit by default; add documentation on jit-disabling envvar #34

Open · wants to merge 1 commit into main
11 changes: 10 additions & 1 deletion README.md
@@ -67,6 +67,15 @@ run it (torch)
PYTHONPATH=. poetry run python entropix/torch_main.py
```


NOTES:
If you're using the torch parts only, you can `export XLA_PYTHON_CLIENT_PREALLOCATE=false` to stop JAX from preallocating the GPU and hogging your VRAM
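If you'd rather set this from Python than from your shell, a minimal sketch (not part of this PR; the flag only takes effect if it's set before JAX first initializes its backend):
```
import os

# Must run before anything imports jax; XLA reads this at backend startup.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```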

## Tips and Tricks

### Jit is Too Slow

For rapid iteration, `jax.jit`'s compilation overhead might be too slow. In that case, set:
```
JAX_DISABLE_JIT=True
```
in your environment to disable it.
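Not part of this diff, but the same switch also exists in Python, which can be handier when you only want to debug a single call. A minimal sketch:
```
import jax
import jax.numpy as jnp

# Process-wide, equivalent to setting JAX_DISABLE_JIT in the environment:
jax.config.update("jax_disable_jit", True)

# Or scoped to a single block via the context manager:
with jax.disable_jit():
    print(jnp.dot(jnp.ones((2, 2)), jnp.ones((2, 2))))  # runs op-by-op, no tracing
```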
18 changes: 8 additions & 10 deletions entropix/model.py
@@ -1,15 +1,13 @@
+from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

-from functools import partial
-
from entropix.config import ModelParams
from entropix.kvcache import KVCache
from entropix.stats import AttnStats
-from entropix.weights import XfmrWeights, LayerWeights
-
+from entropix.weights import LayerWeights, XfmrWeights

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

@@ -19,7 +17,7 @@ def rms_norm(x: jax.Array, w: jax.Array, eps: float = 1e-6) -> jax.Array:
return w * (x * jax.lax.rsqrt(jax.lax.pow(x, 2).mean(-1, keepdims=True) + eps))


#@partial(jax.jit, static_argnames=("dtype"))
@partial(jax.jit, static_argnames=("dtype"))
def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype: jnp.dtype = jnp.float32) -> Tuple[jax.Array, jax.Array]:
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
@@ -31,8 +29,8 @@ def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype:
xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
return xq_out.astype(dtype), xk_out.astype(dtype)

#@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx"))
def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array] = None) -> Tuple[jax.Array, KVCache]:
@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx"))
def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array] = None) -> Tuple[jax.Array, KVCache, jax.Array]:
bsz, _, _ = x.shape
n_rep = model_params.n_local_heads // model_params.n_local_kv_heads
xq = jnp.dot(x, layer_weights.wq.T).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
@@ -56,12 +54,12 @@ def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos:
out = jnp.dot(output, layer_weights.wo.T)
return out, kvcache, pre_scores

-#@partial(jax.jit)
+@partial(jax.jit)
def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array:
return jnp.dot(jax.nn.silu(jnp.dot(x, layer_weights.w1.T)) * jnp.dot(x, layer_weights.w3.T), layer_weights.w2.T)

#@partial(jax.jit, static_argnames=("model_params", "cur_pos"))
def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]:
@partial(jax.jit, static_argnames=("model_params", "cur_pos"))
def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache, jax.Array, AttnStats]:
h = xfmr_weights.tok_embeddings[tokens]
attn_stats = AttnStats.new(
bsz=tokens.shape[0],
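For context on the pattern this PR enables: `static_argnames` tells `jax.jit` to treat those arguments as compile-time constants, so they can drive shapes and slice bounds, at the cost of one recompile per distinct value. A standalone sketch with a hypothetical toy function (not from this repo):
```
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("cur_pos",))
def prefix_sum(x: jax.Array, cur_pos: int) -> jax.Array:
    # Slice bounds must be known at trace time, which is why cur_pos is
    # static here, just as in attention() and xfmr() above.
    return x[:cur_pos].sum()

x = jnp.arange(8)
print(prefix_sum(x, 4))  # traces and compiles for cur_pos=4
print(prefix_sum(x, 5))  # recompiles for cur_pos=5
```
That per-value recompilation is also why the README now documents `JAX_DISABLE_JIT` for rapid iteration.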