-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest_interactive.py
455 lines (341 loc) · 16.2 KB
/
test_interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# python test.py --dataroot ./inp/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch --loadSize=256
import time
import os
import glob
import shutil
import util
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
import numpy as np
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from PIL import Image
import copy
import _thread
from util.fit_boxes import fit_boxes, LabelClass, LabelFit
from util.fit_circles import fit_circles
import traceback
import urllib.request
import sys
# Label palettes, one per network family.
# Each LabelClass is built as (name, RGB colour, integer id). The id is simply
# the entry's position in its list, so it is generated with enumerate().
# The order of the classes is their z order (the first one is at the back).


def _label_palette(entries):
    """Build LabelClass objects from (name, rgb) pairs, numbering them 0..n-1."""
    return [LabelClass(label_name, rgb, label_id)
            for label_id, (label_name, rgb) in enumerate(entries)]


cmp_classes = _label_palette([
    ('other', [0, 0, 0]),  # black borders or sky
    ('background', [0, 0, 170]),
    ('facade', [0, 0, 255]),
    ('moulding', [255, 85, 0]),
    ('cornice', [0, 255, 255]),
    ('pillar', [255, 0, 0]),
    ('window', [0, 85, 255]),
    ('door', [0, 170, 255]),
    ('sill', [85, 255, 170]),
    ('blind', [255, 255, 0]),
    ('balcony', [170, 255, 85]),
    ('shop', [170, 0, 0]),
    ('deco', [255, 170, 0]),
])
fit_cmp_labels = {label: LabelFit(-1)
                  for label in ('window', 'door', 'sill', 'balcony', 'shop')}
fit_cmp_labels_extended = {label: LabelFit(-1)
                           for label in ('window', 'door', 'sill', 'balcony', 'shop',
                                         'moulding', 'cornice')}

roof_classes = _label_palette([
    ('other', [0, 0, 0]),
    ('flat_roof', [255, 0, 0]),
    ('slanted_roof', [0, 255, 255]),
    ('edge', [255, 0, 255]),
    ('chimney', [255, 200, 0]),
    ('velux', [0, 0, 255]),
])
fit_roof_labels = {label: LabelFit(max_count=3) for label in ('velux', 'chimney')}

blank_classes = _label_palette([
    ('other', [0, 0, 0]),
    ('wall', [0, 0, 255]),
    ('window', [0, 255, 0]),
])
fit_blank_labels = {label: LabelFit(-1) for label in ('wall', 'window')}

pane_classes = _label_palette([
    ('other', [0, 0, 0]),  # black borders or sky
    ('frame', [255, 0, 0]),
    ('pane', [0, 0, 255]),
    ('object', [0, 255, 0]),
])
fit_pane_labels = {label: LabelFit(-1) for label in ('frame', 'pane', 'object')}
def save_image(image_numpy, image_path):
    """Write an image array to ``image_path`` as a PNG, creating parent folders.

    Args:
        image_numpy: array accepted by ``PIL.Image.fromarray`` (callers pass
            network outputs; presumably HxWx3 uint8 -- TODO confirm).
        image_path: destination file. The PNG format is forced regardless of
            the path's suffix (e.g. the ``*_label`` outputs have no image
            extension).

    Raises:
        OSError: if the parent folder cannot be created or the file written.
    """
    directory = os.path.dirname(image_path)
    # A bare file name has no parent folder; os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    # PNG is lossless: Pillow's PNG writer has no ``quality`` option, so none is passed.
    Image.fromarray(image_numpy).save(image_path, 'PNG')
def touch(fname, times=None):
    """Create ``fname`` if needed (without truncating it) and set its timestamps.

    ``times`` is forwarded to ``os.utime``: None means "now", otherwise an
    ``(atime, mtime)`` pair.
    """
    handle = open(fname, 'a')  # append mode: creates the file, keeps any content
    try:
        os.utime(fname, times)
    finally:
        handle.close()
def rmrf(file):
    """Best-effort delete of every regular file matching the glob pattern ``file``.

    Despite the name, this does not recurse: directories that match the
    pattern are skipped. A file that cannot be removed (e.g. it vanished in
    the meantime or is locked) is skipped, and the remaining matches are
    still processed.
    """
    for path in glob.glob(file):
        if not os.path.isfile(path):
            continue
        try:
            os.remove(path)
        except OSError:
            # Deliberately best effort: one failure must not stop the cleanup.
            pass
