diff --git a/README.md b/README.md
index c238a8f..b90be95 100644
--- a/README.md
+++ b/README.md
@@ -67,6 +67,15 @@ run it (torch)
 PYTHONPATH=. poetry run python entropix/torch_main.py
 ```
-
 NOTES: If you're using the torch parts only, you can `export XLA_PYTHON_CLIENT_PREALLOCATE=false` to prevent jax from doing jax things and hogging your VRAM
+
+## Tips and Tricks
+
+### Jit is Too Slow
+
+For rapid iteration, `jax.jit` might be too slow. In this case, set:
+```
+JAX_DISABLE_JIT=True
+```
+in your environment to disable it.
diff --git a/entropix/model.py b/entropix/model.py
index 31acb91..7e7e01e 100644
--- a/entropix/model.py
+++ b/entropix/model.py
@@ -1,15 +1,13 @@
+from functools import partial
 from typing import Optional, Tuple
 
 import jax
 import jax.numpy as jnp
-from functools import partial
-
 from entropix.config import ModelParams
 from entropix.kvcache import KVCache
 from entropix.stats import AttnStats
-from entropix.weights import XfmrWeights, LayerWeights
-
+from entropix.weights import LayerWeights, XfmrWeights
 
 DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
 
@@ -19,7 +17,7 @@ def rms_norm(x: jax.Array, w: jax.Array, eps: float = 1e-6) -> jax.Array:
   return w * (x * jax.lax.rsqrt(jax.lax.pow(x, 2).mean(-1, keepdims=True) + eps))
 
-#@partial(jax.jit, static_argnames=("dtype"))
+@partial(jax.jit, static_argnames=("dtype"))
 def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype: jnp.dtype = jnp.float32) -> Tuple[jax.Array, jax.Array]:
   reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
   reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
@@ -31,8 +29,8 @@ def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype:
   xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
   return xq_out.astype(dtype), xk_out.astype(dtype)
 
-#@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx"))
-def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array] = None) -> Tuple[jax.Array, KVCache]:
+@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx"))
+def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array] = None) -> Tuple[jax.Array, KVCache, jax.Array]:
   bsz, _, _ = x.shape
   n_rep = model_params.n_local_heads // model_params.n_local_kv_heads
   xq = jnp.dot(x, layer_weights.wq.T).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
@@ -56,12 +54,12 @@ def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos:
   out = jnp.dot(output, layer_weights.wo.T)
   return out, kvcache, pre_scores
 
-#@partial(jax.jit)
+@partial(jax.jit)
 def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array:
   return jnp.dot(jax.nn.silu(jnp.dot(x, layer_weights.w1.T)) * jnp.dot(x, layer_weights.w3.T), layer_weights.w2.T)
 
-#@partial(jax.jit, static_argnames=("model_params", "cur_pos"))
-def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]:
+@partial(jax.jit, static_argnames=("model_params", "cur_pos"))
+def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache, jax.Array, AttnStats]:
   h = xfmr_weights.tok_embeddings[tokens]
   attn_stats = AttnStats.new(
     bsz=tokens.shape[0],
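Not part of the patch above, but worth noting alongside the `JAX_DISABLE_JIT` tip it adds: JAX also exposes the same switch from inside Python, which can be handier than an environment variable when you only want eager execution for part of a run (e.g. to step through a function in a debugger). A minimal sketch using the standard `jax.config` / `jax.disable_jit` APIs:

```python
import jax
import jax.numpy as jnp

# Process-wide switch: equivalent to setting JAX_DISABLE_JIT=True in the environment.
jax.config.update("jax_disable_jit", True)

# Scoped switch: only the code inside this block runs eagerly; everything else
# in the program keeps its compiled (jit) behavior.
with jax.disable_jit():
    x = jnp.arange(8.0)
    y = jax.jit(lambda v: v * 2)(x)  # executes eagerly despite jax.jit
```

The context-manager form keeps the rest of the program compiled, so iteration stays fast outside the block you are inspecting.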