-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset.py
46 lines (40 loc) · 1.55 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import List, Optional, Tuple
import os
import torch
from torch.utils.data import Dataset
def load_data_from_dir(
    data_folder: str, limit: int = 200
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]], List[Optional[torch.Tensor]]]:
    """Load saved samples from the ``.pt`` files found in ``data_folder``.

    File names are sorted lexicographically and at most ``limit`` of them are
    read, so the selection is reproducible across runs. Each file is expected
    to deserialize to a dict-like object with the keys ``"latent"`` and
    ``"img"``, and optionally ``"c"`` and ``"uc"``.

    Args:
        data_folder: Directory that is scanned (non-recursively) for files.
        limit: Maximum number of files to load, taken from the sorted listing.

    Returns:
        Four parallel lists ``(latents, targets, conditions, unconditions)``
        with one entry per loaded file. Entries of ``conditions`` and
        ``unconditions`` are ``None`` when the file has no ``"c"`` / ``"uc"``
        key.

    Raises:
        KeyError: If a loaded file lacks the ``"latent"`` or ``"img"`` key.
    """
    latents, targets, conditions, unconditions = [], [], [], []
    # Match the real ".pt" extension. A bare "pt" suffix would also pick up
    # unrelated files such as "model.ckpt" or "receipt".
    pt_files = [f for f in os.listdir(data_folder) if f.endswith('.pt')]
    for file_name in sorted(pt_files)[:limit]:
        file_path = os.path.join(data_folder, file_name)
        # NOTE(security): torch.load unpickles the file; only point this
        # function at directories whose contents are trusted.
        data = torch.load(file_path)
        latents.append(data["latent"])
        targets.append(data["img"])
        # Conditioning entries are optional; a missing key yields None.
        conditions.append(data.get("c", None))
        unconditions.append(data.get("uc", None))
    return latents, targets, conditions, unconditions
class LD3Dataset(Dataset):
    """Index-aligned dataset over five parallel lists.

    Item ``idx`` is assembled from position ``idx`` of every list; the
    dataset length is the length of ``ori_latent``.
    """

    def __init__(
        self,
        ori_latent: List[torch.Tensor],
        latent: List[torch.Tensor],
        target: List[torch.Tensor],
        condition: List[Optional[torch.Tensor]],
        uncondition: List[Optional[torch.Tensor]],
    ):
        """Store the given lists as-is (no copying, no length validation).

        Args:
            ori_latent: Per-sample original latents; defines the length.
            latent: Per-sample latents.
            target: Per-sample targets, returned first by ``__getitem__``.
            condition: Per-sample condition entries; entries may be ``None``.
            uncondition: Per-sample uncondition entries; entries may be
                ``None``.
        """
        self.ori_latent = ori_latent
        self.latent = latent
        self.target = target
        self.condition = condition
        self.uncondition = uncondition

    def __len__(self) -> int:
        """Return the number of samples, i.e. ``len(ori_latent)``."""
        return len(self.ori_latent)

    def __getitem__(self, idx: int):
        """Return ``(img, latent, ori_latent, condition, uncondition)`` at ``idx``."""
        # Lookups happen in the order target -> latent -> ori_latent ->
        # condition -> uncondition, so the first list to fail on a bad index
        # is the same as with sequential assignments.
        return (
            self.target[idx],
            self.latent[idx],
            self.ori_latent[idx],
            self.condition[idx],
            self.uncondition[idx],
        )