class RunG(FileSystemEventHandler):
    """Watchdog handler that runs the generator over one batch of input images.

    ``on_created`` fires for every file created in the watched ``val/`` folder.
    Nothing happens until a ``go`` file exists in that folder. Its first line
    is the batch name. The whole dataset is then run through the model and the
    results are written under ``./output/<directory>/<batch name>/``. The
    input folders and the ``go`` marker are cleared afterwards.
    """

    def __init__(self, model, opt, fit_boxes, fit_circles, directory):
        """
        Args:
            model: network from ``create_model``; ``set_input``, ``test_simple``
                and ``get_image_paths`` are called on it.
            opt: parsed test options; ``dataroot`` and the ``*_condition`` flags
                are read, and ``mlabel_dataroot``/``empty_dataroot`` are set here.
            fit_boxes: ``(classes, fit_labels)`` passed to ``fit_boxes`` for
                every generated image, or None.
            fit_circles: ``(classes, fit_labels)`` passed to ``fit_circles``,
                or None.
            directory: name of the ./input and ./output sub-folder served.
        """
        self.model = model
        self.opt = opt
        self.fit_boxes = fit_boxes
        self.fit_circles = fit_circles
        self.directory = directory

    def on_created(self, event):  # when file is created
        """Process the batch once the ``go`` marker exists, then clean up the inputs."""
        print("Got G event for file %s" % event.src_path)
        try:
            # The marker lives next to the created file: <watched folder>/go
            go = os.path.abspath(os.path.join(event.src_path, os.pardir, "go"))
            if not os.path.isfile(go):
                return  # batch not complete yet; wait for the marker
            with open(go) as f:
                # Drop any line terminator so it cannot leak into the output folder name.
                name = f.readlines()[0].rstrip("\r\n")
            print("starting to process %s" % name)
            if self.opt.mlabel_condition:
                self.opt.mlabel_dataroot = self.opt.dataroot.rstrip('/\\') + '_mlabels'
            if self.opt.metrics_condition or self.opt.empty_condition:
                self.opt.empty_dataroot = self.opt.dataroot.rstrip('/\\') + '_empty'
            self.model.opt = self.opt
            data_loader = CreateDataLoader(self.opt)
            dataset = data_loader.load_data()
            out_dir = "./output/%s/%s" % (self.directory, name)
            for i, data in enumerate(dataset):
                # The fields after the first "_" of the input file name (minus a
                # 4-character extension such as ".png") are the latent vector z.
                zs = os.path.basename(data['A_paths'][0])[:-4].split("_")[1:]
                z = np.array([float(s) for s in zs], dtype=np.float32)
                self.model.set_input(data)
                _, real_A, fake_B, real_B, _ = self.model.test_simple(z, encode_real_B=False)
                img_path = self.model.get_image_paths()
                print('%04d: process image... %s' % (i, img_path))
                out_name = os.path.basename(img_path[0])
                save_image(fake_B, "%s/%s" % (out_dir, out_name))
                save_image(real_A, "%s/%s_label" % (out_dir, out_name))
                if self.fit_boxes is not None:
                    fit_boxes(
                        img=fake_B, classes=self.fit_boxes[0], fit_labels=self.fit_boxes[1],
                        json_path="%s/%s_boxes" % (out_dir, out_name))
                if self.fit_circles is not None:
                    fit_circles(
                        img=fake_B, classes=self.fit_circles[0], fit_labels=self.fit_circles[1],
                        json_path="%s/%s_circles" % (out_dir, out_name))
        except Exception as e:
            traceback.print_exc()
            print(e)
        try:
            # Clear every input folder the data loader may have read, plus the marker.
            rmrf('./input/%s/val/*' % self.directory)
            rmrf('./input/%s_empty/val/*' % self.directory)
            # Must match mlabel_dataroot ("<dataroot>_mlabels") set above.
            rmrf('./input/%s_mlabels/val/*' % self.directory)
            if os.path.isfile(go):
                os.remove(go)
        except Exception as e:
            traceback.print_exc()
            print(e)
