Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
xbresson committed Oct 1, 2017
1 parent c61a4a3 commit 6581ba7
Show file tree
Hide file tree
Showing 14 changed files with 26,769 additions and 0 deletions.
575 changes: 575 additions & 0 deletions 01_standard_convnet_lenet5_mnist_pytorch.ipynb

Large diffs are not rendered by default.

746 changes: 746 additions & 0 deletions 02_graph_convnet_lenet5_mnist_pytorch.ipynb

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Graph ConvNets in PyTorch
September 30, 2017
<br>
<br>



<img align="right" src="pic/graph_convnet.jpg" style="width: 200px;"/>


### Xavier Bresson
<img src="pic/home100.jpg" style="width: 15px;"/> http://www.ntu.edu.sg/home/xbresson<br>
<img src="pic/github100.jpg" style="width: 15px;"/> https://github.com/xbresson<br>
<img src="pic/twitter100.jpg" style="width: 15px;"/> https://twitter.com/xbresson <br>
<br>


### Description
Prototype implementation in PyTorch of the NIPS'16 paper:<br>
Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering<br>
M Defferrard, X Bresson, P Vandergheynst<br>
Advances in Neural Information Processing Systems, 3844-3852, 2016<br>
ArXiv preprint: [arXiv:1606.09375](https://arxiv.org/pdf/1606.09375.pdf) <br>
<br>

### Code objective
The code provides a simple example of graph ConvNets for the MNIST classification task.<br>
The graph is a 8-nearest neighbor graph of a 2D grid.<br>
The signals on graph are the MNIST images vectorized as $28^2 \times 1$ vectors.<br>
<br>


### Installation
```sh
git clone https://github.com/xbresson/graph_cnn_pytorch
cd graph_cnn_pytorch
pip install -r requirements.txt # installation for python 3.6.2
python check_install.py
jupyter notebook
```

<br>



### Results
GPU Quadro M4000<br>
* Standard ConvNets: **01_standard_convnet_lenet5_mnist_pytorch.ipynb**, accuracy= 99.31, speed= 6.9 sec/epoch. <br>
* Graph ConvNets: **02_graph_convnet_lenet5_mnist_pytorch.ipynb**, accuracy= 99.21, speed= 100.8 sec/epoch <br>
<br>


### Note
PyTorch has not yet implemented function torch.mm(sparse, dense) for variables: https://github.com/pytorch/pytorch/issues/2389. It will be certainly implemented in the near future but in the meantime, I implemented a new autograd function for sparse variables, called "my_sparse_mm", by subclassing torch.autograd.function and implementing the forward and backward passes.


```python
class my_sparse_mm(torch.autograd.Function):
"""
Implementation of a new autograd function for sparse variables,
called "my_sparse_mm", by subclassing torch.autograd.Function
and implementing the forward and backward passes.
"""

def forward(self, W, x): # W is SPARSE
self.save_for_backward(W, x)
y = torch.mm(W, x)
return y

def backward(self, grad_output):
W, x = self.saved_tensors
grad_input = grad_output.clone()
grad_input_dL_dW = torch.mm(grad_input, x.t())
grad_input_dL_dx = torch.mm(W, grad_input )
return grad_input_dL_dW, grad_input_dL_dx
```
<br>


### When to use this algorithm?
Any problem that can be cast as analyzing a set of signals on a fixed graph, and you want to use ConvNets for this analysis.<br>

<br>

<br>
<br>

33 changes: 33 additions & 0 deletions check_install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/env python3

print('\nRun Python installation test for graph ConvNets')

import os
import sys
major, minor = sys.version_info.major, sys.version_info.minor

if ( (major is not 3) or (minor is not 6) ):
raise Exception('please use Python 3.6 (for PyTorch), you have Python {}.{}'.format(major,minor))

try:
import numpy
print('Recommended version of numpy is {}. You have {}.'.format('1.13.1',numpy.__version__))
import jupyter
print('Recommended version of jupyter is {}. You have {}.'.format('1.0.0',jupyter.__version__))
import torch
print('Recommended version of pytorch is {}. You have {}.'.format('0.2.0_1',torch.__version__))
import tensorflow
print('Recommended version of tensorflow is {}. You have {}.'.format('0.11.0rc0',tensorflow.__version__))
import scipy
print('Recommended version of scipy is {}. You have {}.'.format('0.19.1',scipy.__version__))
import sklearn
print('Recommended version of sklearn is {}. You have {}.'.format('0.19.0',sklearn.__version__))


except:
print('A package is missing or the version of the package.')
print('Install the package below.\n')
raise

print('Successful installation of Python {}.{} and '
'most of the packages required to run the code.\n'.format(major, minor))
Loading

0 comments on commit 6581ba7

Please sign in to comment.