forked from prs-eth/Marigold-DC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
277 lines (228 loc) · 9.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
from pathlib import Path
from typing import Any, Sequence
import click
import cv2
import imagesize
import numpy as np
from PIL import Image
def mae(preds: np.ndarray, depth: np.ndarray, mask: np.ndarray | None = None) -> float:
"""Calculate the mean absolute error between two depth maps.
Args:
preds (np.ndarray): Predicted depth map
depth (np.ndarray): Ground truth depth map
mask (np.ndarray | None, optional): Mask to apply to the depth maps.
Returns:
float: Mean absolute error between the two depth maps
"""
if mask is not None:
preds = preds[mask]
depth = depth[mask]
return float(np.mean(np.abs(preds - depth)))
def rmse(preds: np.ndarray, depth: np.ndarray, mask: np.ndarray | None = None) -> float:
"""Calculate the root mean squared error between two depth maps.
Args:
preds (np.ndarray): Predicted depth map
depth (np.ndarray): Ground truth depth map
mask (np.ndarray | None, optional): Mask to apply to the depth maps.
Returns:
float: Root mean squared error between the two depth maps
"""
if mask is not None:
preds = preds[mask]
depth = depth[mask]
return float(np.sqrt(np.mean((preds - depth) ** 2)))
class CommaSeparated(click.ParamType):
    """A Click parameter type that parses comma-separated values into a list.

    This class extends Click's ParamType to handle comma-separated input strings,
    converting them into a list of values of a specified type. It can optionally
    enforce a specific number of values. Values that are already a list or tuple
    (for example a list given as an option default, which Click also passes
    through ``convert``) are accepted and converted element-wise.

    Args:
        type_ (type): The type to convert each comma-separated value to. Defaults to str.
        n (int | None): If specified, enforces exactly this many comma-separated values.
            Must be None or a positive integer. Defaults to None.

    Raises:
        ValueError: If n is not None and not a positive integer.

    Examples:
        Basic usage with strings:

            @click.command()
            @click.option("--names", type=CommaSeparated())
            def cmd(names):
                # --names "alice,bob,charlie" -> ["alice", "bob", "charlie"]
                pass

        With integers and fixed length:

            @click.command()
            @click.option("--coords", type=CommaSeparated(int, n=2))
            def cmd(coords):
                # --coords "10,20" -> [10, 20]
                # --coords "1,2,3" -> Error: not exactly 2 values
                pass

        With floats:

            @click.command()
            @click.option("--weights", type=CommaSeparated(float))
            def cmd(weights):
                # --weights "0.1,0.2,0.7" -> [0.1, 0.2, 0.7]
                pass
    """

    name = "comma_separated"

    def __init__(self, type_: type = str, n: int | None = None) -> None:
        if n is not None and n <= 0:
            raise ValueError("n must be None or a positive integer")
        self.type = type_
        self.n = n

    def convert(
        self,
        value: str | list[Any] | tuple[Any, ...] | None,
        param: click.Parameter | None,
        ctx: click.Context | None,
    ) -> list[Any] | None:
        """Convert ``value`` into a list of ``self.type`` items.

        Returns None for None. An empty or whitespace-only string yields no
        items, which is an error when ``n`` is set. Calls ``self.fail`` (which
        raises ``click.BadParameter``) on a count mismatch or a failed item
        conversion.
        """
        if value is None:
            return None
        if isinstance(value, (list, tuple)):
            # Click also runs already-converted values (e.g. list defaults)
            # through convert(); they have no .strip()/.split().
            items = list(value)
            shown = ",".join(map(str, items))
        else:
            text = str(value).strip()
            items = text.split(",") if text else []
            shown = text if text else '""'
        # Enforced for empty input too, so "" does not slip past n.
        if self.n is not None and len(items) != self.n:
            self.fail(
                f"{shown} does not contain exactly {self.n} comma separated values",
                param,
                ctx,
            )
        try:
            return [self.type(item) for item in items]
        except (ValueError, TypeError):
            self.fail(
                f"{shown} is not a valid comma separated list of {self.type.__name__}",
                param,
                ctx,
            )
