Skip to content

Commit

Permalink
feat: Updated src/api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 25, 2023
1 parent 79fa4ad commit e6e1d5f
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
"""
This module provides an API endpoint for predicting the digit in an uploaded image using a pre-trained PyTorch model.
The API endpoint '/predict/' accepts POST requests with an image file, preprocesses the image, and returns the predicted digit.
"""
import torch
from torchvision import transforms
from fastapi import FastAPI, File, UploadFile
from main import Net # Importing Net class from main.py
from PIL import Image
from torchvision import transforms

# Load the model
model = Net()
Expand All @@ -26,3 +31,20 @@ async def predict(file: UploadFile = File(...)):
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}
return {"prediction": int(predicted[0])}
Parameters:
- file (UploadFile): The image file to predict.

Returns:
- dict: A dictionary with the key 'prediction' and the predicted digit as the value.
"""
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}

0 comments on commit e6e1d5f

Please sign in to comment.