-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
44 lines (36 loc) · 1.14 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from pathlib import Path
import numpy as np
import torchvision.transforms as transforms
import torchvision.utils
from torchvision.utils import save_image
from digit_recognition import Net
# --- Load and preprocess the input image -----------------------------------
# Alternative input used during development: 'opencv_frame_0.png'
image_path = 'printed_digit/train/3/0_0_7.jpeg'

# Set to True to invert the image (dark digit on a light background ->
# light digit on a dark background). NOTE(review): whether the model expects
# white-on-black digits depends on how it was trained in digit_recognition;
# confirm before relying on either setting.
invert_colors = False

# Preprocessing: tensor in [0, 1] -> 28x28 -> one grayscale channel ->
# optional inversion -> Normalize((0.5,), (0.5,)), i.e. (x - 0.5) / 0.5,
# which maps [0, 1] onto [-1, 1].
preprocess_steps = [transforms.ToTensor(),
                    transforms.Resize((28, 28)),
                    transforms.Grayscale(1)]
if invert_colors:
    # Invert while values are still in [0, 1]. Inverting after Normalize
    # would compute 1 - x on [-1, 1] data and yield values in [0, 2].
    preprocess_steps.append(transforms.Lambda(transforms.functional.invert))
preprocess_steps.append(transforms.Normalize((0.5,), (0.5,)))
preprocess = transforms.Compose(preprocess_steps)

# Force 3-channel RGB so that palette ('P') and alpha ('RGBA'/'LA') images
# reach Grayscale as a 3-channel tensor instead of palette indices or an
# unsupported channel count. The context manager closes the image file.
with Image.open(image_path) as og_img:
    img = preprocess(og_img.convert('RGB'))

# Flatten the 1x28x28 tensor into a single row of 784 features.
# NOTE: `input` shadows the builtin; the name is kept because the
# prediction code reads this variable.
input = img.reshape(-1, 784)
# --- Load the trained model ------------------------------------------------
# Path to a checkpoint dict; it must contain a 'model_state_dict' entry.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
model_dir = 'models/transfer-ce-sgd-4.pt'
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict copies the values into Net's parameters.
state = torch.load(model_dir, map_location='cpu')
model = Net()
model.load_state_dict(state['model_state_dict'])
# Put the network in inference mode (affects dropout/batch-norm, if Net has any).
model.eval()
# --- Predict ---------------------------------------------------------------
# Inference only, so autograd tracking is disabled. Prints the raw result of
# Net.predict, then the flat index of its largest entry. NOTE(review): that
# index is the predicted digit only if predict returns one score per class;
# confirm in digit_recognition.
with torch.no_grad():
    output = model.predict(input)
    print(output)
    predicted_index = torch.argmax(output)
    print(predicted_index)