-
Notifications
You must be signed in to change notification settings - Fork 110
/
Copy pathweightG_fmt_converter.py
44 lines (34 loc) · 1.22 KB
/
weightG_fmt_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import re
import sys

import torch
# Converts old SEGAN-G weight namings to the current ones:
#   Encoder: gen_enc.i.conv.weight/bias (i-th layer) --> enc_blocks.i.conv.weight/bias
#   Decoder: gen_dec.i.conv.weight/bias (i-th layer) --> dec_blocks.i.deconv.weight/bias
# Usage: python weightG_fmt_converter.py <weights ckpt file>
# The converted checkpoint is written next to the input as <ckpt file>.v2.

# Matches a ``conv`` that is not already part of ``deconv``, so the decoder
# rename never produces ``dedeconv``.
_DEC_CONV_RE = re.compile(r'(?<!de)conv')


def rename_key(key):
    """Return the new-format name for one generator parameter name.

    Keys containing ``gen_enc`` or ``gen_dec`` are renamed following the
    mapping listed at the top of this script; any other key is returned
    unchanged.
    """
    if 'gen_enc' in key:
        return key.replace('gen_enc', 'enc_blocks')
    if 'gen_dec' in key:
        return _DEC_CONV_RE.sub('deconv', key.replace('gen_dec', 'dec_blocks'))
    return key


def convert_checkpoint(ckpt):
    """Return a copy of ``ckpt`` whose ``'state_dict'`` uses the new names.

    Every other top-level entry of ``ckpt`` is copied as-is, and each renamed
    or kept parameter is reported on stdout. ``ckpt`` itself is not modified.

    Raises:
        KeyError: if ``ckpt`` has no ``'state_dict'`` entry.
    """
    try:
        old_state = ckpt['state_dict']
    except KeyError:
        raise KeyError("checkpoint has no 'state_dict' entry") from None
    new_state = {}
    # 'state_dict' is inserted first so the output key order matches what
    # this script has always produced.
    new_ckpt = {'state_dict': new_state}
    for k, v in ckpt.items():
        # Exact match: a substring test would also drop entries such as
        # 'optimizer_state_dict'.
        if k != 'state_dict':
            new_ckpt[k] = v
    for k, v in old_state.items():
        nk = rename_key(k)
        if nk == k:
            print('Keeping {}'.format(k))
        else:
            print('{} -> {}'.format(k, nk))
        new_state[nk] = v
    return new_ckpt


def main(argv=None):
    """Convert the checkpoint named in ``argv[1]`` and save it as ``<file>.v2``.

    ``argv`` defaults to ``sys.argv``. Prints the usage and exits with
    status 1 when no checkpoint path is given.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        print('ERROR! Not enough input arguments.')
        print('Usage: {} <weights ckpt file> .'.format(argv[0]))
        # Stop here: falling through would fail on argv[1] with an IndexError.
        sys.exit(1)
    ckpt_file = argv[1]
    out_file = ckpt_file + '.v2'
    # NOTE: torch.load unpickles the file; only convert checkpoints you trust.
    # Returning the storage unchanged from map_location loads tensors on CPU.
    ckpt = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
    torch.save(convert_checkpoint(ckpt), out_file)


if __name__ == '__main__':
    main()