-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path inference.py
37 lines (25 loc) · 928 Bytes
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import io
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image
def model_fn(model_dir):
    """Build the ResNet-18 classifier and load its fine-tuned weights.

    NOTE(review): the name suggests this is the SageMaker PyTorch
    ``model_fn`` hook; confirm against the deployment config.

    Args:
        model_dir: Directory containing ``model.pth``, a ``state_dict`` for
            the architecture built below (ResNet-18 with a 5-output head).

    Returns:
        The ``torch.nn.Module`` with the checkpoint's weights loaded, all
        parameters frozen, and switched to evaluation mode.
    """
    # Build the bare architecture only. Fetching the ImageNet weights
    # (the old ``pretrained=True``) is wasted work: load_state_dict below is
    # strict, so every parameter and buffer gets overwritten by the
    # checkpoint. The download can also fail on hosts without internet access.
    model = models.resnet18()
    # Freeze parameters; no gradients are needed for inference.
    for param in model.parameters():
        param.requires_grad = False
    num_features = model.fc.in_features
    # Swap the 1000-class head for a 5-output one. The Sequential wrapper
    # must stay so the state_dict keys ("fc.0.weight", "fc.0.bias") match
    # the checkpoint under strict loading.
    model.fc = nn.Sequential(nn.Linear(num_features, 5))
    checkpoint_path = os.path.join(model_dir, "model.pth")
    with open(checkpoint_path, "rb") as f:
        # map_location="cpu" lets a checkpoint saved from CUDA tensors load
        # on a CPU-only host; the values are copied into the model's own
        # parameters by load_state_dict either way.
        # NOTE: torch.load unpickles the file; only load trusted artifacts.
        state_dict = torch.load(f, map_location="cpu")
    model.load_state_dict(state_dict)
    # Inference mode: BatchNorm uses its running statistics instead of
    # per-batch statistics.
    model.eval()
    return model
def input_fn(request_body, content_type):
    """Decode an encoded image payload into a normalized single-image batch.

    Args:
        request_body: Raw bytes of an image file in any format PIL can open.
        content_type: Content type of the request.
            NOTE(review): not validated; every payload is handed to PIL
            regardless of this value.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: one image resized to
        224x224 and normalized per channel with the usual ImageNet mean/std.

    Raises:
        PIL.UnidentifiedImageError: If the bytes are not a readable image.
    """
    # Force 3-channel RGB. Grayscale, palette, RGBA (PNG) or CMYK (JPEG)
    # inputs would otherwise become 1- or 4-channel tensors, which the
    # 3-element Normalize below cannot process.
    image = Image.open(io.BytesIO(request_body)).convert("RGB")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transformation = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    # unsqueeze(0) prepends the batch dimension: (3, H, W) -> (1, 3, H, W).
    return transformation(image).unsqueeze(0)