""" SETI ML model generation utilities.
Notes
-----
This module is not intended to be run as a script.
Authors
-------
| Paul Pinchuk ([email protected])
| Jean-Luc Margot
UCLA SETI Group.
University of California, Los Angeles.
Copyright 2021. All rights reserved.
"""
import re
import tempfile
from functools import partial
from itertools import chain

import tensorflow as tf
Conv2D = partial(tf.keras.layers.Conv2D, padding='same', use_bias=False)
class MCDropout(tf.keras.layers.Dropout):
""" A Dropout Layer that applies Monte Carlo Dropout to the input.
References
----------
[1] Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn,
Keras, and TensorFlow: Concepts, Tools, and Techniques to Build
Intelligent Systems. O'Reilly Media, 2019. pp. 368-370.
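Examples
--------
A minimal, hypothetical sketch (the rate and input shape are
illustrative only):
>>> layer = MCDropout(rate=0.5)
>>> out = layer(tf.ones((1, 100)), training=False)
>>> # unlike a regular `Dropout`, units are still dropped even
>>> # though `training=False`, so inference is stochastic
>>> bool(tf.reduce_any(out == 0.0))
True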
"""
def call(self, inputs, **_):
return super().call(inputs, training=True)
class ResidualUnit(tf.keras.layers.Layer):
""" A single residual unit (building block of residual networks).
Parameters
----------
filters : int
Number of filters to use in each convolutional layer.
strides : int, optional
Stride to use for the first convolutional layer and, if
`strides > 1`, the skip-connection layer (the second main
convolution always uses a stride of 1).
Recommended to be either 1 or 2.
activation : str, optional
Name of activation function to use.
**kwargs
Keyword arguments for the :class:`tf.keras.layers.Layer` class.
References
----------
[1] Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn,
Keras, and TensorFlow: Concepts, Tools, and Techniques to Build
Intelligent Systems. O'Reilly Media, 2019. pp. 471-474, 478.
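Examples
--------
A minimal sketch (filters, strides, and input shape are
illustrative only):
>>> unit = ResidualUnit(filters=64, strides=2)
>>> out = unit(tf.zeros((1, 32, 32, 64)))
>>> out.shape
TensorShape([1, 16, 16, 64])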
"""
def __init__(self, filters, strides=1, activation='relu', **kwargs):
super().__init__(**kwargs)
self.activation = tf.keras.activations.get(activation)
self.filters = filters
self.strides = strides
self.main_layers = [
Conv2D(filters, kernel_size=3, strides=strides),
tf.keras.layers.BatchNormalization(),
self.activation,
Conv2D(filters, kernel_size=3, strides=1),
tf.keras.layers.BatchNormalization()
]
self.skip_layers = []
if strides > 1:
self.skip_layers = [
Conv2D(filters, kernel_size=1, strides=strides),
tf.keras.layers.BatchNormalization()
]
def call(self, inputs, **_):
skip_inputs = inputs
# apply main forward pass
for layer in self.main_layers:
inputs = layer(inputs)
# perform skip connection processing, if needed
for layer in self.skip_layers:
skip_inputs = layer(skip_inputs)
# noinspection PyCallingNonCallable
return self.activation(inputs + skip_inputs)
def get_config(self):
base_config = super().get_config()
return {
**base_config,
'filters': self.filters,
'strides': self.strides,
'activation': tf.keras.activations.serialize(self.activation),
}
class SqueezeAndExcitationUnit(tf.keras.layers.Layer):
""" A single squeeze-and-excitation unit.
Parameters
----------
n_chan : int
Number of channels in the input to the SE unit.
ratio : int, optional
Reduction ratio: the latent space has `n_chan // ratio` units.
Recommended to be a factor of `n_chan`.
dense_act : str, optional
Name of activation function to use for Dense layer.
**kwargs
Keyword arguments for the :class:`tf.keras.layers.Layer` class.
References
----------
[1] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks."
Proceedings of the IEEE conference on computer vision
and pattern recognition. 2018.
[2] Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn,
Keras, and TensorFlow: Concepts, Tools, and Techniques to Build
Intelligent Systems. O'Reilly Media, 2019. pp. 476-478.
[3] https://towardsdatascience.com/squeeze-and-excitation-networks-9ef5e71eacd7
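Examples
--------
A minimal sketch (channel count and input shape are
illustrative only):
>>> se = SqueezeAndExcitationUnit(n_chan=64, ratio=16)
>>> out = se(tf.zeros((1, 8, 8, 64)))
>>> out.shape
TensorShape([1, 8, 8, 64])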
"""
def __init__(self, n_chan, ratio=16, dense_act='relu', **kwargs):
super().__init__(**kwargs)
self.n_chan = n_chan
self.ratio = ratio
self.dense_act = dense_act
self.main_layers = [
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(n_chan // ratio, activation=dense_act),
tf.keras.layers.Dense(n_chan, activation='sigmoid'),
]
def call(self, inputs, **_):
out = self.main_layers[0](inputs)
for layer in self.main_layers[1:]:
out = layer(out)
return tf.keras.layers.Multiply()([inputs, out])
def get_config(self):
base_config = super().get_config()
return {
**base_config,
'n_chan': self.n_chan,
'ratio': self.ratio,
'dense_act': self.dense_act
}
def name_model(model, name):
""" Name a keras model.
Parameters
----------
model : `tensorflow.keras.Model`
Model to be renamed.
name : str
New name for model.
Returns
-------
`tensorflow.keras.Model`
Model identical to `model` but with the given name.
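Examples
--------
A minimal sketch (the layer sizes and name are illustrative only):
>>> x = tf.keras.layers.Input((4,))
>>> y = tf.keras.layers.Dense(2)(x)
>>> model = name_model(tf.keras.Model(x, y), 'my_model')
>>> model.name
'my_model'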
"""
model = tf.keras.Model(
inputs=model.inputs,
outputs=model.outputs,
name=name
)
return model
def seti_input_layers(
input_layer,
return_input=False,
include_bn=True,
bn_axis=None
):
"""
Parameters
----------
input_layer : tuple or `tensorflow.keras.layers.Layer` instance
Either the desired shape of the input tensor in tuple
format or an instance of the desired input layer.
return_input : bool, optional
Option to return the input layer along with the
output of the batch normalization layer, if applicable.
include_bn : bool, optional
Option to include a batch normalization layer
immediately after the input layer.
bn_axis : int or tuple, optional
Axis parameter to pass to Batch Normalization layer.
If `None`, then `axis=[1, 2]` is used.
Returns
-------
outputs
Either the output tensor of the input layers, or a tuple of
(input_layer, output_tensor) if `return_input` is True.
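Examples
--------
A minimal sketch (the input shape is illustrative only):
>>> inputs, out = seti_input_layers((224, 224, 1), return_input=True)
>>> model = tf.keras.Model(inputs=inputs, outputs=out)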
"""
try:
input_layer = tf.keras.layers.Input(shape=input_layer, name='inputs')
except TypeError:
# if this error is thrown, we assume
# input layer is already a layer
pass
if include_bn:
# this may actually not help
# see https://towardsdatascience.com/how-to-potty-train-a-siamese-network-3df6ca5e44da
if bn_axis is None:
bn_axis = [1, 2]
out = tf.keras.layers.BatchNormalization(axis=bn_axis)(input_layer)
else:
out = input_layer
if return_input:
return input_layer, out
else:
return out
def transfer_weights(old_model, new_model):
""" Transfer the weights from one model to another layer by layer.
For this function to work properly, the layer structure
of both models should be identical. This is useful, for
example, in cases where a dropout layer is replaced with
a Monte Carlo dropout layer.
Parameters
----------
old_model, new_model : `tensorflow.keras.Model`
`old_model` contains the weights to be transferred
to `new_model`.
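Examples
--------
A minimal sketch with two identically-structured models
(the layer sizes are illustrative only):
>>> src = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
>>> dst = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
>>> transfer_weights(src, dst)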
"""
for new_layer, old_layer in zip(new_model.layers, old_model.layers):
new_layer.set_weights(old_layer.get_weights())
def insert_layer(
model,
layer_regex,
new_layer_factory,
position='after',
reset_model=True
):
""" Insert new layers into an existing model.
This implementation is a heavily modified version of
the top StackOverflow answer in the reference link below.
Parameters
----------
model : `tensorflow.keras.Model`
Model instance containing the original layers
that should be modified in some way.
layer_regex : str
Regular expression used to match the layers around which
the new layers should be inserted.
new_layer_factory : callable
A callable that takes the layer matching the
`layer_regex` as input and outputs
an *iterable* of layers to insert.
position : {'before', 'after', 'replace'}, optional
Flag indicating the position of the inserted layers relative
to the layer matching `layer_regex`. Must be one of
'before', 'after', or 'replace'.
reset_model : bool, optional
Option to save and immediately reload the model graph. This
is highly recommended in order to avoid any problems when
using this function multiple times.
Returns
-------
`tensorflow.keras.Model`
New model instance containing the inserted layers.
References
----------
[1] https://stackoverflow.com/questions/49492255/how-to-replace-or-insert-intermediate-layer-in-keras-model
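Examples
--------
A minimal, hypothetical sketch that swaps `Dropout` layers for
`MCDropout` (layer names and shapes are illustrative only):
>>> x = tf.keras.layers.Input((8,), name='in')
>>> y = tf.keras.layers.Dense(4, name='fc')(x)
>>> y = tf.keras.layers.Dropout(0.5, name='drop')(y)
>>> mc_model = insert_layer(
...     tf.keras.Model(x, y), 'drop',
...     lambda old: [MCDropout(old.rate)],
...     position='replace', reset_model=False
... )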
"""
if position not in ('before', 'after', 'replace'):
raise ValueError('position must be: before, after or replace')
network_dict = convert_model_to_input_output_dict(model)
tf.keras.backend.clear_session()
model_outputs = []
for layer in model.layers[1:]:
x = _inputs_to_layer(layer, network_dict)
if re.match(layer_regex, layer.name):
new_layers = _insert_new_layer(layer, new_layer_factory, position)
for new_layer in new_layers:
x = new_layer(x)
else:
x = _recreate_layer(layer)(x)
network_dict['output_tensor_of'][layer.name] = x
if layer.name in model.output_names:
model_outputs.append(x)
model = tf.keras.Model(
inputs=model.inputs,
outputs=model_outputs,
name=model.name
)
if reset_model:
with tempfile.TemporaryDirectory() as dir_name:
model.save(dir_name)
model = tf.keras.models.load_model(dir_name)
return model
def convert_model_to_input_output_dict(model):
""" Convert a model graph to a dictionary with input/output info.
The dictionary will contain two pieces of info for each layer:
the input layers and the output tensor.
Parameters
----------
model : `tensorflow.keras.Model`
Model containing layers to analyze.
Returns
-------
dict
For every layer in the model, this dictionary contains the
input layers ('input_layers_of') as well as the output tensor
('output_tensor_of').
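Examples
--------
A minimal sketch on a two-layer functional model (the layer
names and sizes are illustrative only):
>>> x = tf.keras.layers.Input((4,), name='in')
>>> y = tf.keras.layers.Dense(2, name='fc')(x)
>>> net = convert_model_to_input_output_dict(tf.keras.Model(x, y))
>>> net['input_layers_of']['fc']
['in']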
"""
# Auxiliary dictionary to describe the network graph
network_dict = {'input_layers_of': {}, 'output_tensor_of': {}}
# Set the input layers of each layer
for layer in model.layers:
for node in layer.outbound_nodes:
layer_name = node.outbound_layer.name
network_dict['input_layers_of'].setdefault(
layer_name, []
).append(layer.name)
network_dict['output_tensor_of'][model.layers[0].name] = model.input
return network_dict
def _inputs_to_layer(layer, network_dict):
""" Extract the inputs to `layer` from the network dict. """
x = [
network_dict['output_tensor_of'][layer_in]
for layer_in in network_dict['input_layers_of'][layer.name]
]
if len(x) == 1:
x = x[0]
return x
def _insert_new_layer(old_layer, new_layer_factory, position):
""" Insert the new layer(s) in correct position w.r.t the old layer. """
new_layers = _new_layers_as_iterable(new_layer_factory, old_layer)
# Instead of re-using the old layer, we create a copy so that the
# graph is clean
layer = _recreate_layer(old_layer)
if position == 'after':
new_layers = chain([layer], new_layers)
elif position == 'before':
new_layers = chain(new_layers, [layer])
return new_layers
def _recreate_layer(layer):
""" Recreate a layer from its config, special-casing SE units. """
layer_conf = layer.get_config()
# the check below is not fail-safe... if the user names the
# layer something that does not contain "squeeze", this
# function will fail
if 'squeeze' in layer_conf['name']:
return SqueezeAndExcitationUnit(**layer_conf)
return layer.__class__(**layer_conf)
def _new_layers_as_iterable(new_layer_factory, old_layer):
""" Call the layer factory and return output as an iterable container. """
new_layers = new_layer_factory(old_layer)
try:
iter(new_layers)
except TypeError:
new_layers = [new_layers]
return new_layers
def __sequential_standard_ResNet34():
""" Create a ResNet34 using `tf.keras.models.Sequential`. """
model = tf.keras.models.Sequential()
model.add(Conv2D(
filters=64, kernel_size=7, strides=2, input_shape=[224, 224, 3]
))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('relu'))
model.add(tf.keras.layers.MaxPool2D(
pool_size=3, strides=2, padding='same'
))
prev_filters = 64
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
model.add(ResidualUnit(
filters=filters, strides=1 if filters == prev_filters else 2
))
prev_filters = filters
model.add(tf.keras.layers.GlobalAvgPool2D())
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=10, activation='softmax'))
return model