-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathutils_network.py
executable file
·107 lines (91 loc) · 4.51 KB
/
utils_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import fully_connected as FC_Net
### CONSTRUCT MULTICELL FOR MULTI-LAYER RNNS
def create_rnn_cell(num_units, num_layers, keep_prob, RNN_type, activation_fn):
    '''
    GOAL          : create a multi-cell (also used for a single layer) to construct a multi-layer RNN

    num_units     : number of units in each layer
    num_layers    : number of layers in the MultiRNNCell
    keep_prob     : keep probability [0, 1] applied to the inputs and outputs of every layer
                    through a DropoutWrapper (if None, dropout is not employed)
    RNN_type      : either 'LSTM' or 'GRU'
    activation_fn : activation handed to every cell; None or the string 'None' selects tf.nn.tanh

    RETURNS       : a tf.contrib.rnn.MultiRNNCell wrapping the `num_layers` cells
    RAISES        : ValueError if RNN_type is neither 'LSTM' nor 'GRU'
    '''
    # Validate before building anything: an unknown type used to leave `cell`
    # unbound inside the loop and fail with an unrelated UnboundLocalError.
    if RNN_type not in ('GRU', 'LSTM'):
        raise ValueError(
            "RNN_type must be 'LSTM' or 'GRU', got {!r}".format(RNN_type)
        )

    # The string 'None' is kept as a sentinel for backward compatibility;
    # a real None is resolved to the same default here.
    if activation_fn is None or activation_fn == 'None':
        activation_fn = tf.nn.tanh

    cells = []
    for _ in range(num_layers):
        if RNN_type == 'GRU':
            cell = tf.contrib.rnn.GRUCell(num_units, activation=activation_fn)
        else:  # 'LSTM' (guaranteed by the validation above)
            cell = tf.contrib.rnn.LSTMCell(num_units, activation=activation_fn, state_is_tuple=True)

        if keep_prob is not None:
            # Dropout on inputs and outputs only; the recurrent state is not dropped.
            cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob, output_keep_prob=keep_prob)
        cells.append(cell)

    return tf.contrib.rnn.MultiRNNCell(cells)
### EXTRACT STATE OUTPUT OF MULTICELL-RNNS
def create_concat_state(state, num_layers, RNN_type, BiRNN=None):
    '''
    GOAL       : concatenate the tuple-type tensor (state) into a single tensor

    state      : state tuple of a MultiRNNCell (i.e. the state output of the multi-layer RNN);
                 each layer entry is the hidden state h for GRU, and is indexed as (c, h) for LSTM.
                 When BiRNN is not None, state[0] and state[1] are each such a per-layer tuple
                 (one per direction).
    num_layers : number of layers in the MultiRNNCell (must be >= 1)
    RNN_type   : either 'LSTM' or 'GRU'
    BiRNN      : if not None, `state` is treated as a pair of per-direction states and the
                 two h states of each layer are concatenated along axis 1

    RETURNS    : the h states of all layers concatenated along axis 1
                 (for a single uni-directional layer, that layer's h state itself)
    RAISES     : ValueError if RNN_type is not 'LSTM'/'GRU' or if num_layers < 1
    '''
    # Fail loudly: the previous behaviour was to print a message and then crash
    # on an unbound local variable.
    if RNN_type not in ('LSTM', 'GRU'):
        raise ValueError(
            "ERROR: WRONG RNN CELL TYPE (expected 'LSTM' or 'GRU', got {!r})".format(RNN_type)
        )
    if num_layers < 1:
        raise ValueError('num_layers must be >= 1, got {!r}'.format(num_layers))

    is_lstm = (RNN_type == 'LSTM')

    def _hidden(layer_state):
        # LSTM keeps h at index 1 of the layer state; a GRU layer state is h itself.
        return layer_state[1] if is_lstm else layer_state

    per_layer = []
    for i in range(num_layers):
        if BiRNN is not None:
            # Join the two directions of the i-th layer.
            per_layer.append(tf.concat([_hidden(state[0][i]), _hidden(state[1][i])], axis=1))
        else:
            per_layer.append(_hidden(state[i]))

    # A single layer is returned untouched (no extra concat op), as before.
    if len(per_layer) == 1:
        return per_layer[0]
    # One concat over all layers instead of a chain of pairwise concats.
    return tf.concat(per_layer, axis=1)
### FEEDFORWARD NETWORK
def create_FCNet(inputs, num_layers, h_dim, h_fn, o_dim, o_fn, w_init, w_reg=None, keep_prob=1.0):
    '''
    GOAL            : create a fully-connected network with the given specification

    inputs (tensor) : input tensor
    num_layers      : total number of layers in the FCNet (must be >= 1);
                      the first num_layers-1 are hidden layers, the last is the output layer
    h_dim (int)     : number of hidden units
    h_fn            : activation function for hidden layers (default: tf.nn.relu)
    o_dim (int)     : number of output units
    o_fn            : activation function for the output layer (default: None)
    w_init          : initializer for the weight matrices (default: Xavier)
    w_reg           : weights regularizer forwarded to every layer (default: None)
    keep_prob       : keep probability [0, 1] for dropout after each hidden layer
                      (if None, dropout is not employed); never applied to the output layer

    RETURNS         : output tensor of the last (output) layer
    RAISES          : ValueError if num_layers < 1
    '''
    # Without at least one layer there is no output to return; this used to
    # surface as an UnboundLocalError on `out`.
    if num_layers < 1:
        raise ValueError('num_layers must be >= 1, got {!r}'.format(num_layers))

    # default activation function for hidden layers (the output default is already None)
    if h_fn is None:
        h_fn = tf.nn.relu
    # default weight initialization
    if w_init is None:
        w_init = tf.contrib.layers.xavier_initializer()  # Xavier initialization

    # Hidden stack: num_layers-1 layers, each optionally followed by dropout.
    h = inputs
    for _ in range(num_layers - 1):
        h = FC_Net(h, h_dim, activation_fn=h_fn, weights_initializer=w_init, weights_regularizer=w_reg)
        if keep_prob is not None:
            h = tf.nn.dropout(h, keep_prob=keep_prob)

    # Output layer; fed directly by `inputs` when num_layers == 1.
    return FC_Net(h, o_dim, activation_fn=o_fn, weights_initializer=w_init, weights_regularizer=w_reg)