-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathprepare_data.py
130 lines (112 loc) · 4.44 KB
/
prepare_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# coding: utf-8
import os, shutil, random
# Overview: every label CSV below is merged into ``label_dict`` (one list of
# (image_path, label) pairs per attribute task); afterwards each task gets its
# own train/val folder tree and every image is copied into its class folder.
label_train_dir = './data2/train/Annotations/label.csv'
label_base_dir = './data1/base/Annotations/label.csv'
label_dir_testa = './data1/fashionAI_attributes_answer_a_20180428.csv'
label_dir_testb = './data1/fashionAI_attributes_answer_b_20180428.csv'

# Attribute task name -> list of samples (filled in later). Each task gets its
# own fresh list; the key order here is the order in which tasks are processed.
label_dict = {
    task_name: []
    for task_name in (
        'coat_length_labels',
        'lapel_design_labels',
        'neckline_design_labels',
        'skirt_length_labels',
        'collar_design_labels',
        'neck_design_labels',
        'pant_length_labels',
        'sleeve_length_labels',
    )
}
task_list = label_dict.keys()
def mkdir_if_not_exist(path):
    """Create a directory, including missing parents, unless it already exists.

    Args:
        path: sequence of path components, joined with ``os.path.join`` to
            form the directory to create.

    Raises:
        OSError: if the directory cannot be created, including when the
            joined path already exists but is not a directory.
    """
    target = os.path.join(*path)
    try:
        os.makedirs(target)
    except OSError:
        # Create first, verify afterwards: this avoids the race between an
        # exists() check and makedirs(). An existing directory is fine, but an
        # existing non-directory (e.g. a regular file) must not be mistaken
        # for a usable folder.
        if not os.path.isdir(target):
            raise
# Label files to merge, each paired with the directory its image paths are
# relative to. Order matters: when the same relative image path occurs more
# than once (in one file or across files), only the first occurrence is kept.
label_sources = [
    (label_train_dir, 'data2/train/'),
    (label_base_dir, 'data1/base/'),
    (label_dir_testa, 'data1/rank/'),
    (label_dir_testb, 'data1/z_rank/'),
]

# Relative image paths already assigned to a task. Only the keys matter; the
# dict is used as a set of seen paths.
all_lable = {}


def _collect_labels(csv_path, image_root):
    """Append ``(image_path, label)`` pairs from one label CSV to ``label_dict``.

    Each non-blank line must have the form ``relative_path,task,label``.
    Lines whose ``relative_path`` is already in ``all_lable`` are skipped;
    otherwise the path is recorded there and ``image_root + relative_path``
    is appended, with its label, to ``label_dict[task]``.

    Raises:
        ValueError: if a non-blank line does not split into exactly three
            comma-separated fields.
        KeyError: if the task field is not one of ``label_dict``'s keys.
    """
    with open(csv_path, 'r') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # Blank (e.g. trailing) lines would otherwise fail to unpack.
                continue
            path, task, label = line.split(',')
            if path in all_lable:
                continue
            all_lable[path] = []
            label_dict[task].append((image_root + path, label))


for csv_path, image_root in label_sources:
    _collect_labels(csv_path, image_root)
# Output layout: <output_root>/<task>/{train,val}/<class_index>/<image file>
output_root = 'data2/train_valid_allset'
# Share of each task's shuffled samples that goes to train/; the rest go to val/.
train_fraction = 0.95

mkdir_if_not_exist([output_root])
for task, path_label in label_dict.items():
    mkdir_if_not_exist([output_root, task])
    n = len(path_label)  # number of samples for this task
    if n == 0:
        # Nothing to copy, and path_label[0] below would raise IndexError.
        print(' skipped ' + task + ' (no samples)')
        continue
    # Number of classes for this task, taken as the length of the first
    # label string (one character per class).
    m = len(path_label[0][1])
    for mm in range(m):
        mkdir_if_not_exist([output_root, task, 'train', str(mm)])
        mkdir_if_not_exist([output_root, task, 'val', str(mm)])
    # Re-seeded for every task, so each task's shuffle/split is reproducible
    # independently of the others. Shuffles label_dict's lists in place.
    random.seed(2018)
    random.shuffle(path_label)
    for i, (path, label) in enumerate(path_label):
        # Class index = position of the 'y' character in the label string;
        # raises ValueError if the label has no 'y'.
        label_index = label.index('y')
        # Disabled alternative: if the label also contains an 'm', pick the
        # 'y' class 60% of the time and the 'm' class 40% of the time.
        # if 'm' in label:
        #     m_index = label.index('m')
        #     label_index = label_index if random.randint(1, 5) > 2 else m_index
        split = 'train' if i < n * train_fraction else 'val'
        shutil.copy(path, os.path.join(output_root, task, split, str(label_index)))
    print(' finished ' + task)
print(' all finished!')
# Add warmup data to skirt task
# label_dict = {'skirt_length_labels': []}
#
# with open(warmup_label_dir, 'r') as f:
# lines = f.readlines()
# tokens = [l.rstrip().split(',') for l in lines]
# for path, task, label in tokens:
# label_dict[task].append((path, label))
#
# for task, path_label in label_dict.items():
# train_count = 0
# n = len(path_label)
# m = len(list(path_label[0][1]))
#
# random.shuffle(path_label)
# for path, label in path_label:
# label_index = list(label).index('y')
# src_path = os.path.join('data/web', path)
# if train_count < n * 0.9:
# shutil.copy(src_path,
# os.path.join('data/train_valid', task, 'train', str(label_index)))
# else:
# shutil.copy(src_path,
# os.path.join('data/train_valid', task, 'val', str(label_index)))
# train_count += 1