class RunE(FileSystemEventHandler):
    """Watchdog handler that runs the encoder over one batch of example images.

    As with ``RunG``, nothing happens until a ``go`` file (first line = batch
    name) exists in the watched folder. For every input image the encoder's
    output ``z`` is written into the *name* of a touched file
    ``./output/<directory>/<batch>/<basename>@<z0>_<z1>_...``.
    """

    def __init__(self, model, opt, directory):
        """
        Args:
            model: network from ``create_model``; ``set_input``,
                ``encode_real_B`` and ``get_image_paths`` are called on it.
            opt: parsed options used to build the data loader.
            directory: name of the ./input and ./output sub-folder served.
        """
        self.model = model
        self.opt = opt
        self.directory = directory

    def on_created(self, event):  # when file is created
        """Encode the batch once the ``go`` marker exists, then clean up the inputs."""
        print("Got E event for file %s" % event.src_path)
        try:
            # The marker lives next to the created file: <watched folder>/go
            go = os.path.abspath(os.path.join(event.src_path, os.pardir, "go"))
            if not os.path.isfile(go):
                return  # batch not complete yet; wait for the marker
            with open(go) as f:
                # Drop any line terminator so it cannot leak into the output folder name.
                name = f.readlines()[0].rstrip("\r\n")
            print("starting to process %s" % self.opt.name)
            self.model.opt = self.opt
            data_loader = CreateDataLoader(self.opt)
            dataset = data_loader.load_data()
            out_dir = "./output/%s/%s" % (self.directory, name)
            for i, data in enumerate(dataset):
                # Strips a 4-character extension such as ".png".
                basename = os.path.basename(data['A_paths'][0])[:-4]
                self.model.set_input(data)
                z = self.model.encode_real_B()
                img_path = self.model.get_image_paths()
                print('%04d: process image... %s' % (i, img_path))
                # z[0]: presumably the single row of a batch of size 1 -- TODO confirm.
                outfile = "%s/%s@%s" % (out_dir, basename, "_".join([str(s) for s in z[0]]))
                # exist_ok makes this idempotent; any other failure is reported below.
                os.makedirs(out_dir, exist_ok=True)
                touch(outfile)
        except Exception as e:
            traceback.print_exc()
            print(e)
        try:
            rmrf('./input/%s/val/*' % self.directory)
            if os.path.isfile(go):
                os.remove(go)
        except Exception as e:
            traceback.print_exc()
            print(e)
# Base URL of the published pretrained weights; files are fetched from
# <DOWNLOAD_ROOT>/<network name>/<file> (no trailing slash here).
DOWNLOAD_ROOT = ("http://geometry.cs.ucl.ac.uk/projects/2018/frankengan/data/"
                 "franken_nets")
