#!/usr/bin/python
# -*- coding: utf-8 -*-
# @Time: 2020/1/16 14:28
# @Author: Casually
# @File: data_preprocess.py
# @Email: [email protected]
# @Software: PyCharm
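'''
Data preprocessing for a garbage-classification image dataset.

Reads per-image label .txt files under <data>/All_data, maps the original
40 fine-grained labels to 4 coarse classes via two JSON rule files, splits
the data into train/verify sets laid out on disk for torchvision, and
renders class-distribution bar charts (pyecharts) plus image aspect-ratio
histograms (seaborn/matplotlib).
'''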
import json, random, os, sys, time
from glob import glob
from pyecharts import options as opts
from pyecharts.charts import Bar
from PIL import Image
import seaborn as sns  # visualization library
import matplotlib.pyplot as plt
import numpy as np
import webbrowser  # to open rendered charts in the browser
import shutil
from collections import Counter  # frequency counting
import torch
import torchvision
# Write JSON
def __json_store(path, data):
    '''
    Write a dict to a JSON file.
    :param path: destination file path
    :param data: data to write
    :return:
    '''
    with open(path, 'w', encoding='utf-8') as fw:  # utf-8 to match __json_load
        json.dump(data, fw, indent=1, ensure_ascii=False)
# Load JSON
def __json_load(path):
    '''
    Read a JSON file.
    :param path: file path
    :return: dict
    '''
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data
# Collect the file name, path and label of every image under the given path
def get_img_info(path):
    '''
    Collect the file name, path and label of every image under the given path.
    :param path: train_data path
    :return img_path_list: list of dicts {'img_name', 'img_lable', 'img_4lable', 'img_path'}
    :return img_name2label_dict: dict {img_name: img_label}
    :return img_label2count_dict: dict {img_label: count}
    '''
    img_path_list = []
    img_name2label_dict = {}
    img_label2count_dict = {}
    data_path_txt = os.path.join(path, 'All_data/*.txt')
    txt_file_list = glob(data_path_txt)  # all .txt label files
    garbage_classify_rule = __json_load(os.path.join(path, 'garbage_classify_rule.json'))
    garbage_index_classify = __json_load(os.path.join(path, 'garbage_index_classify.json'))
    for item in garbage_classify_rule:  # remap the 40 fine-grained classes to 4 coarse classes
        garbage_classify_rule[item] = garbage_index_classify[garbage_classify_rule[item].split('/')[0]]
    # Walk the label files
    for i, file in enumerate(txt_file_list):
        __processBar('{} Extracting data'.format(time.strftime("[%d/%b/%Y %H:%M:%S]", time.localtime())), i + 1, len(txt_file_list))
        with open(file, 'r') as f:
            line = f.readline().split(',')  # [0] img_name, [1] img_label
            img_name2label_dict[line[0]] = line[1]
            img_path = os.path.join(path, 'All_data/{}'.format(line[0]))
            img_path_list.append({'img_name': line[0], 'img_lable': int(line[1]),
                                  'img_4lable': int(garbage_classify_rule[line[1].strip()]), 'img_path': img_path})
            img_label2count_dict[int(line[1])] = img_label2count_dict.get(int(line[1]), 0) + 1
    # Sort img_label2count_dict by key (label) in ascending order
    img_label2count_dict = dict(
        sorted(img_label2count_dict.items(), key=lambda item: item[0]))
    return img_path_list, img_name2label_dict, img_label2count_dict
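# Expected on-disk inputs (inferred from the parsing above; the exact file
# contents are an assumption, shown for illustration only):
#   All_data/img_1.jpg with All_data/img_1.txt, where the .txt holds one line
#   like "img_1.jpg, 0" (image file name, fine-grained label index);
#   garbage_classify_rule.json maps "label index" -> "class name/...", and
#   garbage_index_classify.json maps "class name" -> 4-class index.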
# Collect the ID, size and label of every image under the given path
def get_img_size(path):
    '''
    Collect the image specs of the training set.
    :param path: train_data path
    :return: list of [id, width, height, ratio, label]
    '''
    data = []
    _, img_name2label_dict, _ = get_img_info(path)
    data_path = os.path.join(path, 'All_data')
    img_file_path = os.path.join(data_path, '*.jpg')
    imgs_path = glob(img_file_path)
    for i, img_path in enumerate(imgs_path):
        __processBar('{} Extracting image info'.format(time.strftime("[%d/%b/%Y %H:%M:%S]", time.localtime())), i + 1, len(imgs_path))
        img_id = img_path.split('_')[-1].split('.')[0]
        img_label = img_name2label_dict['img_{}.jpg'.format(img_id)]
        img = Image.open(img_path)
        data.append([int(img_id), img.size[0], img.size[1], float('{:.2f}'.format(img.size[0] / img.size[1])),
                     int(img_label)])
    return data
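# A returned row looks like [1, 640, 480, 1.33, 0]
# (id, width, height, width/height ratio, label) -- illustrative values only.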
# Split train_data into a train set and a verify set
def data_shuffle_split(size, img_path_list):
    '''
    Split train_data into a train set and a verify set.
    :param size: fraction of the data used for training
    :param img_path_list: first return value of get_img_info
    :return:
    '''
    # Shuffle the raw data
    random.shuffle(img_path_list)
    # Split ratio: train : verify = size : (1 - size)
    if isinstance(size, float) and 0 < size < 1:
        path = img_path_list[0]['img_path'].split(img_path_list[0]['img_name'])[0].split('All_data')[0]
        # Split point
        train_size = int(len(img_path_list) * size)
        # Slice off the training set
        train_img_list = img_path_list[:train_size]
        train_img_dict = dict([[item['img_path'], item['img_4lable']] for item in train_img_list])
        train_json_path = os.path.join(path, 'train_img_dict.json')
        # Persist the training set to JSON
        __json_store(train_json_path, train_img_dict)
        # Slice off the verify set
        verify_img_list = img_path_list[train_size:]
        verify_img_dict = dict([[item['img_path'], item['img_4lable']] for item in verify_img_list])
        verify_json_path = os.path.join(path, 'verify_img_dict.json')
        # Persist the verify set to JSON
        __json_store(verify_json_path, verify_img_dict)
        # Clear the target directory so repeated calls do not duplicate data
        print('{} Clearing target directory: '.format(time.strftime("[%d/%b/%Y %H:%M:%S]", time.localtime())), end='')
        shutil.rmtree(os.path.join(path, 'train_data'), ignore_errors=True)
        print('done')
        # One sub-directory per 4-class label
        for i, item in enumerate(train_img_list):
            train_dir = os.path.join(path, 'train_data/train/{}'.format(item['img_4lable']))
            # Create the directory if needed
            if not os.path.exists(train_dir):
                os.makedirs(train_dir)
            # Copy the image
            shutil.copy(item['img_path'], train_dir)
            __processBar('{} Copying train set'.format(time.strftime("[%d/%b/%Y %H:%M:%S]", time.localtime())), i + 1,
                         len(train_img_list))
        for i, item in enumerate(verify_img_list):
            verify_dir = os.path.join(path, 'train_data/val/{}'.format(item['img_4lable']))
            # Create the directory if needed
            if not os.path.exists(verify_dir):
                os.makedirs(verify_dir)
            # Copy the image
            shutil.copy(item['img_path'], verify_dir)
            __processBar('{} Copying verify set'.format(time.strftime("[%d/%b/%Y %H:%M:%S]", time.localtime())), i + 1,
                         len(verify_img_list))
        return train_json_path, verify_json_path, train_img_list, verify_img_list
    else:
        raise TypeError('<size> must be a float in the open interval (0, 1)')
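# The split produces an ImageFolder-style layout:
#   train_data/train/<4-class label>/img_*.jpg
#   train_data/val/<4-class label>/img_*.jpg
# A minimal loading sketch (an assumption about downstream use, not part of
# this module; the path below is illustrative):
#   from torchvision import datasets, transforms
#   train_set = datasets.ImageFolder('data/train_data/train',
#                                    transform=transforms.ToTensor())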
# Data visualization
def data_visualization(data_path, show_web=True,
                       size1={'width': 1500, 'height': 500, 'rotate': 15},
                       size2={'width': 1000, 'height': 500, 'rotate': 0}):
    '''
    Show the per-class distribution of the full data set and of the
    train/verify split, plus the distribution of image aspect ratios.
    :param data_path: train_data path
    :param show_web: open the charts in a browser automatically, default True
    :param size1: {'width': 1500, 'height': 500, 'rotate': 15}
                  chart width, default 1500 px
                  chart height, default 500 px
                  x-axis label rotation, default 15
    :param size2: {'width': 1000, 'height': 500, 'rotate': 0}
                  chart width, default 1000 px
                  chart height, default 500 px
                  x-axis label rotation, default 0
    :return:
    '''
    root_path = data_path.split('All_data')[0]  # tolerate callers passing .../All_data
    train_json_path = os.path.join(root_path, 'train_img_dict.json')
    if not os.path.exists(train_json_path):
        raise FileNotFoundError('Run data_shuffle_split first to split the data set')
    train_json_dict = __json_load(train_json_path)
    train_dict = dict(Counter(train_json_dict.values()))  # Counter tallies occurrences
    train_dict = dict(sorted(train_dict.items()))
    verify_json_path = os.path.join(root_path, 'verify_img_dict.json')
    if not os.path.exists(verify_json_path):
        raise FileNotFoundError('Run data_shuffle_split first to split the data set')
    verify_json_dict = __json_load(verify_json_path)
    verify_dict = dict(Counter(verify_json_dict.values()))  # Counter tallies occurrences
    verify_dict = dict(sorted(verify_dict.items()))
    # Sanity check
    print('Checking data: ', end='')
    assert train_dict.keys() == verify_dict.keys()
    print('done')
    # Per-class visualization of the 40-class data
    # (note: train_img_dict.json / verify_img_dict.json store the 4-class label per image)
    garbage_classify_rule = __json_load(
        os.path.join(root_path, 'garbage_classify_rule.json'))
    # x-axis data
    x = ["{}-{}".format(label_id, garbage_classify_rule[str(label_id)]) for label_id in train_dict.keys()]
    img_path_list, _, img_label2count_dict = get_img_info(root_path)
    # y-axis data
    data_y = list(img_label2count_dict.values())
    train_y = list(train_dict.values())
    verify_y = list(verify_dict.values())
    # Build the Bar chart
    bar = Bar(init_opts=opts.InitOpts(width='{}px'.format(size1['width']), height='{}px'.format(size1['height'])))
    bar.add_xaxis(xaxis_data=x)
    bar.add_yaxis(series_name='All', yaxis_data=data_y)
    bar.add_yaxis(series_name='Train', yaxis_data=train_y)
    bar.add_yaxis(series_name='Verify', yaxis_data=verify_y)
    # Global options
    bar.set_global_opts(
        title_opts=opts.TitleOpts(title='40 classes\nGarbage classification: All/Train/Verify per-class distribution'),
        xaxis_opts=opts.AxisOpts(axislabel_opts=opts.LabelOpts(rotate=size1['rotate']))  # rotate x-axis labels
    )
    # Render the chart
    bar.render('All_Train_Verify.html')
    bar_path = os.getcwd()
    if show_web:
        url = 'file://' + bar_path + '/All_Train_Verify.html'
        webbrowser.open(url)
    else:
        print('Open \'' + bar_path + '/All_Train_Verify.html\' to view the chart')
    # 4-class visualization
    data_4y = {}
    garbage_classify_rule = __json_load(os.path.join(root_path, 'garbage_classify_rule.json'))
    garbage_index_classify = __json_load(os.path.join(root_path, 'garbage_index_classify.json'))
    train_img_dict = __json_load(os.path.join(root_path, 'train_img_dict.json'))
    verify_img_dict = __json_load(os.path.join(root_path, 'verify_img_dict.json'))
    # Remap the 40 fine-grained classes to 4 coarse classes
    for item in garbage_classify_rule:
        garbage_classify_rule[item] = garbage_index_classify[garbage_classify_rule[item].split('/')[0]]
    # Tally the full data set
    for item in img_path_list:
        data_4y[item['img_4lable']] = data_4y.get(item['img_4lable'], 0) + 1
    # Sort by label
    data_4y = dict(sorted(data_4y.items(), key=lambda item: item[0]))
    # Remap train_img_dict from 40 classes to 4 classes
    for item in train_img_dict:
        train_img_dict[item] = garbage_classify_rule[str(train_img_dict[item])]
    # Tally and sort per class
    train_4y = dict(sorted((dict(Counter(train_img_dict.values()))).items(), key=lambda item: item[0]))
    # Remap verify_img_dict from 40 classes to 4 classes
    for item in verify_img_dict:
        verify_img_dict[item] = garbage_classify_rule[str(verify_img_dict[item])]
    # Tally and sort per class
    verify_4y = dict(sorted((dict(Counter(verify_img_dict.values()))).items(), key=lambda item: item[0]))
    # x-axis data
    x = ["Other waste", "Kitchen waste", "Recyclables", "Hazardous waste"]
    # y-axis data
    data_4y = list(data_4y.values())
    train_4y = list(train_4y.values())
    verify_4y = list(verify_4y.values())
    # Build the Bar chart
    bar = Bar(init_opts=opts.InitOpts('{}px'.format(size2['width']), '{}px'.format(size2['height'])))
    bar.add_xaxis(xaxis_data=x)
    bar.add_yaxis(series_name='All', yaxis_data=data_4y)
    bar.add_yaxis(series_name='Train', yaxis_data=train_4y)
    bar.add_yaxis(series_name='Verify', yaxis_data=verify_4y)
    # Global options
    bar.set_global_opts(
        title_opts=opts.TitleOpts(title='4 classes\nGarbage classification: per-class distribution'),
        xaxis_opts=opts.AxisOpts(axislabel_opts=opts.LabelOpts(rotate=size2['rotate']))  # rotate x-axis labels
    )
    # Render the chart
    bar.render('4classify.html')
    bar_path = os.getcwd()
    if show_web:
        url = 'file://' + bar_path + '/4classify.html'
        webbrowser.open(url)
    else:
        print('Open \'' + bar_path + '/4classify.html\' to view the chart')
    # Histogram of image aspect ratios
    data = get_img_size(root_path)
    ratio_list = [item[3] for item in data]  # width/height ratios from get_img_size
    new_ratio_list = list(filter(lambda x: 0.5 < x <= 2, ratio_list))
    sns.set()
    np.random.seed(0)
    __set_zh()  # configure CJK-capable fonts
    # distplot is deprecated in recent seaborn; sns.histplot(..., kde=True) is the modern equivalent
    sns.distplot(ratio_list)
    plt.title('Raw ratio distribution')
    plt.show()
    sns.distplot(new_ratio_list)
    plt.title('Filtered ratio distribution (0.5 < ratio <= 2)')
    plt.show()
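# Note: bar.render is called with relative file names, so the two HTML charts
# are written into the current working directory (os.getcwd()), not data_path.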
# Progress display
def __processBar(message, num, total):
    '''
    Progress display: "message: num/total 100%"
    :param message: message prefix
    :param num: current progress
    :param total: total count
    :return:
    '''
    rate = num / total
    rate_num = int(rate * 100)
    if rate_num == 100:
        r = '\r{}:\t{}/{}\t100%\n'.format(message, num, total)
    else:
        r = '\r{}:\t{}/{}\t{}%'.format(message, num, total, rate_num)
    sys.stdout.write(r)
    sys.stdout.flush()
# Configure matplotlib for Chinese (CJK) text
def __set_zh():
    '''
    Enable CJK font rendering in matplotlib.
    :return:
    '''
    platform_name = sys.platform
    if platform_name == 'win32':
        plt.rcParams['font.sans-serif'] = ['KaiTi']
    elif platform_name == 'linux':
        plt.rcParams['font.sans-serif'] = ['SimHei']
    else:
        plt.rcParams['font.sans-serif'] = ['Songti SC']  # macOS fallback
    plt.rc('axes', unicode_minus=False)  # render the minus sign on negative axis ticks correctly
def __version__():
    device = 'GPU' if torch.cuda.is_available() else 'CPU'
    print('torch version:', torch.__version__)
    print('torchvision version:', torchvision.__version__)
    print('Device:', device)
    if torch.cuda.is_available():
        print('GPU list:')
        for i in range(torch.cuda.device_count()):
            print('  {}. {}'.format(i, torch.cuda.get_device_name(i)))
        print('Using:', torch.cuda.get_device_name(torch.cuda.current_device()))
# Example usage
if __name__ == '__main__':
    Root_Path = ''
    Data_Path = Root_Path + 'data/'
    img_path_list, img_name2label_dict, img_label2count_dict = get_img_info(Data_Path)
    train_json_path, verify_json_path, train_img_list, verify_img_list = data_shuffle_split(0.8, img_path_list)
    # data_visualization(Data_Path, show_web=False)