Commit

fix: correct variable names and string formatting, ensure consistency
tanaymeh committed Mar 5, 2024
1 parent df75d55 commit f2ba058
Showing 2 changed files with 97 additions and 70 deletions.
data.py — 76 changes: 46 additions & 30 deletions

@@ -6,11 +6,12 @@

 import lance

+
 def apply_fim(sample, fim_prefix, fim_middle, fim_suffix, fim_pad, mode, np_rng):
     """
     Applies FIM transformation on one sample
     """
-    boundaries = sorted(np_rng.randint(low=0, high=len(sample)+1, size=2))
+    boundaries = sorted(np_rng.randint(low=0, high=len(sample) + 1, size=2))

     prefix = sample[: boundaries[0]]
     middle = sample[boundaries[0] : boundaries[1]]
@@ -23,18 +24,18 @@ def apply_fim(sample, fim_prefix, fim_middle, fim_suffix, fim_pad, mode, np_rng)
     elif diff < 0:
         extend = torch.cat([fim_pad for _ in range(-diff)])
         suffix = torch.cat([suffix, extend])
-    if mode == 'spm':
+
+    if mode == "spm":
         # Apply SPM
-        transfomed_example = torch.cat([
-            fim_prefix, fim_suffix, suffix, fim_middle, prefix, middle
-        ])
+        transfomed_example = torch.cat(
+            [fim_prefix, fim_suffix, suffix, fim_middle, prefix, middle]
+        )
     else:
         # Apply PSM
-        transfomed_example = torch.cat([
-            fim_prefix, prefix, fim_suffix, suffix, fim_middle, middle
-        ])
+        transfomed_example = torch.cat(
+            [fim_prefix, prefix, fim_suffix, suffix, fim_middle, middle]
+        )

     return transfomed_example


@@ -48,16 +49,16 @@ def __init__(
         fim_suffix,
         fim_pad,
         fim_rate=0.5,
-        mode='psm',
-        rng_seed=42
+        mode="psm",
+        rng_seed=42,
     ):
         # Load the lance dataset from the saved path
         self.ds = lance.dataset(dataset_path)
         self.context_len = context_len

         # Doing this so the sampler never asks for an index at the end of text
         self.length = self.ds.count_rows() - context_len

         self.np_rng = np.random.RandomState(seed=rng_seed)

         self.fim_prefix = torch.tensor([fim_prefix])
@@ -75,14 +76,14 @@ def from_idxs(self, idxs):
         Little utility function to get the data from lance
         """
         data = self.ds.take(idxs).to_pylist()
-        data = torch.tensor(list(map(lambda x: x['value'], data)))
+        data = torch.tensor(list(map(lambda x: x["value"], data)))
         return data

     def apply_fim(self, sample):
         """
         Applies FIM transformation on one sample
         """
-        boundaries = sorted(self.np_rng.randint(low=0, high=len(sample)+1, size=2))
+        boundaries = sorted(self.np_rng.randint(low=0, high=len(sample) + 1, size=2))

         prefix = sample[: boundaries[0]]
         middle = sample[boundaries[0] : boundaries[1]]
@@ -95,36 +96,50 @@ def apply_fim(self, sample):
         elif diff < 0:
             extend = torch.cat([self.fim_pad for _ in range(-diff)])
             suffix = torch.cat([suffix, extend])
-        if self.mode == 'spm':
+
+        if self.mode == "spm":
             # Apply SPM
-            transfomed_example = torch.cat([
-                self.fim_prefix, self.fim_suffix, suffix, self.fim_middle, prefix, middle
-            ])
+            transfomed_example = torch.cat(
+                [
+                    self.fim_prefix,
+                    self.fim_suffix,
+                    suffix,
+                    self.fim_middle,
+                    prefix,
+                    middle,
+                ]
+            )
         else:
             # Apply PSM
-            transfomed_example = torch.cat([
-                self.fim_prefix, prefix, self.fim_suffix, suffix, self.fim_middle, middle
-            ])
+            transfomed_example = torch.cat(
+                [
+                    self.fim_prefix,
+                    prefix,
+                    self.fim_suffix,
+                    suffix,
+                    self.fim_middle,
+                    middle,
+                ]
+            )

         return transfomed_example

     def __getitem__(self, idx):
         """
         Generate a list of indices starting from the current idx to idx+context_len+1
         with optional fim transformation
         """
-        current_window_idxs = np.arange(idx, idx+self.context_len+1)
+        current_window_idxs = np.arange(idx, idx + self.context_len + 1)
         sample = self.from_idxs(current_window_idxs)

         # Apply FIM transformation depending on the rate
         if self.np_rng.binomial(1, self.fim_rate):
             sample = self.apply_fim(sample)

         # +1 in labels because it is 1 step ahead of input tokens
-        tokens = sample[0:self.context_len]
-        labels = sample[1:self.context_len+1]
-        return {'tokens': tokens, 'labels': labels}
+        tokens = sample[0 : self.context_len]
+        labels = sample[1 : self.context_len + 1]
+        return {"tokens": tokens, "labels": labels}


 class MambaSampler(Sampler):
@@ -137,11 +152,12 @@ class MambaSampler(Sampler):

     def __init__(self, data_source, k=16):
         self.data_source = data_source
         self.num_samples = len(self.data_source)
         self.available_indices = list(range(0, self.num_samples, k))
         random.shuffle(self.available_indices)

     def __iter__(self):
         yield from self.available_indices

     def __len__(self) -> int:
-        return len(self.available_indices)
\ No newline at end of file
+        return len(self.available_indices)
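For readers following the data.py change: apply_fim splits one context_len + 1 window of token ids at two random cut points into prefix / middle / suffix and re-orders the pieces around sentinel tokens, in either PSM (prefix-suffix-middle) or SPM order. Below is a minimal, self-contained sketch of the PSM path; the sentinel ids (50277-50279) are made up for illustration, and the suffix padding/truncation that the real code performs is omitted.

import numpy as np
import torch

# Hypothetical sentinel ids; the real ids come from tokenizer.convert_tokens_to_ids in train.py.
fim_prefix = torch.tensor([50277])
fim_suffix = torch.tensor([50278])
fim_middle = torch.tensor([50279])

np_rng = np.random.RandomState(seed=42)
sample = torch.arange(10)  # stand-in for one context_len + 1 window of token ids

# Same boundary logic as apply_fim: two random cut points define prefix / middle / suffix.
boundaries = sorted(np_rng.randint(low=0, high=len(sample) + 1, size=2))
prefix = sample[: boundaries[0]]
middle = sample[boundaries[0] : boundaries[1]]
suffix = sample[boundaries[1] :]

# PSM ordering: <prefix_tok> prefix <suffix_tok> suffix <middle_tok> middle
psm = torch.cat([fim_prefix, prefix, fim_suffix, suffix, fim_middle, middle])
print(psm)

Note also that MambaSampler(train_dataset, k=Args.context_len + 1) in train.py yields shuffled start indices 0, k, 2k, ..., so the windows drawn by __getitem__ do not overlap.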
train.py — 91 changes: 51 additions & 40 deletions

@@ -18,12 +18,15 @@

 import wandb

+
 # Params (replace with Arg parser later)
 class Args:
     wandb = False
     tokenizer_model = "EleutherAI/gpt-neox-20b"
     model_name = "state-spaces/mamba-790m"
-    dataset_path = "/teamspace/studios/codeparrot-dataset-lance/code_parrot_github_python.lance"
+    dataset_path = (
+        "/teamspace/studios/codeparrot-dataset-lance/code_parrot_github_python.lance"
+    )
     eval_dataset_path = "fim_data_eval.lance"
     dataset = lance.dataset(dataset_path)
     low_cpu_mem_usage = False
@@ -43,9 +46,10 @@ class Args:
     T_0 = 1000
     T_mult = 1
     eta_min = 1e-5
-    device = torch.device('cuda:0')
+    device = torch.device("cuda:0")
     # Total chunks of context_len+1 size we can get
-    steps_per_epoch = (dataset.count_rows() // context_len+1) // 4
+    steps_per_epoch = (dataset.count_rows() // context_len + 1) // 4

+
 # Define Tokenizer and Model
 tokenizer = transformers.AutoTokenizer.from_pretrained(Args.tokenizer_model)
@@ -56,7 +60,14 @@ class Args:
 ).to(Args.device)

 # Get the FIM-specific tokens and get their token ids
-tokenizer.add_tokens([Args.fim_prefix_token, Args.fim_middle_token, Args.fim_middle_token, Args.fim_pad_token])
+tokenizer.add_tokens(
+    [
+        Args.fim_prefix_token,
+        Args.fim_middle_token,
+        Args.fim_middle_token,
+        Args.fim_pad_token,
+    ]
+)
 prefix_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_prefix_token)
 middle_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_middle_token)
 suffix_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_middle_token)
@@ -75,53 +86,48 @@ class Args:
 mean = original_embeddings.mean(dim=0)
 n = original_embeddings.size()[0]
 sigma = ((original_embeddings - mean).T @ (original_embeddings - mean)) / n
-dist = torch.distributions.MultivariateNormal(
-    mean,
-    covariance_matrix=1e-5*sigma
-)
+dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=1e-5 * sigma)
 new_token_embeddings = torch.stack(
-    tuple((dist.sample() for _ in range(len(fim_tokens)))),
-    dim=0
+    tuple((dist.sample() for _ in range(len(fim_tokens)))), dim=0
 )

 # Get updated embedding layer and make a copy of it's weights
 embeddings = model.get_input_embeddings()
 new_embeddings = embeddings.weight.clone()

 # Set the new token' embeddings to the newly sampled embeddings
-new_embeddings[-len(fim_tokens):] = new_token_embeddings
+new_embeddings[-len(fim_tokens) :] = new_token_embeddings

 # Update the model's embeddings with the new embeddings
 embeddings.weight = torch.nn.Parameter(new_embeddings)
 model.set_input_embeddings(embeddings)

 # Make train dataset and train dataloader
 train_dataset = MambaDataset(
-    Args.dataset_path,
-    context_len=Args.context_len,
+    Args.dataset_path,
+    context_len=Args.context_len,
     fim_prefix=prefix_tok_id,
     fim_middle=middle_tok_id,
     fim_suffix=suffix_tok_id,
     fim_pad=pad_tok_id,
     fim_rate=Args.fim_rate,
-    mode='psm',
+    mode="psm",
 )

-train_dataloader = iter(DataLoader(
-    train_dataset,
-    batch_size=Args.train_batch_size,
-    sampler=MambaSampler(train_dataset, k=Args.context_len+1),
-    shuffle=False,
-    pin_memory=True
-))
+train_dataloader = iter(
+    DataLoader(
+        train_dataset,
+        batch_size=Args.train_batch_size,
+        sampler=MambaSampler(train_dataset, k=Args.context_len + 1),
+        shuffle=False,
+        pin_memory=True,
+    )
+)

 # Optimizer and Scheduler
 optimizer = torch.optim.AdamW(model.parameters(), lr=Args.lr)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-    optimizer,
-    T_0=Args.T_0,
-    T_mult=Args.T_mult,
-    eta_min=Args.eta_min
+    optimizer, T_0=Args.T_0, T_mult=Args.T_mult, eta_min=Args.eta_min
 )

 # Start training
@@ -131,31 +137,36 @@ class Args:
 print(f"Training steps per epoch: {Args.steps_per_epoch:,}\n")
 # print(f"Total training steps in training: {Args.steps_per_epoch * Args.epochs:,}")

+
 def wandb_log(**kwargs):
     """Easy interface to log stuff to wandb"""
     for k, v in kwargs.items():
         wandb.log({k: v})

+
 if Args.wandb:
     # Convert the Config class to a dict for logging
     config_dict = dict(vars(Args))
-    del[config_dict['__module__']]
-    del[config_dict['__dict__']]
-    del[config_dict['__weakref__']]
-    del[config_dict['__doc__']]
+    del [config_dict["__module__"]]
+    del [config_dict["__dict__"]]
+    del [config_dict["__weakref__"]]
+    del [config_dict["__doc__"]]

     from dotenv import load_dotenv
+
     load_dotenv()
     wandb.login()
     run = wandb.init(
-        project='pytorch',
+        project="pytorch",
         config=config_dict,
-        group='mamba-train',
-        job_type='train',
+        group="mamba-train",
+        job_type="train",
     )
     wandb.watch(model)

-prog_bar = tqdm(range(Args.steps_per_epoch * Args.epochs), total=Args.steps_per_epoch * Args.epochs)
+prog_bar = tqdm(
+    range(Args.steps_per_epoch * Args.epochs), total=Args.steps_per_epoch * Args.epochs
+)
 for epoch in range(Args.epochs):
     model.train()
     total_loss = []
@@ -166,12 +177,12 @@ def wandb_log(**kwargs):
             batch[k] = v.to(Args.device)

         # Get predictions
-        predictions = model(batch['tokens'])
+        predictions = model(batch["tokens"])

         # Reshape predictions and calculate loss
         B, C, V = predictions.shape
-        predictions = predictions.view(B*C, V)
-        targets = batch['labels'].view(B*C)
+        predictions = predictions.view(B * C, V)
+        targets = batch["labels"].view(B * C)
         loss = torch.nn.functional.cross_entropy(predictions, targets)
         prog_bar.set_description((f"loss: {loss.item():.4f}"))

@@ -189,14 +200,14 @@ def wandb_log(**kwargs):
     try:
         perplexity = np.exp(np.mean(total_loss))
     except OverflowError:
-        perplexity = float('-inf')
+        perplexity = float("-inf")

     if Args.wandb:
         wandb_log(train_perplexity=perplexity)

     print(f"epoch: {epoch} | train perplexity: {perplexity:.4f}")

 # Save the model after training
-model_name = Args.model_name.split('/')[-1]
+model_name = Args.model_name.split("/")[-1]
 torch.save(model.state_dict(), f"{model_name}-fim.bin")
-print("Saved the model!")
\ No newline at end of file
+print("Saved the model!")
