Skip to content

Commit

Permalink
Merge pull request #16 from dennistang742/master
Browse files Browse the repository at this point in the history
Change for route parser and compatible with Tiny YOLO V3
  • Loading branch information
xiaochus authored Nov 30, 2018
2 parents bc1866c + e0d29b1 commit db63e48
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions yad2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
from keras import backend as K
from keras.layers import (Conv2D, GlobalAveragePooling2D, Input, Reshape,
ZeroPadding2D, UpSampling2D, Activation)
ZeroPadding2D, UpSampling2D, Activation, Lambda, MaxPooling2D)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.merge import concatenate, add
from keras.layers.normalization import BatchNormalization
Expand Down Expand Up @@ -198,7 +198,17 @@ def _main(args):
elif activation == 'leaky':
act_layer = LeakyReLU(alpha=0.1)(prev_layer)
prev_layer = act_layer
all_layers.append(act_layer)
all_layers.append(prev_layer)

elif section.startswith('maxpool'):
size = int(cfg_parser[section]['size'])
stride = int(cfg_parser[section]['stride'])
all_layers.append(
MaxPooling2D(
padding='same',
pool_size=(size, size),
strides=(stride, stride))(prev_layer))
prev_layer = all_layers[-1]

elif section.startswith('avgpool'):
if cfg_parser.items(section) != []:
Expand All @@ -208,6 +218,11 @@ def _main(args):

elif section.startswith('route'):
ids = [int(i) for i in cfg_parser[section]['layers'].split(',')]
if len(ids) == 2:
for i, item in enumerate(ids):
if item != -1:
ids[i] = item + 1

layers = [all_layers[i] for i in ids]

if len(layers) > 1:
Expand Down Expand Up @@ -268,9 +283,8 @@ def _main(args):
if remaining_weights > 0:
print('Warning: {} unused weights'.format(remaining_weights))

if args.plot_model:
plot(model, to_file='{}.png'.format(output_root), show_shapes=True)
print('Saved model plot to {}.png'.format(output_root))
plot(model, to_file='{}.png'.format(output_root), show_shapes=True)
print('Saved model plot to {}.png'.format(output_root))


if __name__ == '__main__':
Expand Down

0 comments on commit db63e48

Please sign in to comment.