Skip to content

Commit

Permalink
added code to train model, trained PI-Net model and other supporting …
Browse files Browse the repository at this point in the history
…files
  • Loading branch information
Anirudh Som committed Jun 2, 2019
1 parent 6cf01d5 commit 0d708e2
Show file tree
Hide file tree
Showing 26 changed files with 398 additions and 1 deletion.
Binary file added Examples/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/13.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/14.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/15.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/17.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/18.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/19.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/20.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Examples/9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added PI-Net_CIFAR10_model.zip
Binary file not shown.
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
# PI-Net
# Code for PI-Net: A Deep Learning Approach to Extract Topological Persistence Images

Here we provide sample code to compute persistence images (PIs) using the proposed Image PI-Net model. We load weights from a pre-trained model trained on the CIFAR10 dataset.


## Key Files

For sample test-set images in CIFAR10, both files first load weghts from a pre-trained Image PI-Net model; next, compute PIs using the Image PI-Net model and finally compare the generated PIs to ground-truth PIs obtained using conventional topological data analysis (TDA) tools. In addition, the "main.py" file saves the PI comparisons for each sample image in the "Examples folder".

- main.ipynb

- main.py

## Required Packages

We ask the reviewers to run the code on a linux machine and have the following packages installed. We assume all necessary packages are already installed.

- numpy
- scipy
- matplotlib
- keras (with tensorflow backend)

**Note:** If you have trouble running these codes, we illustrate the generated PIs in the "Examples" folder and for each image compare the generated PIs using the PI-Net model to the ground-truth PIs.
Binary file added Sample_Images_PI.mat
Binary file not shown.
234 changes: 234 additions & 0 deletions main.ipynb

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
## Load packages
print("\n Loading packages ...")

import os
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from tensorflow.python.keras import models, layers, losses, optimizers, utils
from tensorflow.python.keras import backend as K

from pi_net import *

## Load images and ground-truth persistence images
print("\n Loading images and ground-truth PIs ...")
temp = sio.loadmat('Sample_Images_PI.mat')
imgs = temp['imgs']
PIs = temp['PIs']

## Load model and weights
print("\n Loading model and weights ...")
model = PINet_CIFAR10()
model.load_weights('PI-Net_CIFAR10.h5')

## Generate PIs using PI-Net
print("\n Generating PIs ...")
PIs_generated = model.predict(imgs)

## Saving generated PIs
if not os.path.exists('Examples'):
os.makedirs('Examples')

j = 0
for i in range(len(imgs)):
fig = plt.figure(figsize = (15,5))#,frameon=False)
fig.add_subplot(131)
plt.imshow(imgs[i])
plt.title('Input Image',fontdict={'fontsize':20})

fig.add_subplot(132)
plt.imshow(PIs[i].reshape((3,50,50))[j])
plt.colorbar()
plt.clim(0,1)
plt.title('Ground-truth PI',fontdict={'fontsize':20})

fig.add_subplot(133)
plt.imshow(PIs_generated[i].reshape((3,50,50))[j])
plt.colorbar()
plt.clim(0,0.8)
plt.title('Generated PI',fontdict={'fontsize':20})

fig.savefig('Examples/' + str(i+1) + '.png' )

print("\n Please go into 'Examples' folder to view saved images \n")
88 changes: 88 additions & 0 deletions pi_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from tensorflow.python.keras import models, layers, losses, optimizers, utils
from tensorflow.python.keras import backend as K


def PINet_CIFAR10():

## model
input_shape = [32,32,3]
initial_conv_width=3
initial_stride=1
initial_filters=64
initial_pool_width=3
initial_pool_stride=2
use_global_pooling = True
dropout_rate = 0.2

model_input = layers.Input(shape=input_shape)

x = layers.Conv2D(
128,
initial_conv_width,
strides=initial_stride,
padding="same")(model_input)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.MaxPooling2D(
pool_size=initial_pool_width,
strides=initial_pool_stride,
padding="same")(x)

x = layers.Conv2D(
256,
initial_conv_width,
strides=initial_stride,
padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.MaxPooling2D(
pool_size=initial_pool_width,
strides=initial_pool_stride,
padding="same")(x)

x = layers.Conv2D(
512,
initial_conv_width,
strides=initial_stride,
padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.MaxPooling2D(
pool_size=initial_pool_width,
strides=initial_pool_stride,
padding="same")(x)

x = layers.Conv2D(
1024,
initial_conv_width,
strides=initial_stride,
padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)

if use_global_pooling:
x = layers.GlobalAveragePooling2D()(x)


x_logits1 = layers.Dense(2500, activation="relu")(x)

x_logits1_reshape = layers.Reshape((1,50,50))(x_logits1)

x_logits1_reshape = layers.Permute((2,3,1))(x_logits1_reshape)

x_logits2 = layers.Conv2DTranspose(
3,
50,
strides=initial_stride,
padding="same")(x_logits1_reshape)
x_logits2 = layers.BatchNormalization()(x_logits2)
x_logits2 = layers.Activation("relu")(x_logits2)

model_output = layers.Flatten()(x_logits2)

model = models.Model(model_input, model_output)

return model

0 comments on commit 0d708e2

Please sign in to comment.