From 35d113cc33410edb513987f9c8cec297e100b4d9 Mon Sep 17 00:00:00 2001 From: Aryan Pandey Date: Fri, 1 Dec 2023 07:33:17 +0530 Subject: [PATCH] fix #397 LLM-VM does not support multiple GPUs currently --- src/llm_vm/onsite_llm.py | 134 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/src/llm_vm/onsite_llm.py b/src/llm_vm/onsite_llm.py index f78db919..aabc09a6 100644 --- a/src/llm_vm/onsite_llm.py +++ b/src/llm_vm/onsite_llm.py @@ -28,6 +28,140 @@ from sentence_transformers import SentenceTransformer +def train_func(config): + """Your training function that will be launched on each worker.""" + + # Unpack training configs + lr = config["lr"] + seed = config["seed"] + num_epochs = config["num_epochs"] + train_batch_size = config["train_batch_size"] + eval_batch_size = config["eval_batch_size"] + train_ds_size = config["train_dataset_size"] + + set_seed(seed) + + # Initialize accelerator + accelerator = Accelerator() + + # Load datasets and metrics + metric = evaluate.load("glue", "mrpc") + + # Prepare Ray Data loaders + # ==================================================== + train_ds = ray.train.get_dataset_shard("train") + eval_ds = ray.train.get_dataset_shard("validation") + + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + + def collate_fn(batch): + outputs = tokenizer( + list(batch["sentence1"]), + list(batch["sentence2"]), + truncation=True, + padding="longest", + return_tensors="pt", + ) + outputs["labels"] = torch.LongTensor(batch["label"]) + outputs = {k: v.to(accelerator.device) for k, v in outputs.items()} + return outputs + + train_dataloader = train_ds.iter_torch_batches( + batch_size=train_batch_size, collate_fn=collate_fn + ) + eval_dataloader = eval_ds.iter_torch_batches( + batch_size=eval_batch_size, collate_fn=collate_fn + ) + # ==================================================== + + # Instantiate the model, optimizer, lr_scheduler + model = AutoModelForSequenceClassification.from_pretrained( + "bert-base-cased", return_dict=True + ) + + optimizer = AdamW(params=model.parameters(), lr=lr) + + steps_per_epoch = train_ds_size // (accelerator.num_processes * train_batch_size) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=(steps_per_epoch * num_epochs), + ) + + # Prepare everything with accelerator + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + for epoch in range(num_epochs): + # Training + model.train() + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Evaluation + model.eval() + for batch in eval_dataloader: + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + + predictions, references = accelerator.gather_for_metrics( + (predictions, batch["labels"]) + ) + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + accelerator.print(f"epoch {epoch}:", eval_metric) + + # Report Checkpoint and metrics to Ray Train + # ========================================== + with TemporaryDirectory() as tmpdir: + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + accelerator.save(unwrapped_model, f"{tmpdir}/ckpt_{epoch}.bin") + checkpoint = Checkpoint.from_directory(tmpdir) + else: + checkpoint = None + ray.train.report(metrics=eval_metric, checkpoint=checkpoint) + + +if __name__ == "__main__": + config = { + "lr": 2e-5, + "num_epochs": 3, + "seed": 42, + "train_batch_size": 16, + "eval_batch_size": 32, + } + + # Prepare Ray Datasets + hf_datasets = load_dataset("glue", "mrpc") + ray_datasets = { + "train": ray.data.from_huggingface(hf_datasets["train"]), + "validation": ray.data.from_huggingface(hf_datasets["validation"]), + } + config["train_dataset_size"] = ray_datasets["train"].count() + + trainer = TorchTrainer( + train_func, + train_loop_config=config, + datasets=ray_datasets, + dataset_config=DataConfig(datasets_to_split=["train", "validation"]), + scaling_config=ScalingConfig(num_workers=4, use_gpu=True), + ) + + result = trainer.fit() + +# __accelerate_torch_basic_example_end__ + + __private_key_value_models_map = {} # [] {