-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgen_bracketplus.py
160 lines (121 loc) · 6.28 KB
/
gen_bracketplus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# coding=utf-8
import os
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import torch.nn as nn
import torch.distributions as tdist
import glob
from function import *
from unprocess import add_noise, random_iso1600near_levels
from tqdm import tqdm
import random
import imageio
''' Generating paired data for BracketIRE+ task '''
def write_img(raw_img, meta_data, folder_dir, name='', raw_max=1, down='x2'):
    """Save a RAW image as ``.npy`` together with an RGB ``.png`` preview.

    Two modes, selected by ``name``:

    * ``name == 'gt'`` (ground truth):
      - ``raw_max`` is saved to ``<folder_dir>alignratio.npy``.
      - ``raw_img * raw_max`` is saved unquantized to ``<folder_dir>raw_gt.npy``.
      - The preview is ``(raw_img * raw_max) / 16`` clamped to [0, 1]. It is
        rendered with ``get_raw2rgb(..., lineRGB=True)``, passed through
        ``mu_tonemap`` and written as a 16-bit PNG.
    * any other ``name``:
      - Outputs go to the sub-folder ``<folder_dir><down>/`` (created if needed).
      - The RAW is scaled to the 10-bit range [0, 1023] and saved rounded as
        uint16 to ``raw_<name>.npy``.
      - The preview is rendered with ``lineRGB=False`` and written as an 8-bit PNG.

    Args:
        raw_img: RAW tensor, presumably in [0, 1] (the non-gt callers in this
            file clamp before calling) -- TODO confirm the layout that
            ``get_raw2rgb`` expects.
        meta_data: camera metadata forwarded unchanged to ``get_raw2rgb``.
        folder_dir: output folder. Paths are built by string concatenation,
            so it must end with a path separator.
        name: file-name tag; ``'gt'`` selects the ground-truth mode.
        raw_max: scale applied to the ground truth. Either a tensor or a plain
            number (such as the default ``1``).
        down: sub-folder name for non-gt outputs; ignored for ``'gt'``.

    Raises:
        OSError: if OpenCV fails to write a preview PNG.
    """
    if name == 'gt':
        # torch.as_tensor makes a plain-number raw_max (e.g. the default 1)
        # work as well as a tensor; a tensor is passed through unchanged.
        np.save(folder_dir + 'alignratio.npy', torch.as_tensor(raw_max).cpu().numpy())
        raw_img = raw_img * raw_max
        np.save(folder_dir + 'raw_' + name + '.npy', raw_img.cpu().numpy())
        # NOTE(review): 1/16 presumably brings the HDR range into [0, 1] for
        # the preview only -- confirm against the data's actual range.
        preview = torch.clamp(raw_img / 16., 0, 1)
        rgb_img = get_raw2rgb(preview, meta_data, demosaic='net', lineRGB=True)  # demosaic: menon2007 or net
        rgb_img = torch.clamp(mu_tonemap(rgb_img) * 65535, 0.0, 65535.0).cpu().numpy().astype(np.uint16)
        png_path = folder_dir + 'rgb_vis_' + name + '.png'
        # cv2 expects BGR channel order, hence the channel flip.
        if not cv2.imwrite(png_path, rgb_img[..., ::-1]):
            raise OSError('cv2.imwrite failed for ' + png_path)
        return

    white_level = 2 ** 10 - 1  # inputs are stored as 10-bit RAW
    new_folder_dir = folder_dir + down + '/'
    os.makedirs(new_folder_dir, exist_ok=True)
    raw_img = torch.clamp(raw_img * white_level, 0, white_level)
    np.save(new_folder_dir + 'raw_' + name + '.npy', raw_img.cpu().numpy().round().astype(np.uint16))
    rgb_img = get_raw2rgb(raw_img / white_level, meta_data, demosaic='net', lineRGB=False)  # demosaic: menon2007 or net
    rgb_img = torch.clamp(rgb_img * 255.0, 0.0, 255.0).cpu().numpy().astype(np.uint8)
    png_path = new_folder_dir + 'rgb_vis_' + name + '.png'
    # cv2 expects BGR channel order, hence the channel flip.
    if not cv2.imwrite(png_path, rgb_img[..., ::-1]):
        raise OSError('cv2.imwrite failed for ' + png_path)
def down_sam(raw_img, scale):
    """Downsample an ``[H, W, C]`` tensor by a factor of ``scale``.

    The image is moved to ``[1, C, H, W]`` layout for bicubic
    ``F.interpolate`` (``align_corners=True``), then moved back to
    ``[H', W', C]``. It is clamped to [0, 1] because bicubic resampling can
    overshoot. The input tensor is left unmodified.
    """
    batched = raw_img.permute(2, 0, 1)[None]
    resized = F.interpolate(batched, scale_factor=1 / scale, mode='bicubic', align_corners=True)
    return resized[0].permute(1, 2, 0).clamp(0, 1)
