From d08e425a792c028835169c26fc1f6cc3a5197a48 Mon Sep 17 00:00:00 2001 From: nhzc123 Date: Fri, 13 Jan 2017 14:38:07 +0800 Subject: [PATCH] fix predict() function error about image_dim_ordering --- boss_train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/boss_train.py b/boss_train.py index 5a14eb9..b915231 100644 --- a/boss_train.py +++ b/boss_train.py @@ -151,9 +151,12 @@ def load(self, file_path=FILE_PATH): self.model = load_model(file_path) def predict(self, image): - if image.shape != (1, 3, IMAGE_SIZE, IMAGE_SIZE): + if K.image_dim_ordering() == 'th' and image.shape != (1, 3, IMAGE_SIZE, IMAGE_SIZE): image = resize_with_pad(image) image = image.reshape((1, 3, IMAGE_SIZE, IMAGE_SIZE)) + elif K.image_dim_ordering() == 'tf' and image.shape != (1, IMAGE_SIZE, IMAGE_SIZE, 3): + image = resize_with_pad(image) + image = image.reshape((1, IMAGE_SIZE, IMAGE_SIZE, 3)) image = image.astype('float32') image /= 255 result = self.model.predict_proba(image) @@ -177,4 +180,4 @@ def evaluate(self, dataset): model = Model() model.load() - model.evaluate(dataset) \ No newline at end of file + model.evaluate(dataset)