-
Notifications
You must be signed in to change notification settings - Fork 423
/
omniglot.py
109 lines (90 loc) · 3.54 KB
/
omniglot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch.utils.data as data
import os
import os.path
import errno
class Omniglot(data.Dataset):
    '''
    Omniglot character dataset, downloaded on demand as two zip archives.

    Indexing yields (image path, class index) pairs. The image is NOT
    loaded here: the path string is handed to ``transform`` (when given),
    which is expected to do the loading. The mapping from class key to
    class index can be found in self.idx_classes.

    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input (called with the image path)
    - target_transform: how to transform the target (called with the class index)
    - download: need to download the dataset
    '''
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, transform=None, target_transform=None, download=False):
        '''
        Locate the extracted dataset under ``root`` (optionally downloading
        it first) and build the item list and the class index.

        Raises:
        - RuntimeError: the dataset is missing and ``download`` is false
        '''
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        '''
        Return ``(img, target)`` for the item at ``index``.

        ``img`` is the path built from the item's directory (field 2) and
        file name (field 0); ``target`` is the index registered in
        ``self.idx_classes`` for the item's class key (field 1). Each is
        passed through its transform when one was supplied.
        '''
        filename, class_key, folder = self.all_items[index][:3]
        img = os.path.join(folder, filename)
        target = self.idx_classes[class_key]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        '''Return the number of items found on disk.'''
        return len(self.all_items)

    def _check_exists(self):
        '''Return True when both extracted image folders are present.'''
        return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))

    def _make_folders(self):
        '''Create the raw and processed folders, tolerating existing ones.'''
        # Each folder gets its own try block: with a shared one, an already
        # existing raw folder would skip the creation of the processed one.
        for folder in (self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

    def download(self):
        '''
        Download every archive in ``self.urls`` into the raw folder and
        extract it into the processed folder. Does nothing when the
        extracted folders are already present.
        '''
        if self._check_exists():
            return

        # Imported lazily so that merely using an already downloaded dataset
        # does not require these modules.
        from contextlib import closing
        import shutil
        import zipfile
        from six.moves import urllib

        self._make_folders()
        file_processed = os.path.join(self.root, self.processed_folder)

        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            # closing() releases the connection even if the write fails;
            # copyfileobj streams in chunks instead of buffering the archive.
            with closing(urllib.request.urlopen(url)) as response:
                with open(file_path, 'wb') as f:
                    shutil.copyfileobj(response, f)

            print("== Unzip from " + file_path + " to " + file_processed)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(file_processed)

        print("Download finished.")
def find_classes(root_dir):
    """Collect every file ending in "png" below ``root_dir``.

    Args:
        root_dir: directory that is walked recursively.

    Returns:
        A list of ``(filename, class_key, directory)`` tuples, where
        ``class_key`` is the last two components of ``directory`` joined
        with "/" and ``directory`` is the path reported by ``os.walk``.
        Directories and files are visited in sorted order, so the list
        (and any class numbering derived from it) is reproducible.
    """
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        # os.walk yields entries in arbitrary filesystem order; sorting the
        # directory list in place fixes the order of the top-down descent.
        dirs.sort()
        # Normalise the separator so the key is built the same way on
        # platforms where os.sep is not "/". Slicing the last two parts
        # avoids negative-index wraparound on a single-component path.
        parts = root.replace(os.sep, "/").split("/")
        class_key = "/".join(parts[-2:])
        for f in sorted(files):
            if f.endswith("png"):
                retour.append((f, class_key, root))
    print("== Found %d items " % len(retour))
    return retour
def index_classes(items):
    """Number the distinct class keys of ``items`` in first-seen order.

    Args:
        items: iterable of indexable records whose element 1 is the class key.

    Returns:
        A dict mapping each class key to a consecutive integer starting at 0.
    """
    class_to_index = {}
    for record in items:
        # len() is evaluated before the insert, so a new key receives the
        # next free index and a known key keeps the one it already has.
        class_to_index.setdefault(record[1], len(class_to_index))
    print("== Found %d classes" % len(class_to_index))
    return class_to_index