Skip to content

Commit

Permalink
fix predict() function error about image_dim_ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
nhzc123 committed Jan 13, 2017
1 parent 15c1a3a commit d08e425
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions boss_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,12 @@ def load(self, file_path=FILE_PATH):
self.model = load_model(file_path)

def predict(self, image):
if image.shape != (1, 3, IMAGE_SIZE, IMAGE_SIZE):
if K.image_dim_ordering() == 'th' and image.shape != (1, 3, IMAGE_SIZE, IMAGE_SIZE):
image = resize_with_pad(image)
image = image.reshape((1, 3, IMAGE_SIZE, IMAGE_SIZE))
elif K.image_dim_ordering() == 'tf' and image.shape != (1, IMAGE_SIZE, IMAGE_SIZE, 3):
image = resize_with_pad(image)
image = image.reshape((1, IMAGE_SIZE, IMAGE_SIZE, 3))
image = image.astype('float32')
image /= 255
result = self.model.predict_proba(image)
Expand All @@ -177,4 +180,4 @@ def evaluate(self, dataset):

model = Model()
model.load()
model.evaluate(dataset)
model.evaluate(dataset)

0 comments on commit d08e425

Please sign in to comment.