diff --git a/data.py b/data.py
index 53b882c..5e571ed 100644
--- a/data.py
+++ b/data.py
@@ -6,11 +6,12 @@ import lance
 
+
 def apply_fim(sample, fim_prefix, fim_middle, fim_suffix, fim_pad, mode, np_rng):
     """
     Applies FIM transformation on one sample
     """
-    boundaries = sorted(np_rng.randint(low=0, high=len(sample)+1, size=2))
+    boundaries = sorted(np_rng.randint(low=0, high=len(sample) + 1, size=2))
 
     prefix = sample[: boundaries[0]]
     middle = sample[boundaries[0] : boundaries[1]]
@@ -23,18 +24,18 @@ def apply_fim(sample, fim_prefix, fim_middle, fim_suffix, fim_pad, mode, np_rng)
     elif diff < 0:
         extend = torch.cat([fim_pad for _ in range(-diff)])
         suffix = torch.cat([suffix, extend])
-    
-    if mode == 'spm':
+
+    if mode == "spm":
         # Apply SPM
-        transfomed_example = torch.cat([
-            fim_prefix, fim_suffix, suffix, fim_middle, prefix, middle
-        ])
+        transfomed_example = torch.cat(
+            [fim_prefix, fim_suffix, suffix, fim_middle, prefix, middle]
+        )
     else:
         # Apply PSM
-        transfomed_example = torch.cat([
-            fim_prefix, prefix, fim_suffix, suffix, fim_middle, middle
-        ])
-    
+        transfomed_example = torch.cat(
+            [fim_prefix, prefix, fim_suffix, suffix, fim_middle, middle]
+        )
+
     return transfomed_example
@@ -48,16 +49,16 @@ def __init__(
         fim_suffix,
         fim_pad,
         fim_rate=0.5,
-        mode='psm',
-        rng_seed=42
+        mode="psm",
+        rng_seed=42,
     ):
         # Load the lance dataset from the saved path
         self.ds = lance.dataset(dataset_path)
         self.context_len = context_len
-        
+
         # Doing this so the sampler never asks for an index at the end of text
         self.length = self.ds.count_rows() - context_len
-        
+
         self.np_rng = np.random.RandomState(seed=rng_seed)
 
         self.fim_prefix = torch.tensor([fim_prefix])
@@ -75,14 +76,14 @@ def from_idxs(self, idxs):
         Little utility function to get the data from lance
         """
         data = self.ds.take(idxs).to_pylist()
-        data = torch.tensor(list(map(lambda x: x['value'], data)))
+        data = torch.tensor(list(map(lambda x: x["value"], data)))
         return data
 
     def apply_fim(self, sample):
         """
         Applies FIM transformation on one sample
         """
-        boundaries = sorted(self.np_rng.randint(low=0, high=len(sample)+1, size=2))
+        boundaries = sorted(self.np_rng.randint(low=0, high=len(sample) + 1, size=2))
 
         prefix = sample[: boundaries[0]]
         middle = sample[boundaries[0] : boundaries[1]]
@@ -95,18 +96,32 @@ def apply_fim(self, sample):
         elif diff < 0:
             extend = torch.cat([self.fim_pad for _ in range(-diff)])
             suffix = torch.cat([suffix, extend])
-        
-        if self.mode == 'spm':
+
+        if self.mode == "spm":
             # Apply SPM
-            transfomed_example = torch.cat([
-                self.fim_prefix, self.fim_suffix, suffix, self.fim_middle, prefix, middle
-            ])
+            transfomed_example = torch.cat(
+                [
+                    self.fim_prefix,
+                    self.fim_suffix,
+                    suffix,
+                    self.fim_middle,
+                    prefix,
+                    middle,
+                ]
+            )
         else:
             # Apply PSM
-            transfomed_example = torch.cat([
-                self.fim_prefix, prefix, self.fim_suffix, suffix, self.fim_middle, middle
-            ])
-        
+            transfomed_example = torch.cat(
+                [
+                    self.fim_prefix,
+                    prefix,
+                    self.fim_suffix,
+                    suffix,
+                    self.fim_middle,
+                    middle,
+                ]
+            )
+
         return transfomed_example
 
     def __getitem__(self, idx):
@@ -114,7 +129,7 @@ def __getitem__(self, idx):
         Generate a list of indices starting from the current idx to idx+context_len+1
         with optional fim transformation
         """
