Refactors #27

Open · wants to merge 60 commits into base: main

Commits (60), all by doomslide:
426281c  attn_entropy visualization (Oct 5, 2024)
966b1a6  Merge branch 'main' of https://github.com/xjdr-alt/entropix (Oct 5, 2024)
c080dde  factored stats out of attention (Oct 5, 2024)
1773362  factored stats out of attention (Oct 5, 2024)
7bc0599  merged with origin (Oct 5, 2024)
167effe  initial (Oct 5, 2024)
793feae  lets goo (Oct 5, 2024)
7d450f8  added stft (Oct 5, 2024)
530e793  merged with origin (Oct 5, 2024)
029d844  attn_entropy visualization (Oct 5, 2024)
caec96f  Merge branch 'main' of https://github.com/xjdr-alt/entropix (Oct 5, 2024)
5e8a241  repeated 'def main' (Oct 5, 2024)
6c6a516  removed repeated 'def main' (Oct 5, 2024)
65b7a6c  merged main (Oct 5, 2024)
8bbc014  deleted tokenizer directory (Oct 5, 2024)
982e247  renaming (Oct 5, 2024)
36d7711  implemented forensics (Oct 5, 2024)
4c6b7ad  forensics still not working (Oct 6, 2024)
1abb5bc  attn_entropy visualization (Oct 5, 2024)
a26ded2  factored stats out of attention (Oct 5, 2024)
f71de86  factored stats out of attention (Oct 5, 2024)
627b82b  attn_entropy visualization (Oct 5, 2024)
06d3f2f  Merge branch 'main' of https://github.com/xjdr-alt/entropix (Oct 5, 2024)
99a3696  Merge branch 'shrek' into frog (Oct 6, 2024)
835e6a8  cleaned some dependencies (Oct 6, 2024)
a09a6da  rope tests (Oct 6, 2024)
4167572  moved shit around (Oct 6, 2024)
d900250  minor (Oct 6, 2024)
f621d9a  rewrite of xmfr (Oct 6, 2024)
a887232  Merge branch 'main' into frog (Oct 6, 2024)
4f59cc1  trying stuff (Oct 6, 2024)
471e6c6  Merge remote-tracking branch 'origin' into frog (Oct 6, 2024)
f4e34f7  in the process of factoring xfmr_plus (Oct 6, 2024)
e410d40  still weird (Oct 6, 2024)
c729795  still off... (Oct 6, 2024)
cb33990  Merge remote-tracking branch 'origin' into forensics (Oct 6, 2024)
80aa926  fixed params (Oct 6, 2024)
fa5d25f  small fix (Oct 7, 2024)
d114369  Merge branch 'frog' into refactor (Oct 7, 2024)
bc4d15c  refactored calls to sampler (Oct 7, 2024)
bad1767  refactored rope (Oct 7, 2024)
5a8d007  fixed score update (Oct 7, 2024)
c6ac8aa  refactors (Oct 7, 2024)
f612a75  refactoring around attention stats (Oct 7, 2024)
7456c88  Merge remote-tracking branch 'origin' into refactor (Oct 7, 2024)
33d1a25  refactors of kvcache and attn_stats (Oct 7, 2024)
04b7049  removed lm_state (Oct 7, 2024)
c72bce2  . (Oct 7, 2024)
15fa837  Merge remote-tracking branch 'origin' into forensics (Oct 7, 2024)
db77fe3  Merge branch 'refactor' into forensics (Oct 7, 2024)
6736035  factored generation out of main and initialization out of generation (Oct 7, 2024)
9f15bf0  . (Oct 7, 2024)
c864f3b  . (Oct 7, 2024)
4e931e3  cleaned imports (Oct 7, 2024)
60499c8  Delete entropix/lm_state.py (Oct 7, 2024)
7101fa0  added vanilla generator (Oct 7, 2024)
c979592  Merge remote-tracking branch 'fork/refactor' into refactor (Oct 7, 2024)
f4233ef  . (Oct 7, 2024)
5e1de5a  cleaned some accidental nonsense (Oct 7, 2024)
5e9e79c  some scoring for ar generation (Oct 7, 2024)
Empty file added: FETCH_HEAD
38 changes: 25 additions & 13 deletions entropix/config.py
@@ -1,5 +1,21 @@
 from typing import NamedTuple
 
+class ScaledRopeParams(NamedTuple):
+  scale_factor: int  # = 8
+  low_freq_factor: int  # = 1
+  high_freq_factor: int  # = 4
+  old_context_len: int  # = 8192 (original llama3 length)
+
+class ModelParams(NamedTuple):
+  n_layers: int
+  n_local_heads: int
+  n_local_kv_heads: int
+  head_dim: int
+  vocab_size: int
+  max_seq_len: int
+  scaled_rope_params: ScaledRopeParams
+  use_scaled_rope: bool
+  rope_theta: float
+
 params = {
   "dim": 2048,
@@ -15,23 +31,19 @@
   "max_seq_len": 4096
 }
 
-
-class ModelParams(NamedTuple):
-  n_layers: int
-  n_local_heads: int
-  n_local_kv_heads: int
-  head_dim: int
-  max_seq_len: int
-  rope_theta: float
-  use_scaled_rope: bool
-
-
 LLAMA_1B_PARAMS = ModelParams(
   n_layers=params["n_layers"],
   n_local_heads=params["n_heads"],
   n_local_kv_heads=params["n_kv_heads"],
   head_dim=params["dim"] // params["n_heads"],
+  vocab_size=params["vocab_size"],
   max_seq_len=params["max_seq_len"],
   rope_theta=params["rope_theta"],
-  use_scaled_rope=params["use_scaled_rope"]
-)
+  use_scaled_rope=params["use_scaled_rope"],
+  scaled_rope_params=ScaledRopeParams(
+    scale_factor=8,
+    low_freq_factor=1,
+    high_freq_factor=4,
+    old_context_len=8192
+  )
+)
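Editorial note: these fields lift the constants that were hardcoded in the apply_scaling function deleted from entropix/main.py (that deletion appears further down this diff). A minimal sketch of the parameterized version, assuming the deleted logic otherwise carries over unchanged into entropix/rope.py, whose actual contents are not shown in this PR view:

import math
import jax
import jax.numpy as jnp

def apply_scaling(freqs: jax.Array, p: ScaledRopeParams) -> jax.Array:
  # Piecewise RoPE frequency scaling: leave high frequencies alone, divide low
  # frequencies by scale_factor, and interpolate smoothly in between.
  low_freq_wavelen = p.old_context_len / p.low_freq_factor
  high_freq_wavelen = p.old_context_len / p.high_freq_factor

  def scale_freq(freq):
    wavelen = 2 * math.pi / freq

    def scale_mid(_):
      smooth = (p.old_context_len / wavelen - p.low_freq_factor) / (p.high_freq_factor - p.low_freq_factor)
      return (1 - smooth) * freq / p.scale_factor + smooth * freq

    return jax.lax.cond(
      wavelen < high_freq_wavelen,
      lambda _: freq,  # short wavelengths (high freq): unscaled
      lambda _: jax.lax.cond(wavelen > low_freq_wavelen,
                             lambda _: freq / p.scale_factor,  # long wavelengths: fully scaled
                             scale_mid, None),                 # middle band: smooth blend
      None
    )

  return jax.vmap(scale_freq)(freqs)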
135 changes: 135 additions & 0 deletions entropix/generator.py
@@ -0,0 +1,135 @@
from typing import NamedTuple

import jax
import jax.numpy as jnp

from entropix.config import ModelParams
from entropix.model import KVCache, xfmr
from entropix.rope import precompute_freqs_cis
from entropix.sampler import SamplerConfig, _sample, sample
from entropix.stats import AttnStats
from entropix.utils import calculate_varentropy_logsoftmax

class InitialState(NamedTuple):
  tokens: jax.Array
  kvcache: KVCache
  attn_stats: AttnStats
  freqs_cis: jax.Array
  attn_mask: jax.Array
  stop_tokens: jax.Array
  sampler_cfg: SamplerConfig
  logits_cache: jax.Array


def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
  mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
  if seqlen > 1:
    mask = jnp.full((seqlen, seqlen), float('-inf'))
    mask = jnp.triu(mask, k=1)
    mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32)
  return mask
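# Editorial worked example (not part of the diff): build_attn_mask(seqlen=3, start_pos=2) returns
#   [[0. 0. 0. -inf -inf]
#    [0. 0. 0.   0. -inf]
#    [0. 0. 0.   0.   0.]]
# i.e. start_pos all-zero columns keep already-cached positions visible, followed by a
# causal upper-triangular -inf block over the seqlen new positions.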


def initialize(model_params, tokens, max_gen_len):
  tokens = jnp.array([tokens], dtype=jnp.int32)
  bsz, seqlen = tokens.shape
  max_total_len = seqlen + max_gen_len
  attn_mask = build_attn_mask(seqlen, 0)
  freqs_cis = precompute_freqs_cis(model_params)
  kvcache = KVCache.new(model_params, bsz, max_total_len)
  attn_stats = AttnStats.new(model_params, bsz, max_total_len)
  logits_cache = jnp.zeros((bsz, max_total_len, model_params.vocab_size), dtype=jnp.float32)
  stop = jnp.array([128001, 128008, 128009], dtype=jnp.int32)
  sampler_cfg = SamplerConfig()
  return {
    'tokens': tokens,
    'kvcache': kvcache,
    'attn_stats': attn_stats,
    'freqs_cis': freqs_cis,
    'attn_mask': attn_mask,
    'stop_tokens': stop,
    'sampler_cfg': sampler_cfg,
    'logits_cache': logits_cache,
  }

def generate(xfmr_weights, model_params, tokenizer, initial_state, max_gen_len):
  kvcache = initial_state['kvcache']
  attn_stats = initial_state['attn_stats']
  attn_mask = initial_state['attn_mask']
  freqs_cis = initial_state['freqs_cis']
  stop_tokens = initial_state['stop_tokens']
  sampler_cfg = initial_state['sampler_cfg']
  tokens = initial_state['tokens']

  prompt_len = tokens.shape[1]

  # Prefill: run the whole prompt through the transformer once.
  logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, 0, freqs_cis[:prompt_len], kvcache, attn_stats, attn_mask=attn_mask)
  cur_pos, max_total_len = prompt_len, prompt_len + max_gen_len
  next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
  gen_tokens = next_token
  print(tokenizer.decode([next_token.item()]), end='', flush=True)
  while cur_pos < max_total_len:
    cur_pos += 1
    logits, kvcache, scores, attn_stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache, attn_stats)
    next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg)
    gen_tokens = jnp.concatenate((gen_tokens, next_token))
    print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
    if jnp.isin(next_token, stop_tokens).any():
      break

def vanilla_generate(xfmr_weights, model_params, tokenizer, initial_state, n_gen_tokens, rng):
  kvcache = initial_state['kvcache']
  attn_stats = initial_state['attn_stats']
  attn_mask = initial_state['attn_mask']
  freqs_cis = initial_state['freqs_cis']
  sampler_cfg = initial_state['sampler_cfg']
  logits_cache = initial_state['logits_cache']
  tokens = initial_state['tokens']

  prompt_len = tokens.shape[1]

  # Prefill and cache the per-position logits for the prompt.
  logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, 0, freqs_cis[:prompt_len], kvcache, attn_stats, attn_mask=attn_mask)
  logits_cache = logits_cache.at[:, :prompt_len, :].set(logits)
  cur_pos, max_total_len = prompt_len, prompt_len + n_gen_tokens
  # Sample the first token from the logits of the last prompt position
  # (index cur_pos-1; the cache holds nothing at cur_pos yet).
  rng, step_key = jax.random.split(rng)
  next_token = _sample(logits_cache[:, cur_pos-1:cur_pos, :], temperature=sampler_cfg.temp, min_p=sampler_cfg.min_p, top_k=sampler_cfg.top_k, top_p=sampler_cfg.top_p, key=step_key)
  gen_tokens = next_token

  print(tokenizer.decode([next_token.item()]), end='', flush=True)
  while cur_pos < max_total_len:
    cur_pos += 1
    rng, step_key = jax.random.split(rng)  # fresh key per step so draws are independent
    logits, kvcache, scores, attn_stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache, attn_stats)
    logits_cache = logits_cache.at[:, cur_pos:cur_pos+1, :].set(logits)
    next_token = _sample(logits, temperature=sampler_cfg.temp, min_p=sampler_cfg.min_p, top_k=sampler_cfg.top_k, top_p=sampler_cfg.top_p, key=step_key)
    gen_tokens = jnp.concatenate((gen_tokens, next_token))
    print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
  return gen_tokens, logits_cache, attn_stats

def score_N(xfmr_weights, model_params: ModelParams, tokens: jax.Array, start_pos: int, N: int):
  """
  Scores a batch of N candidate token sequences per batch element under the model.

  tokens: jax.Array of shape (batch_size, seqlen, N) -- N candidate sequences per batch element
  start_pos: position in the sequence from which scoring starts
  N: number of sequences to score per batch element
  """
  initial_state = initialize(model_params, tokens, 1)
  seqlen = tokens.shape[1]
  logits, _, scores, attn_stats = xfmr(
    xfmr_weights=xfmr_weights,
    model_params=model_params,
    cur_pos=0,
    tokens=initial_state['tokens'],
    freqs_cis=initial_state['freqs_cis'][:seqlen],
    kvcache=initial_state['kvcache'],
    attn_stats=initial_state['attn_stats'],
    attn_mask=initial_state['attn_mask']
  )
  # (batch_size, seqlen*N, vocab_size) -> (batch_size, seqlen, vocab_size, N)
  logits = logits.reshape(tokens.shape[0], tokens.shape[1], N, model_params.vocab_size).transpose(0, 1, 3, 2)
  log_probs = jax.nn.log_softmax(logits, axis=2)
  log_joint_probs = log_probs.sum(axis=1)  # (batch_size, vocab_size, N)
  joint_entropy, joint_varentropy = calculate_varentropy_logsoftmax(log_joint_probs, axis=-1)
  # Gather the log-probability of each realized token along the vocab axis (axis=2).
  log_likelihood = jnp.take_along_axis(log_probs[:, start_pos-1:-1, :], tokens[:, start_pos:, None, :], axis=2).squeeze(2)  # (batch_size, seqlen-start_pos, N)
  cross_entropy = log_likelihood.sum(axis=1)  # (batch_size, N)
  return cross_entropy, joint_entropy, joint_varentropy
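For orientation, a hypothetical driver showing how these pieces compose; the weights path and prompt are placeholders, and the flow mirrors the rewritten entropix/main.py below:

from pathlib import Path

import jax

from entropix.config import LLAMA_1B_PARAMS
from entropix.generator import initialize, generate, vanilla_generate
from entropix.tokenizer import Tokenizer
from entropix.weights import load_weights

model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights(Path('weights/1B-Instruct'))  # placeholder path
tokenizer = Tokenizer('entropix/tokenizer.model')

tokens = tokenizer.encode('Why is the sky blue?', bos=False, eos=False, allowed_special='all')

# Entropy/attention-guided decoding (prints tokens as it goes):
initial_state = initialize(model_params, tokens, 100)
generate(xfmr_weights, model_params, tokenizer, initial_state, 100)

# Plain autoregressive decoding; returns tokens plus the logits cache and attention stats.
# Re-initializing gives the second run a fresh KV cache.
gen_tokens, logits_cache, attn_stats = vanilla_generate(
  xfmr_weights, model_params, tokenizer,
  initialize(model_params, tokens, 100), 100, jax.random.PRNGKey(0))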
23 changes: 14 additions & 9 deletions entropix/kvcache.py
@@ -1,20 +1,24 @@
-from typing import NamedTuple
-
 import jax
 import jax.numpy as jnp
+from entropix.config import ModelParams
+from typing import NamedTuple
+
 
 
 class KVCache(NamedTuple):
   k: jax.Array
   v: jax.Array
 
   @classmethod
-  def new(cls, layers: int, bsz: int, max_seq_len: int, kv_heads: int, head_dim: int) -> 'KVCache':
-    return cls(
-      k=jnp.zeros((layers, bsz, max_seq_len, kv_heads, head_dim), dtype=jnp.bfloat16),
-      v=jnp.zeros((layers, bsz, max_seq_len, kv_heads, head_dim), dtype=jnp.bfloat16)
-    )
-
+  def new(cls, model_params: ModelParams, bsz: int, max_total_len) -> 'KVCache':
+    kv_heads = model_params.n_local_kv_heads
+    layers = model_params.n_layers
+    head_dim = model_params.head_dim
+    return cls(
+      k=jnp.zeros((layers, bsz, max_total_len, kv_heads, head_dim), dtype=jnp.bfloat16),
+      v=jnp.zeros((layers, bsz, max_total_len, kv_heads, head_dim), dtype=jnp.bfloat16)
+    )
+
   def update(self, xk: jax.Array, xv: jax.Array, layer_idx: int, cur_pos: int, n_rep: int):
     ck = jax.lax.dynamic_update_slice(self.k, jnp.bfloat16(xk[None, ...]), (layer_idx, 0, cur_pos, 0, 0))
     cv = jax.lax.dynamic_update_slice(self.v, jnp.bfloat16(xv[None, ...]), (layer_idx, 0, cur_pos, 0, 0))
@@ -27,3 +31,4 @@ def update(self, xk: jax.Array, xv: jax.Array, layer_idx: int, cur_pos: int, n_rep: int):
 
     return keys, values, KVCache(k=ck, v=cv)
 
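The signature change moves the cache geometry from the call site into ModelParams; a small before/after sketch, with shapes taken from LLAMA_1B_PARAMS in entropix/config.py:

from entropix.config import LLAMA_1B_PARAMS
from entropix.kvcache import KVCache

mp, bsz, max_total_len = LLAMA_1B_PARAMS, 1, 4096

# Before: every dimension spelled out by the caller.
# cache = KVCache.new(mp.n_layers, bsz, max_total_len, mp.n_local_kv_heads, mp.head_dim)

# After: the cache derives its shape from the params object.
cache = KVCache.new(mp, bsz, max_total_len)
assert cache.k.shape == (mp.n_layers, bsz, max_total_len, mp.n_local_kv_heads, mp.head_dim)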
83 changes: 15 additions & 68 deletions entropix/main.py
@@ -5,51 +5,17 @@
 import jax.numpy as jnp
 import tyro
 
 from pathlib import Path
 
 from entropix.config import LLAMA_1B_PARAMS
-from entropix.kvcache import KVCache
-from entropix.model import xfmr
-from entropix.sampler import SamplerConfig, sample
 from entropix.prompts import create_prompts_from_csv, prompt
-from entropix.sampler import sample
 from entropix.tokenizer import Tokenizer
 from entropix.weights import load_weights
+from entropix.generator import generate, vanilla_generate, initialize
 
-DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights'
-
-def apply_scaling(freqs: jax.Array):
-  SCALE_FACTOR = 8
-  LOW_FREQ_FACTOR = 1
-  HIGH_FREQ_FACTOR = 4
-  OLD_CONTEXT_LEN = 8192  # original llama3 length
-
-  low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR
-  high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR
-
-  def scale_freq(freq):
-    wavelen = 2 * math.pi / freq
-
-    def scale_mid(_):
-      smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR)
-      return (1 - smooth) * freq / SCALE_FACTOR + smooth * freq
-
-    return jax.lax.cond(
-      wavelen < high_freq_wavelen,
-      lambda _: freq,
-      lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None),
-      None
-    )
-
-  return jax.vmap(scale_freq)(freqs)
-
-
-def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array:
-  freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
-  if use_scaled:
-    freqs = apply_scaling(freqs)
-  t = jnp.arange(end, dtype=dtype)
-  freqs = jnp.outer(t, freqs)
-  return jnp.exp(1j * freqs)
-
+DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights'
 
 def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
   mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
@@ -62,47 +28,28 @@ def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
 
 def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath('1B-Instruct')):
   model_params = LLAMA_1B_PARAMS
-  xfmr_weights = load_weights(weights_path.absolute())
-  tokenizer = Tokenizer('entropix/tokenizer.model')
-
-  # Create the batch of tokens
-  def generate(xfmr_weights, model_params, tokens):
-    gen_tokens = None
-    cur_pos = 0
-    tokens = jnp.array([tokens], jnp.int32)
-    bsz, seqlen = tokens.shape
-    attn_mask = build_attn_mask(seqlen, cur_pos)
-    freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
-    kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
-    logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
-    next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
-    gen_tokens = next_token
-    print(tokenizer.decode([next_token.item()]), end='', flush=True)
-    cur_pos = seqlen
-    stop = jnp.array([128001, 128008, 128009])
-    sampler_cfg = SamplerConfig()
-    while cur_pos < 8192:
-      cur_pos += 1
-      logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
-      next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg)
-      gen_tokens = jnp.concatenate((gen_tokens, next_token))
-      print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
-      if jnp.isin(next_token, stop).any():
-        break
+  xfmr_weights = load_weights(weights_path)
 
-  tokenizer = Tokenizer('entropix/tokenizer.model')
+  tokenizer = Tokenizer('entropix/tokenizer.model')
   raw_tokens1 = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
   csv_path = Path('entropix/data/prompts.csv')
   prompts = create_prompts_from_csv(csv_path)
   PROMPT_TEST = False
 
+  # Create a random key
+  rng_key = jax.random.PRNGKey(0)
+
   if PROMPT_TEST:
     for p in prompts:
       print(p)
       tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special='all')
-      generate(xfmr_weights, model_params, tokens)
+      initial_state = initialize(model_params, tokens, 100)
+      vanilla_generate(xfmr_weights, model_params, tokenizer, initial_state, 100, rng_key)
   else:
     print(prompt)
    tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
-    generate(xfmr_weights, model_params, tokens)
+    initial_state = initialize(model_params, tokens, 100)
+    vanilla_generate(xfmr_weights, model_params, tokenizer, initial_state, 100, rng_key)
 
 if __name__ == '__main__':
-  tyro.cli(main)
+  tyro.cli(main)
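main remains a tyro entry point, so it can be driven from code as well; a sketch, with the CLI form assuming tyro's default kebab-case flag naming:

from pathlib import Path
from entropix.main import main

# CLI equivalent (assumed flag spelling): python -m entropix.main --weights-path weights/1B-Instruct
main(weights_path=Path('weights/1B-Instruct'))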
1 change: 0 additions & 1 deletion entropix/mcts.py
@@ -1,5 +1,4 @@
-import torch
 import torch.nn.functional as F
 from typing import Tuple
 
 from entropix.torch_main import calculate_varentropy_logsoftmax, _sample