-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_loader_custom_eval.py
67 lines (49 loc) · 2 KB
/
data_loader_custom_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.autograd import Variable as V
NUM_WORKERS = 4
def get_cifar(num_classes=10, dataset_dir='./data', batch_size=128, crop=False):
"""
:param num_classes: 10 for cifar10, 100 for cifar100
:param dataset_dir: location of datasets, default is a directory named 'data'
:param batch_size: batchsize, default to 128
:param crop: whether or not use randomized horizontal crop, default to False
:return:
"""
normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
simple_transform = transforms.Compose([transforms.ToTensor(), normalize])
if crop is True:
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
else:
train_transform = simple_transform
if num_classes == 100:
trainset = torchvision.datasets.CIFAR100(root=dataset_dir, train=True,
download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR100(root=dataset_dir, train=False,
download=True, transform=simple_transform)
else:
trainset = torchvision.datasets.CIFAR10(root=dataset_dir, train=True,
download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root=dataset_dir, train=False,
download=True, transform=simple_transform)
#print(testset)
#print(testset[1].shape)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=NUM_WORKERS,
pin_memory=True, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=NUM_WORKERS,
pin_memory=True, shuffle=False)
print("data set is loaded")
#for images, _ in testloader:
# print('images.shape:', images.shape)
# plt.figure(figsize=(16,8))
# plt.axis('off')
# plt.imshow(make_grid(images, nrow=16).permute((1, 2, 0)))
# break
return trainloader, testloader