forked from Jianbo-Lab/HSJA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_data.py
147 lines (87 loc) · 3.18 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import absolute_import, division, print_function
import numpy as np
import tensorflow as tf
import os
import time
import numpy as np
import sys
import os
import keras
import math
from keras.utils import to_categorical
class ImageData():
    """Container for a Keras image dataset with scaled pixels and one-hot labels.

    Attributes set by ``__init__``:
        x_train, x_val: float32 image arrays, pixel values divided by 255.
        y_train, y_val: labels converted with ``to_categorical``.
        x_train_mean: all-zero array shaped like a single training image.
        clip_min, clip_max: 0.0 and 1.0 (presumably the valid pixel range
            for downstream consumers -- confirm against callers).
    """

    def __init__(self, dataset_name):
        """Load ``dataset_name`` via ``keras.datasets`` and preprocess it.

        Args:
            dataset_name: one of ``'mnist'``, ``'cifar10'`` or ``'cifar100'``.

        Raises:
            ValueError: if ``dataset_name`` is not a supported dataset.
        """
        # The keras.datasets imports stay inside the branches so that only
        # the requested dataset module is imported.
        if dataset_name == 'mnist':
            from keras.datasets import mnist
            (x_train, y_train), (x_val, y_val) = mnist.load_data()
            # Add an explicit trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
            x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
            x_val = x_val.reshape(x_val.shape[0], 28, 28, 1)
        elif dataset_name == 'cifar100':
            from keras.datasets import cifar100
            (x_train, y_train), (x_val, y_val) = cifar100.load_data()
        elif dataset_name == 'cifar10':
            from keras.datasets import cifar10
            (x_train, y_train), (x_val, y_val) = cifar10.load_data()
        else:
            # Without this branch an unknown name fell through and crashed
            # later with an UnboundLocalError on x_train.
            raise ValueError(
                "Unsupported dataset_name {!r}; expected one of "
                "'mnist', 'cifar10', 'cifar100'".format(dataset_name))

        # Scale raw pixel values into [0, 1] as float32.
        x_train = x_train.astype('float32') / 255
        x_val = x_val.astype('float32') / 255

        # One-hot encode the integer class labels.
        y_train = to_categorical(y_train)
        y_val = to_categorical(y_val)

        # The "mean" is all zeros, so the subtraction below leaves the data
        # unchanged; it is kept so x_train_mean remains available as an
        # attribute with the per-image shape.
        x_train_mean = np.zeros(x_train.shape[1:])
        x_train -= x_train_mean
        x_val -= x_train_mean

        self.clip_min = 0.0
        self.clip_max = 1.0

        self.x_train = x_train
        self.x_val = x_val
        self.y_train = y_train
        self.y_val = y_val
        self.x_train_mean = x_train_mean
def split_data(x, y, model, num_classes = 10, split_rate = 0.8, sample_per_class = 100):
    """Build a class-balanced train/test split from correctly classified samples.

    The global NumPy RNG is reseeded with 10086 on every call, the model's
    accuracy on ``(x, y)`` is printed, and misclassified samples are dropped.
    For each class id in ``range(num_classes)`` the first ``sample_per_class``
    remaining samples are taken; the leading ``split_rate`` fraction goes to
    the train part and the rest to the test part. Both parts are shuffled
    before being returned.

    Args:
        x: array of inputs, indexable with boolean masks along axis 0.
        y: array of per-class label rows aligned with ``x`` (argmax over
            axis 1 gives the true class).
        model: object whose ``predict(x)`` returns per-class scores with
            classes along axis 1.
        num_classes: number of class ids to iterate over.
        split_rate: fraction of each class's samples placed in the train part.
        sample_per_class: maximum samples kept per class before splitting.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``.
    """
    np.random.seed(10086)

    predicted = np.argmax(model.predict(x), axis = 1)
    actual = np.argmax(y, axis = 1)
    is_correct = predicted == actual
    print('Accuracy is {}'.format(np.mean(is_correct)))

    # Only correctly classified samples take part in the split.
    x, y, predicted = x[is_correct], y[is_correct], predicted[is_correct]

    train_x_parts, test_x_parts = [], []
    train_y_parts, test_y_parts = [], []
    for class_id in range(num_classes):
        members = predicted == class_id
        class_x = x[members][:sample_per_class]
        class_y = y[members][:sample_per_class]

        # Index at which this class's samples switch from train to test.
        cut = int(len(class_x) * split_rate)
        train_x_parts.append(class_x[:cut])
        test_x_parts.append(class_x[cut:])
        train_y_parts.append(class_y[:cut])
        test_y_parts.append(class_y[cut:])

    x_train = np.concatenate(train_x_parts, axis = 0)
    x_test = np.concatenate(test_x_parts, axis = 0)
    y_train = np.concatenate(train_y_parts, axis = 0)
    y_test = np.concatenate(test_y_parts, axis = 0)

    # Draw the train permutation before the test permutation so the RNG
    # stream is consumed in a fixed order.
    order_train = np.random.permutation(len(x_train))
    order_test = np.random.permutation(len(x_test))

    return (x_train[order_train], y_train[order_train],
            x_test[order_test], y_test[order_test])
if __name__ == '__main__':
    # Script entry point: load a dataset and a pre-trained model, then build
    # the class-balanced split of the validation set.
    import argparse
    from build_model import ImageModel

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', type = str,
        choices = ['mnist', 'cifar10', 'cifar100'],
        default = 'mnist')
    parser.add_argument('--model_name', type = str,
        choices = ['cnn', 'resnet', 'densenet'],
        default = 'cnn')
    args = parser.parse_args()

    dataset = ImageData(args.dataset_name)
    model = ImageModel(args.model_name, args.dataset_name, train = False, load = True)

    x, y = dataset.x_val, dataset.y_val
    # Take the class count from the one-hot label width rather than
    # hard-coding 10, so that 'cifar100' is split over all of its classes.
    x_train, y_train, x_test, y_test = split_data(
        x, y, model, num_classes = y.shape[1], split_rate = 0.8)