-        current_window_idxs = np.arange(idx, idx+self.context_len+1)
+        current_window_idxs = np.arange(idx, idx + self.context_len + 1)
         sample = self.from_idxs(current_window_idxs)
 
         # Apply FIM transformation depending on the rate
@@ -122,9 +137,9 @@ def __getitem__(self, idx):
             sample = self.apply_fim(sample)
 
         # +1 in labels because it is 1 step ahead of input tokens
-        tokens = sample[0:self.context_len]
-        labels = sample[1:self.context_len+1]
-        return {'tokens': tokens, 'labels': labels}
+        tokens = sample[0 : self.context_len]
+        labels = sample[1 : self.context_len + 1]
+        return {"tokens": tokens, "labels": labels}
 
 
 class MambaSampler(Sampler):
@@ -137,6 +152,7 @@ class MambaSampler(Sampler):
 
     def __init__(self, data_source, k=16):
         self.data_source = data_source
+        self.num_samples = len(self.data_source)
         self.available_indices = list(range(0, self.num_samples, k))
         random.shuffle(self.available_indices)
@@ -144,4 +160,4 @@ def __iter__(self):
         yield from self.available_indices
 
     def __len__(self) -> int:
-        return len(self.available_indices)
\ No newline at end of file
+        return len(self.available_indices)
diff --git a/train.py b/train.py
index 09370ab..9e2acf7 100644
--- a/train.py
+++ b/train.py
@@ -18,12 +18,15 @@ import wandb
 
+
 # Params (replace with Arg parser later)
 class Args:
     wandb = False
     tokenizer_model = "EleutherAI/gpt-neox-20b"
     model_name = "state-spaces/mamba-790m"
-    dataset_path = "/teamspace/studios/codeparrot-dataset-lance/code_parrot_github_python.lance"
+    dataset_path = (
+        "/teamspace/studios/codeparrot-dataset-lance/code_parrot_github_python.lance"
+    )
     eval_dataset_path = "fim_data_eval.lance"
     dataset = lance.dataset(dataset_path)
     low_cpu_mem_usage = False
@@ -43,9 +46,10 @@ class Args:
     T_0 = 1000
     T_mult = 1
     eta_min = 1e-5
-    device = torch.device('cuda:0')
+    device = torch.device("cuda:0")
     # Total chunks of context_len+1 size we can get
-    steps_per_epoch = (dataset.count_rows() // context_len+1) // 4
+    steps_per_epoch = (dataset.count_rows() // context_len + 1) // 4
+
 
 # Define Tokenizer and Model
 tokenizer = transformers.AutoTokenizer.from_pretrained(Args.tokenizer_model)
@@ -56,7 +60,14 @@ class Args:
 ).to(Args.device)
 
 # Get the FIM-specific tokens and get their token ids
-tokenizer.add_tokens([Args.fim_prefix_token, Args.fim_middle_token, Args.fim_middle_token, Args.fim_pad_token])
+tokenizer.add_tokens(
+    [
+        Args.fim_prefix_token,
+        Args.fim_middle_token,
+        Args.fim_middle_token,
+        Args.fim_pad_token,
+    ]
+)
 prefix_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_prefix_token)
 middle_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_middle_token)
 suffix_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_middle_token)
@@ -75,13 +86,9 @@ class Args:
 mean = original_embeddings.mean(dim=0)
 n = original_embeddings.size()[0]
 sigma = ((original_embeddings - mean).T @ (original_embeddings - mean)) / n
-dist = torch.distributions.MultivariateNormal(
-    mean,
-    covariance_matrix=1e-5*sigma
-)
+dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=1e-5 * sigma)
 new_token_embeddings = torch.stack(
-    tuple((dist.sample() for _ in range(len(fim_tokens)))),
-    dim=0
+    tuple((dist.sample() for _ in range(len(fim_tokens)))), dim=0
 )
 
 # Get updated embedding layer and make a copy of it's weights
@@ -89,7 +96,7 @@ class Args:
 new_embeddings = embeddings.weight.clone()
 
 # Set the new token' embeddings to the newly sampled embeddings
-new_embeddings[-len(fim_tokens):] = new_token_embeddings
+new_embeddings[-len(fim_tokens) :] = new_token_embeddings
 
 # Update the model's embeddings with the new embeddings
 embeddings.weight = torch.nn.Parameter(new_embeddings)
