Review notes (free text ahead of the first "diff --git" header; patch tools
skip it).

Line breaks and indentation were reconstructed from a collapsed copy of this
patch. Whitespace-only details (indentation of context lines, the
whitespace-only line removed in src/main.py) are a best guess; confirm with
`git apply --check` before applying.

Changes made during review:
  * src/api.py: removed the leftover `from main import Net`. Nothing in the
    visible hunk uses Net once the model is CNN(), and src/main.py trains and
    saves mnist_model.pth at module level, so importing it from the API would
    appear to retrain and overwrite the checkpoint on every start-up.
    TODO confirm Net is not referenced further down in api.py.
  * src/cnn.py: dropout2 is nn.Dropout instead of nn.Dropout2d because it is
    applied to the flattened fc1 output; docstrings added.
  * The post-image blob ids on the "index" lines of those two files predate
    these edits.

diff --git a/src/api.py b/src/api.py
index 36c257a..e60d97a 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,22 +1,23 @@
-from fastapi import FastAPI, UploadFile, File
-from PIL import Image
 import torch
+from fastapi import FastAPI, File, UploadFile
+from PIL import Image
 from torchvision import transforms
-from main import Net  # Importing Net class from main.py
+
+from cnn import CNN
 
 # Load the model
-model = Net()
+model = CNN()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
 # Transform used for preprocessing the image
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
 app = FastAPI()
 
+
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
     image = Image.open(file.file).convert("L")
diff --git a/src/cnn.py b/src/cnn.py
new file mode 100644
index 0000000..1fa75b2
--- /dev/null
+++ b/src/cnn.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CNN(nn.Module):
+    """Two-conv, two-linear classifier that returns 10 log-probabilities.
+
+    ``fc1`` takes 9216 = 64 * 12 * 12 features, so ``forward`` needs
+    single-channel 28x28 input: 28 -> 26 -> 24 through the two 3x3
+    convolutions, then 12 after the 2x2 max-pool.
+    """
+
+    def __init__(self):
+        super(CNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout2d(0.25)
+        # Plain Dropout: this layer runs on the flattened (N, 128) fc1
+        # output, and Dropout2d is only meant for 3-D/4-D feature maps.
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        """Return log-softmax scores of shape (N, 10) for the batch ``x``."""
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
diff --git a/src/main.py b/src/main.py
index 243a31e..774cc5d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,20 +1,22 @@
-from PIL import Image
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchvision import datasets, transforms
+from PIL import Image
 from torch.utils.data import DataLoader
-import numpy as np
+from torchvision import datasets, transforms
+
+from cnn import CNN
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
+
 # Step 2: Define the PyTorch Model
 class Net(nn.Module):
     def __init__(self):
@@ -22,7 +24,7 @@ def __init__(self):
         self.fc1 = nn.Linear(28 * 28, 128)
         self.fc2 = nn.Linear(128, 64)
         self.fc3 = nn.Linear(64, 10)
-    
+
     def forward(self, x):
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
@@ -30,8 +32,9 @@ def forward(self, x):
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
-model = Net()
+model = CNN()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
 criterion = nn.NLLLoss()
 
@@ -45,4 +48,4 @@ def forward(self, x):
         loss.backward()
         optimizer.step()
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")