class Interactive():
    """Serve one pretrained network through a pair of watched input folders.

    Construction downloads any missing weights, builds the generator options
    (``optG``) plus an encoder copy (``optE``), loads the model and starts a
    background thread running ``go``. That thread watches
    ``./input/<directory>/val/`` (generator, ``RunG``) and
    ``./input/<directory>_e/val/`` (encoder, ``RunE``).

    Args:
        directory: sub-folder name under ./input and ./output for this net.
        name: network name; weights live in ./checkpoints/<name>/.
        size: used as both ``loadSize`` and ``fineSize``.
        which_model_netE, which_direction, norm, nz, dataset_mode, pytorch_v2,
        lbl_classes, metrics_mask_color, normalize_metrics, normalize_metrics2:
            copied onto the options object unchanged.
        fit_boxes, fit_circles: optional ``(classes, fit_labels)`` tuples that
            make ``RunG`` fit boxes/circles to each generated image.
        *_condition: enable extra conditioning inputs; each one widens
            ``input_nc`` by the channel count noted in ``__init__``.

    With ``--download`` on the command line the constructor returns right
    after fetching weights, without loading the model or starting a watcher.
    """

    def __init__(self, directory, name, size=256, which_model_netE='resnet_256', which_direction="BtoA",
                 fit_boxes=None, fit_circles=None, lbl_classes=None,
                 walldist_condition=False, imgpos_condition=False, noise_condition=False,
                 empty_condition=False, mlabel_condition=False, metrics_condition=False,
                 metrics_mask_color=None, norm='instance', nz=8, pytorch_v2=False, dataset_mode='aligned',
                 normalize_metrics=False, normalize_metrics2=False):
        # options
        optG = TestOptions().parse()
        optG.name = name
        optG.loadSize = size
        optG.fineSize = size
        optG.nThreads = 0  # min(1, optG.nThreads) # test code only supports nThreads=1
        optG.batchSize = 1  # test code only supports batchSize=1
        optG.serial_batches = True  # no shuffle
        optG.which_model_netE = which_model_netE
        optG.which_direction = which_direction
        optG.pytorch_v2 = pytorch_v2
        self.download_if_missing(name, "latest_net_G.pth")
        # The encoder weights are not fetched for mlabel-conditioned nets.
        if not mlabel_condition:
            self.download_if_missing(name, "latest_net_E.pth")
        if "--download" in sys.argv:
            return  # download-only mode: do not load the model or start watching
        optG.G_path = "./checkpoints/%s/latest_net_G.pth" % optG.name
        optG.E_path = "./checkpoints/%s/latest_net_E.pth" % optG.name
        optG.dataroot = "./input/%s/" % directory
        optG.no_flip = True
        optG.lbl_classes = lbl_classes
        optG.walldist_condition = walldist_condition
        optG.imgpos_condition = imgpos_condition
        optG.noise_condition = noise_condition
        optG.mlabel_condition = mlabel_condition
        optG.metrics_condition = metrics_condition
        optG.empty_condition = empty_condition
        optG.metrics_mask_color = metrics_mask_color
        optG.normalize_metrics = normalize_metrics
        optG.normalize_metrics2 = normalize_metrics2
        optG.norm = norm
        optG.nz = nz
        optG.dataset_mode = dataset_mode
        optG.use_dropout = False
        # Each enabled conditioning input adds channels to the generator input.
        if optG.imgpos_condition:
            optG.input_nc += 2  # 2 image position x,y channels
        if optG.walldist_condition:
            optG.input_nc += 1  # 1 wall distance channel
        if optG.noise_condition:
            optG.input_nc += 1  # 1 wall noise channel
        if optG.mlabel_condition:
            optG.input_nc += 3  # 3 additional channels: RGB
        if optG.metrics_condition:
            optG.input_nc += 6  # 6 additional channels
        if optG.empty_condition:
            optG.input_nc += 3  # 3 additional channels: RGB
        # Encoder options: separate input folder, all extra conditions switched off.
        optE = copy.deepcopy(optG)
        optE.dataroot = "./input/%s_e/" % directory
        optE.name = optG.name + "_e"
        optE.walldist_condition = False
        optE.imgpos_condition = False
        optE.noise_condition = False
        optE.mlabel_condition = False
        optE.metrics_condition = False
        optE.empty_condition = False
        model = create_model(optG)
        model.eval()
        self.optG = optG
        self.optE = optE
        self.model = model
        _thread.start_new_thread(self.go, (directory, name, size, fit_boxes, fit_circles))

    def go(self, directory, name, size, fit_boxes, fit_circles):
        """Watch the generator and encoder input folders; blocks until the observer stops.

        Both input folders are wiped and recreated first, so leftovers from a
        previous run are not picked up. ``size`` is accepted but unused.
        """
        observer = Observer()
        input_folder = './input/%s/' % directory
        shutil.rmtree(input_folder, ignore_errors=True)
        os.makedirs(input_folder + "val", exist_ok=True)
        observer.schedule(
            RunG(model=self.model, opt=self.optG, fit_boxes=fit_boxes, fit_circles=fit_circles, directory=directory),
            path=input_folder + "val/")
        input_folder_e = './input/%s_e/' % directory
        shutil.rmtree(input_folder_e, ignore_errors=True)
        os.makedirs(input_folder_e + "val", exist_ok=True)
        observer.schedule(
            RunE(self.model, self.optE, directory + "_e"),
            path=input_folder_e + "val/")
        observer.start()
        # Auxiliary conditioning folders (RunG's empty_dataroot / mlabel_dataroot).
        shutil.rmtree('./input/%s_empty/' % directory, ignore_errors=True)
        shutil.rmtree('./input/%s_mlabels/' % directory, ignore_errors=True)
        print('[network %s is awaiting input]' % name)
        observer.join()

    def download_if_missing(self, directory, file):
        """Ensure ./checkpoints/<directory>/<file> exists, downloading it if needed.

        The file is fetched from ``<DOWNLOAD_ROOT>/<directory>/<file>``. It is
        downloaded to ``<file>.part`` first and renamed into place only after
        the transfer succeeds. An interrupted download therefore never leaves
        a truncated checkpoint that later runs would mistake for a valid one.

        Exits the process (``sys.exit``) if the weights cannot be obtained.
        """
        # nets should be listed here: http://geometry.cs.ucl.ac.uk/projects/2018/frankengan/data/franken_nets/
        local = "./checkpoints/%s/%s" % (directory, file)
        if not os.path.isfile(local):
            os.makedirs("./checkpoints/%s" % directory, exist_ok=True)
            remote = "%s/%s/%s" % (DOWNLOAD_ROOT, directory, file)
            print("downloading %s" % remote)
            partial = local + ".part"
            try:
                urllib.request.urlretrieve(remote, partial)
                os.replace(partial, local)  # atomic within the same folder
                print("done")
            except Exception as e:
                print(e)
                try:
                    os.remove(partial)
                except OSError:
                    pass  # nothing was written, or it is already gone
                sys.exit("Couldn't find or download weights for %s" % file)
        if not os.path.isfile(local):
            sys.exit("Couldn't find or download weights for %s" % file)
