-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
88 lines (76 loc) · 2.59 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from __future__ import print_function, division
import torch
import torch.nn as nn
import argparse
import time
import os
import cv2
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from PIL import Image as pil_image
from tqdm import tqdm
from network.classifier import *
from network.transform import mesonet_data_transforms
def preprocess_image(image, cuda=True):
    """Convert an OpenCV frame into a batched input tensor for the model.

    The frame is converted from BGR to RGB, wrapped in a PIL image, and passed
    through ``mesonet_data_transforms['test']``. A leading batch axis of
    size 1 is then added with ``unsqueeze(0)``.

    Args:
        image: Frame as returned by ``cv2.imread`` (BGR channel order).
        cuda: When True, move the resulting tensor to the default CUDA
            device; when False, leave it on the CPU.

    Returns:
        The transformed tensor with an extra leading batch dimension of
        size 1. (Remaining dims come from the project transform; presumably
        C x H x W — confirm against network/transform.py.)

    Raises:
        ValueError: If ``image`` is None, e.g. because ``cv2.imread`` could
            not read or decode the file.
    """
    # cv2.imread reports failure by returning None rather than raising, so
    # surface that here instead of letting cv2.cvtColor fail obscurely.
    if image is None:
        raise ValueError('image is None; it could not be read or decoded')
    # OpenCV decodes images in BGR order; convert to RGB before handing the
    # array to PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    preprocess = mesonet_data_transforms['test']
    preprocessed_image = preprocess(pil_image.fromarray(image))
    preprocessed_image = preprocessed_image.unsqueeze(0)
    # Respect the caller's device choice. The original forced cuda=True here,
    # which made the parameter meaningless.
    if cuda:
        preprocessed_image = preprocessed_image.cuda()
    return preprocessed_image
def predict_with_model(image, model, post_function=nn.Softmax(dim=1), cuda=True):
    """Run ``model`` on a single frame and return the predicted class.

    Args:
        image: Frame as returned by ``cv2.imread`` (BGR channel order).
        model: A callable ``torch.nn.Module``. Its output is assumed to hold
            per-class scores along dim 1, as implied by ``torch.max(..., 1)``
            below. It must live on the same device selected by ``cuda``.
        post_function: Applied to the raw model output before the arg-max.
            Defaults to a softmax over dim 1; note that the default module
            instance is created once and shared across calls.
        cuda: When True, the input tensor is moved to CUDA before inference.

    Returns:
        ``(prediction, output)``: ``prediction`` is the arg-max class index
        as an ``int``, and ``output`` is the post-processed score tensor.
    """
    preprocessed_image = preprocess_image(image, cuda)
    # Pure inference: don't build an autograd graph for every frame.
    with torch.no_grad():
        output = post_function(model(preprocessed_image))
    _, prediction = torch.max(output, 1)
    # The batch holds exactly one image, so the index tensor has a single
    # element; .item() converts it straight to a Python number.
    return int(prediction.item()), output
def _count_predictions(images_path, model, cuda):
    """Classify every readable image in ``images_path``; return (fake, real).

    Class index 0 is counted as fake and any other index as real, matching
    the labels used in the report printed by ``test_images``. Files that
    ``cv2.imread`` cannot decode (it returns None) are reported and skipped
    rather than aborting the whole run.
    """
    fake_count = 0
    real_count = 0
    for image_name in os.listdir(images_path):
        image_file = os.path.join(images_path, image_name)
        image = cv2.imread(image_file)
        if image is None:
            print('Skipping unreadable file:', image_file)
            continue
        prediction, _ = predict_with_model(image, model, cuda=cuda)
        if prediction == 0:
            fake_count += 1
        else:
            real_count += 1
    return fake_count, real_count


def test_images(images_path_1, images_path_2, model_path, cuda=True):
    """Evaluate a saved model on a folder of real and a folder of fake images.

    For each folder, prints how many frames the model classified as fake
    (class 0) and how many as real (any other class).

    Args:
        images_path_1: Directory of real images; reported first.
        images_path_2: Directory of fake images; reported second.
        model_path: Path to a whole model object saved with ``torch.save``.
        cuda: When True, the model and inputs are moved to CUDA.

    Raises:
        ValueError: If ``model_path`` is None. The original only printed a
            message and then crashed on the undefined ``model``.
    """
    if model_path is None:
        raise ValueError('No model found, please check it!')
    # NOTE(review): torch.load unpickles the full model object, so only load
    # trusted files. Newer PyTorch releases default to weights_only=True,
    # which rejects whole-module pickles; pass weights_only=False there if
    # the checkpoint is trusted.
    model = torch.load(model_path, map_location='cpu')
    print('Model found in {}'.format(model_path))
    if cuda:
        model = model.cuda()
    # Put dropout/batch-norm layers into inference mode. A model saved
    # mid-training would otherwise keep training-mode behaviour here.
    model.eval()

    fake_count, real_count = _count_predictions(images_path_1, model, cuda)
    print("Testing real images: ")
    print("fake frame is:", fake_count)
    print("real frame is:", real_count)

    fake_count, real_count = _count_predictions(images_path_2, model, cuda)
    print("Testing fake images: ")
    print("fake frame is:", fake_count)
    print("real frame is:", real_count)
if __name__ == '__main__':
    # Command-line entry point. Every option has a default, so running the
    # script bare evaluates output/mobilenetv3_drop_4_layer.pkl on the
    # extended_database/test/real and extended_database/test/fake folders.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--images_path_1', '-i1', type=str,
                        default="extended_database/test/real")
    parser.add_argument('--images_path_2', '-i2', type=str,
                        default="extended_database/test/fake")
    parser.add_argument('--model_path', '-mi', type=str,
                        default="output/mobilenetv3_drop_4_layer.pkl")
    cli_args = parser.parse_args()
    test_images(images_path_1=cli_args.images_path_1,
                images_path_2=cli_args.images_path_2,
                model_path=cli_args.model_path)