def is_empty_img(img: Image.Image) -> bool:
    """Return True when every value of ``img`` is zero.

    The image is converted with ``np.array`` and all values across every
    channel are checked. Any non-zero value makes the image non-empty.

    Args:
        img (Image.Image): Input PIL Image.

    Returns:
        bool: True if the image contains only zeros, False otherwise.
    """
    pixels = np.array(img)
    has_content = pixels.any()
    return not has_content
def make_grid(
    imgs: np.ndarray,
    rows: int | None = None,
    cols: int | None = None,
    resize: tuple[int, int] | None = None,
    interpolation: int | list[int] = cv2.INTER_LINEAR,
) -> np.ndarray:
    """Create a grid of images from a numpy array.

    Takes a batch of images and arranges them in a grid pattern, filled row by
    row. Can optionally resize the final grid output.

    Args:
        imgs (np.ndarray): Array of images with shape (N, H, W, C) where:
            N is number of images
            H is height of each image
            W is width of each image
            C is number of channels per image
        rows (int | None, optional): Number of rows in output grid. If None:
            - Will be calculated from cols if cols is specified
            - Will create a square-ish grid if cols is also None
        cols (int | None, optional): Number of columns in output grid. If None:
            - Will be calculated from rows if rows is specified
            - Will create a square-ish grid if rows is also None
        resize (tuple[int, int] | None, optional): Target (height, width) to resize final grid to.
            - If None: No resizing is performed
            - If either dimension is -1: That dimension is calculated to preserve aspect ratio
            - If both dimensions are -1: No resizing is performed
            Each image is resized to (height // rows, width // cols), so the grid
            can be slightly smaller than requested when the target is not
            divisible by the grid shape.
        interpolation (cv2.InterpolationFlags | list[cv2.InterpolationFlags], optional):
            OpenCV interpolation method(s) for resizing. Can be either:
            - A single interpolation flag to use for all images
            - A list of flags matching the number of input images
            Defaults to cv2.INTER_LINEAR.

    Returns:
        np.ndarray: Grid image with shape (grid_height, grid_width, C) containing all input
        images arranged in a grid pattern, with the dtype of the (possibly resized)
        images. Cells beyond the N-th image are zero-filled.

    Raises:
        ValueError: If imgs is empty or not a 4D array
        ValueError: If a list of interpolation methods is provided but length doesn't match
            number of input images
        ValueError: If rows or cols is not positive, or rows * cols < N
        ValueError: If the resize target leaves a zero-sized cell

    Example:
        >>> # Create 2x2 grid from 4 images
        >>> grid = make_grid(images, rows=2, cols=2)
        >>> # Create auto-sized grid, resized to 512x512
        >>> grid = make_grid(images, resize=(512,512))
        >>> # Create grid with different interpolation per image
        >>> grid = make_grid(images, interpolation=[cv2.INTER_LINEAR, cv2.INTER_NEAREST])
    """
    if imgs.size == 0 or len(imgs.shape) != 4:
        raise ValueError("Images must be non-empty 4D array (N,H,W,C)")
    n = imgs.shape[0]
    per_image = isinstance(interpolation, Sequence)
    if per_image and len(interpolation) != n:
        raise ValueError(
            f"Interpolation list length ({len(interpolation)}) must match number of images ({n})"
        )
    if (rows is not None and rows <= 0) or (cols is not None and cols <= 0):
        raise ValueError("rows and cols must be positive integers")

    # Calculate grid dimensions
    if rows is None and cols is None:
        cols = int(np.ceil(np.sqrt(n)))
    if rows is None:
        rows = int(np.ceil(n / cols))
    if cols is None:
        cols = int(np.ceil(n / rows))
    if rows * cols < n:
        raise ValueError(f"A {rows}x{cols} grid cannot hold {n} images")

    h, w = imgs.shape[1:3]

    # Calculate target size for the grid
    if resize is not None:
        th, tw = resize
        if th != -1 or tw != -1:
            grid_h, grid_w = h * rows, w * cols
            # A -1 dimension is derived from the other to keep the grid's aspect ratio.
            target_h = th if th != -1 else int(tw * grid_h / grid_w)
            target_w = tw if tw != -1 else int(th * grid_w / grid_h)
            # Each cell gets an equal (floored) share of the target size.
            h, w = target_h // rows, target_w // cols
            if h <= 0 or w <= 0:
                raise ValueError(
                    f"Resize target ({target_h}, {target_w}) is too small "
                    f"for a {rows}x{cols} grid"
                )
            methods = interpolation if per_image else [interpolation] * n
            resized = []
            for img, method in zip(imgs, methods, strict=True):
                out = cv2.resize(img, (w, h), interpolation=method)
                # cv2.resize drops a singleton channel axis, (H, W, 1) -> (H, W);
                # restore it so the output keeps the documented (.., .., C) layout.
                if out.ndim == 2:
                    out = out[..., np.newaxis]
                resized.append(out)
            imgs = np.stack(resized)

    # Create and fill grid, row-major
    grid = np.zeros((h * rows, w * cols) + imgs.shape[3:], dtype=imgs.dtype)
    for idx, img in enumerate(imgs):
        i, j = divmod(idx, cols)
        grid[i * h : (i + 1) * h, j * w : (j + 1) * w] = img
    return grid
