forked from Zheng222/DMFN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
27 lines (21 loc) · 1.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import yaml
import torch
import torchvision.utils as vutils
import os
def get_config(config):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.load`` for the document.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(config, 'r', encoding='utf-8') as stream:
        # NOTE(review): FullLoader accepts more than plain scalars, lists and
        # mappings. If config files can come from untrusted sources, prefer
        # yaml.safe_load.
        return yaml.load(stream, Loader=yaml.FullLoader)
def _write_images(images, display_image_num, file_name):
    """Stack a list of image tensors, tile them into a grid and save it.

    Args:
        images: List of tensors to stack along a new leading dimension.
        display_image_num: Passed to ``make_grid`` as ``nrow``.
        file_name: Destination handed to ``save_image``.

    NOTE(review): the original inline comment states each tensor has shape
    [N,C,H,W]; ``torch.stack`` adds one more leading dimension on top of
    that -- confirm the shape callers actually pass.
    """
    stacked = torch.stack(images, dim=0)
    # Grid is built without padding and with value normalization enabled.
    grid = vutils.make_grid(
        stacked,
        nrow=display_image_num,
        padding=0,
        normalize=True,
    )
    vutils.save_image(grid, file_name, nrow=1)
def prepare_sub_folder(output_directory):
    """Ensure the ``images`` and ``checkpoints`` sub-folders exist.

    Both folders are created under *output_directory* (including any missing
    parents) when they are not already present; a message is printed for
    each folder that has to be created.

    Args:
        output_directory: Root folder that holds the two sub-folders.

    Returns:
        Tuple ``(checkpoint_directory, image_directory)`` -- the checkpoint
        path comes first.
    """
    image_directory = os.path.join(output_directory, 'images')
    checkpoint_directory = os.path.join(output_directory, 'checkpoints')
    # Images folder is handled first to keep the original message order.
    for directory in (image_directory, checkpoint_directory):
        if not os.path.exists(directory):
            print("Creating directory: {}".format(directory))
            # exist_ok=True: another process may create the folder between
            # the existence check above and this call.
            os.makedirs(directory, exist_ok=True)
    return checkpoint_directory, image_directory