Adding tokens and perplexity calculation #4

Open · wants to merge 5 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -8,3 +8,4 @@ __pycache__/
*.pt
*.pyc
input.txt
wandb
2 changes: 1 addition & 1 deletion config/train_shakespeare_char.py
@@ -34,4 +34,4 @@

# on macbook also add
# device = 'cpu' # run on cpu only
# compile = False # do not torch compile the model
35 changes: 27 additions & 8 deletions model.py
@@ -19,6 +19,16 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm


# A modified softmax with an extra +1 term in the denominator, so the output
# probabilities can sum to less than one (a row can "attend to nothing"):
def surftmex(x, dim=-1):
    maxes = torch.max(x, dim, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    x_exp_sum = torch.sum(x_exp, dim, keepdim=True)
    # << the key bit is the +torch.exp(-maxes): after factoring exp(maxes) out of
    # the denominator, the implicit "+1" becomes exp(-maxes)
    output_custom = x_exp / (torch.exp(-maxes) + x_exp_sum)
    return output_custom
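
# A quick check of the "+1" property (a sketch, not part of this diff; surftmex
# does not appear to be wired into the attention layers yet):
#   x = torch.full((1, 4), -10.0)
#   F.softmax(x, dim=-1).sum()   # 1.0 -- standard softmax must allocate all mass
#   surftmex(x, dim=-1).sum()    # ~2e-4 -- the +1 term absorbs nearly everything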

class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -381,28 +391,37 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, pbar=False):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        ppl = torch.zeros(idx.shape[0], device=idx.device)
        for _ in tqdm(range(max_new_tokens), disable=not pbar):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = (
                idx
                if idx.size(1) <= self.config.block_size
                else idx[:, -self.config.block_size :]
            )
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            tlogits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(tlogits, min(top_k, tlogits.size(-1)))
                tlogits[tlogits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(tlogits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

            # accumulate the average negative log2-probability of each sampled token;
            # gather (rather than fancy indexing + diag) also works for batch size 1
            ppl += -torch.log2(probs.gather(1, idx_next).squeeze(1)) / max_new_tokens
        # exponentiate bits-per-token to get an actual perplexity; change as needed
        ppl = (2.0 ** ppl).detach().cpu().numpy()

        return idx, ppl
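
With this change generate returns a pair, so existing call sites (e.g. sample.py)
must be updated to unpack the extra value. A hypothetical usage sketch:

    model.eval()
    x = torch.zeros((4, 1), dtype=torch.long, device=device)  # batch of 4 prompts
    y, ppl = model.generate(x, max_new_tokens=100, top_k=200, pbar=True)
    print(ppl)  # numpy array, one perplexity value per sequence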
38 changes: 25 additions & 13 deletions train.py
@@ -66,6 +66,9 @@
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# perplexity evaluation
start = "cevap" # prompt used to seed generation, or "<|endoftext|>" etc.
sample_length = 16 # number of new tokens generated per perplexity estimate
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
@@ -113,8 +116,20 @@

# poor man's data loader
data_dir = os.path.join('data', dataset)
# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# small vocabularies fit in uint8, halving the token files' footprint; fall back
# to uint16 when meta.pkl is absent (e.g. GPT-2 BPE datasets) or the vocab is large
# NOTE: this dtype must match the one data/<dataset>/prepare.py used to write the .bin files
mtype = np.uint16 if meta_vocab_size is None or meta_vocab_size > 256 else np.uint8

train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=mtype, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=mtype, mode='r')
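# a cheap consistency check one could add here (a sketch, not part of this PR;
# it assumes the .bin files were written with this same dtype policy):
#   assert meta_vocab_size is None or int(val_data.max()) < meta_vocab_size, \
#       "token ids exceed vocab_size: .bin dtype likely mismatched"
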
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
@@ -131,15 +146,6 @@ def get_batch(split):
iter_num = 0
best_val_loss = 1e9

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
@@ -249,23 +255,28 @@ def get_lr(it):
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
token_num = 0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        # calculate perplexity on completions of a fixed prompt; done only on eval
        # steps to avoid paying for sampling every iteration, and via raw_model so
        # it also works under DDP (the DDP wrapper exposes no .generate)
        gen_x = torch.tensor([meta["stoi"][i] for i in start]).tile(batch_size, 1).to(device)
        _, ppl = raw_model.generate(gen_x, sample_length)
        losses = estimate_loss()
        print(f"step {iter_num}, tokens {token_num:,d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}",
              f"avg perplexity: {ppl.mean():.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "perplexity": ppl.mean(),
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
@@ -324,6 +335,7 @@ def get_lr(it):
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
iter_num += 1
local_iter_num += 1
token_num += tokens_per_iter

    # termination conditions
    if iter_num > max_iters: