-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
26,769 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Graph ConvNets in PyTorch | ||
September 30, 2017 | ||
<br> | ||
<br> | ||
|
||
|
||
|
||
<img align="right" src="pic/graph_convnet.jpg" style="width: 200px;"/> | ||
|
||
|
||
### Xavier Bresson | ||
<img src="pic/home100.jpg" style="width: 15px;"/> http://www.ntu.edu.sg/home/xbresson<br> | ||
<img src="pic/github100.jpg" style="width: 15px;"/> https://github.com/xbresson<br> | ||
<img src="pic/twitter100.jpg" style="width: 15px;"/> https://twitter.com/xbresson <br> | ||
<br> | ||
|
||
|
||
### Description | ||
Prototype implementation in PyTorch of the NIPS'16 paper:<br> | ||
Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering<br> | ||
M Defferrard, X Bresson, P Vandergheynst<br> | ||
Advances in Neural Information Processing Systems, 3844-3852, 2016<br> | ||
ArXiv preprint: [arXiv:1606.09375](https://arxiv.org/pdf/1606.09375.pdf) <br> | ||
<br> | ||
|
||
### Code objective | ||
The code provides a simple example of graph ConvNets for the MNIST classification task.<br> | ||
The graph is a 8-nearest neighbor graph of a 2D grid.<br> | ||
The signals on graph are the MNIST images vectorized as $28^2 \times 1$ vectors.<br> | ||
<br> | ||
|
||
|
||
### Installation | ||
```sh | ||
git clone https://github.com/xbresson/graph_cnn_pytorch | ||
cd graph_cnn_pytorch | ||
pip install -r requirements.txt # installation for python 3.6.2 | ||
python check_install.py | ||
jupyter notebook | ||
``` | ||
|
||
<br> | ||
|
||
|
||
|
||
### Results | ||
GPU Quadro M4000<br> | ||
* Standard ConvNets: **01_standard_convnet_lenet5_mnist_pytorch.ipynb**, accuracy= 99.31, speed= 6.9 sec/epoch. <br> | ||
* Graph ConvNets: **02_graph_convnet_lenet5_mnist_pytorch.ipynb**, accuracy= 99.21, speed= 100.8 sec/epoch <br> | ||
<br> | ||
|
||
|
||
### Note | ||
PyTorch has not yet implemented function torch.mm(sparse, dense) for variables: https://github.com/pytorch/pytorch/issues/2389. It will be certainly implemented in the near future but in the meantime, I implemented a new autograd function for sparse variables, called "my_sparse_mm", by subclassing torch.autograd.function and implementing the forward and backward passes. | ||
|
||
|
||
```python | ||
class my_sparse_mm(torch.autograd.Function): | ||
""" | ||
Implementation of a new autograd function for sparse variables, | ||
called "my_sparse_mm", by subclassing torch.autograd.Function | ||
and implementing the forward and backward passes. | ||
""" | ||
|
||
def forward(self, W, x): # W is SPARSE | ||
self.save_for_backward(W, x) | ||
y = torch.mm(W, x) | ||
return y | ||
|
||
def backward(self, grad_output): | ||
W, x = self.saved_tensors | ||
grad_input = grad_output.clone() | ||
grad_input_dL_dW = torch.mm(grad_input, x.t()) | ||
grad_input_dL_dx = torch.mm(W, grad_input ) | ||
return grad_input_dL_dW, grad_input_dL_dx | ||
``` | ||
<br> | ||
|
||
|
||
### When to use this algorithm? | ||
Any problem that can be cast as analyzing a set of signals on a fixed graph, and you want to use ConvNets for this analysis.<br> | ||
|
||
<br> | ||
|
||
<br> | ||
<br> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#!/bin/env python3 | ||
|
||
print('\nRun Python installation test for graph ConvNets') | ||
|
||
import os | ||
import sys | ||
major, minor = sys.version_info.major, sys.version_info.minor | ||
|
||
if ( (major is not 3) or (minor is not 6) ): | ||
raise Exception('please use Python 3.6 (for PyTorch), you have Python {}.{}'.format(major,minor)) | ||
|
||
try: | ||
import numpy | ||
print('Recommended version of numpy is {}. You have {}.'.format('1.13.1',numpy.__version__)) | ||
import jupyter | ||
print('Recommended version of jupyter is {}. You have {}.'.format('1.0.0',jupyter.__version__)) | ||
import torch | ||
print('Recommended version of pytorch is {}. You have {}.'.format('0.2.0_1',torch.__version__)) | ||
import tensorflow | ||
print('Recommended version of tensorflow is {}. You have {}.'.format('0.11.0rc0',tensorflow.__version__)) | ||
import scipy | ||
print('Recommended version of scipy is {}. You have {}.'.format('0.19.1',scipy.__version__)) | ||
import sklearn | ||
print('Recommended version of sklearn is {}. You have {}.'.format('0.19.0',sklearn.__version__)) | ||
|
||
|
||
except: | ||
print('A package is missing or the version of the package.') | ||
print('Install the package below.\n') | ||
raise | ||
|
||
print('Successful installation of Python {}.{} and ' | ||
'most of the packages required to run the code.\n'.format(major, minor)) |
Oops, something went wrong.