diff --git a/FETCH_HEAD b/FETCH_HEAD new file mode 100644 index 0000000..e69de29 diff --git a/entropix/config.py b/entropix/config.py index b1ab03b..b704bfa 100644 --- a/entropix/config.py +++ b/entropix/config.py @@ -1,5 +1,21 @@ from typing import NamedTuple +class ScaledRopeParams(NamedTuple): + scale_factor: int # = 8 + low_freq_factor: int # = 1 + high_freq_factor: int # = 4 + old_context_len: int # = 8192 (original llama3 length) + +class ModelParams(NamedTuple): + n_layers: int + n_local_heads: int + n_local_kv_heads: int + head_dim: int + vocab_size: int + max_seq_len: int + scaled_rope_params: ScaledRopeParams + use_scaled_rope: bool + rope_theta: float params = { "dim": 2048, @@ -15,23 +31,19 @@ "max_seq_len": 4096 } - -class ModelParams(NamedTuple): - n_layers: int - n_local_heads: int - n_local_kv_heads: int - head_dim: int - max_seq_len: int - rope_theta: float - use_scaled_rope: bool - - LLAMA_1B_PARAMS = ModelParams( n_layers=params["n_layers"], n_local_heads=params["n_heads"], n_local_kv_heads=params["n_kv_heads"], head_dim=params["dim"] // params["n_heads"], + vocab_size=params["vocab_size"], max_seq_len=params["max_seq_len"], rope_theta=params["rope_theta"], - use_scaled_rope=params["use_scaled_rope"] -) + use_scaled_rope=params["use_scaled_rope"], + scaled_rope_params=ScaledRopeParams( + scale_factor=8, + low_freq_factor=1, + high_freq_factor=4, + old_context_len=8192 + ) +) \ No newline at end of file diff --git a/entropix/generator.py b/entropix/generator.py new file mode 100644 index 0000000..effdc1d --- /dev/null +++ b/entropix/generator.py @@ -0,0 +1,135 @@ +from entropix.model import KVCache, xfmr +from entropix.rope import precompute_freqs_cis +from entropix.sampler import sample, SamplerConfig +from entropix.config import ModelParams +from entropix.stats import AttnStats +from entropix.sampler import SamplerConfig +from entropix.model import xfmr +from entropix.sampler import _sample +from entropix.utils import calculate_varentropy_logsoftmax + +import jax.numpy as jnp +import jax +from typing import NamedTuple +class InitialState(NamedTuple): + tokens: jax.Array + kvcache: KVCache + attn_stats: AttnStats + freqs_cis: jax.Array + attn_mask: jax.Array + stop_tokens: jax.Array + sampler_cfg: SamplerConfig + logits_cache: jax.Array + + +def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array: + mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32) + if seqlen > 1: + mask = jnp.full((seqlen, seqlen), float('-inf')) + mask = jnp.triu(mask, k=1) + mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32) + return mask + + +def initialize(model_params, tokens, max_gen_len): + tokens = jnp.array([tokens], dtype=jnp.int32) + bsz, seqlen = tokens.shape + max_total_len = seqlen + max_gen_len + attn_mask = build_attn_mask(seqlen, 0) + freqs_cis = precompute_freqs_cis(model_params) + kvcache = KVCache.new(model_params, bsz, max_total_len) + attn_stats = AttnStats.new(model_params, bsz, max_total_len) + logits_cache = jnp.zeros((bsz, max_total_len, model_params.vocab_size), dtype=jnp.float32) + stop = jnp.array([128001, 128008, 128009], dtype=jnp.int32) + sampler_cfg = SamplerConfig() + return { + 'tokens': tokens, + 'kvcache': kvcache, + 'attn_stats': attn_stats, + 'freqs_cis': freqs_cis, + 'attn_mask': attn_mask, + 'stop_tokens': stop, + 'sampler_cfg': sampler_cfg, + 'logits_cache': logits_cache, + } + +def generate(xfmr_weights, model_params, tokenizer, initial_state, max_gen_len): + kvcache = initial_state['kvcache'] + attn_stats = 
initial_state['attn_stats'] + attn_mask = initial_state['attn_mask'] + freqs_cis = initial_state['freqs_cis'] + stop_tokens = initial_state['stop_tokens'] + sampler_cfg = initial_state['sampler_cfg'] + tokens = initial_state['tokens'] + + prompt_len = tokens.shape[1] + + logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, 0, freqs_cis[:prompt_len], kvcache, attn_stats, attn_mask=attn_mask) + cur_pos, max_total_len = prompt_len, prompt_len + max_gen_len + next_token = jnp.argmax(logits[:,-1], axis=-1, keepdims=True).astype(jnp.int32) + gen_tokens = next_token + print(tokenizer.decode([next_token.item()]), end='', flush=True) + while cur_pos < max_total_len: + cur_pos += 1 + logits, kvcache, scores, attn_stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache, attn_stats) + next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg) + gen_tokens = jnp.concatenate((gen_tokens, next_token)) + print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True) + if jnp.isin(next_token, stop_tokens).any(): + break + +def vanilla_generate(xfmr_weights, model_params, tokenizer, initial_state, n_gen_tokens, rng): + kvcache = initial_state['kvcache'] + attn_stats = initial_state['attn_stats'] + attn_mask = initial_state['attn_mask'] + freqs_cis = initial_state['freqs_cis'] + sampler_cfg = initial_state['sampler_cfg'] + logits_cache = initial_state['logits_cache'] + tokens = initial_state['tokens'] + + prompt_len = tokens.shape[1] + + logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, 0, freqs_cis[:prompt_len], kvcache, attn_stats, attn_mask=attn_mask) + logits_cache = logits_cache.at[:, :prompt_len, :].set(logits) + cur_pos, max_total_len = prompt_len, prompt_len + n_gen_tokens + next_token = _sample(logits_cache[:, cur_pos:cur_pos+1, :], temperature=sampler_cfg.temp, min_p=sampler_cfg.min_p, top_k=sampler_cfg.top_k, top_p=sampler_cfg.top_p, key=rng) + gen_tokens = next_token + + print(tokenizer.decode([next_token.item()]), end='', flush=True) + while cur_pos < max_total_len: + cur_pos += 1 + logits, kvcache, scores, attn_stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache, attn_stats) + logits_cache = logits_cache.at[:, cur_pos:cur_pos+1, :].set(logits) + next_token = _sample(logits, temperature=sampler_cfg.temp, min_p=sampler_cfg.min_p, top_k=sampler_cfg.top_k, top_p=sampler_cfg.top_p, key=rng) + gen_tokens = jnp.concatenate((gen_tokens, next_token)) + print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True) + return gen_tokens, logits_cache, attn_stats + +def score_N(xfmr_weights: jax.Array, model_params: ModelParams, tokens: jax.Array, start_pos: int, N:int): + """ + This function calculates a model's scoring of a (batch of) sequence(s) of tokens in various ways. 
+ + tokens: jax.Array, shape (batch_size, tokens.shape[1], N) + start_pos: int, the position in the sequence to start scoring + N: int, the number of sequences to score + """ + initial_state = initialize(model_params, tokens, 1) + seqlen = tokens.shape[1] + logits, _, scores, attn_stats = xfmr( + xfmr_weights=xfmr_weights, + model_params=model_params, + cur_pos=0, + tokens=initial_state['tokens'], + freqs_cis=initial_state['freqs_cis'][:seqlen], + kvcache=initial_state['kvcache'], + attn_stats=initial_state['attn_stats'], + attn_mask=initial_state['attn_mask'] + ) + shape = logits.shape # (batch_size, tokens.shape[1]*N, vocab_size) + logits = logits.reshape(tokens.shape[0], tokens.shape[1], N, model_params.vocab_size).transpose(0, 1, 3, 2) # (batch_size, tokens.shape[1], vocab_size, N) <--(batch_size, tokens.shape[1]*N, vocab_size) + log_probs = jax.nn.log_softmax(logits, axis=2) + log_joint_probs = log_probs.sum(axis=1) + joint_entropy, joint_varentropy = calculate_varentropy_logsoftmax(log_joint_probs, axis=-1) # (batch_size, tokens.shape[1], vocab_size) + log_likelihood = jnp.take_along_axis(log_probs[:, start_pos-1:-1,:], tokens[:, start_pos:, :, None], axis=-1).squeeze(-1) # (batch_size, tokens.shape[1]-start_pos, N) + cross_entropy = log_likelihood.sum(axis=1) # (batch_size, N) + return cross_entropy, joint_entropy, joint_varentropy \ No newline at end of file diff --git a/entropix/kvcache.py b/entropix/kvcache.py index 392c80a..2d16075 100644 --- a/entropix/kvcache.py +++ b/entropix/kvcache.py @@ -1,20 +1,24 @@ -from typing import NamedTuple - import jax import jax.numpy as jnp +from entropix.config import ModelParams +from typing import NamedTuple + class KVCache(NamedTuple): k: jax.Array v: jax.Array - + @classmethod - def new(cls, layers: int, bsz: int, max_seq_len: int, kv_heads: int, head_dim: int) -> 'KVCache': - return cls( - k=jnp.zeros((layers, bsz, max_seq_len, kv_heads, head_dim), dtype=jnp.bfloat16), - v=jnp.zeros((layers, bsz, max_seq_len, kv_heads, head_dim), dtype=jnp.bfloat16) - ) - + def new(cls, model_params: ModelParams, bsz: int, max_total_len) -> 'KVCache': + kv_heads = model_params.n_local_kv_heads + layers = model_params.n_layers + head_dim = model_params.head_dim + return cls( + k=jnp.zeros((layers, bsz, max_total_len, kv_heads, head_dim), dtype=jnp.bfloat16), + v=jnp.zeros((layers, bsz, max_total_len, kv_heads, head_dim), dtype=jnp.bfloat16) + ) + def update(self, xk: jax.Array, xv: jax.Array, layer_idx: int, cur_pos: int, n_rep: int): ck = jax.lax.dynamic_update_slice(self.k, jnp.bfloat16(xk[None, ...]), (layer_idx, 0, cur_pos, 0, 0)) cv = jax.lax.dynamic_update_slice(self.v, jnp.bfloat16(xv[None, ...]), (layer_idx, 0, cur_pos, 0, 0)) @@ -27,3 +31,4 @@ def update(self, xk: jax.Array, xv: jax.Array, layer_idx: int, cur_pos: int, n_r return keys, values, KVCache(k=ck, v=cv) + diff --git a/entropix/main.py b/entropix/main.py index b8dcc42..6c2db0a 100644 --- a/entropix/main.py +++ b/entropix/main.py @@ -5,51 +5,17 @@ import jax.numpy as jnp import tyro +from pathlib import Path + from entropix.config import LLAMA_1B_PARAMS -from entropix.kvcache import KVCache -from entropix.model import xfmr -from entropix.sampler import SamplerConfig, sample from entropix.prompts import create_prompts_from_csv, prompt from entropix.sampler import sample from entropix.tokenizer import Tokenizer from entropix.weights import load_weights +from entropix.generator import generate, vanilla_generate, initialize -DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights' - -def 
apply_scaling(freqs: jax.Array): - SCALE_FACTOR = 8 - LOW_FREQ_FACTOR = 1 - HIGH_FREQ_FACTOR = 4 - OLD_CONTEXT_LEN = 8192 # original llama3 length - - low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR - high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR - - def scale_freq(freq): - wavelen = 2 * math.pi / freq - - def scale_mid(_): - smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR) - return (1 - smooth) * freq / SCALE_FACTOR + smooth * freq - - return jax.lax.cond( - wavelen < high_freq_wavelen, - lambda _: freq, - lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None), - None - ) - - return jax.vmap(scale_freq)(freqs) - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array: - freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) - if use_scaled: - freqs = apply_scaling(freqs) - t = jnp.arange(end, dtype=dtype) - freqs = jnp.outer(t, freqs) - return jnp.exp(1j * freqs) +DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights' def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array: mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32) @@ -62,47 +28,28 @@ def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array: def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath('1B-Instruct')): model_params = LLAMA_1B_PARAMS - xfmr_weights = load_weights(weights_path.absolute()) - tokenizer = Tokenizer('entropix/tokenizer.model') - - # Create the batch of tokens - def generate(xfmr_weights, model_params, tokens): - gen_tokens = None - cur_pos = 0 - tokens = jnp.array([tokens], jnp.int32) - bsz, seqlen = tokens.shape - attn_mask = build_attn_mask(seqlen, cur_pos) - freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope) - kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim) - logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask) - next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32) - gen_tokens = next_token - print(tokenizer.decode([next_token.item()]), end='', flush=True) - cur_pos = seqlen - stop = jnp.array([128001, 128008, 128009]) - sampler_cfg = SamplerConfig() - while cur_pos < 8192: - cur_pos += 1 - logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache) - next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg) - gen_tokens = jnp.concatenate((gen_tokens, next_token)) - print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True) - if jnp.isin(next_token, stop).any(): - break + xfmr_weights = load_weights(weights_path) + tokenizer = Tokenizer('entropix/tokenizer.model') + raw_tokens1 = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all') csv_path = Path('entropix/data/prompts.csv') prompts = create_prompts_from_csv(csv_path) PROMPT_TEST = False + # Create a random key + rng_key = jax.random.PRNGKey(0) + if PROMPT_TEST: for p in prompts: print(p) tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special='all') - generate(xfmr_weights, model_params, tokens) + initial_state = initialize(model_params, tokens, 100) + vanilla_generate(xfmr_weights, model_params, tokenizer, initial_state, 100, rng_key) else: 
print(prompt) tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all') - generate(xfmr_weights, model_params, tokens) + initial_state = initialize(model_params, tokens, 100) + vanilla_generate(xfmr_weights, model_params, tokenizer, initial_state, 100, rng_key) if __name__ == '__main__': - tyro.cli(main) + tyro.cli(main) \ No newline at end of file diff --git a/entropix/mcts.py b/entropix/mcts.py index a6b0289..9c13f1c 100644 --- a/entropix/mcts.py +++ b/entropix/mcts.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from typing import Tuple from entropix.torch_main import calculate_varentropy_logsoftmax, _sample diff --git a/entropix/model.py b/entropix/model.py index 31acb91..41fbd2e 100644 --- a/entropix/model.py +++ b/entropix/model.py @@ -1,37 +1,23 @@ from typing import Optional, Tuple - import jax import jax.numpy as jnp - from functools import partial from entropix.config import ModelParams from entropix.kvcache import KVCache from entropix.stats import AttnStats +from entropix.stats import AttnStats from entropix.weights import XfmrWeights, LayerWeights - +from entropix.rope import apply_rotary_emb DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) -#@partial(jax.jit, static_argnames=("eps")) +# @partial(jax.jit, static_argnames=("eps")) def rms_norm(x: jax.Array, w: jax.Array, eps: float = 1e-6) -> jax.Array: return w * (x * jax.lax.rsqrt(jax.lax.pow(x, 2).mean(-1, keepdims=True) + eps)) - -#@partial(jax.jit, static_argnames=("dtype")) -def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype: jnp.dtype = jnp.float32) -> Tuple[jax.Array, jax.Array]: - reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) - reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) - xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) - xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) - xq_out = xq_ * freqs_cis[None, :, None, :] - xk_out = xk_ * freqs_cis[None, :, None, :] - xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) - xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) - return xq_out.astype(dtype), xk_out.astype(dtype) - -#@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx")) +# @partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx")) def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array] = None) -> Tuple[jax.Array, KVCache]: bsz, _, _ = x.shape n_rep = model_params.n_local_heads // model_params.n_local_kv_heads @@ -56,22 +42,17 @@ def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: out = jnp.dot(output, layer_weights.wo.T) return out, kvcache, pre_scores -#@partial(jax.jit) +# @partial(jax.jit) def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array: return jnp.dot(jax.nn.silu(jnp.dot(x, layer_weights.w1.T)) * jnp.dot(x, layer_weights.w3.T), layer_weights.w2.T) -#@partial(jax.jit, static_argnames=("model_params", "cur_pos")) -def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]: +# @partial(jax.jit, static_argnames=("model_params", "cur_pos")) +def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, 
cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_stats: AttnStats, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]: h = xfmr_weights.tok_embeddings[tokens] - attn_stats = AttnStats.new( - bsz=tokens.shape[0], - n_layers=model_params.n_layers, - n_heads=model_params.n_local_heads - ) for i in range(model_params.n_layers): norm_x = rms_norm(h, xfmr_weights.layer_weights[i].attention_norm) h_attn, kvcache, scores = attention(norm_x, xfmr_weights.layer_weights[i], model_params, cur_pos, i, freqs_cis, kvcache, attn_mask=attn_mask) - attn_stats = attn_stats.update(scores[:,:,-1,:], i) + attn_stats = attn_stats.update(scores,cur_pos,i) h = h + h_attn h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i]) logits = jnp.dot(rms_norm(h, xfmr_weights.norm), xfmr_weights.output.T) diff --git a/entropix/rope.py b/entropix/rope.py new file mode 100644 index 0000000..495d588 --- /dev/null +++ b/entropix/rope.py @@ -0,0 +1,50 @@ + +from typing import Tuple +import jax.numpy as jnp +import jax +import math +from entropix.config import ModelParams, ScaledRopeParams + +#@partial(jax.jit, static_argnames=("dtype")) +def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype: jnp.dtype = jnp.float32) -> Tuple[jax.Array, jax.Array]: + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + xq_out = xq_ * freqs_cis[None, :, None, :] + xk_out = xk_ * freqs_cis[None, :, None, :] + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) + xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) + return xq_out.astype(dtype), xk_out.astype(dtype) + +def precompute_freqs_cis(model_params: ModelParams, dtype: jnp.dtype = jnp.float32) -> jax.Array: + dim = model_params.head_dim + freqs = 1.0 / (model_params.rope_theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) + if model_params.use_scaled_rope: + freqs = apply_scaling(model_params.scaled_rope_params, freqs) + freqs = jnp.outer(jnp.arange(model_params.max_seq_len, dtype=dtype), freqs) + return jnp.exp(1j * freqs) + +def apply_scaling(scaled_rope_params: ScaledRopeParams, freqs: jax.Array): + scale_factor = scaled_rope_params.scale_factor + low_freq_factor = scaled_rope_params.low_freq_factor + high_freq_factor = scaled_rope_params.high_freq_factor + old_context_len = scaled_rope_params.old_context_len # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + def scale_freq(freq): + wavelen = 2 * math.pi / freq + + def scale_mid(_): + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + return (1 - smooth) * freq / scale_factor + smooth * freq + + return jax.lax.cond( + wavelen < high_freq_wavelen, + lambda _: freq, + lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / scale_factor, scale_mid, None), + None + ) + return jax.vmap(scale_freq)(freqs) diff --git a/entropix/sampler.py b/entropix/sampler.py index e9b9420..a37c2ea 100644 --- a/entropix/sampler.py +++ b/entropix/sampler.py @@ -1,24 +1,8 @@ -from typing import Dict, Tuple - +from typing import Dict import chex -import jax import jax.numpy as jnp - -LN_2 = 0.69314718056 # 
ln(2) = 1.0 / LOG2_E - -@jax.jit -def calculate_varentropy_logsoftmax(logits: jnp.ndarray, axis: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Calculate the entropy and varentropy of the probability distribution using logsoftmax.""" - log_probs = jax.nn.log_softmax(logits, axis=axis) - probs = jnp.exp(log_probs) - entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2 # Convert to base-2 - varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, axis=axis) - return entropy, varentropy - -def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array: - """Samples one token from a multinomial distribution with sorted probabilities.""" - q = jax.random.exponential(key=key, shape=probs_sort.shape) - return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32) +from entropix.utils import calculate_varentropy_logsoftmax, multinomial_sample_one +import jax def _sample( logits: jax.Array, *, temperature: float | jax.Array, top_p: float | jax.Array, top_k: int | jax.Array, min_p: float | jax.Array, key=jax.random.PRNGKey(1337),) -> jax.Array: @@ -45,17 +29,17 @@ def _sample( logits: jax.Array, *, temperature: float | jax.Array, top_p: float next_token_g = jnp.take_along_axis(probs_idx, next_token.reshape(bsz, 1), axis=-1) return next_token_g.astype(jnp.int32) -def calculate_metrics(logits: jnp.ndarray, attention_scores: jnp.ndarray) -> Dict[str, jnp.ndarray]: +def calculate_metrics(logits: jnp.ndarray, attn_scores: jnp.ndarray) -> Dict[str, jnp.ndarray]: entropy, varentropy = calculate_varentropy_logsoftmax(logits) - attention_probs = jax.nn.softmax(attention_scores, axis=-1) + attention_probs = jax.nn.softmax(attn_scores, axis=-1) attn_entropy = -jnp.sum(attention_probs * jnp.log2(jnp.clip(attention_probs, 1e-10, 1.0)), axis=-1) attn_varentropy = jnp.var(attn_entropy, axis=1) mean_attention = jnp.mean(attention_probs, axis=1) agreement = jnp.mean(jnp.abs(attention_probs - mean_attention[:, None, :]), axis=(1, 2)) - interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3)) + interaction_strength = jnp.mean(jnp.abs(attn_scores), axis=(1, 2, 3)) return { "logits_entropy": jnp.mean(entropy), @@ -114,7 +98,6 @@ class SamplerConfig: ada_score_agree: float = 0.5 ada_score_int: float = 0.6 - def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array, cfg: SamplerConfig, clarifying_question_token: int = 2564, key=jax.random.PRNGKey(1337)) -> jax.Array: @@ -123,7 +106,6 @@ def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"] agreement = metrics["agreement"] interaction_strength = metrics["interaction_strength"] - # Low Entropy, Low Varentropy: "flowing with unspoken intent" if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh: return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32) @@ -151,7 +133,7 @@ def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array top_p_adj = max(0.5, cfg.top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high return _sample(logits, temperature=max(2.0, cfg.temp * temp_adj), top_p=top_p_adj, top_k=cfg.top_k, min_p=cfg.min_p, key=key) - # Middle ground: use adaptive sampling + # Middle ground: smooth transition else: logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"] attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"] diff --git a/entropix/stats.py b/entropix/stats.py 
index 9a54131..bbc5160 100644 --- a/entropix/stats.py +++ b/entropix/stats.py @@ -1,20 +1,21 @@ from typing import NamedTuple import jax import jax.numpy as jnp +from entropix.config import ModelParams class AttnStats(NamedTuple): - entropy: jax.Array # (bsz, n_layers, num_heads) - varentropy: jax.Array # (bsz, n_layers, num_heads) - n_layers: int - n_heads: int + scores: jax.Array # (bsz, seqlen, n_layers, num_heads, seqlen) + entropy: jax.Array # (bsz, seqlen, n_layers, num_heads) + varentropy: jax.Array # (bsz, seqlen, n_layers, num_heads) + @classmethod - def new(cls, bsz: int, n_layers: int, n_heads: int) -> 'AttnStats': + def new(cls, model_params: ModelParams, bsz: int, max_total_len: int) -> 'AttnStats': + n_heads, n_layers = model_params.n_local_heads, model_params.n_layers return cls( - entropy=jnp.zeros((bsz, n_layers, n_heads), dtype=jnp.float32), - varentropy=jnp.zeros((bsz, n_layers, n_heads), dtype=jnp.float32), - n_layers=n_layers, - n_heads=n_heads + entropy=jnp.zeros((bsz, max_total_len, n_layers, n_heads), dtype=jnp.float32), + varentropy=jnp.zeros((bsz, max_total_len, n_layers, n_heads), dtype=jnp.float32), + scores=jnp.zeros((bsz, max_total_len, n_layers, n_heads, max_total_len), dtype=jnp.float32) ) @property @@ -25,21 +26,21 @@ def avg_entropy(self): def std_error(self): return jnp.sqrt(jnp.mean(self.varentropy)) / (self.n_heads * self.n_layers) - def update(self, scores: jax.Array, layer_idx: int): - # scores shape: (bsz, n_heads, seqlen, n_words) + def update(self, scores: jax.Array, cur_pos: int, layer_idx: int): + # scores shape: (bsz, n_heads, seqlen, seqlen) + seqlen = scores.shape[-1] probs = jax.nn.softmax(scores, axis=-1) new_entropy = -jnp.sum(jnp.where(probs > 0, probs * jnp.log(probs), 0), axis=-1) - new_varentropy = jnp.sum(probs * (jnp.log(probs) + new_entropy[..., None])**2, axis=-1) - - # print(f"Layer {layer_idx} - Scores shape: {scores.shape}, Probs shape: {probs.shape}") - # print(f"Layer {layer_idx} - New entropy shape: {new_entropy.shape}, Min: {jnp.min(new_entropy)}, Max: {jnp.max(new_entropy)}") - - updated_stats = self._replace( - entropy=self.entropy.at[:, layer_idx, :].set(new_entropy), - varentropy=self.varentropy.at[:, layer_idx, :].set(new_varentropy) + new_varentropy = jnp.sum(probs * (jnp.log(probs) + new_entropy[..., None])**2, axis=-1) + if cur_pos==0: + return self._replace( + scores=self.scores.at[:, :seqlen, layer_idx, :, :seqlen].set(scores.transpose(0,2,1,3)), # (bsz, seqlen, n_layers, n_heads, seqlen) <-- (bsz, seqlen, n_heads, seqlen) + entropy=self.entropy.at[:, :seqlen, layer_idx, :].set(new_entropy.transpose(0,2,1)), + varentropy=self.varentropy.at[:, :seqlen, layer_idx, :].set(new_varentropy.transpose(0,2,1)) ) - - # print(f"Layer {layer_idx} - Updated entropy shape: {updated_stats.entropy.shape}") - # print(f"Layer {layer_idx} - Updated entropy for this layer: {updated_stats.entropy[:, layer_idx, :]}") - - return updated_stats \ No newline at end of file + else: + return self._replace( + scores=self.scores.at[:, cur_pos, layer_idx, :, :].set(scores[:,:,-1,:]), # (bsz, seqlen, n_layers, n_heads, seqlen) <-- (bsz, seqlen, n_heads, seqlen) + entropy=self.entropy.at[:, cur_pos, layer_idx, :].set(new_entropy[...,-1]), + varentropy=self.varentropy.at[:, cur_pos, layer_idx, :].set(new_varentropy[...,-1]) + ) \ No newline at end of file diff --git a/entropix/utils.py b/entropix/utils.py new file mode 100644 index 0000000..9a18ec0 --- /dev/null +++ b/entropix/utils.py @@ -0,0 +1,19 @@ +import jax +import jax.numpy as jnp +from 
typing import Tuple + +LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E + +@jax.jit +def calculate_varentropy_logsoftmax(logits: jnp.ndarray, axis: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Calculate the entropy and varentropy of the probability distribution using logsoftmax.""" + log_probs = jax.nn.log_softmax(logits, axis=axis) + probs = jnp.exp(log_probs) + entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2 # Convert to base-2 + varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, axis=axis) + return entropy, varentropy + +def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array: + """Samples one token from a multinomial distribution with sorted probabilities.""" + q = jax.random.exponential(key=key, shape=probs_sort.shape) + return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32)
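For quick reference, a minimal sketch (not part of the patch) exercising the helpers that this change moves into entropix/utils.py; the toy logits tensor and PRNG key are made up purely for illustration:

import jax
import jax.numpy as jnp
from entropix.utils import calculate_varentropy_logsoftmax, multinomial_sample_one

toy_logits = jnp.array([[2.0, 1.0, 0.5, -1.0]])                      # (bsz=1, vocab=4), made-up values
entropy, varentropy = calculate_varentropy_logsoftmax(toy_logits)    # each of shape (1,), reported in bits via the LN_2 division
probs_sort = jnp.sort(jax.nn.softmax(toy_logits, axis=-1), axis=-1)[..., ::-1]   # descending probabilities
sorted_idx = multinomial_sample_one(probs_sort, key=jax.random.PRNGKey(0))       # index into the *sorted* probabilities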
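And a sketch of how the refactored entry points fit together, mirroring the new flow in entropix/main.py (the weights path here is a placeholder; adjust it to your checkout):

import jax
from pathlib import Path

from entropix.config import LLAMA_1B_PARAMS
from entropix.generator import initialize, vanilla_generate
from entropix.prompts import prompt
from entropix.tokenizer import Tokenizer
from entropix.weights import load_weights

model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights(Path('weights/1B-Instruct'))     # placeholder path
tokenizer = Tokenizer('entropix/tokenizer.model')

tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
initial_state = initialize(model_params, tokens, 100)        # sizes KV/logits caches for prompt_len + 100
gen_tokens, logits_cache, attn_stats = vanilla_generate(
    xfmr_weights, model_params, tokenizer, initial_state, 100, jax.random.PRNGKey(0))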