@@ -97,31 +104,30 @@ class Args:
 # Make train dataset and train dataloader
 train_dataset = MambaDataset(
-    Args.dataset_path, 
-    context_len=Args.context_len, 
+    Args.dataset_path,
+    context_len=Args.context_len,
     fim_prefix=prefix_tok_id,
     fim_middle=middle_tok_id,
     fim_suffix=suffix_tok_id,
     fim_pad=pad_tok_id,
     fim_rate=Args.fim_rate,
-    mode='psm',
+    mode="psm",
 )
-train_dataloader = iter(DataLoader(
-    train_dataset,
-    batch_size=Args.train_batch_size,
-    sampler=MambaSampler(train_dataset, k=Args.context_len+1),
-    shuffle=False,
-    pin_memory=True
-))
+train_dataloader = iter(
+    DataLoader(
+        train_dataset,
+        batch_size=Args.train_batch_size,
+        sampler=MambaSampler(train_dataset, k=Args.context_len + 1),
+        shuffle=False,
+        pin_memory=True,
+    )
+)
 
 # Optimizer and Scheduler
 optimizer = torch.optim.AdamW(model.parameters(), lr=Args.lr)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-    optimizer,
-    T_0=Args.T_0,
-    T_mult=Args.T_mult,
-    eta_min=Args.eta_min
+    optimizer, T_0=Args.T_0, T_mult=Args.T_mult, eta_min=Args.eta_min
 )
 
 # Start training
@@ -131,31 +137,36 @@ class Args:
 print(f"Training steps per epoch: {Args.steps_per_epoch:,}\n")
 # print(f"Total training steps in training: {Args.steps_per_epoch * Args.epochs:,}")
 
+
 def wandb_log(**kwargs):
     """Easy interface to log stuff to wandb"""
     for k, v in kwargs.items():
         wandb.log({k: v})
 
+
 if Args.wandb:
     # Convert the Config class to a dict for logging
     config_dict = dict(vars(Args))
-    del[config_dict['__module__']]
-    del[config_dict['__dict__']]
-    del[config_dict['__weakref__']]
-    del[config_dict['__doc__']]
+    del [config_dict["__module__"]]
+    del [config_dict["__dict__"]]
+    del [config_dict["__weakref__"]]
+    del [config_dict["__doc__"]]
 
     from dotenv import load_dotenv
+
     load_dotenv()
 
     wandb.login()
     run = wandb.init(
-        project='pytorch',
+        project="pytorch",
         config=config_dict,
-        group='mamba-train',
-        job_type='train',
+        group="mamba-train",
+        job_type="train",
     )
     wandb.watch(model)
 
-prog_bar = tqdm(range(Args.steps_per_epoch * Args.epochs), total=Args.steps_per_epoch * Args.epochs)
+prog_bar = tqdm(
+    range(Args.steps_per_epoch * Args.epochs), total=Args.steps_per_epoch * Args.epochs
+)
 for epoch in range(Args.epochs):
     model.train()
     total_loss = []
@@ -166,12 +177,12 @@ def wandb_log(**kwargs):
             batch[k] = v.to(Args.device)
 
         # Get predictions
-        predictions = model(batch['tokens'])
+        predictions = model(batch["tokens"])
 
         # Reshape predictions and calculate loss
         B, C, V = predictions.shape
-        predictions = predictions.view(B*C, V)
-        targets = batch['labels'].view(B*C)
+        predictions = predictions.view(B * C, V)
+        targets = batch["labels"].view(B * C)
         loss = torch.nn.functional.cross_entropy(predictions, targets)
 
         prog_bar.set_description((f"loss: {loss.item():.4f}"))
@@ -189,14 +200,14 @@ def wandb_log(**kwargs):
     try:
         perplexity = np.exp(np.mean(total_loss))
     except OverflowError:
-        perplexity = float('-inf')
-    
+        perplexity = float("-inf")
+
     if Args.wandb:
         wandb_log(train_perplexity=perplexity)
     print(f"epoch: {epoch} | train perplexity: {perplexity:.4f}")
 
 
 # Save the model after training
-model_name = Args.model_name.split('/')[-1]
+model_name = Args.model_name.split("/")[-1]
 torch.save(model.state_dict(), f"{model_name}-fim.bin")
-print("Saved the model!")
\ No newline at end of file
+print("Saved the model!")