-
Notifications
You must be signed in to change notification settings - Fork 6
/
helper.py
152 lines (116 loc) · 4.87 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# ------------------------------------------------------------------------------------
# NeRF-Factory
# Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Modified from NeRF (https://github.com/bmild/nerf)
# Copyright (c) 2020 Google LLC. All Rights Reserved.
# ------------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
def img2mse(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    The squared difference is averaged over every element, giving a scalar tensor.
    """
    diff = x - y
    return torch.mean(diff ** 2)
def mse2psnr(x):
    """Convert a mean squared error ``x`` to PSNR in decibels.

    Computes -10 * log10(x), written as a natural log divided by ln(10).
    """
    scaled_log = -10.0 * torch.log(x)
    return scaled_log / np.log(10)
def cast_rays(t_vals, origins, directions):
    """Return points ``origins + t * directions`` for every distance in ``t_vals``.

    ``origins`` and ``directions`` gain a sample axis before their last axis, and
    ``t_vals`` gains a trailing axis. The result therefore has one point per entry
    of ``t_vals``.
    """
    return origins.unsqueeze(-2) + t_vals.unsqueeze(-1) * directions.unsqueeze(-2)
def sample_along_rays(
    rays_o,
    rays_d,
    num_samples,
    near,
    far,
    randomized,
    lindisp,
):
    """Place ``num_samples + 1`` distances along each ray and compute their points.

    Args:
        rays_o: ray origins; ``rays_o.shape[0]`` is used as the batch size.
        rays_d: ray directions, passed to ``cast_rays`` with ``rays_o``.
        num_samples: the number of intervals; ``num_samples + 1`` distances are made.
        near, far: the start and end distances of the sampling range.
        randomized: if True, jitter each distance uniformly inside its stratum.
        lindisp: if True, space the distances linearly in inverse depth instead
            of in depth.

    Returns:
        ``(t_vals, coords)``. With scalar near/far, ``t_vals`` has shape
        ``(bsz, num_samples + 1)``. For tensor-valued near/far, presumably it
        broadcasts the same way; confirm against callers. ``coords`` is
        ``cast_rays(t_vals, rays_o, rays_d)``.
    """
    batch_size = rays_o.shape[0]
    num_points = num_samples + 1
    steps = torch.linspace(0.0, 1.0, num_points, device=rays_o.device)

    if not lindisp:
        t_vals = near * (1.0 - steps) + far * steps
    else:
        # Interpolate 1/near and 1/far linearly, then invert.
        t_vals = 1.0 / (1.0 / near * (1.0 - steps) + 1.0 / far * steps)

    if not randomized:
        t_vals = torch.broadcast_to(t_vals, (batch_size, num_points))
    else:
        # Stratified jitter: point i is drawn uniformly from [lower_i, upper_i].
        # The bounds are the neighbouring midpoints, clamped to the two end points.
        midpoints = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
        upper = torch.cat([midpoints, t_vals[..., -1:]], -1)
        lower = torch.cat([t_vals[..., :1], midpoints], -1)
        jitter = torch.rand((batch_size, num_points), device=rays_o.device)
        t_vals = lower + (upper - lower) * jitter

    return t_vals, cast_rays(t_vals, rays_o, rays_d)
def pos_enc(x, min_deg, max_deg):
    """Return ``x`` concatenated with its sinusoidal positional encoding.

    Every coordinate in the last axis of ``x`` is multiplied by ``2**d`` for
    ``d`` in ``range(min_deg, max_deg)``. The sine of each scaled value is then
    taken, with and without a phase shift of pi/2; the shifted sine equals the
    cosine. The output's last axis has ``D * (1 + 2 * (max_deg - min_deg))``
    entries, where ``D = x.shape[-1]``.
    """
    freqs = torch.tensor([2**d for d in range(min_deg, max_deg)]).type_as(x)
    scaled = x.unsqueeze(-2) * freqs.unsqueeze(-1)
    scaled = scaled.reshape(list(x.shape[:-1]) + [-1])
    # sin(z + pi/2) == cos(z), so a single sin call produces both halves.
    features = torch.sin(torch.cat([scaled, scaled + 0.5 * np.pi], dim=-1))
    return torch.cat([x, features], dim=-1)
def volumetric_rendering(rgb, density, t_vals, dirs, white_bkgd):
    """Alpha-composite per-sample colors along each ray.

    Args:
        rgb: per-sample colors, summed over axis -2 after weighting. This
            presumably means shape (..., S, C); confirm.
        density: per-sample densities; only channel 0 of the last axis is read.
        t_vals: sample distances along each ray, shape (..., S).
        dirs: ray directions, shape (..., 3). Their norm scales the spacing
            between samples.
        white_bkgd: accepted but currently ignored (see the note before the
            return).

    Returns:
        ``(comp_rgb, depth, acc, weights)``: the composited color, the
        weight-averaged distance, the summed weights per ray, and the
        per-sample weights.
    """
    eps = 1e-10
    # Spacing between consecutive samples. The last sample gets a huge spacing so
    # it absorbs the remaining transmittance. full_like keeps the padding in
    # t_vals' dtype and device.
    dists = torch.cat(
        [
            t_vals[..., 1:] - t_vals[..., :-1],
            torch.full_like(t_vals[..., :1], 1e10),
        ],
        dim=-1,
    )
    # Scale the spacings by the norm of the ray direction.
    dists = dists * torch.norm(dirs[..., None, :], dim=-1)
    # Opacity of each sample.
    alpha = 1.0 - torch.exp(-density[..., 0] * dists)
    # Transmittance reaching each sample, i.e. the product of (1 - alpha) over
    # earlier samples. Adding eps keeps every factor positive even when alpha == 1.
    accum_prod = torch.cat(
        [
            torch.ones_like(alpha[..., :1]),
            torch.cumprod(1.0 - alpha[..., :-1] + eps, dim=-1),
        ],
        dim=-1,
    )
    weights = alpha * accum_prod
    comp_rgb = (weights[..., None] * rgb).sum(dim=-2)
    depth = (weights * t_vals).sum(dim=-1)
    acc = weights.sum(dim=-1)
    # NOTE(review): white-background compositing (comp_rgb + (1.0 - acc[..., None]))
    # is disabled in this version, so white_bkgd has no effect. Presumably the
    # caller handles the background; confirm before re-enabling.
    return comp_rgb, depth, acc, weights
def _find_interval(mask, x):
    """Return the values of ``x`` at the lower and upper edges of each sample's bin.

    ``mask[..., i, j]`` is True when sample ``j`` is at or beyond CDF entry ``i``.
    The lower edge is the maximum of ``x`` over positions where the mask is True,
    together with ``x[..., 0]``. The upper edge is the minimum of ``x`` over
    positions where the mask is False, together with ``x[..., -1]``.
    ``torch.where`` selects values instead of multiplying by the mask, so
    non-finite entries of ``x`` cannot turn into ``0 * inf = NaN``.
    """
    x0 = torch.where(mask, x[..., None], x[..., :1, None]).max(dim=-2).values
    x1 = torch.where(~mask, x[..., None], x[..., -1:, None]).min(dim=-2).values
    return x0, x1


def sorted_piecewise_constant_pdf(
    bins, weights, num_samples, randomized, float_min_eps=2**-32
):
    """Draw samples from the piecewise-constant PDF that ``weights`` defines over ``bins``.

    Args:
        bins: sorted bin edges. The CDF built below has ``weights.shape[-1] + 1``
            entries and is paired element-wise with ``bins``, so ``bins`` must
            have that many entries in its last axis.
        weights: non-negative, unnormalized weights, one per bin.
        num_samples: the number of samples to draw per distribution.
        randomized: if True, draw ``u ~ U[0, 1)``. If False, use evenly spaced
            ``u`` from 0 to ``1 - float_min_eps``.
        float_min_eps: keeps the deterministic ``u`` just below 1.
            NOTE(review): 2**-32 is below float32 resolution near 1.0, so in
            float32 the endpoint likely rounds to exactly 1.0; confirm intent.

    Returns:
        A tensor of sample locations with shape
        ``weights.shape[:-1] + (num_samples,)``. The samples are found by
        inverting the CDF with linear interpolation inside each bin.
    """
    eps = 1e-5
    weight_sum = weights.sum(dim=-1, keepdim=True)
    # If the weights sum to less than eps, spread the deficit evenly across the
    # bins so that the PDF never divides by (nearly) zero.
    padding = torch.fmax(torch.zeros_like(weight_sum), eps - weight_sum)
    weights = weights + padding / weights.shape[-1]
    weight_sum = weight_sum + padding
    pdf = weights / weight_sum
    cdf = torch.fmin(
        torch.ones_like(pdf[..., :-1]), torch.cumsum(pdf[..., :-1], dim=-1)
    )
    # Pin the CDF to exactly 0 at the start and 1 at the end. weight_sum has the
    # (..., 1) shape the caps need, as well as the inputs' dtype and device.
    cdf = torch.cat(
        [torch.zeros_like(weight_sum), cdf, torch.ones_like(weight_sum)], dim=-1
    )

    sample_shape = list(cdf.shape[:-1]) + [num_samples]
    if randomized:
        u = torch.rand(sample_shape, device=cdf.device)
    else:
        u = torch.linspace(0.0, 1.0 - float_min_eps, num_samples, device=cdf.device)
        u = torch.broadcast_to(u, sample_shape)

    # mask[..., i, j]: sample j is at or beyond CDF entry i.
    mask = u[..., None, :] >= cdf[..., :, None]
    bin0, bin1 = _find_interval(mask, bins)
    cdf0, cdf1 = _find_interval(mask, cdf)
    # A zero-width CDF step gives 0/0. Map the NaN to 0, then clamp t to [0, 1].
    t = torch.clip(torch.nan_to_num((u - cdf0) / (cdf1 - cdf0), 0), 0, 1)
    return bin0 + t * (bin1 - bin0)
def sample_pdf(bins, weights, origins, directions, t_vals, num_samples, randomized):
    """Add importance samples drawn from ``weights`` to ``t_vals`` and compute points.

    ``num_samples`` new distances are drawn with
    ``sorted_piecewise_constant_pdf(bins, weights, ...)``. They are detached, so
    no gradient flows through the resampling step. They are then merged with
    ``t_vals``, and the combined set is sorted along the last axis.

    Returns:
        ``(t_vals, coords)``: the sorted, merged distances and
        ``cast_rays(t_vals, origins, directions)``.
    """
    drawn = sorted_piecewise_constant_pdf(bins, weights, num_samples, randomized)
    drawn = drawn.detach()
    merged, _ = torch.sort(torch.cat([t_vals, drawn], dim=-1), dim=-1)
    return merged, cast_rays(merged, origins, directions)