diff --git a/src/main.py b/src/main.py index 243a31e..eb11d9c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,10 +1,10 @@ -from PIL import Image +import numpy as np import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms +from PIL import Image from torch.utils.data import DataLoader -import numpy as np +from torchvision import datasets, transforms # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -29,20 +29,29 @@ def forward(self, x): x = nn.functional.relu(self.fc2(x)) x = self.fc3(x) return nn.functional.log_softmax(x, dim=1) +class Trainer: + def __init__(self, learning_rate, model_path): + self.model = Net() + self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate) + self.criterion = nn.NLLLoss() + self.model_path = model_path -# Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() + def train(self, epochs): + for epoch in range(epochs): + for images, labels in trainloader: + self.optimizer.zero_grad() + output = self.model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() + + def save_model(self): + torch.save(self.model.state_dict(), self.model_path) -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file +# Step 3: Train the Model + +# Now let's create a Trainer instance and train and save the model +trainer = Trainer(learning_rate=0.01, model_path="mnist_model.pth") +trainer.train(epochs=3) +trainer.save_model()