diff --git a/.gitignore b/.gitignore
index a3c82a25b7..ae49f32bd2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ __pycache__/
 *.pt
 *.pyc
 input.txt
+wandb
diff --git a/config/train_shakespeare_char.py b/config/train_shakespeare_char.py
index 41c81dfb4e..82ff98ee27 100644
--- a/config/train_shakespeare_char.py
+++ b/config/train_shakespeare_char.py
@@ -34,4 +34,4 @@
 
 # on macbook also add
 # device = 'cpu' # run on cpu only
-# compile = False # do not torch compile the model
+# compile = False # do not torch compile the model
\ No newline at end of file
diff --git a/model.py b/model.py
index 827e8ca970..6b4523f66b 100644
--- a/model.py
+++ b/model.py
@@ -19,6 +19,16 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from tqdm import tqdm
+
+
+# The modified softmax version:
+def surftmex(x, dim=-1):
+    maxes = torch.max(x, dim, keepdim=True)[0]
+    x_exp = torch.exp(x-maxes)
+    x_exp_sum = torch.sum(x_exp, dim, keepdim=True)
+    output_custom = x_exp/(torch.exp(-maxes)+x_exp_sum) # << The key bit is the +torch.exp(-maxes)
+    return output_custom
 
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -381,28 +391,37 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
         return mfu
 
     @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, pbar=False):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
         the sequence max_new_tokens times, feeding the predictions back into the model each time.
         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
         """
-        for _ in range(max_new_tokens):
+        ppl = torch.zeros(idx.shape[0], device=idx.device)
+        for _ in tqdm(range(max_new_tokens), disable=not pbar):
             # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            idx_cond = (
+                idx
+                if idx.size(1) <= self.config.block_size
+                else idx[:, -self.config.block_size :]
+            )
             # forward the model to get the logits for the index in the sequence
             logits, _ = self(idx_cond)
             # pluck the logits at the final step and scale by desired temperature
-            logits = logits[:, -1, :] / temperature
+            tlogits = logits[:, -1, :] / temperature
             # optionally crop the logits to only the top k options
             if top_k is not None:
-                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                logits[logits < v[:, [-1]]] = -float('Inf')
+                v, _ = torch.topk(tlogits, min(top_k, tlogits.size(-1)))
+                tlogits[tlogits < v[:, [-1]]] = -float("Inf")
             # apply softmax to convert logits to (normalized) probabilities
-            probs = F.softmax(logits, dim=-1)
+            probs = F.softmax(tlogits, dim=-1)
             # sample from the distribution
             idx_next = torch.multinomial(probs, num_samples=1)
             # append sampled index to the running sequence and continue
             idx = torch.cat((idx, idx_next), dim=1)
-        return idx
+            # calculate perplexity from final probs, feel free to change as needed
+            ppl += -torch.log2(probs[:, idx_next.squeeze()].diag()) / max_new_tokens
+        ppl = ppl.detach().cpu().numpy()
+
+        return idx, ppl
 
diff --git a/train.py b/train.py
index a482ab7f4e..27caccb170 100644
--- a/train.py
+++ b/train.py
@@ -66,6 +66,9 @@
 warmup_iters = 2000 # how many steps to warm up for
 lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
 min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
+# Perplexity Required
+start = "cevap" # or "<|endoftext|>" or etc.
+sample_length = 16
 # DDP settings
 backend = 'nccl' # 'nccl', 'gloo', etc.
 # system
@@ -113,8 +116,20 @@
 
 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
+
+# attempt to derive vocab_size from the dataset
+meta_path = os.path.join(data_dir, 'meta.pkl')
+meta_vocab_size = None
+if os.path.exists(meta_path):
+    with open(meta_path, 'rb') as f:
+        meta = pickle.load(f)
+    meta_vocab_size = meta['vocab_size']
+    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
+
+mtype = np.uint16 if meta_vocab_size > 256 else np.uint8
+
+train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=mtype, mode='r')
+val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=mtype, mode='r')
 def get_batch(split):
     data = train_data if split == 'train' else val_data
     ix = torch.randint(len(data) - block_size, (batch_size,))
@@ -131,15 +146,6 @@ def get_batch(split):
 iter_num = 0
 best_val_loss = 1e9
 
-# attempt to derive vocab_size from the dataset
-meta_path = os.path.join(data_dir, 'meta.pkl')
-meta_vocab_size = None
-if os.path.exists(meta_path):
-    with open(meta_path, 'rb') as f:
-        meta = pickle.load(f)
-    meta_vocab_size = meta['vocab_size']
-    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
-
 # model init
 model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                   bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
@@ -249,23 +255,28 @@ def get_lr(it):
 local_iter_num = 0 # number of iterations in the lifetime of this process
 raw_model = model.module if ddp else model # unwrap DDP container if needed
 running_mfu = -1.0
+token_num = 0
 while True:
-
     # determine and set the learning rate for this iteration
     lr = get_lr(iter_num) if decay_lr else learning_rate
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
+    # Calculate perplexity
+    gen_x = torch.tensor([meta["stoi"][i] for i in start]).tile(batch_size, 1).to(device)
+    _, ppl = model.generate(gen_x, sample_length)
 
     # evaluate the loss on train/val sets and write checkpoints
     if iter_num % eval_interval == 0 and master_process:
         losses = estimate_loss()
-        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+        print(f"step {iter_num}, tokens {token_num:,d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}",
+              f"avg perplexity: {ppl.mean():.4f}")
         if wandb_log:
             wandb.log({
                 "iter": iter_num,
                 "train/loss": losses['train'],
                 "val/loss": losses['val'],
                 "lr": lr,
+                "perplexity": ppl.mean(),
                 "mfu": running_mfu*100, # convert to percentage
             })
         if losses['val'] < best_val_loss or always_save_checkpoint:
@@ -324,6 +335,7 @@ def get_lr(it):
         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
     iter_num += 1
     local_iter_num += 1
+    token_num += tokens_per_iter
 
     # termination conditions
     if iter_num > max_iters:
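A note on the surftmex helper added to model.py (it is defined but not called anywhere else in this patch): because exp(x - maxes) / (exp(-maxes) + sum(exp(x - maxes))) simplifies to exp(x) / (1 + sum(exp(x))), it behaves like a softmax with one extra implicit zero logit, so its outputs sum to less than 1. A minimal standalone sketch that checks this numerically (illustrative only, not part of the patch):

import torch

def surftmex(x, dim=-1):
    # same definition as in the model.py hunk above
    maxes = torch.max(x, dim, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    x_exp_sum = torch.sum(x_exp, dim, keepdim=True)
    return x_exp / (torch.exp(-maxes) + x_exp_sum)

x = torch.randn(4, 8)
# algebraically equivalent form without the max-subtraction trick
reference = torch.exp(x) / (1 + torch.exp(x).sum(dim=-1, keepdim=True))
print(torch.allclose(surftmex(x), reference))   # True
print(surftmex(x).sum(dim=-1))                  # each row sums to less than 1
print(torch.softmax(x, dim=-1).sum(dim=-1))     # ordinary softmax rows sum to 1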
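Callers should also note that generate() now returns a tuple (idx, ppl) rather than just idx, so existing call sites such as sample.py need to unpack two values. The ppl computed here is the mean negative log2 probability of the sampled tokens (bits per token) rather than an exponentiated perplexity, and the probs[:, idx_next.squeeze()].diag() indexing is only well-formed when the batch dimension is greater than one. A rough usage sketch under those assumptions, where model, meta, and device stand in for the objects set up in train.py/sample.py:

import torch

prompt = "ROMEO: "  # any string the character-level stoi can encode
# batch of two identical prompts; the .diag() trick in generate needs batch size > 1
x = torch.tensor([[meta['stoi'][c] for c in prompt]] * 2, dtype=torch.long, device=device)
idx, ppl = model.generate(x, max_new_tokens=100, temperature=0.8, top_k=200, pbar=True)
for row, score in zip(idx.tolist(), ppl):
    print(''.join(meta['itos'][i] for i in row))
    print(f"avg -log2 p per sampled token: {score:.4f}")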