-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #112 from andreped/batch-norm-fix
Added method to replace BN layers [no ci]
- Loading branch information
Showing
7 changed files
with
196 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import tensorflow as tf | ||
|
||
from .layers import AccumBatchNormalization | ||
|
||
|
||
def replace_batchnorm_layers(model, accum_steps, position="replace"): | ||
# Auxiliary dictionary to describe the network graph | ||
network_dict = {"input_layers_of": {}, "new_output_tensor_of": {}} | ||
|
||
# Set the input layers of each layer | ||
for layer in model.layers: | ||
for node in layer._outbound_nodes: | ||
layer_name = node.outbound_layer.name | ||
if layer_name not in network_dict["input_layers_of"]: | ||
network_dict["input_layers_of"].update( | ||
{layer_name: [layer.name]} | ||
) | ||
else: | ||
network_dict["input_layers_of"][layer_name].append(layer.name) | ||
|
||
# Set the output tensor of the input layer | ||
network_dict["new_output_tensor_of"].update( | ||
{model.layers[0].name: model.input} | ||
) | ||
|
||
# Iterate over all layers after the input | ||
model_outputs = [] | ||
iter_ = 0 | ||
for layer in model.layers[1:]: | ||
|
||
# Determine input tensors | ||
layer_input = [ | ||
network_dict["new_output_tensor_of"][layer_aux] | ||
for layer_aux in network_dict["input_layers_of"][layer.name] | ||
] | ||
if len(layer_input) == 1: | ||
layer_input = layer_input[0] | ||
|
||
# Insert layer if name matches | ||
if isinstance(layer, tf.keras.layers.BatchNormalization): | ||
if position == "replace": | ||
x = layer_input | ||
else: | ||
raise ValueError("position must be: replace") | ||
|
||
# build new layer | ||
new_layer = AccumBatchNormalization( | ||
accum_steps=accum_steps, | ||
name="AccumBatchNormalization_" + str(iter_), | ||
) | ||
new_layer.build(input_shape=layer.input_shape) | ||
|
||
iter_ += 1 | ||
|
||
# set weights in new layer to match old layer | ||
new_layer.accum_mean = layer.moving_mean | ||
new_layer.moving_mean = layer.moving_mean | ||
|
||
new_layer.accum_variance = layer.moving_variance | ||
new_layer.moving_variance = layer.moving_variance | ||
|
||
# forward step | ||
x = new_layer(x) | ||
|
||
else: | ||
x = layer(layer_input) | ||
|
||
# Set new output tensor (original one/the one of the inserted layer) | ||
network_dict["new_output_tensor_of"].update({layer.name: x}) | ||
|
||
# Save tensor in output list if it is output in initial model | ||
if layer_name in model.output_names: | ||
model_outputs.append(x) | ||
|
||
return tf.keras.Model(inputs=model.inputs, outputs=x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
|
||
setuptools.setup( | ||
name="gradient-accumulator", | ||
version="0.5.1", | ||
version="0.5.2", | ||
author="André Pedersen and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo", | ||
author_email="[email protected]", | ||
description="Package for gradient accumulation in TensorFlow", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import os | ||
import random as python_random | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
from tensorflow.keras.models import load_model | ||
|
||
from gradient_accumulator import GradientAccumulateModel | ||
from gradient_accumulator.layers import AccumBatchNormalization | ||
from gradient_accumulator.utils import replace_batchnorm_layers | ||
|
||
from .utils import gray2rgb | ||
from .utils import normalize_img | ||
from .utils import reset | ||
from .utils import resizeImage | ||
|
||
|
||
def test_swap_layer( | ||
custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1 | ||
): | ||
# load dataset | ||
(ds_train, ds_test), ds_info = tfds.load( | ||
"mnist", | ||
split=["train", "test"], | ||
shuffle_files=True, | ||
as_supervised=True, | ||
with_info=True, | ||
) | ||
|
||
# build train pipeline | ||
ds_train = ds_train.map(normalize_img) | ||
ds_train = ds_train.map(gray2rgb) | ||
ds_train = ds_train.map(resizeImage) | ||
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) | ||
ds_train = ds_train.batch(bs) | ||
ds_train = ds_train.prefetch(1) | ||
|
||
# build test pipeline | ||
ds_test = ds_test.map(normalize_img) | ||
ds_test = ds_test.map(gray2rgb) | ||
ds_test = ds_test.map(resizeImage) | ||
ds_test = ds_test.batch(bs) | ||
ds_test = ds_test.prefetch(1) | ||
|
||
# create model | ||
base_model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), weights="imagenet", include_top=False) | ||
base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps) | ||
|
||
input_ = tf.keras.layers.Input(shape=(32, 32, 3)) | ||
x = base_model(input_) | ||
x = tf.keras.layers.Dense(10, activation="softmax")(x) | ||
model = tf.keras.Model(inputs=input_, outputs=x) | ||
|
||
# wrap model to use gradient accumulation | ||
if accum_steps > 1: | ||
model = GradientAccumulateModel( | ||
accum_steps=accum_steps, inputs=model.input, outputs=model.output | ||
) | ||
|
||
# compile model | ||
model.compile( | ||
optimizer=tf.keras.optimizers.SGD(1e-2), | ||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], | ||
) | ||
|
||
# train model | ||
model.fit( | ||
ds_train, | ||
epochs=epochs, | ||
validation_data=ds_test, | ||
steps_per_epoch=4, | ||
validation_steps=4, | ||
) | ||
|
||
model.save("./trained_model") | ||
|
||
# load trained model and test | ||
del model | ||
trained_model = load_model("./trained_model", compile=True) | ||
|
||
result = trained_model.evaluate(ds_test, verbose=1) | ||
print(result) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters