forked from anirudhsom/PI-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added code to train model, trained PI-Net model and other supporting …
…files
- Loading branch information
Anirudh Som
committed
Jun 2, 2019
1 parent
6cf01d5
commit 0d708e2
Showing
26 changed files
with
398 additions
and
1 deletion.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,23 @@ | ||
# PI-Net | ||
# Code for PI-Net: A Deep Learning Approach to Extract Topological Persistence Images | ||
|
||
Here we provide sample code to compute persistence images (PIs) using the proposed Image PI-Net model. We load weights from a pre-trained model trained on the CIFAR10 dataset. | ||
|
||
|
||
## Key Files | ||
|
||
For sample test-set images in CIFAR10, both files first load weghts from a pre-trained Image PI-Net model; next, compute PIs using the Image PI-Net model and finally compare the generated PIs to ground-truth PIs obtained using conventional topological data analysis (TDA) tools. In addition, the "main.py" file saves the PI comparisons for each sample image in the "Examples folder". | ||
|
||
- main.ipynb | ||
|
||
- main.py | ||
|
||
## Required Packages | ||
|
||
We ask the reviewers to run the code on a linux machine and have the following packages installed. We assume all necessary packages are already installed. | ||
|
||
- numpy | ||
- scipy | ||
- matplotlib | ||
- keras (with tensorflow backend) | ||
|
||
**Note:** If you have trouble running these codes, we illustrate the generated PIs in the "Examples" folder and for each image compare the generated PIs using the PI-Net model to the ground-truth PIs. |
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
## Load packages | ||
print("\n Loading packages ...") | ||
|
||
import os | ||
import numpy as np | ||
import scipy.io as sio | ||
import matplotlib.pyplot as plt | ||
from tensorflow.python.keras import models, layers, losses, optimizers, utils | ||
from tensorflow.python.keras import backend as K | ||
|
||
from pi_net import * | ||
|
||
## Load images and ground-truth persistence images | ||
print("\n Loading images and ground-truth PIs ...") | ||
temp = sio.loadmat('Sample_Images_PI.mat') | ||
imgs = temp['imgs'] | ||
PIs = temp['PIs'] | ||
|
||
## Load model and weights | ||
print("\n Loading model and weights ...") | ||
model = PINet_CIFAR10() | ||
model.load_weights('PI-Net_CIFAR10.h5') | ||
|
||
## Generate PIs using PI-Net | ||
print("\n Generating PIs ...") | ||
PIs_generated = model.predict(imgs) | ||
|
||
## Saving generated PIs | ||
if not os.path.exists('Examples'): | ||
os.makedirs('Examples') | ||
|
||
j = 0 | ||
for i in range(len(imgs)): | ||
fig = plt.figure(figsize = (15,5))#,frameon=False) | ||
fig.add_subplot(131) | ||
plt.imshow(imgs[i]) | ||
plt.title('Input Image',fontdict={'fontsize':20}) | ||
|
||
fig.add_subplot(132) | ||
plt.imshow(PIs[i].reshape((3,50,50))[j]) | ||
plt.colorbar() | ||
plt.clim(0,1) | ||
plt.title('Ground-truth PI',fontdict={'fontsize':20}) | ||
|
||
fig.add_subplot(133) | ||
plt.imshow(PIs_generated[i].reshape((3,50,50))[j]) | ||
plt.colorbar() | ||
plt.clim(0,0.8) | ||
plt.title('Generated PI',fontdict={'fontsize':20}) | ||
|
||
fig.savefig('Examples/' + str(i+1) + '.png' ) | ||
|
||
print("\n Please go into 'Examples' folder to view saved images \n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import numpy as np | ||
import scipy.io as sio | ||
import matplotlib.pyplot as plt | ||
from tensorflow.python.keras import models, layers, losses, optimizers, utils | ||
from tensorflow.python.keras import backend as K | ||
|
||
|
||
def PINet_CIFAR10(): | ||
|
||
## model | ||
input_shape = [32,32,3] | ||
initial_conv_width=3 | ||
initial_stride=1 | ||
initial_filters=64 | ||
initial_pool_width=3 | ||
initial_pool_stride=2 | ||
use_global_pooling = True | ||
dropout_rate = 0.2 | ||
|
||
model_input = layers.Input(shape=input_shape) | ||
|
||
x = layers.Conv2D( | ||
128, | ||
initial_conv_width, | ||
strides=initial_stride, | ||
padding="same")(model_input) | ||
x = layers.BatchNormalization()(x) | ||
x = layers.Activation("relu")(x) | ||
x = layers.MaxPooling2D( | ||
pool_size=initial_pool_width, | ||
strides=initial_pool_stride, | ||
padding="same")(x) | ||
|
||
x = layers.Conv2D( | ||
256, | ||
initial_conv_width, | ||
strides=initial_stride, | ||
padding="same")(x) | ||
x = layers.BatchNormalization()(x) | ||
x = layers.Activation("relu")(x) | ||
x = layers.MaxPooling2D( | ||
pool_size=initial_pool_width, | ||
strides=initial_pool_stride, | ||
padding="same")(x) | ||
|
||
x = layers.Conv2D( | ||
512, | ||
initial_conv_width, | ||
strides=initial_stride, | ||
padding="same")(x) | ||
x = layers.BatchNormalization()(x) | ||
x = layers.Activation("relu")(x) | ||
x = layers.MaxPooling2D( | ||
pool_size=initial_pool_width, | ||
strides=initial_pool_stride, | ||
padding="same")(x) | ||
|
||
x = layers.Conv2D( | ||
1024, | ||
initial_conv_width, | ||
strides=initial_stride, | ||
padding="same")(x) | ||
x = layers.BatchNormalization()(x) | ||
x = layers.Activation("relu")(x) | ||
|
||
if use_global_pooling: | ||
x = layers.GlobalAveragePooling2D()(x) | ||
|
||
|
||
x_logits1 = layers.Dense(2500, activation="relu")(x) | ||
|
||
x_logits1_reshape = layers.Reshape((1,50,50))(x_logits1) | ||
|
||
x_logits1_reshape = layers.Permute((2,3,1))(x_logits1_reshape) | ||
|
||
x_logits2 = layers.Conv2DTranspose( | ||
3, | ||
50, | ||
strides=initial_stride, | ||
padding="same")(x_logits1_reshape) | ||
x_logits2 = layers.BatchNormalization()(x_logits2) | ||
x_logits2 = layers.Activation("relu")(x_logits2) | ||
|
||
model_output = layers.Flatten()(x_logits2) | ||
|
||
model = models.Model(model_input, model_output) | ||
|
||
return model |