From 9588c9198557d98ae578b682e74847822c11edc6 Mon Sep 17 00:00:00 2001
From: Raphael Delhome
Date: Fri, 23 Mar 2018 17:41:22 +0100
Subject: [PATCH 1/3] Implement a simplified version of VGG network

---
 sources/feature_detection.py | 39 +++++++++++++++++++++++++++++++++++-
 sources/test.py | 6 +++---
 sources/train.py | 9 ++++-----
 3 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/sources/feature_detection.py b/sources/feature_detection.py
index 2582301f..2040d87a 100644
--- a/sources/feature_detection.py
+++ b/sources/feature_detection.py
@@ -25,8 +25,15 @@ def __init__(self, network_name="mapillary", image_size=512, nb_channels=3,
                                  shape=[None, self._nb_labels])
         if netsize == "small":
             self.add_layers_3_1()
-        else:
+        elif netsize == "medium":
             self.add_layers_6_2()
+        elif netsize == "vgg":
+            self.add_vgg_layers()
+        elif netsize == "inception":
+            self.add_inception_layers()
+        else:
+            utils.logger.error("Unsupported network.")
+            sys.exit(1)
         self.compute_loss()
         self.optimize()
         self._cm = self.compute_dashboard(self._Y, self._Y_pred)
@@ -112,6 +119,36 @@ def add_layers_6_2(self):
         layer = self.fullyconnected_layer(2, self._is_training, layer, 1024, 512, self._dropout)
         return self.output_layer(layer, 512)

+    def add_vgg_layers(self):
+        """Build the structure of a convolutional neural network from image data `input_layer`
+        to the last hidden layer, in a manner similar to VGG-net (see Simonyan & Zisserman,
+        Very Deep Convolutional Networks for Large-Scale Image Recognition, arXiv technical
+        report, 2014); not necessarily the *same* structure, as the input shape is not
+        necessarily identical
+
+        Returns
+        -------
+        tensor
+            Output layer of the neural network, *i.e.* a 1 X 1 X nb_class structure that contains
+            model predictions
+        """
+        layer = self.convolutional_layer(1, self._is_training, self._X, self._nb_channels, 3, 64)
+        layer = self.maxpooling_layer(1, layer, 2, 2)
+        layer = self.convolutional_layer(2, self._is_training, layer, 64, 3, 128)
+        layer = self.maxpooling_layer(2, layer, 2, 2)
+        layer = self.convolutional_layer(3, self._is_training, layer, 128, 3, 256)
+        layer = self.convolutional_layer(4, self._is_training, layer, 256, 3, 256)
+        layer = self.maxpooling_layer(3, layer, 2, 2)
+        layer = self.convolutional_layer(5, self._is_training, layer, 256, 3, 512)
+        layer = self.convolutional_layer(6, self._is_training, layer, 512, 3, 512)
+        layer = self.maxpooling_layer(4, layer, 2, 2)
+        layer = self.convolutional_layer(7, self._is_training, layer, 512, 3, 512)
+        layer = self.convolutional_layer(8, self._is_training, layer, 512, 3, 512)
+        layer = self.maxpooling_layer(5, layer, 2, 2)
+        last_layer_dim = self.get_last_conv_layer_dim(32, 512)
+        layer = self.fullyconnected_layer(1, self._is_training, layer, last_layer_dim, 1024, self._dropout)
+        return self.output_layer(layer, 1024)
+
     def compute_loss(self):
         """Define the loss tensor as well as the optimizer; it uses a decaying learning rate
         following the equation
diff --git a/sources/test.py b/sources/test.py
index e683adca..5b492818 100644
--- a/sources/test.py
+++ b/sources/test.py
@@ -50,7 +50,7 @@
     args = parser.parse_args()

     # instance name decomposition (instance name = name + image size + network size)
-    _, image_size, network_size, _, aggregate_value, _, _ = args.name.split('_')
+    _, image_size, network, _, aggregate_value, _, _ = args.name.split('_')
     image_size = int(image_size)

     if image_size > 1024:
@@ -58,9 +58,9 @@
                             "reasonable image size (less than 1024)"))
         sys.exit(1)
-    if not network_size in ["small", "medium"]:
+    if not network in ["small", "medium", "vgg", "inception", "inception-v2", "inception-v3"]:
         utils.logger.error(("Unsupported network size. "
-                            "Please choose 'small' or 'medium'."))
+                            "Please choose 'small' or 'medium'.")) # TO UPDATE
         sys.exit(1)

     if not args.model in ["feature_detection", "semantic_segmentation"]:
diff --git a/sources/train.py b/sources/train.py
index 99f0e2ef..b4475371 100644
--- a/sources/train.py
+++ b/sources/train.py
@@ -60,11 +60,11 @@
     parser.add_argument('-n', '--name', default="cnnmapil", nargs='?',
                         help=("Model name that will be used for results, "
                               "checkout and graph storage on file system"))
-    parser.add_argument('-ns', '--network-size', default='small',
+    parser.add_argument('-N', '--network', default='small',
                         help=("Neural network size, either 'small' or 'medium'"
                               "('small' refers to 3 conv/pool blocks and 1 "
                               "fully-connected layer, and 'medium' refers to 6"
-                              "conv/pool blocks and 2 fully-connected layers)"))
+                              "conv/pool blocks and 2 fully-connected layers)")) # TO UPDATE
     parser.add_argument('-r', '--learning-rate', required=False, nargs="+",
                         default=[0.01, 1000, 0.95], type=float,
                         help=("List of learning rate components (starting LR, "
@@ -127,7 +127,7 @@
                             "each batch (convex weights with min at 50%)..."))
         sys.exit(1)
-    if not args.network_size in ["small", "medium"]:
+    if not args.network in ["small", "medium", "vgg", "inception", "inception-v2", "inception-v3"]:
         utils.logger.error(("Unsupported network size description"
                             "Please use this parameter with 'small' or "
                             "'medium' values"))
         sys.exit(1)
@@ -145,7 +145,7 @@
     # Instance name (name + image size + network size + batch_size + aggregate? + dropout +
     # learning_rate)
-    instance_args = [args.name, args.image_size, args.network_size, args.batch_size,
+    instance_args = [args.name, args.image_size, args.network, args.batch_size,
                      aggregate_value, args.dropout, utils.list_to_str(args.learning_rate)]
     instance_name = utils.list_to_str(instance_args, "_")
@@ -233,4 +233,3 @@
         utils.logger.error(("Unknown type of model. Please use "
                             "'feature_detection' or 'semantic_segmentation'"))
         sys.exit(1)
-

From 00043adaa8f744a1e1847b36462769f252eff955 Mon Sep 17 00:00:00 2001
From: Raphael Delhome
Date: Tue, 27 Mar 2018 13:39:22 +0200
Subject: [PATCH 2/3] Implement Inception network

---
 sources/feature_detection.py | 85 ++++++++++++++++++++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/sources/feature_detection.py b/sources/feature_detection.py
index 2040d87a..1a84bc81 100644
--- a/sources/feature_detection.py
+++ b/sources/feature_detection.py
@@ -149,6 +149,91 @@ def add_vgg_layers(self):
         layer = self.fullyconnected_layer(1, self._is_training, layer, last_layer_dim, 1024, self._dropout)
         return self.output_layer(layer, 1024)

+    def inception_block(self, counter, input_layer, input_depth, depth_1,
+                        depth_3_reduce, depth_3, depth_5_reduce, depth_5, depth_pool):
+        """Apply an Inception block (concatenation of convolved inputs, see Szegedy et al., 2014)
+
+        Concatenation of several filtered outputs:
+        - 1*1 convolved image
+        - 1*1 and 3*3 convolved images
+        - 1*1 and 5*5 convolved images
+        - 3*3 max-pooled and 1*1 convolved images
+
+        Parameters
+        ----------
+        counter : integer
+            Inception block ID
+        input_layer : tensor
+            Input layer that has to be transformed in the Inception block
+        input_depth : integer
+            Input layer depth
+        depth_1 : integer
+            Depth of the 1*1 convolved output
+        depth_3_reduce : integer
+            Hidden layer depth, between 1*1 and 3*3 convolution
+        depth_3 : integer
+            Depth of the 3*3 convolved output
+        depth_5_reduce : integer
+            Hidden layer depth, between 1*1 and 5*5 convolution
+        depth_5 : integer
+            Depth of the 5*5 convolved output
+        depth_pool : integer
+            Depth of the max-pooled output (after 1*1 convolution)
+
+        Returns
+        -------
+        tensor
+            Output layer, after Inception block treatments
+        """
+        filter_1_1 = self.convolutional_layer("i"+str(counter)+"1", self._is_training, input_layer,
+                                              input_depth, 1, depth_1)
+        filter_3_3 = self.convolutional_layer("i"+str(counter)+"3a", self._is_training, input_layer,
+                                              input_depth, 1, depth_3_reduce)
+        filter_3_3 = self.convolutional_layer("i"+str(counter)+"3b", self._is_training, filter_3_3,
+                                              depth_3_reduce, 3, depth_3)
+        filter_5_5 = self.convolutional_layer("i"+str(counter)+"5a", self._is_training, input_layer,
+                                              input_depth, 1, depth_5_reduce)
+        filter_5_5 = self.convolutional_layer("i"+str(counter)+"5b", self._is_training, filter_5_5,
+                                              depth_5_reduce, 5, depth_5)
+        filter_pool = self.maxpooling_layer("i"+str(counter), input_layer, 3, 1)
+        filter_pool = self.convolutional_layer("i"+str(counter)+"p", self._is_training,
+                                               filter_pool, input_depth, 1, depth_pool)
+        return tf.concat([filter_1_1, filter_3_3, filter_5_5, filter_pool], axis=3)
+
+    def add_inception_layers(self):
+        """Build the structure of a convolutional neural network from image data `input_layer`
+        to the last hidden layer, in a manner similar to Inception networks (see Szegedy et al.,
+        Going Deeper with Convolutions, arXiv technical report, 2014); not necessarily the
+        *same* structure, as the input shape is not necessarily identical
+
+        Returns
+        -------
+        tensor
+            Output layer of the neural network, *i.e.* a 1 X 1 X nb_class structure that contains
+            model predictions
+        """
+        layer = self.convolutional_layer(1, self._is_training, self._X, self._nb_channels, 7, 64,
+                                         2)
+        layer = self.maxpooling_layer(1, layer, 3, 2)
+        layer = self.convolutional_layer(2, self._is_training, layer, 64, 3, 192)
+        layer = self.maxpooling_layer(2, layer, 3, 2)
+        layer = self.inception_block('3a', layer, 192, 64, 96, 128, 16, 32, 32)
+        layer = self.inception_block('3b', layer, 256, 128, 128, 192, 32, 96, 64)
+        layer = self.maxpooling_layer(3, layer, 3, 2)
+        layer = self.inception_block('4a', layer, 480, 192, 96, 208, 16, 48, 64)
+        layer = self.inception_block('4b', layer, 512, 160, 112, 224, 24, 64, 64)
+        layer = self.inception_block('4c', layer, 512, 128, 128, 256, 24, 64, 64)
+        layer = self.inception_block('4d', layer, 512, 112, 144, 288, 32, 64, 64)
+        layer = self.inception_block('4e', layer, 528, 256, 160, 320, 32, 128, 128)
+        layer = self.maxpooling_layer(4, layer, 3, 2)
+        layer = self.inception_block('5a', layer, 832, 256, 160, 320, 32, 128, 128)
+        layer = self.inception_block('5b', layer, 832, 384, 192, 384, 48, 128, 128)
+        layer = tf.nn.avg_pool(layer, ksize=[1, 7, 7, 1], strides=[1, 1, 1, 1],
+                               padding="VALID", name="avg_pool")
+        layer = tf.reshape(layer, [-1, 1024])
+        layer = tf.nn.dropout(layer, self._dropout, name="final_dropout")
+        return self.output_layer(layer, 1024)
+
     def compute_loss(self):
         """Define the loss tensor as well as the optimizer; it uses a decaying learning rate
         following the equation
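
A quick way to sanity-check the two architectures introduced above, outside the patches themselves: in the VGG-style stack, five 2*2/stride-2 poolings divide the spatial resolution by 32, which is presumably what get_last_conv_layer_dim(32, 512) accounts for; in the Inception-style stack, the depth passed to each inception_block must equal the concatenated depth of the previous block (depth_1 + depth_3 + depth_5 + depth_pool, the four branches joined by tf.concat on axis 3, which also requires the stride-1 pooled branch to keep its spatial size, i.e. "same" padding in the helpers). The sketch below only replays that arithmetic with the literal values used in the patches; it is an illustration, not code from the patches.

    # Illustrative sanity check (not part of the patches): shape/depth bookkeeping
    # for the VGG-style and Inception-style stacks, using the patches' literal values.

    # VGG-style stack: five 2x2/stride-2 poolings -> spatial size divided by 2**5 = 32,
    # so a 512x512 input reaches the fully-connected layer as 16 * 16 * 512 features.
    def last_conv_layer_dim(image_size, strides=32, last_depth=512):
        spatial = image_size // strides
        return spatial * spatial * last_depth

    assert last_conv_layer_dim(512) == 16 * 16 * 512  # 131072

    # Inception-style stack: each block outputs depth_1 + depth_3 + depth_5 + depth_pool
    # channels; that value must match the input_depth literal fed to the next block.
    BLOCKS = [
        # (name, input_depth, depth_1, depth_3_reduce, depth_3, depth_5_reduce, depth_5, depth_pool)
        ("3a", 192, 64, 96, 128, 16, 32, 32),
        ("3b", 256, 128, 128, 192, 32, 96, 64),
        ("4a", 480, 192, 96, 208, 16, 48, 64),
        ("4b", 512, 160, 112, 224, 24, 64, 64),
        ("4c", 512, 128, 128, 256, 24, 64, 64),
        ("4d", 512, 112, 144, 288, 32, 64, 64),
        ("4e", 528, 256, 160, 320, 32, 128, 128),
        ("5a", 832, 256, 160, 320, 32, 128, 128),
        ("5b", 832, 384, 192, 384, 48, 128, 128),
    ]

    for name, input_depth, d1, _, d3, _, d5, d_pool in BLOCKS:
        out_depth = d1 + d3 + d5 + d_pool
        print("block {}: {} -> {} channels".format(name, input_depth, out_depth))
    # The chain ends at 1024 channels, the width expected by tf.reshape(layer, [-1, 1024]).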

From 91e22138f3bc8df0e66af7b0f34dd4a56baad7cf Mon Sep 17 00:00:00 2001
From: Raphael Delhome
Date: Tue, 27 Mar 2018 17:06:20 +0200
Subject: [PATCH 3/3] Update docstrings regarding chosen network architecture

---
 sources/train.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/sources/train.py b/sources/train.py
index b4475371..c25b6edd 100644
--- a/sources/train.py
+++ b/sources/train.py
@@ -61,10 +61,13 @@
                         help=("Model name that will be used for results, "
                               "checkout and graph storage on file system"))
     parser.add_argument('-N', '--network', default='small',
-                        help=("Neural network size, either 'small' or 'medium'"
+                        help=("Neural network size, either 'small', 'medium',"
+                              " 'vgg' or 'inception' "
                               "('small' refers to 3 conv/pool blocks and 1 "
-                              "fully-connected layer, and 'medium' refers to 6"
-                              "conv/pool blocks and 2 fully-connected layers)")) # TO UPDATE
+                              "fully-connected layer; 'medium' refers to 6 "
+                              "conv/pool blocks and 2 fully-connected layers;"
+                              " 'vgg' and 'inception' refer to "
+                              "state-of-the-art networks)"))
     parser.add_argument('-r', '--learning-rate', required=False, nargs="+",
                         default=[0.01, 1000, 0.95], type=float,
                         help=("List of learning rate components (starting LR, "
@@ -127,10 +130,10 @@
                             "each batch (convex weights with min at 50%)..."))
         sys.exit(1)
-    if not args.network in ["small", "medium", "vgg", "inception", "inception-v2", "inception-v3"]:
+    if not args.network in ["small", "medium", "vgg", "inception"]:
         utils.logger.error(("Unsupported network size description"
-                            "Please use this parameter with 'small' or "
-                            "'medium' values"))
+                            "Please use this parameter with 'small', "
+                            "'medium', 'vgg' or 'inception' values"))
         sys.exit(1)

     if not args.model in ["feature_detection", "semantic_segmentation"]:
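
For reference, the instance name that ties train.py and test.py together after these patches is a seven-field, underscore-separated string (name, image size, network, batch size, aggregation mode, dropout, learning rate), which is why test.py can unpack it with a single split('_'). The snippet below rebuilds and re-parses such a name; it assumes that utils.list_to_str simply joins the stringified fields with the given separator, and every field value shown is purely illustrative.

    # Illustration only: round trip of the instance name used by train.py / test.py,
    # assuming utils.list_to_str joins the stringified fields with the given separator.
    def list_to_str(values, sep="-"):
        return sep.join(str(v) for v in values)

    # Hypothetical values for the seven fields (name, image size, network, batch size,
    # aggregation mode, dropout, learning rate components).
    instance_args = ["cnnmapil", 512, "vgg", 20, "full", 1.0, list_to_str([0.01, 1000, 0.95])]
    instance_name = list_to_str(instance_args, "_")
    # e.g. "cnnmapil_512_vgg_20_full_1.0_0.01-1000-0.95"

    # test.py-style decomposition: seven underscore-separated fields, so none of the
    # individual fields may itself contain an underscore.
    _, image_size, network, _, aggregate_value, _, _ = instance_name.split('_')
    assert int(image_size) == 512 and network == "vgg"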