#------------------------------------------#
# latest set
#------------------------------------------#
# Each Interactive(...) fetches/loads one network and starts a background
# thread serving ./input/<directory>/ -> ./output/<directory>/.
os.makedirs("./input", exist_ok=True)
os.makedirs("./output", exist_ok=True)
Interactive("door textures", "labels2door_5",
            dataset_mode='multi',
            empty_condition=True, metrics_condition=True, imgpos_condition=True, normalize_metrics=True,
            metrics_mask_color=[255, 0, 0])
Interactive("roof greebles", "r3_clabels2labels_f001_400",
            size=512, which_model_netE='resnet_512',
            dataset_mode='multi', fit_circles=(roof_classes, fit_roof_labels),
            empty_condition=True, metrics_condition=True, imgpos_condition=True,
            noise_condition=True,
            metrics_mask_color=[0, 0, 255], normalize_metrics=True)
Interactive("roof textures", "r3_labels2image_f001_400",
            size=512, which_model_netE='resnet_512',
            dataset_mode='multi',
            empty_condition=True, metrics_condition=True, imgpos_condition=True,
            metrics_mask_color=[0, 0, 255], normalize_metrics=True)
Interactive("pane labels", "w3_empty2labels_f009_200",
            dataset_mode='multi', fit_boxes=(pane_classes, fit_pane_labels),
            empty_condition=True, metrics_condition=True, imgpos_condition=True,
            metrics_mask_color=[255, 0, 0])
Interactive("pane textures", "w3_labels2image_f013_400",
            dataset_mode='multi',
            empty_condition=True, metrics_condition=True, imgpos_condition=True,
            metrics_mask_color=[255, 0, 0])
Interactive("facade labels", "empty2windows_f009v2_400",
            dataset_mode='multi', fit_boxes=(blank_classes, fit_blank_labels),
            empty_condition=True, metrics_condition=True, imgpos_condition=True,
            metrics_mask_color=[0, 0, 255])
Interactive("facade textures", "facade_windows_f013v2_150",
            dataset_mode='multi',
            empty_condition=True, metrics_condition=True, imgpos_condition=True,
            metrics_mask_color=[0, 0, 255])
Interactive("facade greebles", "image2celabels_f001_335",
            dataset_mode='multi', fit_boxes=(cmp_classes, fit_cmp_labels_extended),
            empty_condition=True, metrics_condition=True, mlabel_condition=True,
            metrics_mask_color=[0, 0, 255], nz=0)
Interactive("facade super", "super6", pytorch_v2=True)
Interactive("roof super", "super10", pytorch_v2=True)
#------------------------------------------#
# #------------------------------------------#
# # Pix2Pix comparison:
# #------------------------------------------#
# Interactive("facade textures", "empty2image_p2p001",
#             dataset_mode='multi', nz=0, pytorch_v2=True)
# Interactive("roof textures", "r3_clabels2image_p2p001",
#             size=512, which_model_netE='resnet_512',
#             dataset_mode='multi', nz=0, pytorch_v2=True)
# #------------------------------------------#
# #------------------------------------------#
# # BicycleGAN comparison:
# #------------------------------------------#
# Interactive("facade textures", "empty2image_bg001",
#             dataset_mode='multi', pytorch_v2=True)
# Interactive("roof textures", "r3_clabels2image_bg001",
#             size=512, which_model_netE='resnet_512',
#             dataset_mode='multi', pytorch_v2=True)
# #------------------------------------------#
print("all nets up")
# Keep the main thread alive so the watcher threads keep serving requests.
# With --download the constructors only fetched weights, so exit immediately.
while "--download" not in sys.argv:
    time.sleep(1000)