fix anarchy-ai#397: LLM-VM does not currently support multiple GPUs
Aryan8912 committed Dec 1, 2023
1 parent 3a65d0c commit 35d113c
Showing 1 changed file with 134 additions and 0 deletions.
src/llm_vm/onsite_llm.py: 134 additions & 0 deletions
@@ -28,6 +28,140 @@
from sentence_transformers import SentenceTransformer
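# Imports assumed by the Ray Train / Accelerate example added below; they are
# not part of this hunk, and some may already be imported elsewhere in
# onsite_llm.py. Listed here only as a sketch of what the code needs to run.
import torch
import evaluate
import ray
import ray.data
import ray.train
from tempfile import TemporaryDirectory
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.optim import AdamW
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from ray.train import Checkpoint, DataConfig, ScalingConfig
from ray.train.torch import TorchTrainer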


def train_func(config):
"""Your training function that will be launched on each worker."""

# Unpack training configs
lr = config["lr"]
seed = config["seed"]
num_epochs = config["num_epochs"]
train_batch_size = config["train_batch_size"]
eval_batch_size = config["eval_batch_size"]
train_ds_size = config["train_dataset_size"]

set_seed(seed)

# Initialize accelerator
accelerator = Accelerator()

    # Load the evaluation metric
metric = evaluate.load("glue", "mrpc")

# Prepare Ray Data loaders
# ====================================================
train_ds = ray.train.get_dataset_shard("train")
eval_ds = ray.train.get_dataset_shard("validation")

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def collate_fn(batch):
outputs = tokenizer(
list(batch["sentence1"]),
list(batch["sentence2"]),
truncation=True,
padding="longest",
return_tensors="pt",
)
outputs["labels"] = torch.LongTensor(batch["label"])
outputs = {k: v.to(accelerator.device) for k, v in outputs.items()}
return outputs

train_dataloader = train_ds.iter_torch_batches(
batch_size=train_batch_size, collate_fn=collate_fn
)
eval_dataloader = eval_ds.iter_torch_batches(
batch_size=eval_batch_size, collate_fn=collate_fn
)
# ====================================================

# Instantiate the model, optimizer, lr_scheduler
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-cased", return_dict=True
)

optimizer = AdamW(params=model.parameters(), lr=lr)

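    # Each optimizer step consumes accelerator.num_processes * train_batch_size samples in total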
steps_per_epoch = train_ds_size // (accelerator.num_processes * train_batch_size)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=(steps_per_epoch * num_epochs),
)

# Prepare everything with accelerator
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

for epoch in range(num_epochs):
# Training
model.train()
for batch in train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

# Evaluation
model.eval()
for batch in eval_dataloader:
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)

predictions, references = accelerator.gather_for_metrics(
(predictions, batch["labels"])
)
metric.add_batch(
predictions=predictions,
references=references,
)

eval_metric = metric.compute()
accelerator.print(f"epoch {epoch}:", eval_metric)

# Report Checkpoint and metrics to Ray Train
# ==========================================
with TemporaryDirectory() as tmpdir:
if accelerator.is_main_process:
unwrapped_model = accelerator.unwrap_model(model)
accelerator.save(unwrapped_model, f"{tmpdir}/ckpt_{epoch}.bin")
checkpoint = Checkpoint.from_directory(tmpdir)
else:
checkpoint = None
ray.train.report(metrics=eval_metric, checkpoint=checkpoint)


if __name__ == "__main__":
config = {
"lr": 2e-5,
"num_epochs": 3,
"seed": 42,
"train_batch_size": 16,
"eval_batch_size": 32,
}

# Prepare Ray Datasets
hf_datasets = load_dataset("glue", "mrpc")
ray_datasets = {
"train": ray.data.from_huggingface(hf_datasets["train"]),
"validation": ray.data.from_huggingface(hf_datasets["validation"]),
}
config["train_dataset_size"] = ray_datasets["train"].count()

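    # Launch train_func on 4 Ray workers with one GPU each; the datasets named
    # in DataConfig are sharded across the workers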
trainer = TorchTrainer(
train_func,
train_loop_config=config,
datasets=ray_datasets,
dataset_config=DataConfig(datasets_to_split=["train", "validation"]),
scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)

result = trainer.fit()
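    # result.metrics holds the metrics from the last ray.train.report() call,
    # and result.checkpoint the checkpoint saved with it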

# __accelerate_torch_basic_example_end__



__private_key_value_models_map = {}
# [] {