diff --git a/examples/research_projects/codeparrot/requirements.txt b/examples/research_projects/codeparrot/requirements.txt
index 7eff3ac7f1..70ef16d3a2 100644
--- a/examples/research_projects/codeparrot/requirements.txt
+++ b/examples/research_projects/codeparrot/requirements.txt
@@ -1,9 +1,9 @@
-transformers==4.19.0
+git+https://github.com/loubnabnl/transformers.git@loss-reduction-none
+accelerate==0.15.0
 datasets==1.16.0
 wandb==0.12.0
 tensorboard==2.6.0
 torch==1.11.0
 huggingface-hub==0.1.0
-git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
 datasketch==1.5.7
 dpu_utils
\ No newline at end of file
diff --git a/examples/research_projects/codeparrot/scripts/arguments.py b/examples/research_projects/codeparrot/scripts/arguments.py
index 4def9ac3b8..bb89706461 100644
--- a/examples/research_projects/codeparrot/scripts/arguments.py
+++ b/examples/research_projects/codeparrot/scripts/arguments.py
@@ -20,12 +20,16 @@ class TrainingArguments:
     dataset_name_valid: Optional[str] = field(
         default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
     )
-    train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."})
-    valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
+    train_batch_size: Optional[int] = field(default=320, metadata={"help": "Batch size for training."})
+    train_batch_size_select: Optional[int] = field(default=32, metadata={"help": "Batch size to subselect for training."})
+    valid_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for evaluation."})
     weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
     shuffle_buffer: Optional[int] = field(
         default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
     )
+    no_streaming: Optional[bool] = field(
+        default=False, metadata={"help": "If True, do not use streaming for the dataset."}
+    )
     learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
     lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
     num_warmup_steps: Optional[int] = field(
@@ -47,6 +51,18 @@ class TrainingArguments:
         default=1024,
         metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
    )
+    selection_method: Optional[str] = field(
+        default=None, metadata={"help": "Selection method to subselect from the batch. Can be 'uniform' or 'rholoss'."}
+    )
+    irred_losses: Optional[str] = field(
+        default="irred_losses.pt",
+        metadata={"help": "Path to the irreducible losses .pt file. Must be supplied if selection_method is 'rholoss'."},
+    )
+    compute_irred_losses: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "If True, irreducible losses are computed and saved to the path specified by irred_losses."
+        },
+    )
     resume_from_checkpoint: Optional[str] = field(
         default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
     )
diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
index b2af8767a2..393f7730fc 100644
--- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py
+++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
@@ -1,10 +1,7 @@
-import logging
 import os
 import time
 from argparse import Namespace
-from pathlib import Path
 
-import datasets
 import torch
 from datasets import load_dataset
 from torch.optim import AdamW
@@ -12,11 +9,11 @@ from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
 
-import transformers
 from accelerate import Accelerator, DistributedType
 from arguments import TrainingArguments
-from huggingface_hub import Repository
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
+from utils_rholoss import compute_save_control_examples, sanity_check_irred_losses
+from utils_training import compute_tflops, evaluate, get_grouped_params, get_lr, setup_logging
 
 
 class ConstantLengthDataset(IterableDataset):
@@ -94,37 +91,17 @@ def shuffle(self, buffer_size=1000):
         return ShufflerIterDataPipe(self, buffer_size=buffer_size)
 
 
-def setup_logging(args):
-    project_name = args.model_ckpt.split("/")[-1]
-    logger = logging.getLogger(__name__)
-    log_dir = Path(args.save_dir) / "log/"
-    log_dir.mkdir(exist_ok=True)
-    filename = f"debug_{accelerator.process_index}.log"
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
-    )
-    if accelerator.is_main_process:  # we only want to setup logging once
-        accelerator.init_trackers(project_name, vars(args))
-        run_name = accelerator.trackers[0].run.name
-        logger.setLevel(logging.INFO)
-        datasets.utils.logging.set_verbosity_info()
-        transformers.utils.logging.set_verbosity_info()
-    else:
-        run_name = ""
-        logger.setLevel(logging.ERROR)
-        datasets.utils.logging.set_verbosity_error()
-        transformers.utils.logging.set_verbosity_error()
-    return logger, run_name
-
-
 def create_dataloaders(args):
-    ds_kwargs = {"streaming": True}
-    train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
-    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
-    valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
+    train_data = load_dataset(args.dataset_name_train, split="train", use_auth_token=True)
+    train_data = train_data.shuffle(seed=args.seed)
+    if args.dataset_name_train == args.dataset_name_valid:
+        # Split the dataset into train and validation
+        data = train_data.train_test_split(test_size=0.005, shuffle=False, seed=args.seed)
+        train_data = data["train"]
+        valid_data = data["test"]
+        print(f"Size of train set: {len(train_data)} and validation set {len(valid_data)}")
+    else:
+        valid_data = load_dataset(args.dataset_name_valid, split="train", use_auth_token=True)
     train_dataset = ConstantLengthDataset(
         tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
     )
@@ -132,64 +109,18 @@ def create_dataloaders(args):
         tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
     )
     train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer)
+    # ShufflerIterDataPipe does not work in torch > 1.11
     train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
     eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
     return train_dataloader, eval_dataloader
 
 
-def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
-    params_with_wd, params_without_wd = [], []
-    for n, p in model.named_parameters():
-        if any(nd in n for nd in no_decay):
-            params_without_wd.append(p)
-        else:
-            params_with_wd.append(p)
-    return [
-        {"params": params_with_wd, "weight_decay": args.weight_decay},
-        {"params": params_without_wd, "weight_decay": 0.0},
-    ]
-
-
 def log_metrics(step, metrics):
     logger.info(f"Step {step}: {metrics}")
     if accelerator.is_main_process:
         accelerator.log(metrics, step)
 
 
-def compute_tflops(elapsed_time, accelerator, args):
-    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
-    config_model = accelerator.unwrap_model(model).config
-    checkpoint_factor = 4 if args.gradient_checkpointing else 3
-    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
-    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
-    flops_per_iteration = factor * (
-        1.0
-        + (args.seq_length / (6.0 * config_model.n_embd))
-        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
-    )
-    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
-    return tflops
-
-
-def evaluate(args):
-    model.eval()
-    losses = []
-    for step, batch in enumerate(eval_dataloader):
-        with torch.no_grad():
-            outputs = model(batch, labels=batch)
-        loss = outputs.loss.repeat(args.valid_batch_size)
-        losses.append(accelerator.gather(loss))
-        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
-            break
-    losses = torch.cat(losses)
-    loss = losses[: eval_dataloader.dataset.current_size].mean()
-    try:
-        perplexity = torch.exp(loss)
-    except OverflowError:
-        perplexity = float("inf")
-    return loss.item(), perplexity.item()
-
-
 # Settings
 parser = HfArgumentParser(TrainingArguments)
 args = parser.parse_args()
@@ -202,23 +133,15 @@ def evaluate(args):
 samples_per_step = accelerator.state.num_processes * args.train_batch_size
 set_seed(args.seed)
 
-# Clone model repository
-if accelerator.is_main_process:
-    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
-
 # Logging
-logger, run_name = setup_logging(args)
+logger, run_name = setup_logging(accelerator, args)
 logger.info(accelerator.state)
 
-# Checkout new branch on repo
-if accelerator.is_main_process:
-    hf_repo.git_checkout(run_name, create_branch_ok=True)
-
 # Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained(args.save_dir)
+model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
 if args.gradient_checkpointing:
     model.gradient_checkpointing_enable()
-tokenizer = AutoTokenizer.from_pretrained(args.save_dir)
+tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
 
 # Load dataset and dataloader
 train_dataloader, eval_dataloader = create_dataloaders(args)
@@ -234,41 +157,121 @@ def evaluate(args):
 accelerator.register_for_checkpointing(lr_scheduler)
 
 
-def get_lr():
-    return optimizer.param_groups[0]["lr"]
-
-
 # Prepare everything with our `accelerator`.
 model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
     model, optimizer, train_dataloader, eval_dataloader
 )
 
-# load in the weights and states from a previous save
-if args.resume_from_checkpoint:
-    if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
-        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
-        accelerator.load_state(args.resume_from_checkpoint)
-        path = os.path.basename(args.resume_from_checkpoint)
-    else:
-        # Get the most recent checkpoint
-        dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
-        dirs.sort(key=os.path.getctime)
-        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-    # Extract the step of the checkpoint to continue from there
-    training_difference = os.path.splitext(path)[0]
-    resume_step = int(training_difference.replace("step_", ""))
-
 # Train model
 model.train()
 completed_steps = 0
 t_start = time.time()
 loss_tracking = 0
+
+samples_per_step = args.train_batch_size * accelerator.state.num_processes
+total_samples = args.max_train_steps * samples_per_step
+
+if args.compute_irred_losses:
+    irred_losses = torch.zeros(total_samples, device="cpu")
+    compute_save_control_examples(
+        accelerator,
+        train_dataloader,
+        args.train_batch_size,
+        limit_control_examples=600,
+        seq_length=args.seq_length,
+        save_dir="./control_examples",
+    )
+    if accelerator.is_main_process:
+        print(f"Total number of samples: {total_samples}")
+        print("Irreducible loss examples were saved in ./control_examples for future sanity checks on the order")
+
+elif args.irred_losses and not args.compute_irred_losses:
+    # Should be of shape total_samples
+    irred_losses = torch.load(args.irred_losses, map_location=torch.device("cpu"))
+    assert irred_losses.shape[0] == total_samples, (
+        f"shape of irred_losses: {irred_losses.shape[0]}, expected number of training samples: {total_samples}"
+    )
+
+if args.selection_method and args.selection_method == "rholoss":
+    if accelerator.is_main_process:
+        print("Running sanity checks for irreducible losses")
+    # run sanity checks to verify the order of irreducible losses wrt current batches
+    sanity_check_irred_losses(
+        accelerator,
+        train_dataloader,
+        args.train_batch_size,
+        limit_control_examples=600,
+        seq_length=args.seq_length,
+        save_dir="./control_examples",
+    )
+
 for step, batch in enumerate(train_dataloader, start=1):
-    if args.resume_from_checkpoint and step < resume_step:
-        continue  # we need to skip steps until we reach the resumed step
-    loss = model(batch, labels=batch, use_cache=False).loss
-    avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
-    loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
+    if args.selection_method:
+        # select a smaller batch of examples to train on using a selection method
+        if args.selection_method == "uniform":
+            batch = batch[torch.randperm(batch.size(0))[: args.train_batch_size_select]]
+        elif args.selection_method == "rholoss":
+            # in this setting global_batch_size = train_batch_size x nb_workers
+            with torch.no_grad():
+                losses = []
+                # To avoid running OOM, compute the loss on smaller sub-batches
+                sub_batches = torch.split(batch, args.gradient_accumulation_steps)
+                for sub_batch in sub_batches:
+                    # we use reduction="none" in GPT2LMHeadModel loss implementation (install transformers from a fork)
+                    loss = model(sub_batch, labels=sub_batch, use_cache=False).loss
+                    loss = loss.view(sub_batch.size(0), -1).mean(dim=1)
+                    losses.append(loss)
+                losses = torch.cat(losses, dim=0)
+            assert losses.shape == torch.Size(
+                [args.train_batch_size]
+            ), "make sure you are using GPT2LMHeadModel with reduction=none in the loss"
+            # TODO check data at the end that may be duplicated to divide batch equally among all workers
+            losses = accelerator.gather(losses)
+            cur_irred_losses = irred_losses[(step - 1) * samples_per_step : step * samples_per_step]
+            assert losses.shape == cur_irred_losses.shape, (
+                f"Size mismatch between training losses {losses.shape} and irreducible losses {cur_irred_losses.shape}"
+            )
+            red_losses = losses - cur_irred_losses.to(losses.device)
+            print(f"shape red_losses {red_losses.shape} and max is {torch.max(red_losses)}")
+            # Select the top args.train_batch_size_select losses & produce a new batch
+            top_losses, top_indices = torch.topk(red_losses, args.train_batch_size_select)
+            batches = accelerator.gather(batch)
+            batch = torch.index_select(batches, 0, top_indices)
+            # assert first element has the highest loss
+            assert torch.eq(top_losses[0], torch.max(red_losses)), "first element should have the highest loss"
+            print(f"new size of batch is {batch.shape}")
+
+    if args.compute_irred_losses:
+        # compute irreducible losses over the entire dataset and exit
+        with torch.no_grad():
+            # we use reduction="none" in GPT2LMHeadModel loss implementation
+            loss = model(batch, labels=batch, use_cache=False).loss
+            loss = loss.view(batch.size(0), -1).mean(dim=1)
+            assert loss.shape == torch.Size(
+                [args.train_batch_size]
+            ), "make sure you are using GPT2LMHeadModel with reduction=none in the loss"
+            losses = accelerator.gather(loss)
+            try:
+                irred_losses[(step - 1) * samples_per_step : step * samples_per_step] = losses
+            except RuntimeError:
+                print(
+                    f"Size mismatch at step {step} between current losses {losses.shape} and irreducible losses "
+                    f"{irred_losses[(step - 1) * samples_per_step : step * samples_per_step].shape}"
+                )
+                break
+        if step >= args.max_train_steps:
+            break
+        continue
+
+    # model training
+    # TODO: a batch of 32 does not fit in memory => split the batch and add "proper" gradient accumulation
+    # we are using reduction="none" in GPT2LMHeadModel loss => we add .mean() over tokens
+    loss = model(batch, labels=batch, use_cache=False).loss.mean()
+    # no need to gather across workers: they all hold the same selected batch, so we would just repeat the loss
+    loss_tracking += loss.item() / args.gradient_accumulation_steps
     log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
     loss = loss / args.gradient_accumulation_steps
     if step % args.gradient_accumulation_steps != 0:
@@ -279,14 +282,14 @@ def get_lr():
         else:
             accelerator.backward(loss)
     else:
-        lr = get_lr()
+        lr = get_lr(optimizer)
         accelerator.backward(loss)
         accelerator.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
         lr_scheduler.step()
         optimizer.zero_grad()
         elapsed_time = time.time() - t_start
-        tflops = compute_tflops(elapsed_time, accelerator, args)
+        tflops = compute_tflops(model, tokenizer, elapsed_time, accelerator, args)
         log_metrics(
             step,
             {
@@ -302,25 +305,28 @@ def get_lr():
         completed_steps += 1
         if step % args.save_checkpoint_steps == 0:
             logger.info("Evaluating and saving model checkpoint")
-            eval_loss, perplexity = evaluate(args)
+            eval_loss, perplexity = evaluate(model, accelerator, eval_dataloader, args)
             log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
             accelerator.wait_for_everyone()
             save_dir = os.path.join(args.save_dir, f"step_{step}")
             accelerator.save_state(save_dir)
-            if accelerator.is_main_process:
-                hf_repo.push_to_hub(commit_message=f"step {step}")
             model.train()
     if completed_steps >= args.max_train_steps:
         break
 
+# Save irred losses
+if args.compute_irred_losses:
+    if accelerator.is_main_process:
+        print(f"saving irred losses with shape: {irred_losses.shape}")
+        torch.save(irred_losses, args.irred_losses)
+    exit()
+
 # Evaluate and save the last checkpoint
 logger.info("Evaluating and saving model after training")
-eval_loss, perplexity = evaluate(args)
+eval_loss, perplexity = evaluate(model, accelerator, eval_dataloader, args)
 log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
 accelerator.wait_for_everyone()
 unwrapped_model = accelerator.unwrap_model(model)
 unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
 save_dir = os.path.join(args.save_dir, f"step_{step}")
 accelerator.save_state(save_dir)
-if accelerator.is_main_process:
-    hf_repo.push_to_hub(commit_message="final model")
diff --git a/examples/research_projects/codeparrot/scripts/utils_rholoss.py b/examples/research_projects/codeparrot/scripts/utils_rholoss.py
new file mode 100644
index 0000000000..cc8c6a7f1b
--- /dev/null
+++ b/examples/research_projects/codeparrot/scripts/utils_rholoss.py
@@ -0,0 +1,58 @@
+import os
+
+import torch
+
+
+def get_control_examples(accelerator, dataloader, train_batch_size, limit_control_examples=600, seq_length=1024):
+    """get control examples for computing the irreducible losses over the entire dataset"""
+    control_examples = torch.zeros((limit_control_examples, seq_length), device="cpu").type(torch.LongTensor)
+    for step, batch in enumerate(dataloader, start=1):
+        indexes = list(
+            range(
+                (step - 1) * train_batch_size * accelerator.state.num_processes,
+                step * train_batch_size * accelerator.state.num_processes,
+            )
+        )
+        if max(indexes) >= limit_control_examples:
+            break
+        with torch.no_grad():
+            batches = accelerator.gather(batch)
+            control_examples[indexes] = batches.cpu()
+    return control_examples
+
+
+def compute_save_control_examples(
+    accelerator,
+    dataloader,
+    train_batch_size,
+    limit_control_examples=600,
+    seq_length=1024,
+    save_dir="./control_examples",
+):
+    """compute and save control examples for computing the irreducible losses over the entire dataset"""
+    control_examples = get_control_examples(
+        accelerator, dataloader, train_batch_size, limit_control_examples, seq_length
+    )
+    if accelerator.is_main_process:
+        # save the control examples
+        os.makedirs(save_dir, exist_ok=True)
+        torch.save(control_examples, f"{save_dir}/control_examples.pt")
+
+
+def sanity_check_irred_losses(
+    accelerator,
+    dataloader,
+    train_batch_size,
+    limit_control_examples=600,
+    seq_length=1024,
+    save_dir="./control_examples",
+):
+    """sanity check of the order of loaded irreducible losses wrt batches of the current dataset"""
+    loaded_examples = torch.load(f"{save_dir}/control_examples.pt")
+    control_examples = get_control_examples(
+        accelerator, dataloader, train_batch_size, limit_control_examples, seq_length
+    )
+    # check if the loaded tensors are the same as the ones we saved
+    assert torch.all(torch.eq(loaded_examples, control_examples)), "control examples do not match the dataloader order"
+    if accelerator.is_main_process:
+        print("Sanity check for irreducible loss order passed")
diff --git a/examples/research_projects/codeparrot/scripts/utils_training.py b/examples/research_projects/codeparrot/scripts/utils_training.py
new file mode 100644
index 0000000000..632deb6b7a
--- /dev/null
+++ b/examples/research_projects/codeparrot/scripts/utils_training.py
@@ -0,0 +1,84 @@
+import logging
+from pathlib import Path
+
+import datasets
+import torch
+
+import transformers
+
+
+def setup_logging(accelerator, args):
+    project_name = args.model_ckpt.split("/")[-1]
+    logger = logging.getLogger(__name__)
+    log_dir = Path(args.save_dir) / "log/"
+    log_dir.mkdir(exist_ok=True)
+    filename = f"debug_{accelerator.process_index}.log"
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
+    )
+    if accelerator.is_main_process:  # we only want to setup logging once
+        accelerator.init_trackers(project_name, vars(args))
+        run_name = accelerator.trackers[0].run.name
+        logger.setLevel(logging.INFO)
+        datasets.utils.logging.set_verbosity_info()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        run_name = ""
+        logger.setLevel(logging.ERROR)
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+    return logger, run_name
+
+
+def compute_tflops(model, tokenizer, elapsed_time, accelerator, args):
+    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
+    config_model = accelerator.unwrap_model(model).config
+    checkpoint_factor = 4 if args.gradient_checkpointing else 3
+    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
+    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
+    flops_per_iteration = factor * (
+        1.0
+        + (args.seq_length / (6.0 * config_model.n_embd))
+        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
+    )
+    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
+    return tflops
+
+
+def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
+    params_with_wd, params_without_wd = [], []
+    for n, p in model.named_parameters():
+        if any(nd in n for nd in no_decay):
+            params_without_wd.append(p)
+        else:
+            params_with_wd.append(p)
+    return [
+        {"params": params_with_wd, "weight_decay": args.weight_decay},
+        {"params": params_without_wd, "weight_decay": 0.0},
+    ]
+
+
+def get_lr(optimizer):
+    return optimizer.param_groups[0]["lr"]
+
+
+def evaluate(model, accelerator, eval_dataloader, args):
+    model.eval()
+    losses = []
+    for step, batch in enumerate(eval_dataloader):
+        with torch.no_grad():
+            outputs = model(batch, labels=batch)
+        # with the reduction="none" fork the loss is per-token => average over tokens first
+        loss = outputs.loss.mean().repeat(args.valid_batch_size)
+        losses.append(accelerator.gather(loss))
+        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
+            break
+    losses = torch.cat(losses)
+    loss = losses[: eval_dataloader.dataset.current_size].mean()
+    try:
+        perplexity = torch.exp(loss)
+    except OverflowError:
+        perplexity = float("inf")
+    return loss.item(), perplexity.item()
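Note on the selection step in codeparrot_training.py: it relies on a forked GPT2LMHeadModel whose loss is returned with reduction="none" (see requirements.txt). The sketch below shows the same RHO-Loss selection computed directly from the logits, so it also runs on stock transformers; it is only an illustration of the technique, and the helper names (per_example_lm_loss, select_batch_rholoss) are not part of this PR.

import torch
import torch.nn.functional as F


def per_example_lm_loss(model, batch):
    """Per-example causal LM loss: mean over tokens of the shifted cross-entropy."""
    logits = model(batch, use_cache=False).logits
    shift_logits = logits[:, :-1, :]  # predict token t+1 from tokens up to t
    shift_labels = batch[:, 1:]
    loss_per_token = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    ).view(batch.size(0), -1)
    return loss_per_token.mean(dim=1)  # one loss value per sequence


def select_batch_rholoss(model, batch, irred_losses, k, micro_batch_size=8):
    """Keep the k examples with the largest reducible loss (training loss - irreducible loss).

    `irred_losses` holds the holdout-model losses for the examples of this batch, in the same order.
    """
    with torch.no_grad():
        losses = torch.cat(
            [per_example_lm_loss(model, sub) for sub in torch.split(batch, micro_batch_size)]
        )
    reducible_losses = losses - irred_losses.to(losses.device)
    _, top_indices = torch.topk(reducible_losses, k)
    return batch[top_indices]

Assumed two-stage usage with the flags defined in arguments.py (example values, not prescribed by the PR): first run codeparrot_training.py with --compute_irred_losses True and the holdout model to produce irred_losses.pt, then train with --selection_method rholoss --irred_losses irred_losses.pt --train_batch_size 320 --train_batch_size_select 32.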