Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
songrotek committed Mar 28, 2018
1 parent cada92b commit addd2c9
Show file tree
Hide file tree
Showing 30 changed files with 2,509 additions and 1 deletion.
Binary file added .DS_Store
Binary file not shown.
88 changes: 87 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,88 @@
# LearningToCompare_FSL
PyTorch code for CVPR 2018 paper: Learning to Compare: Relation Network for Few-Shot Learning (Few-Shot Learning part)
PyTorch code for CVPR 2018 paper: [Learning to Compare: Relation Network for Few-Shot Learning](https://arxiv.org/abs/1711.06025) (Few-Shot Learning part)

For Zero-Shot Learning part, please visit [here](https://github.com/lzrobots/LearningToCompare_ZSL).

# Requirements

Python 2.7

Pytorch 0.3

# Data

For Omniglot experiments, I directly attach omniglot 28x28 resized images in the git, which is created based on [omniglot](https://github.com/brendenlake/omniglot) and [maml](https://github.com/cbfinn/maml).

For mini-Imagenet experiments, please download [mini-Imagenet](https://drive.google.com/open?id=0B3Irx3uQNoBMQ1FlNXJsZUdYWEE) and put it in ./datas/mini-Imagenet and run proc_image.py to preprocess generate train/val/test datasets. (This process method is based on [maml](https://github.com/cbfinn/maml)).

# Train

omniglot 5way 1 shot:

```
python omniglot_train_one_shot.py -w 5 -s 1 -b 19
```

omniglot 5way 5 shot:

```
python omniglot_train_few_shot.py -w 5 -s 5 -b 15
```

omniglot 20way 1 shot:

```
python omniglot_train_one_shot.py -w 20 -s 1 -b 10
```

omniglot 20way 5 shot:

```
python omniglot_train_few_shot.py -w 20 -s 5 -b 5
```

mini-Imagenet 5 way 1 shot:

```
python miniimagenet_train_one_shot.py -w 5 -s 1 -b 15
```

mini-Imagenet 5 way 5 shot:

```
python miniimagenet_train_few_shot.py -w 5 -s 5 -b 10
```

you can change -b parameter based on your GPU memory. Currently It will load my trained model, if you want to train from scratch, you can delete models by yourself.

## Test

omniglot 5way 1 shot:

```
python omniglot_test_one_shot.py -w 5 -s 1
```

Other experiments' testings are similar.


## Citing

If you use this code in your research, please use the following BibTeX entry.

```
@inproceedings{sung2018learning,
title={Learning to Compare: Relation Network for Few-Shot Learning},
author={Sung, Flood and Yang, Yongxin and Zhang, Li and Xiang, Tao and Torr, Philip HS and Hospedales, Timothy M},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2018}
}
```

## Reference

[MAML](https://github.com/cbfinn/maml)

[MAML-pytorch](https://github.com/katerakelly/pytorch-maml)


Binary file added datas/.DS_Store
Binary file not shown.
48 changes: 48 additions & 0 deletions datas/miniImagenet/proc_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
code copied from https://github.com/cbfinn/maml/blob/master/data/miniImagenet/proc_images.py
Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code)
Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the
csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'.
Then run this script from the miniImagenet directory:
cd data/miniImagenet/
python proc_images.py
"""

from __future__ import print_function
import csv
import glob
import os

from PIL import Image

path_to_images = 'images/'

all_images = glob.glob(path_to_images + '*')

# Resize images

for i, image_file in enumerate(all_images):
im = Image.open(image_file)
im = im.resize((84, 84), resample=Image.LANCZOS)
im.save(image_file)
if i % 500 == 0:
print(i)

# Put in correct directory
for datatype in ['train', 'val', 'test']:
os.system('mkdir ' + datatype)

with open(datatype + '.csv', 'r') as f:
reader = csv.reader(f, delimiter=',')
last_label = ''
for i, row in enumerate(reader):
if i == 0: # skip the headers
continue
label = row[1]
image_name = row[0]
if label != last_label:
cur_dir = datatype + '/' + label + '/'
os.system('mkdir ' + cur_dir)
last_label = label
os.system('cp images/' + image_name + ' ' + cur_dir)
Binary file added datas/omniglot_28x28.zip
Binary file not shown.
Binary file added miniimagenet/.DS_Store
Binary file not shown.
212 changes: 212 additions & 0 deletions miniimagenet/miniimagenet_test_few_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#-------------------------------------
# Project: Learning to Compare: Relation Network for Few-Shot Learning
# Date: 2017.9.21
# Author: Flood Sung
# All Rights Reserved
#-------------------------------------


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import task_generator_test as tg
import os
import math
import argparse
import scipy as sp
import scipy.stats

parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
parser.add_argument("-f","--feature_dim",type = int, default = 64)
parser.add_argument("-r","--relation_dim",type = int, default = 8)
parser.add_argument("-w","--class_num",type = int, default = 5)
parser.add_argument("-s","--sample_num_per_class",type = int, default = 5)
parser.add_argument("-b","--batch_num_per_class",type = int, default = 10)
parser.add_argument("-e","--episode",type = int, default= 10)
parser.add_argument("-t","--test_episode", type = int, default = 600)
parser.add_argument("-l","--learning_rate", type = float, default = 0.001)
parser.add_argument("-g","--gpu",type=int, default=0)
parser.add_argument("-u","--hidden_unit",type=int,default=10)
args = parser.parse_args()


# Hyper Parameters
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SAMPLE_NUM_PER_CLASS = args.sample_num_per_class
BATCH_NUM_PER_CLASS = args.batch_num_per_class
EPISODE = args.episode
TEST_EPISODE = args.test_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu
HIDDEN_UNIT = args.hidden_unit

def mean_confidence_interval(data, confidence=0.95):
a = 1.0*np.array(data)
n = len(a)
m, se = np.mean(a), scipy.stats.sem(a)
h = se * sp.stats.t._ppf((1+confidence)/2., n-1)
return m,h

class CNNEncoder(nn.Module):
"""docstring for ClassName"""
def __init__(self):
super(CNNEncoder, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,padding=0),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(),
nn.MaxPool2d(2))
self.layer2 = nn.Sequential(
nn.Conv2d(64,64,kernel_size=3,padding=0),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(),
nn.MaxPool2d(2))
self.layer3 = nn.Sequential(
nn.Conv2d(64,64,kernel_size=3,padding=1),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU())
self.layer4 = nn.Sequential(
nn.Conv2d(64,64,kernel_size=3,padding=1),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU())

def forward(self,x):
out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
#out = out.view(out.size(0),-1)
return out # 64

class RelationNetwork(nn.Module):
"""docstring for RelationNetwork"""
def __init__(self,input_size,hidden_size):
super(RelationNetwork, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(64*2,64,kernel_size=3,padding=0),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(),
nn.MaxPool2d(2))
self.layer2 = nn.Sequential(
nn.Conv2d(64,64,kernel_size=3,padding=0),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(),
nn.MaxPool2d(2))
self.fc1 = nn.Linear(input_size*3*3,hidden_size)
self.fc2 = nn.Linear(hidden_size,1)

def forward(self,x):
out = self.layer1(x)
out = self.layer2(out)
out = out.view(out.size(0),-1)
out = F.relu(self.fc1(out))
out = F.sigmoid(self.fc2(out))
return out

def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif classname.find('Linear') != -1:
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data = torch.ones(m.bias.data.size())

def main():
# Step 1: init data folders
print("init data folders")
# init character folders for dataset construction
metatrain_folders,metatest_folders = tg.mini_imagenet_folders()

# Step 2: init neural networks
print("init neural networks")

feature_encoder = CNNEncoder()
relation_network = RelationNetwork(FEATURE_DIM,RELATION_DIM)


feature_encoder.cuda(GPU)
relation_network.cuda(GPU)

feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE)
feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=100000,gamma=0.5)
relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE)
relation_network_scheduler = StepLR(relation_network_optim,step_size=100000,gamma=0.5)

if os.path.exists(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
feature_encoder.load_state_dict(torch.load(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
print("load feature encoder success")
if os.path.exists(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
print("load relation network success")

total_accuracy = 0.0
for episode in range(EPISODE):


# test
print("Testing...")

accuracies = []
for i in range(TEST_EPISODE):
total_rewards = 0
task = tg.MiniImagenetTask(metatest_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,15)
sample_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
num_per_class = 5
test_dataloader = tg.get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False)

sample_images,sample_labels = sample_dataloader.__iter__().next()
for test_images,test_labels in test_dataloader:
batch_size = test_labels.shape[0]
# calculate features
sample_features = feature_encoder(Variable(sample_images).cuda(GPU)) # 5x64
sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,19,19)
sample_features = torch.sum(sample_features,1).squeeze(1)
test_features = feature_encoder(Variable(test_images).cuda(GPU)) # 20x64

# calculate relations
# each batch sample link to every samples to calculate relations
# to form a 100x128 matrix for relation network
sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1,1,1)

test_features_ext = test_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1)
test_features_ext = torch.transpose(test_features_ext,0,1)
relation_pairs = torch.cat((sample_features_ext,test_features_ext),2).view(-1,FEATURE_DIM*2,19,19)
relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

_,predict_labels = torch.max(relations.data,1)

rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]

total_rewards += np.sum(rewards)


accuracy = total_rewards/1.0/CLASS_NUM/15
accuracies.append(accuracy)

test_accuracy,h = mean_confidence_interval(accuracies)

print("test accuracy:",test_accuracy,"h:",h)

total_accuracy += test_accuracy

print("aver_accuracy:",total_accuracy/EPISODE)






if __name__ == '__main__':
main()
Loading

0 comments on commit addd2c9

Please sign in to comment.