Just to see the diff #3

Open
wants to merge 14 commits into main
Changes from 1 commit
examples/research_projects/codeparrot/scripts/arguments.py (15 additions & 2 deletions)
@@ -20,8 +20,9 @@ class TrainingArguments:
     dataset_name_valid: Optional[str] = field(
         default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
     )
-    train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."})
-    valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
+    train_batch_size: Optional[int] = field(default=320, metadata={"help": "Batch size for training."})
+    train_batch_size_select: Optional[int] = field(default=32, metadata={"help": "Batch size to subselect for training."})
+    valid_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for evaluation."})
     weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
     shuffle_buffer: Optional[int] = field(
         default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
@@ -47,6 +48,18 @@ class TrainingArguments:
         default=1024,
         metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
     )
+    selection_method: Optional[str] = field(
+        default=None, metadata={"help": "Selection method to subselect from the batch size. Can be uniform or rholoss"}
+    )
+    irred_losses: Optional[str] = field(
+        default="irred_losses.pt", metadata={"help": "Path to irreducible losses pt file. Must be supplied if selection_method is rholoss"}
+    )
+    compute_irred_losses: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "If True irreducible losses are computed and saved to the path specified by irred_losses."
+        },
+    )
     resume_from_checkpoint: Optional[str] = field(
         default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
     )
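For reference, a minimal sketch (not part of this PR) of how the new fields could be parsed and sanity-checked, assuming the training script reads this TrainingArguments dataclass through transformers' HfArgumentParser like the other codeparrot scripts; the cross-field checks are illustrative assumptions:

from transformers import HfArgumentParser

from arguments import TrainingArguments  # the dataclass extended above

parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()

# rholoss needs a precomputed irreducible-loss file unless this run is the one creating it
if args.selection_method == "rholoss" and not args.compute_irred_losses:
    assert args.irred_losses is not None, "pass --irred_losses when selection_method is rholoss"

# the selected sub-batch must fit inside the full batch (e.g. 32 out of 320)
if args.selection_method is not None:
    assert args.train_batch_size_select <= args.train_batch_size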
@@ -263,11 +263,38 @@ def get_lr():
 completed_steps = 0
 t_start = time.time()
 loss_tracking = 0
+
+if args.compute_irred_losses:
+    irred_losses = torch.zeros(len(train_dataloader) * args.train_batch_size, device="cpu")
+elif args.irred_losses and not args.compute_irred_losses:
+    # Should be of shape len(train_dataloader) * args.train_batch_size
+    irred_losses = torch.load(args.irred_losses, map_location=torch.device("cpu"))
+    assert irred_losses.shape[0] == len(train_dataloader) * args.train_batch_size
+
 for step, batch in enumerate(train_dataloader, start=1):
     if args.resume_from_checkpoint and step < resume_step:
         continue  # we need to skip steps until we reach the resumed step
+    if args.selection_method:
+        if args.selection_method == "uniform":
+            batch = {k: v[: args.train_batch_size_select] for k, v in batch.items()}
+        elif args.selection_method == "rholoss":
+            with torch.no_grad():
+                out = model(batch, labels=batch, use_cache=False).loss
Contributor:

Comment about the batch size: we're assuming that we can fit a batch size of 320 with our workers, but I think we can only fit 12 sequences on an A100 40GB (so on 16 workers: a batch of 16 * 12 = 192).

So we should probably either incorporate gradient accumulation and store the losses over 2 iterations (2 * 10 (small batch size) * 16 GPUs = 320), or change the batch sizes from 320/32 to something that suits us with a 10% ratio, like 160/16. In the paper they just talk about the 10% ratio, but I'm not sure if using large batches is also important.

Contributor Author:

There are no gradients here, which means that a) we can likely fit a bigger batch size than 12, and b) instead of gradient accumulation we can just run multiple forward passes one right after another and store the losses if it doesn't fit.

Contributor:

Yes, right! By gradient accumulation I also meant doing similar iterations over the losses.
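To make the suggestion above concrete, here is a hypothetical sketch (not code from this PR) of running the scoring pass in smaller chunks when the full 320-example batch does not fit in memory; the helper name, the chunk size, and treating batch as a tensor of input ids are assumptions:

import torch

def gather_proxy_losses(model, batch, micro_bs=32):
    # Score `batch` (a LongTensor of input ids) with several no-grad forward passes
    # run back to back, so no single pass has to hold all 320 sequences at once.
    # The per-chunk mean loss is repeated once per example as a proxy score,
    # mirroring the loss.repeat(...) trick used in the training loop below.
    losses = []
    with torch.no_grad():
        for start in range(0, batch.shape[0], micro_bs):
            chunk = batch[start : start + micro_bs]
            loss = model(chunk, labels=chunk, use_cache=False).loss
            losses.append(loss.repeat(chunk.shape[0]))
    return torch.cat(losses)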

+            losses = accelerator.gather(out.repeat(args.train_batch_size))
+            cur_irred_losses = irred_losses[(step - 1) * args.train_batch_size : step * args.train_batch_size]
+            assert losses.shape == cur_irred_losses.shape
+            red_losses = losses - cur_irred_losses
+            # Select the top args.train_batch_size_select losses & produce a new batch
+            top_losses, top_indices = torch.topk(red_losses, args.train_batch_size_select)
+            batch = {k: v[top_indices] for k, v in batch.items()}
+    if args.compute_irred_losses:
+        with torch.no_grad():
+            out = model(batch, labels=batch, use_cache=False).loss
+        losses = accelerator.gather(out.repeat(args.train_batch_size))
+        irred_losses[(step - 1) * args.train_batch_size : step * args.train_batch_size] = losses
+        continue
     loss = model(batch, labels=batch, use_cache=False).loss
-    avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+    avg_loss = accelerator.gather(loss.repeat(args.train_batch_size_select)).mean()
     loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
     log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
     loss = loss / args.gradient_accumulation_steps
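One thing to note about the rholoss branch above: loss.repeat(args.train_batch_size) spreads the batch-mean loss over every example, so it is only a proxy for per-example scores. A hedged sketch of a true per-example causal-LM loss, which is the quantity the reducible-loss ranking ultimately operates on, could look like this (helper name hypothetical, batch assumed to be input ids):

import torch
import torch.nn.functional as F

def per_example_lm_loss(model, batch):
    # One next-token cross-entropy value per sequence in `batch`, instead of a
    # single batch-mean loss repeated for every example.
    with torch.no_grad():
        logits = model(batch, use_cache=False).logits
    shift_logits = logits[:, :-1, :]  # predictions for positions 1..L-1
    shift_labels = batch[:, 1:]       # the tokens those positions should predict
    loss = F.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels, reduction="none"
    )  # shape: (batch, seq_len - 1)
    return loss.mean(dim=1)  # one scalar per example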
@@ -313,6 +340,11 @@ def get_lr():
     if completed_steps >= args.max_train_steps:
         break
+
+# Save irreducible losses so a later rholoss run can load them
+if args.compute_irred_losses:
+    torch.save(irred_losses, os.path.join(args.save_dir, args.irred_losses))
+    exit()

 # Evaluate and save the last checkpoint
 logger.info("Evaluating and saving model after training")
 eval_loss, perplexity = evaluate(args)