diff --git a/train.py b/train.py
index 590e815..f7a9e0b 100644
--- a/train.py
+++ b/train.py
@@ -111,12 +111,14 @@ def preprocess_data(data, tokenizer, max_length, overlap):
     return sequences
 
 
-def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=10):
+def train(model, train_loader, val_loader, optimizer, criterion, epochs=10):
     """Training loop for the model."""
     model.train()
     for epoch in range(epochs):
         for inputs, targets in train_loader:
-            inputs, targets = inputs.to(device), targets.to(device)
+            inputs, targets = inputs.to(accelerator.device), targets.to(
+                accelerator.device
+            )
             print(f"Input shape: {inputs.shape}")
             print(f"Target shape: {targets.shape}")
 
@@ -130,7 +132,7 @@ def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=
             optimizer.zero_grad()
             outputs = model(inputs)
             loss = criterion(outputs.permute(0, 2, 1), targets)
-            loss.backward()
+            accelerator.backward(loss)
             optimizer.step()
 
         # Validation step
@@ -138,7 +140,9 @@ def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=
         val_loss = 0
         with torch.no_grad():
             for inputs, targets in val_loader:
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.to(accelerator.device), targets.to(
+                    accelerator.device
+                )
                 outputs = model(inputs)
                 loss = criterion(outputs.permute(0, 2, 1), targets)
                 val_loss += loss.item()
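The hunks above drop the explicit `device` argument and instead reference a module-level `accelerator`, whose setup is outside the shown context. Below is a minimal sketch of how that object and the accompanying `prepare()` call would typically be wired up with Hugging Face Accelerate; the toy model, data, and hyperparameters are illustrative assumptions, not the project's actual code.

```python
# Illustrative setup for the `accelerator` global used by train() in the diff.
# Assumes the standard Hugging Face Accelerate API; model/data below are dummies.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # selects CPU / single GPU / multi-GPU automatically

vocab_size, seq_len = 100, 16
model = nn.Sequential(
    nn.Embedding(vocab_size, 32),
    nn.Linear(32, vocab_size),  # (batch, seq, vocab), matching permute(0, 2, 1)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Random token data standing in for the real preprocessed sequences.
inputs = torch.randint(0, vocab_size, (64, seq_len))
targets = torch.randint(0, vocab_size, (64, seq_len))
dataset = TensorDataset(inputs, targets)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset, batch_size=8)

# prepare() moves the model to accelerator.device and wraps the optimizer and
# dataloaders, which is why train() no longer needs an explicit `device` arg.
model, optimizer, train_loader, val_loader = accelerator.prepare(
    model, optimizer, train_loader, val_loader
)

train(model, train_loader, val_loader, optimizer, criterion, epochs=1)
```

With this arrangement, `inputs.to(accelerator.device)` inside the loop is largely redundant for prepared dataloaders (Accelerate already places their batches on the right device), but it is harmless and keeps the loop working if an unprepared loader is passed in.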