-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathtest_dataset.py
61 lines (48 loc) · 1.71 KB
/
test_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pytest
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
from blendtorch import btt
BLENDDIR = Path(__file__).parent / "blender"
#@pytest.mark.background
def test_dataset():
    """Stream a fixed number of items from one Blender instance through a DataLoader.

    Launches ``dataset.blend.py`` (from ``BLENDDIR``) as a single Blender
    instance in background mode with a ``DATA`` named socket. It then reads at
    most 16 items via ``btt.RemoteIterableDataset`` using a 4-worker
    ``DataLoader``. Each batch must carry ``img`` of shape (4, 64, 64) and
    ``frameid`` of shape (4,), and exactly 16 // 4 = 4 batches must arrive.
    """
    max_items = 16
    batch_size = 4
    expected_batches = max_items // batch_size

    launcher_kwargs = {
        "scene": "",
        "script": BLENDDIR / "dataset.blend.py",
        "num_instances": 1,
        "named_sockets": ["DATA"],
        "background": True,
    }
    with btt.BlenderLauncher(**launcher_kwargs) as launcher:
        data_addresses = launcher.launch_info.addresses["DATA"]
        # Note, https://github.com/pytorch/pytorch/issues/44108
        dataset = btt.RemoteIterableDataset(data_addresses, max_items=max_items)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=4,
            drop_last=False,
            shuffle=False,
        )

        batches_seen = 0
        for batch in loader:
            assert batch["img"].shape == (batch_size, 64, 64)
            assert batch["frameid"].shape == (batch_size,)
            batches_seen += 1
        assert batches_seen == expected_batches
#@pytest.mark.background
def test_dataset_robustness():
    """Check that items from every launched Blender instance reach the loader.

    Launches ``dataset_robust.blend.py`` (from ``BLENDDIR``) as two Blender
    instances in background mode with a ``DATA`` named socket. It then consumes
    batches (capped at 5000 items) until the number of distinct ``btid``
    values equals the number of launched instances. Every batch must carry
    ``img`` of shape (4, 64, 64) and ``frameid`` of shape (4,).
    """
    num_instances = 2
    batch_size = 4

    launch_args = dict(
        scene="",
        script=BLENDDIR / "dataset_robust.blend.py",
        num_instances=num_instances,
        named_sockets=["DATA"],
        background=True,
    )
    with btt.BlenderLauncher(**launch_args) as bl:
        # NOTE(review): presumably one DATA address per instance; the dataset
        # is expected to read from all of them — confirm against btt.
        addr = bl.launch_info.addresses["DATA"]
        # Note, https://github.com/pytorch/pytorch/issues/44108
        ds = btt.RemoteIterableDataset(addr, max_items=5000)
        dl = DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=0,
            drop_last=False,
            shuffle=False,
        )

        # Track distinct ids incrementally instead of recomputing np.unique
        # over the entire, growing history after every batch.
        seen_ids = set()
        for item in dl:
            assert item["img"].shape == (batch_size, 64, 64)
            assert item["frameid"].shape == (batch_size,)
            seen_ids.update(item["btid"].tolist())
            # Assumes btid identifies the producing instance (the original
            # compared the distinct count against the instance count).
            if len(seen_ids) == num_instances:
                break

        assert len(seen_ids) == num_instances, (
            f"expected items from {num_instances} instances, "
            f"got btids {sorted(seen_ids)}"
        )