-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_loaders.py
122 lines (103 loc) · 4.19 KB
/
data_loaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.transforms import transforms
from PIL import Image
import math
# Names exported by `from <module> import *`: the three DataLoader subclasses.
# `build_transform` is deliberately absent from this list.
__all__ = ['CIFAR10DataLoader', 'ImageNetDataLoader', 'CIFAR100DataLoader']
class CIFAR10DataLoader(DataLoader):
    """DataLoader wrapping torchvision's ``CIFAR10`` dataset.

    Args:
        data_dir: Root directory handed to ``CIFAR10`` (created with
            ``download=True``).
        split: ``'train'`` selects the training set; any other value
            selects the non-training set.
        image_size: Target size passed to ``transforms.Resize``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    The underlying dataset is kept on ``self.dataset``. Batches are
    shuffled only for the training split.
    """

    def __init__(self, data_dir, split='train', image_size=224, batch_size=16, num_workers=8):
        is_train = split == 'train'

        # Both splits share resize -> tensor -> normalize; the random flip
        # augmentation is inserted for the training split only.
        steps = [transforms.Resize(image_size)]
        if is_train:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))

        self.dataset = CIFAR10(
            root=data_dir,
            train=is_train,
            transform=transforms.Compose(steps),
            download=True,
        )
        super().__init__(
            dataset=self.dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
        )
class CIFAR100DataLoader(DataLoader):
    """DataLoader wrapping torchvision's ``CIFAR100`` dataset.

    Args:
        data_dir: Root directory handed to ``CIFAR100`` (created with
            ``download=True``).
        split: ``'train'`` selects the training set; any other value
            selects the non-training set.
        image_size: Target size passed to ``transforms.Resize``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    The underlying dataset is kept on ``self.dataset``. Batches are
    shuffled only for the training split.
    """

    def __init__(self, data_dir, split='train', image_size=224, batch_size=16, num_workers=8):
        use_train_set = split == 'train'

        # Ordered pipeline; the flip augmentation applies to training only.
        head = [transforms.Resize(image_size)]
        augment = [transforms.RandomHorizontalFlip()] if use_train_set else []
        tail = [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        pipeline = transforms.Compose(head + augment + tail)

        self.dataset = CIFAR100(
            root=data_dir,
            train=use_train_set,
            transform=pipeline,
            download=True,
        )
        super().__init__(
            dataset=self.dataset,
            batch_size=batch_size,
            shuffle=use_train_set,
            num_workers=num_workers,
        )
class ImageNetDataLoader(DataLoader):
    """DataLoader over an ImageNet-style folder tree read with ``ImageFolder``.

    Images are loaded from ``os.path.join(data_dir, split)`` and run through
    the pipeline returned by ``build_transform``. The same pipeline is used
    for every split.

    Args:
        model: Model name. If it contains the substring ``'vit'``, the
            0.5/0.5 mean/std normalization with ``crop_pct=0.9`` is used;
            otherwise the (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)
            statistics with ``crop_pct=0.875`` are used.
        data_dir: Directory containing one sub-directory per split.
        split: Sub-directory name; batches are shuffled only for ``'train'``.
        image_size: Side length of the final crop, forwarded to
            ``build_transform`` as ``input_size``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    The underlying dataset is kept on ``self.dataset``.
    """

    def __init__(self, model, data_dir, split='train', image_size=224, batch_size=16, num_workers=8):
        # `image_size` must be forwarded explicitly; otherwise
        # build_transform falls back to its own default of 224 and the
        # caller's requested resolution is silently ignored.
        if 'vit' in model:
            transform = build_transform(
                input_size=image_size,
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
                crop_pct=0.9)
        else:
            transform = build_transform(
                input_size=image_size,
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                crop_pct=0.875)

        self.dataset = ImageFolder(root=os.path.join(data_dir, split), transform=transform)
        super(ImageNetDataLoader, self).__init__(
            dataset=self.dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            pin_memory=True)
def build_transform(input_size=224,
                    interpolation='bicubic',
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    crop_pct=0.875):
    """Compose a deterministic resize / center-crop / normalize pipeline.

    Args:
        input_size: Side length of the final center crop. Resize and crop
            are only added when this is greater than 32.
        interpolation: One of ``'bicubic'``, ``'lanczos'`` or ``'hamming'``;
            any other value falls back to bilinear.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel std passed to ``transforms.Normalize``.
        crop_pct: Ratio of crop size to resize size; the image is resized to
            ``floor(input_size / crop_pct)`` before cropping.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    pipeline = []

    if input_size > 32:
        # Resize to a slightly larger size so that the center crop keeps the
        # same ratio w.r.t. 224 images.
        resize_to = int(math.floor(input_size / crop_pct))
        resample = {
            'bicubic': Image.BICUBIC,
            'lanczos': Image.LANCZOS,
            'hamming': Image.HAMMING,
        }.get(interpolation, Image.BILINEAR)
        pipeline.append(transforms.Resize(resize_to, interpolation=resample))
        pipeline.append(transforms.CenterCrop(input_size))

    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(mean, std))
    return transforms.Compose(pipeline)
if __name__ == '__main__':
    # Manual smoke test: iterate the ImageNet validation split and print the
    # target labels of each batch.
    data_loader = ImageNetDataLoader(
        # `model` is a required argument; omitting it raises TypeError.
        # NOTE(review): 'vit' is assumed here because of image_size=384 --
        # confirm which normalization this check is meant to exercise.
        model='vit',
        data_dir='/home/hchen/Projects/vat_contrast/data/ImageNet',
        split='val',
        image_size=384,
        batch_size=16,
        num_workers=0)

    for images, targets in data_loader:
        print(targets)