import os
import re

import torch
from torch import nn
class JoinedDataLoader:
    """Loader for sampling from multiple loaders with probability proportional
    to their length. Stops when both loaders are exhausted.
    Useful when you can't join samples from different datasets in a single batch.
    """
    def __init__(self, loaderA, loaderB):
        self.loaderA = loaderA
        self.loaderB = loaderB
        self.probA = len(loaderA) / (len(loaderA) + len(loaderB))
        self.loaderAiter, self.loaderBiter = iter(loaderA), iter(loaderB)

    def __iter__(self):
        return self

    def __next__(self):
        loader_choice = torch.rand(1).item()
        if loader_choice < self.probA:
            # Prefer loader A; fall back to B once A is exhausted.
            try:
                n = next(self.loaderAiter)
            except StopIteration:
                try:
                    n = next(self.loaderBiter)
                except StopIteration:
                    # Both loaders exhausted: reset so the object is reusable.
                    self._reset_iterators()
                    raise StopIteration
        else:
            # Prefer loader B; fall back to A once B is exhausted.
            try:
                n = next(self.loaderBiter)
            except StopIteration:
                try:
                    n = next(self.loaderAiter)
                except StopIteration:
                    self._reset_iterators()
                    raise StopIteration
        return n

    def _reset_iterators(self):
        self.loaderAiter, self.loaderBiter = iter(self.loaderA), iter(self.loaderB)

    def __len__(self):
        # Total number of batches, computed from the loaders themselves rather
        # than from the iterators, which are not guaranteed to expose __len__.
        return len(self.loaderA) + len(self.loaderB)
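
# Hypothetical usage sketch (not part of the original module): drains two toy
# DataLoaders through JoinedDataLoader. Dataset and batch sizes are arbitrary.
def _demo_joined_loader():
    from torch.utils.data import DataLoader, TensorDataset
    loaderA = DataLoader(TensorDataset(torch.randn(100, 3)), batch_size=10)
    loaderB = DataLoader(TensorDataset(torch.randn(40, 5)), batch_size=8)
    joined = JoinedDataLoader(loaderA, loaderB)
    assert len(joined) == 15  # 10 batches from A + 5 from B
    for batch in joined:
        pass  # batches from A and B arrive randomly interleaved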
def conv_out_shape(dims, conv):
    """Computes the output shape of a given convolution module.
    Args:
        dims (tuple): input spatial dimensions, e.g. (w, h)
        conv (module): a PyTorch convolutional module
    """
    kernel_size, stride, pad, dilation = conv.kernel_size, conv.stride, conv.padding, conv.dilation
    return tuple(int(((dims[i] + 2 * pad[i] - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i]) + 1)
                 for i in range(len(dims)))
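
# Hypothetical usage sketch: cross-checks conv_out_shape against an actual
# forward pass; the layer hyperparameters are arbitrary examples.
def _demo_conv_out_shape():
    conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
    predicted = conv_out_shape((64, 64), conv)  # -> (32, 32)
    actual = tuple(conv(torch.randn(1, 3, 64, 64)).shape[2:])
    assert predicted == actual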
def general_same_padding(i, k, d=1, s=1, dims=2):
    """Computes the padding that preserves the input shape under convolution.
    Args:
        - i, k, d, s (tuple or int): input size, kernel size, dilation, stride
        - dims (int): number of dimensions for the padding
    """
    # Broadcast i, k, d and s to tuples if they are ints.
    i = (i,) * dims if isinstance(i, int) else i
    k = (k,) * dims if isinstance(k, int) else k
    d = (d,) * dims if isinstance(d, int) else d
    s = (s,) * dims if isinstance(s, int) else s
    # 2p = d*(k-1) + (i-1)*(s-1); the int() truncation is exact when 2p is even.
    return tuple(int(0.5 * (d[j] * (k[j] - 1) - (1 - i[j]) * (s[j] - 1))) for j in range(dims))
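
# Hypothetical usage sketch: with unit stride and an odd kernel the truncation
# in the formula is exact, so the output shape matches the input exactly.
def _demo_general_same_padding():
    pad = general_same_padding(i=64, k=3, d=1, s=1)  # -> (1, 1)
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=pad)
    assert tuple(conv(torch.randn(1, 3, 64, 64)).shape[2:]) == (64, 64)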
def same_padding(k, d=1, dims=2):
    """Computes the padding that preserves the input shape under convolution
    in the unit-stride case, where it is independent of the input size.
    Args:
        - k, d (tuple or int): kernel size, dilation
        - dims (int): number of dimensions for the padding
    """
    # Broadcast k and d to tuples if they are ints.
    k = (k,) * dims if isinstance(k, int) else k
    d = (d,) * dims if isinstance(d, int) else d
    return tuple(int(0.5 * d[j] * (k[j] - 1)) for j in range(dims))
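
# Hypothetical usage sketch: in the unit-stride case the padding depends only
# on kernel size and dilation; the layer hyperparameters are arbitrary.
def _demo_same_padding():
    pad = same_padding(k=5, d=2)  # -> (4, 4)
    conv = nn.Conv2d(3, 8, kernel_size=5, dilation=2, padding=pad)
    assert tuple(conv(torch.randn(1, 3, 32, 32)).shape[2:]) == (32, 32)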
def load_model(model, model_dir, run_tag):
    """Loads the latest '<run_tag>_epoch<N>.pt' checkpoint from model_dir,
    if any exists, and returns the model together with the epoch reached."""
    epoch = 0
    check = [f for f in os.listdir(model_dir) if f.startswith(run_tag)]
    if len(check) > 0:
        epoch = max(int(re.findall(r'epoch\d+', c)[0][5:]) for c in check)
        model.load_state_dict(torch.load(os.path.join(model_dir, run_tag + '_epoch' + str(epoch) + '.pt')))
        print("Resuming training from epoch %d" % epoch)
    return model, epoch
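
# Hypothetical usage sketch: the directory, run tag and epoch below are made
# up; load_model only assumes the '<run_tag>_epoch<N>.pt' naming convention.
def _demo_load_model():
    os.makedirs('checkpoints', exist_ok=True)
    model = nn.Linear(10, 2)
    torch.save(model.state_dict(), os.path.join('checkpoints', 'run1_epoch3.pt'))
    model, epoch = load_model(model, 'checkpoints', 'run1')
    assert epoch == 3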