def to_depth(
    img: Image.Image, dtype: str = "float32", max_distance: float = 120.0
) -> np.ndarray:
    """Decode a depth map stored in the red channel of an RGB image.

    The image is converted to an array of ``dtype``. Only its red channel is
    used, linearly mapped from [0, 255] onto [0, max_distance]; the green and
    blue channels are ignored.

    Args:
        img (Image.Image): Input PIL image; must be in "RGB" mode.
        dtype (str, optional): Data type used when converting the image to an
            array. Defaults to "float32".
        max_distance (float, optional): Maximum depth value in meters.
            Defaults to 120.0.

    Returns:
        np.ndarray: Depth map as numpy array with values ranging from 0 to max_distance,
        where 0 represents the closest depth and max_distance the farthest.

    Raises:
        ValueError: If input image is not in RGB format
    """
    mode = img.mode
    if mode != "RGB":
        raise ValueError(f"Input image must be RGB format, got {mode}")
    red_channel = np.array(img, dtype=dtype)[..., 0]
    scaled = max_distance * red_channel
    return scaled / 255.0
def is_img_file(path: Path) -> bool:
    """Check whether a path points to a file with a recognizable image header.

    Detection uses ``imagesize.get``, which reads only the file header. A file
    whose size cannot be determined (``(-1, -1)``), or whose header makes
    ``imagesize`` raise, is treated as not an image.

    Args:
        path (Path): Path to check.

    Returns:
        bool: True if ``path`` is a regular file whose header ``imagesize``
        recognizes, False otherwise (including unreadable or malformed files).
    """
    if not path.is_file():
        return False
    try:
        # imagesize raises ValueError on some malformed headers (e.g. truncated
        # JPEGs), and opening the file can raise OSError (e.g. permissions).
        return imagesize.get(path) != (-1, -1)
    except (OSError, ValueError):
        return False
def get_img_paths(root: Path) -> list[Path]:
    """Get all image file paths under the given root directory, recursively.

    Every entry yielded by ``root.rglob("*")`` is checked with ``is_img_file``.
    The result is sorted because ``rglob`` order depends on the filesystem;
    sorting keeps dataset ordering reproducible across machines.

    Args:
        root (Path): Root directory to search for images.

    Returns:
        list[Path]: Sorted list of paths to image files (empty if none are found).
    """
    return sorted(path for path in root.rglob("*") if is_img_file(path))