~nothing #9

Closed · wants to merge 5 commits into from
6 changes: 6 additions & 0 deletions README.md
@@ -45,7 +45,13 @@ poetry run python download_weights.py --model-id meta-llama/Llama-3.2-1B-Instruc
```

download `tokenizer.model` from Hugging Face (or wherever) into the entropix folder, e.g.
```bash
cd entropix
huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "original/tokenizer.model" --local-dir ./llama3.1-tokenizer
```
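
Note: with `--local-dir ./llama3.1-tokenizer`, the file lands at `llama3.1-tokenizer/original/tokenizer.model`, while `main.py` constructs `Tokenizer('entropix/tokenizer.model')`. A sketch of the extra step that may be needed (assuming the commands above were run as written, from inside `entropix/`):

```bash
# Hypothetical cleanup: move the downloaded file to the path main.py expects.
mv llama3.1-tokenizer/original/tokenizer.model tokenizer.model
```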

run it
```bash
PYTHONPATH=. poetry run python entropix/main.py
Binary file added data/STEER_TOKENS.npy
Binary file not shown.
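
Since the new `data/STEER_TOKENS.npy` binary isn't rendered in the diff, reviewers can inspect it locally. A minimal sketch (nothing about shape or dtype is guaranteed by this PR):

```python
import numpy as np

# Peek at the token ids this PR ships; shape and dtype are whatever was saved.
steer = np.load('data/STEER_TOKENS.npy')
print(steer.shape, steer.dtype)
print(steer[:10])  # first few token ids
```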
22 changes: 22 additions & 0 deletions entropix/LMState.py
@@ -0,0 +1,22 @@
from dataclasses import dataclass

import jax
import jax.numpy as jnp

from entropix.model import KVCache

@dataclass
class LMState:
    prompt: jax.Array      # (bsz, prompt_len)
    gen_tokens: jax.Array  # (bsz, seq_len)
    logits: jax.Array      # (bsz, n_words)
    kvcache: KVCache
    freqs_cis: jax.Array
    attn_mask: jax.Array
    pos: jax.Array         # (bsz, max_seq_len)
    state: jax.Array       # (bsz, n_states) which state are we in? (flow, turn, fork, explore...)

    @property
    def context(self) -> jax.Array:
        return jnp.concatenate((self.prompt, self.gen_tokens), axis=1)
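
A minimal sketch of what the `context` property yields (toy values, not from this PR; the other fields are omitted since only `prompt` and `gen_tokens` matter here):

```python
import jax.numpy as jnp

# Toy example: batch of 1, a 4-token prompt, and 2 generated tokens.
prompt = jnp.array([[1, 2, 3, 4]], dtype=jnp.int32)
gen_tokens = jnp.array([[7, 8]], dtype=jnp.int32)

# LMState.context concatenates along the sequence axis:
context = jnp.concatenate((prompt, gen_tokens), axis=1)  # shape (1, 6)
```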
43 changes: 16 additions & 27 deletions entropix/config.py
@@ -1,37 +1,26 @@
 from typing import NamedTuple
+import jax
 
-
-params = {
-    "dim": 2048,
-    "n_layers": 16,
-    "n_heads": 32,
-    "n_kv_heads": 8,
-    "vocab_size": 128256,
-    "ffn_dim_multiplier": 1.5,
-    "multiple_of": 256,
-    "norm_eps": 1e-05,
-    "rope_theta": 500000.0,
-    "use_scaled_rope": True,
-    "max_seq_len": 4096
-}
+class RopeParams(NamedTuple):
+    rope_theta: float
+    use_scaled_rope: bool
+    scale_factor: int     # = 8
+    low_freq_factor: int  # = 1
+    high_freq_factor: int # = 4
+    old_context_len: int  # = 8192 (original llama3 length)
 
 class ModelParams(NamedTuple):
     n_layers: int
     n_local_heads: int
     n_local_kv_heads: int
     head_dim: int
     max_seq_len: int
-    rope_theta: float
-    use_scaled_rope: bool
-
-LLAMA_1B_PARAMS = ModelParams(
-    n_layers=params["n_layers"],
-    n_local_heads=params["n_heads"],
-    n_local_kv_heads=params["n_kv_heads"],
-    head_dim=params["dim"] // params["n_heads"],
-    max_seq_len=params["max_seq_len"],
-    rope_theta=params["rope_theta"],
-    use_scaled_rope=params["use_scaled_rope"]
-)
+    rope_params: RopeParams
+    d_model: int
+
+class SamplerParams(NamedTuple):
+    steer_tokens: jax.Array
+    stop_tokens: jax.Array
+    base_temp: float
+    base_top_p: float
+    base_top_k: int
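
For context on the new `RopeParams` fields: the defaults noted in the comments (scale_factor 8, low/high freq factors 1/4, old_context_len 8192) match Llama-3.1-style rope frequency scaling. A sketch of how such params are typically consumed, assuming this repo's `apply_scaling` follows the reference implementation (illustrative, not lifted from this PR):

```python
import math
import jax.numpy as jnp

def apply_scaling_sketch(freqs: jnp.ndarray, scale_factor: int = 8,
                         low_freq_factor: int = 1, high_freq_factor: int = 4,
                         old_context_len: int = 8192) -> jnp.ndarray:
    # Long wavelengths (> old_context_len / low_freq_factor) are divided by
    # scale_factor; short ones (< old_context_len / high_freq_factor) pass
    # through unchanged; the band in between is smoothly interpolated.
    low_wavelen = old_context_len / low_freq_factor
    high_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    mid = (1 - smooth) * freqs / scale_factor + smooth * freqs
    return jnp.where(wavelen > low_wavelen, freqs / scale_factor,
                     jnp.where(wavelen < high_wavelen, freqs, mid))
```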
48 changes: 48 additions & 0 deletions entropix/generator.py
@@ -0,0 +1,48 @@
import jax
import jax.numpy as jnp

from entropix.config import ModelParams, RopeParams
from entropix.LMState import LMState
from entropix.model import KVCache, xfmr
from entropix.rope import precompute_freqs_cis
from entropix.sampler import SamplerParams, sample
from entropix.tokenizer import Tokenizer

def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
    mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
    if seqlen > 1:
        mask = jnp.full((seqlen, seqlen), float('-inf'))
        mask = jnp.triu(mask, k=1)
        mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32)
    return mask

def generate(xfmr_weights, model_params, sampler_params, tokenizer, tokens):
    tokens = jnp.array([tokens], jnp.int32)
    n_words = tokenizer.n_words
    bsz, seqlen = tokens.shape
    lm_state = LMState(
        prompt=tokens,
        logits=jnp.zeros((bsz, n_words), dtype=jnp.bfloat16),
        freqs_cis=precompute_freqs_cis(head_dim=model_params.head_dim, max_seq_len=model_params.max_seq_len, rope_params=model_params.rope_params),
        kvcache=KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim),
        attn_mask=build_attn_mask(seqlen, 0),
        gen_tokens=jnp.zeros((bsz, 0), dtype=jnp.int32),
        state=jnp.zeros((bsz, 1), dtype=jnp.int32),
        pos=0
    )
    # prefill: run the whole prompt through the model once
    lm_state.logits, lm_state.kvcache, _ = xfmr(xfmr_weights, model_params, lm_state.context, lm_state.pos, freqs_cis=lm_state.freqs_cis[:seqlen], kvcache=lm_state.kvcache, attn_mask=lm_state.attn_mask)
    next_token = jnp.argmax(lm_state.logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
    lm_state.gen_tokens, lm_state.pos = jnp.concatenate((lm_state.gen_tokens, next_token), axis=1), seqlen
    print(tokenizer.decode([next_token.item()]), end='', flush=True)
    #stop = jnp.array(tokenizer.stop_tokens)
    # decode loop: feed one token at a time until a stop token or the length cap
    while lm_state.pos < 2048:
        lm_state.pos += 1
        lm_state.logits, lm_state.kvcache, _ = xfmr(xfmr_weights, model_params, next_token, lm_state.pos, lm_state.freqs_cis[lm_state.pos:lm_state.pos+1], lm_state.kvcache)
        next_token = sample(sampler_params, lm_state)
        lm_state.gen_tokens = jnp.concatenate((lm_state.gen_tokens, next_token), axis=1)
        print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
        if jnp.isin(next_token, sampler_params.stop_tokens).any():
            break
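
To make the new `build_attn_mask` easier to review, here is what it produces for a tiny case (a sketch; output shown as comments):

```python
from entropix.generator import build_attn_mask

mask = build_attn_mask(4, 0)
print(mask)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]
# With start_pos > 0, start_pos columns of zeros are prepended so
# already-cached positions stay attendable during decode.
```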


80 changes: 66 additions & 14 deletions entropix/main.py
@@ -1,18 +1,12 @@
 from typing import NamedTuple, Optional, Tuple
 
+from entropix.config import SamplerParams, ModelParams, RopeParams
 import jax
 import jax.numpy as jnp
 
+from entropix.generator import generate
 import math
 import tyro
 
 from pathlib import Path
 from functools import partial
 
-from entropix.config import LLAMA_1B_PARAMS
-from entropix.kvcache import KVCache
-from entropix.model import xfmr
 from entropix.tokenizer import Tokenizer
 from entropix.weights import load_weights
@@ -27,7 +21,6 @@
<thinking>
"""


bp1 = """
<antThinking>
You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.
@@ -85,6 +78,7 @@

Can you retrieve the details for the user with the ID 7890, who has black as their special request?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

bp3 = """
Here is a list of functions in JSON format that I can invoke.[
{
@@ -129,6 +123,8 @@
Let me tell you a story about the adventures of the elven mage Frieren and her band of heroes
"""

<<<<<<< HEAD
=======


def apply_scaling(freqs: jax.Array):
@@ -250,12 +246,61 @@ def sample(gen_tokens: jax.Array, logits: jax.Array, temperature=0.666, top_p=0.
    return _sample(logits, temperature=t * temperature)


>>>>>>> origin/main
def main():

    params = {
        "dim": 2048,
        "n_layers": 16,
        "n_heads": 32,
        "n_kv_heads": 8,
        "n_words": 128256,
        "ffn_dim_multiplier": 1.5,
        "multiple_of": 256,
        "norm_eps": 1e-05,
        "rope_theta": 500000.0,
        "scale_factor": 8,
        "low_freq_factor": 1,
        "high_freq_factor": 4,
        "old_context_len": 8192,
        "use_scaled_rope": True,
        "max_seq_len": 4096
    }

    LLAMA_1B_ROPE = RopeParams(
        rope_theta=params["rope_theta"],
        use_scaled_rope=params["use_scaled_rope"],
        scale_factor=params["scale_factor"],
        low_freq_factor=params["low_freq_factor"],
        high_freq_factor=params["high_freq_factor"],
        old_context_len=params["old_context_len"]
    )

    LLAMA_1B_PARAMS = ModelParams(
        n_layers=params["n_layers"],
        n_local_heads=params["n_heads"],
        n_local_kv_heads=params["n_kv_heads"],
        head_dim=params["dim"] // params["n_heads"],
        max_seq_len=params["max_seq_len"],
        rope_params=LLAMA_1B_ROPE,
        d_model=params["dim"]
    )

    model_params = LLAMA_1B_PARAMS
    xfmr_weights = load_weights()
    #xfmr_weights = load_weights(ckpt_dir=Path('weights/1B-Base'))

    tokenizer = Tokenizer('entropix/tokenizer.model')
<<<<<<< HEAD
    sampler_params = SamplerParams(
        steer_tokens=jnp.load('data/STEER_TOKENS.npy'),
        stop_tokens=jnp.array([128001, 128008, 128009]),  # llama-3 <|end_of_text|>/<|eom_id|>/<|eot_id|>
        base_temp=0.666,
        base_top_p=0.90,
        base_top_k=27
    )

=======
>>>>>>> origin/main
    raw_tokens1 = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
    raw_tokens2 = tokenizer.encode(prompt2, bos=False, eos=False, allowed_special='all')
    raw_tokens3 = tokenizer.encode(prompt3, bos=False, eos=False, allowed_special='all')
@@ -267,6 +312,8 @@ def main():
    base_raw_tokens4 = tokenizer.encode(bp4, bos=True, eos=False, allowed_special='all')


<<<<<<< HEAD

> [review comment] merge tags left in here

=======
def generate(xfmr_weights, model_params, tokens):
    gen_tokens = None
    cur_pos = 0
@@ -290,18 +337,19 @@ def generate(xfmr_weights, model_params, tokens):
        print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
        if jnp.isin(next_token, stop).any():
            break
>>>>>>> origin/main

     print(prompt)
-    generate(xfmr_weights, model_params, raw_tokens1)
+    generate(xfmr_weights, model_params, sampler_params, tokenizer, raw_tokens1)
     print('\n')
     print(prompt2)
-    generate(xfmr_weights, model_params, raw_tokens2)
+    generate(xfmr_weights, model_params, sampler_params, tokenizer, raw_tokens2)
     print('\n')
     print(prompt3)
-    generate(xfmr_weights, model_params, raw_tokens3)
+    generate(xfmr_weights, model_params, sampler_params, tokenizer, raw_tokens3)
     print('\n')
     print(prompt4)
-    generate(xfmr_weights, model_params, raw_tokens4)
+    generate(xfmr_weights, model_params, sampler_params, tokenizer, raw_tokens4)
     print('\n')

    #print(bp1)
@@ -318,4 +366,8 @@ def generate(xfmr_weights, model_params, tokens):
    #print('\n')

if __name__ == '__main__':
<<<<<<< HEAD
    tyro.cli(main)
=======
    tyro.cli(main)
>>>>>>> origin/main