if __name__ == '__main__':
    # For every group of `frame_num` consecutive HDR frames of each scene, this
    # script writes:
    #   * a ground-truth RAW (+ metadata) built from the first frame;
    #   * five synthetic exposures, each written at x2 and x4 down-sampling
    #     with added noise.
    # Exposure n averages 4**n consecutive frames of a densely interpolated
    # sequence.
    parser = argparse.ArgumentParser(description='Generate paired data for the BracketIRE+ task.')
    # Download from https://www.hdm-stuttgart.de/vmlab/hdm-hdr-2014/#FTPdownload
    parser.add_argument('--read_path',
                        default='/dataset/HDM-HDR-2014/HdM-HDR-2014_Original-HDR-Camera-Footage/',
                        help='root folder of the HdM-HDR-2014 original camera footage')
    parser.add_argument('--write_root', default='/dataset/BracketIRE_Plus/',
                        help='output root of the generated dataset')
    parser.add_argument('--device', default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        help='torch device (default: cuda:0 when available, otherwise cpu)')
    args = parser.parse_args()
    read_path = args.read_path
    write_root = args.write_root
    device = torch.device(args.device)

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)
    random.seed(0)
    torch.set_grad_enabled(False)
    if device.type == 'cuda':
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    # Network passed to frame_inter for the interpolated in-between frames.
    model = bulid_model(modelDir='train_log', device=device)

    frame_num = 12      # HDR key frames read per sample (passed to read_11_paths)
    interp_exp = 5      # `exp` argument forwarded to frame_inter
    num_exposures = 5   # exposure levels n = 0..4; level n averages 4**n frames
    # The levels together consume 1 + 4 + 16 + 64 + 256 = 341 frames.
    frames_needed = sum(4 ** n for n in range(num_exposures))

    img_paths = ['beerfest_lightshow_01', 'beerfest_lightshow_02', 'beerfest_lightshow_02_reconstruction_update_2015',
                 'beerfest_lightshow_03', 'beerfest_lightshow_04', 'beerfest_lightshow_04_reconstruction_update_2015',
                 'beerfest_lightshow_05', 'beerfest_lightshow_06', 'beerfest_lightshow_07',
                 'bistro_01', 'bistro_02', 'bistro_03',
                 'carousel_fireworks_01', 'carousel_fireworks_02', 'carousel_fireworks_03',
                 'carousel_fireworks_04', 'carousel_fireworks_05', 'carousel_fireworks_06',
                 'carousel_fireworks_07', 'carousel_fireworks_08', 'carousel_fireworks_09',
                 'cars_closeshot', 'cars_fullshot', 'cars_longshot',
                 'fireplace_01', 'fireplace_02', 'fishing_closeshot',
                 'showgirl_01', 'showgirl_02', 'smith_hammering',
                 'smith_welding', 'fishing_longshot', 'hdr_testimage',
                 'poker_fullshot', 'poker_travelling_slowmotion']

    for path in img_paths:
        # os.path.join(..., '') keeps the trailing separator that the
        # string-concatenated paths below rely on.
        img_path = os.path.join(read_path, path, '')
        folder_dir = os.path.join(write_root, path, '')
        os.makedirs(folder_dir, exist_ok=True)
        print(folder_dir)
        frame_paths = read_11_paths(img_path, frame_num)
        for frame_path in tqdm(frame_paths):
            final_dir = folder_dir + split_name(frame_path[0]) + '/'
            os.makedirs(final_dir, exist_ok=True)

            # Ground truth from the first HDR frame of the group.
            pre_img, raw_max = read_exr(frame_path[0], device=device)
            pre_img_gt = torch.clamp(gamma(gamma_reverse(pre_img) / raw_max), 0, 1)
            clean_raw, features = get_rgb2raw(pre_img_gt, features=None, device=device)
            H, W, C = clean_raw.size()
            write_img(clean_raw, features, final_dir, name='gt', raw_max=raw_max)
            meta = {key: features[key].cpu().numpy().astype(np.float32) for key in features}
            np.save(final_dir + 'metadata.npy', meta)
            del clean_raw, pre_img_gt, meta

            # Dense sequence: each key frame followed by frame_inter outputs up
            # to the next key frame. All frames share the first frame's scale
            # curr_max, which is undone below.
            curr_max = 65535 / pre_img.max()
            pre_img = pre_img * curr_max
            list_imgs = [pre_img]
            for i in range(frame_num - 1):
                next_img = read_exr(frame_path[i + 1], device=device)[0] * curr_max
                list_imgs.extend(frame_inter(pre_img, next_img, model, exp=interp_exp))
                list_imgs.append(next_img)
                pre_img = next_img
            del pre_img, next_img

            if len(list_imgs) < frames_needed:
                raise RuntimeError('%s: expected at least %d frames after interpolation, got %d'
                                   % (frame_path[0], frames_needed, len(list_imgs)))

            # Exposure levels consume consecutive, non-overlapping frame runs.
            start = 0
            for n in range(num_exposures):
                m = 4 ** n
                raw = torch.zeros([H, W, C], dtype=torch.float, device=device)
                for i in range(m):
                    img_ldr = list_imgs[start + i] / curr_max
                    # Exposure ratio 4**(n-2): 1/16, 1/4, 1, 4, 16 for n = 0..4.
                    img_ldr = torch.clamp(gamma(gamma_reverse(img_ldr) * 4 ** (n - 2)), 0, 1)
                    gt_raw, _ = get_rgb2raw(img_ldr, features, device)
                    raw = raw + gt_raw
                raw = torch.clamp(raw / m, 0, 1)
                # One noise level per exposure, shared by the x2 and x4 outputs
                # (RNG call order: x2 first, then x4).
                shot_noise, read_noise = random_iso1600near_levels(torch.rand(1, device=device)[0])
                for down, scale in (('x2', 2), ('x4', 4)):
                    raw_down = down_sam(raw, scale=scale)
                    raw_down = torch.clamp(add_noise(raw_down, shot_noise, read_noise), 0, 1)
                    write_img(raw_down, features, final_dir, name=str(m), down=down)
                del gt_raw, img_ldr, raw, raw_down
                start = start + m
            del list_imgs