utils.py
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 28 11:07:42 2020

@author: lvjf

utils: helper functions for the JSSP agents
"""
import numpy as np
import torch
from torch import nn


def v_wrap(np_array, dtype=np.float32):
    """Convert a NumPy array to a torch tensor, casting to `dtype` if needed."""
    if np_array.dtype != dtype:
        np_array = np_array.astype(dtype)
    return torch.from_numpy(np_array)


def normalized_columns_initializer(weights, std=1.0):
    """Return a tensor shaped like `weights` with each output row rescaled to norm `std`."""
    out = torch.randn(weights.size())
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True))
    return out


def weights_init(m):
    """Xavier/Glorot-style uniform initialization for Conv and Linear layers.

    Intended to be used via `model.apply(weights_init)`.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
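

# A minimal usage sketch (an assumption, not part of the original file): applying
# the three helpers above to a small PyTorch actor-critic network. `ActorCritic`
# and its layer sizes are hypothetical and exist only for illustration.
if __name__ == '__main__':
    class ActorCritic(nn.Module):
        def __init__(self, n_inputs, n_actions):
            super().__init__()
            self.fc = nn.Linear(n_inputs, 64)
            self.policy = nn.Linear(64, n_actions)
            self.value = nn.Linear(64, 1)

        def forward(self, x):
            h = torch.relu(self.fc(x))
            return self.policy(h), self.value(h)

    model = ActorCritic(n_inputs=10, n_actions=4)
    # Initialize every Conv/Linear layer with the Xavier-style bounds above.
    model.apply(weights_init)
    # Optionally rescale the policy/value heads, as is common in A3C-style code.
    model.policy.weight.data = normalized_columns_initializer(model.policy.weight.data, std=0.01)
    model.value.weight.data = normalized_columns_initializer(model.value.weight.data, std=1.0)

    # v_wrap turns a NumPy observation into a float32 torch tensor.
    state = np.random.rand(1, 10)
    logits, value = model(v_wrap(state))
    print(logits.shape, value.shape)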