-
Notifications
You must be signed in to change notification settings - Fork 320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
~nothing #9
Closed
Closed
~nothing #9
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
426281c
attn_entropy visualization
doomslide 966b1a6
Merge branch 'main' of https://github.com/xjdr-alt/entropix
doomslide c080dde
factored stats out of attention
doomslide 1773362
factored stats out of attention
doomslide 7bc0599
merged with origin
doomslide File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import NamedTuple | ||
from dataclasses import dataclass | ||
from entropix.model import KVCache | ||
import math | ||
import jax | ||
import jax.numpy as jnp | ||
from entropix.model import ModelParams | ||
|
||
@dataclass | ||
class LMState: | ||
prompt: jax.Array # (bsz, prompt_len) | ||
gen_tokens: jax.Array # (bsz, seq_len) | ||
logits: jax.Array # (bsz, n_words) | ||
kvcache: KVCache | ||
freqs_cis: jax.Array | ||
attn_mask: jax.Array | ||
pos: jax.Array # (bsz, max_seq_len) | ||
state: jax.Array # (bsz, n_states) which state are we? (flow, turn, fork, explore...) | ||
|
||
@property | ||
def context(self) -> jnp.ndarray: | ||
return jnp.concatenate((self.prompt, self.gen_tokens), axis=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,26 @@ | ||
from typing import NamedTuple | ||
import jax | ||
|
||
|
||
params = { | ||
"dim": 2048, | ||
"n_layers": 16, | ||
"n_heads": 32, | ||
"n_kv_heads": 8, | ||
"vocab_size": 128256, | ||
"ffn_dim_multiplier": 1.5, | ||
"multiple_of": 256, | ||
"norm_eps": 1e-05, | ||
"rope_theta": 500000.0, | ||
"use_scaled_rope": True, | ||
"max_seq_len": 4096 | ||
} | ||
|
||
class RopeParams(NamedTuple): | ||
rope_theta: float | ||
use_scaled_rope: bool | ||
scale_factor: int # = 8 | ||
low_freq_factor: int # = 1 | ||
high_freq_factor: int # = 4 | ||
old_context_len: int # = 8192 (original llama3 length) | ||
|
||
class ModelParams(NamedTuple): | ||
n_layers: int | ||
n_local_heads: int | ||
n_local_kv_heads: int | ||
head_dim: int | ||
max_seq_len: int | ||
rope_theta: float | ||
use_scaled_rope: bool | ||
|
||
rope_params: RopeParams | ||
d_model: int | ||
|
||
LLAMA_1B_PARAMS = ModelParams( | ||
n_layers=params["n_layers"], | ||
n_local_heads=params["n_heads"], | ||
n_local_kv_heads=params["n_kv_heads"], | ||
head_dim=params["dim"] // params["n_heads"], | ||
max_seq_len=params["max_seq_len"], | ||
rope_theta=params["rope_theta"], | ||
use_scaled_rope=params["use_scaled_rope"] | ||
) | ||
class SamplerParams(NamedTuple): | ||
steer_tokens: jax.Array | ||
stop_tokens: jax.Array | ||
base_temp: float | ||
base_top_p: float | ||
base_top_k: int |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from entropix.model import KVCache | ||
from entropix.rope import precompute_freqs_cis | ||
from entropix.sampler import sample | ||
from entropix.LMState import LMState | ||
from entropix.model import xfmr | ||
from entropix.sampler import SamplerParams | ||
from entropix.tokenizer import Tokenizer | ||
from entropix.config import ModelParams, RopeParams | ||
import jax.numpy as jnp | ||
import jax | ||
|
||
def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array: | ||
mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32) | ||
if seqlen > 1: | ||
mask = jnp.full((seqlen, seqlen), float('-inf')) | ||
mask = jnp.triu(mask, k=1) | ||
mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32) | ||
return mask | ||
|
||
def generate(xfmr_weights, model_params, sampler_params, tokenizer, tokens): | ||
tokens = jnp.array([tokens], jnp.int32) | ||
n_words = tokenizer.n_words | ||
bsz, seqlen = tokens.shape | ||
lm_state = LMState( | ||
prompt=tokens, | ||
logits=jnp.zeros((bsz, n_words), dtype=jnp.bfloat16), | ||
freqs_cis=precompute_freqs_cis(head_dim=model_params.head_dim, max_seq_len=model_params.max_seq_len, rope_params=model_params.rope_params), | ||
kvcache=KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim), | ||
attn_mask=build_attn_mask(seqlen, 0), | ||
gen_tokens=jnp.zeros((bsz, 0), dtype=jnp.int32), | ||
state=jnp.zeros((bsz, 1), dtype=jnp.int32), | ||
pos=0 | ||
) | ||
lm_state.logits, lm_state.kvcache, _ = xfmr(xfmr_weights, model_params, lm_state.context, lm_state.pos, freqs_cis=lm_state.freqs_cis[:seqlen], kvcache=lm_state.kvcache, attn_mask=lm_state.attn_mask) | ||
next_token = jnp.argmax(lm_state.logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32) | ||
lm_state.gen_tokens, lm_state.pos = jnp.concatenate((lm_state.gen_tokens, next_token), axis=1), seqlen | ||
print(tokenizer.decode([next_token.item()]), end='', flush=True) | ||
#stop = jnp.array(tokenizer.stop_tokens) | ||
while lm_state.pos < 2048: | ||
lm_state.pos += 1 | ||
lm_state.logits, lm_state.kvcache, _ = xfmr(xfmr_weights, model_params, next_token, lm_state.pos, lm_state.freqs_cis[lm_state.pos:lm_state.pos+1], lm_state.kvcache) | ||
next_token = sample(sampler_params, lm_state) | ||
lm_state.gen_tokens = jnp.concatenate((lm_state.gen_tokens, next_token), axis=1) | ||
print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True) | ||
if jnp.isin(next_token, sampler_params.stop_tokens).any(): | ||
break | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
merge tags left in here