-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
72 lines (62 loc) · 2.71 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from models import BRGM, LBRGM
import torch
import json
import click
import warnings
# Suppress all warnings for the whole process; this runs as soon as the
# module is imported.
warnings.filterwarnings(action='ignore')

# Maps each --model CLI choice to the model class that run() instantiates.
MODELS = dict(BRGM=BRGM, LBRGM=LBRGM)
def get_testable_fpaths(testables_file="testables.txt"):
    """Return the default list of low-resolution input image paths.

    One path of the form ``256lows/{i}as1024.png`` is produced for each line
    of *testables_file* except the last one. Only the number of lines
    matters; the line contents are not used in the returned paths.

    Args:
        testables_file: Text file with one test-image entry per line.
            Defaults to ``testables.txt`` relative to the current working
            directory.

    Returns:
        A list of relative path strings, indexed from 0. It is empty when the
        file has at most one line.

    Raises:
        OSError: If *testables_file* cannot be opened.
    """
    # Only lines are counted, so undecodable bytes are harmless; replace them
    # rather than failing on an unexpected encoding.
    with open(testables_file, "r", encoding="utf-8", errors="replace") as f:
        n_lines = sum(1 for _ in f)
    # The last line is intentionally not counted, presumably to skip a
    # trailing blank entry in testables.txt -- TODO confirm against the data.
    n_testables = max(n_lines - 1, 0)
    return [f"256lows/{i}as1024.png" for i in range(n_testables)]
def _validate_options(reconstruction_type, input_dim, mask):
    """Check option combinations click cannot express; raise click.BadParameter if invalid."""
    if reconstruction_type == 'superres':
        # Test for None before input_dim is used in arithmetic below.
        if input_dim is None:
            raise click.BadParameter("Specify an input dimension", param_hint="'--input-dim'")
        # Reject non-positive values explicitly: 0 would raise ZeroDivisionError
        # and negative divisors would slip through the modulo test.
        if input_dim <= 0 or 1024 % input_dim != 0:
            raise click.BadParameter(
                "Input dimension need be a divisor of 1024, the height/width of images we can generate",
                param_hint="'--input-dim'",
            )
    if reconstruction_type == 'inpaint' and mask is None:
        raise click.BadParameter(
            "Specify a mask to apply. See masks/1024x1024/ directory for masks",
            param_hint="'--mask'",
        )


@click.command()
@click.option('--device', default=None, help='Device to train on.')
@click.option('--model', type=click.Choice(['LBRGM', 'BRGM'],), default='LBRGM')
# The function itself (not its result) is the default, so click reads
# testables.txt only when --fpaths is omitted instead of at import time.
@click.option('--fpaths', default=get_testable_fpaths, multiple=True, help='Paths to image file.')
@click.option('--outpath', required=True, help='Output directory to save run progress.')
@click.option('--no-steps', default=2000, help='Number of optimization steps.')
@click.option('--reconstruction-type', type=click.Choice(['inpaint', 'superres'],), help='Corruption process: either inpainting or superresolution.')
@click.option('--input-dim', default=64, help='Height and width of input image to have super-resolution applied')
@click.option('--fpath-corrupted', default=True, help='Whether the input image has already had the corruption applied.')
@click.option('--mask', help='Specify path to the mask to be applied. See masks/1024x1024/ directory for masks')
def run(device, fpaths, outpath, no_steps, reconstruction_type, input_dim, fpath_corrupted, model, mask):
    """Reconstruct each image in --fpaths with the selected model.

    A fresh model is built per image, with its output directory set to
    OUTPATH/<index>, and optimized for --no-steps steps.
    """
    _validate_options(reconstruction_type, input_dim, mask)
    os.makedirs(outpath, exist_ok=True)

    # Resolve the class once under a separate name: rebinding `model` to an
    # instance would break the MODELS lookup on the next iteration.
    model_cls = MODELS[model]
    n_images = len(fpaths)
    for i, fpath in enumerate(fpaths):
        print('-'*50)
        print(f"Reconstructing image with file path {fpath}, image {i+1} of {n_images}")
        print('-'*50)
        # Each image gets its own numbered sub-directory under outpath.
        cur_outdir = f"{outpath}/{i}"
        os.makedirs(cur_outdir, exist_ok=True)
        model_args = {
            "fname" : fpath,
            "verbose" : False,
            "im_verbose" : True,
            "out_dir" : cur_outdir,
            "device" : device,
            "fpath_corrupted" : fpath_corrupted,
            "reconstruction_type" : reconstruction_type,
            "input_dim" : input_dim,
            "mask_file" : mask,
        }
        reconstructor = model_cls(**model_args)
        # NOTE(review): im_verbose is passed as True to the constructor and
        # switched off here; presumably the constructor uses it once -- confirm.
        reconstructor.im_verbose = False
        reconstructor.lossprint_interval = 250
        reconstructor.learning_rate = 0.1
        reconstructor.train_model(max_steps=no_steps)


if __name__ == "__main__":
    run()