forked from floodsung/LearningToCompare_FSL
Commit addd2c9 (1 parent: cada92b), committed by songrotek on Mar 28, 2018.
Showing 30 changed files with 2,509 additions and 1 deletion.
README.md:
# LearningToCompare_FSL

PyTorch code for the CVPR 2018 paper: [Learning to Compare: Relation Network for Few-Shot Learning](https://arxiv.org/abs/1711.06025) (few-shot learning part).
For the zero-shot learning part, please visit [LearningToCompare_ZSL](https://github.com/lzrobots/LearningToCompare_ZSL).

# Requirements

Python 2.7

PyTorch 0.3

# Data

For the Omniglot experiments, the 28x28 resized Omniglot images are attached directly in this git repository; they were created from [omniglot](https://github.com/brendenlake/omniglot) and [maml](https://github.com/cbfinn/maml).

For the mini-Imagenet experiments, please download [mini-Imagenet](https://drive.google.com/open?id=0B3Irx3uQNoBMQ1FlNXJsZUdYWEE), put it in ./datas/mini-Imagenet, and run proc_images.py to preprocess the images and generate the train/val/test splits. (This preprocessing method is based on [maml](https://github.com/cbfinn/maml).)
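Concretely, the preprocessing amounts to the following (a minimal sketch; it assumes the images/ folder and the train/val/test csv files sit together in the dataset directory, as described in the proc_images.py docstring below):

```
cd ./datas/mini-Imagenet
python proc_images.py
```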
# Train

omniglot 5-way 1-shot:

```
python omniglot_train_one_shot.py -w 5 -s 1 -b 19
```

omniglot 5-way 5-shot:

```
python omniglot_train_few_shot.py -w 5 -s 5 -b 15
```

omniglot 20-way 1-shot:

```
python omniglot_train_one_shot.py -w 20 -s 1 -b 10
```

omniglot 20-way 5-shot:

```
python omniglot_train_few_shot.py -w 20 -s 5 -b 5
```

mini-Imagenet 5-way 1-shot:

```
python miniimagenet_train_one_shot.py -w 5 -s 1 -b 15
```

mini-Imagenet 5-way 5-shot:

```
python miniimagenet_train_few_shot.py -w 5 -s 5 -b 10
```
You can adjust the -b (batch_num_per_class) parameter based on your GPU memory. Currently the scripts load my trained models; if you want to train from scratch, delete the saved models first.
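For example, to retrain the mini-Imagenet 5-way 5-shot models from scratch, remove the corresponding checkpoints first (the file names follow the pattern used in the test script at the end of this commit):

```
rm ./models/miniimagenet_feature_encoder_5way_5shot.pkl
rm ./models/miniimagenet_relation_network_5way_5shot.pkl
python miniimagenet_train_few_shot.py -w 5 -s 5 -b 10
```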
## Test

omniglot 5-way 1-shot:

```
python omniglot_test_one_shot.py -w 5 -s 1
```

Testing for the other experiments is similar.
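For example, a mini-Imagenet 5-way 5-shot test would look like this (the script name here is an assumption, based on the naming of the training scripts):

```
python miniimagenet_test_few_shot.py -w 5 -s 5
```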
## Citing

If you use this code in your research, please use the following BibTeX entry.

```
@inproceedings{sung2018learning,
  title={Learning to Compare: Relation Network for Few-Shot Learning},
  author={Sung, Flood and Yang, Yongxin and Zhang, Li and Xiang, Tao and Torr, Philip HS and Hospedales, Timothy M},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2018}
}
```
## Reference

[MAML](https://github.com/cbfinn/maml)

[MAML-pytorch](https://github.com/katerakelly/pytorch-maml)
proc_images.py:
""" | ||
code copied from https://github.com/cbfinn/maml/blob/master/data/miniImagenet/proc_images.py | ||
Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code) | ||
Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the | ||
csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'. | ||
Then run this script from the miniImagenet directory: | ||
cd data/miniImagenet/ | ||
python proc_images.py | ||
""" | ||
|
||
from __future__ import print_function | ||
import csv | ||
import glob | ||
import os | ||
|
||
from PIL import Image | ||
|
||
path_to_images = 'images/' | ||
|
||
all_images = glob.glob(path_to_images + '*') | ||
|
||
# Resize images | ||
|
||
for i, image_file in enumerate(all_images): | ||
im = Image.open(image_file) | ||
im = im.resize((84, 84), resample=Image.LANCZOS) | ||
im.save(image_file) | ||
if i % 500 == 0: | ||
print(i) | ||
|
||
# Put in correct directory | ||
for datatype in ['train', 'val', 'test']: | ||
os.system('mkdir ' + datatype) | ||
|
||
with open(datatype + '.csv', 'r') as f: | ||
reader = csv.reader(f, delimiter=',') | ||
last_label = '' | ||
for i, row in enumerate(reader): | ||
if i == 0: # skip the headers | ||
continue | ||
label = row[1] | ||
image_name = row[0] | ||
if label != last_label: | ||
cur_dir = datatype + '/' + label + '/' | ||
os.system('mkdir ' + cur_dir) | ||
last_label = label | ||
os.system('cp images/' + image_name + ' ' + cur_dir) |
mini-Imagenet few-shot test script:
```
#-------------------------------------
# Project: Learning to Compare: Relation Network for Few-Shot Learning
# Date: 2017.9.21
# Author: Flood Sung
# All Rights Reserved
#-------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import task_generator_test as tg
import os
import math
import argparse
import scipy as sp
import scipy.stats

parser = argparse.ArgumentParser(description="One Shot Visual Recognition")
parser.add_argument("-f", "--feature_dim", type=int, default=64)
parser.add_argument("-r", "--relation_dim", type=int, default=8)
parser.add_argument("-w", "--class_num", type=int, default=5)
parser.add_argument("-s", "--sample_num_per_class", type=int, default=5)
parser.add_argument("-b", "--batch_num_per_class", type=int, default=10)
parser.add_argument("-e", "--episode", type=int, default=10)
parser.add_argument("-t", "--test_episode", type=int, default=600)
parser.add_argument("-l", "--learning_rate", type=float, default=0.001)
parser.add_argument("-g", "--gpu", type=int, default=0)
parser.add_argument("-u", "--hidden_unit", type=int, default=10)
args = parser.parse_args()

# Hyper Parameters
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
SAMPLE_NUM_PER_CLASS = args.sample_num_per_class
BATCH_NUM_PER_CLASS = args.batch_num_per_class
EPISODE = args.episode
TEST_EPISODE = args.test_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu
HIDDEN_UNIT = args.hidden_unit

def mean_confidence_interval(data, confidence=0.95):
    # mean and half-width of the t-distribution confidence interval
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * sp.stats.t._ppf((1 + confidence) / 2., n - 1)
    return m, h

class CNNEncoder(nn.Module):
    """Embedding module: four conv blocks producing 64-channel feature maps."""
    def __init__(self):
        super(CNNEncoder, self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=1),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU())
        self.layer4 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=1),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU())

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        #out = out.view(out.size(0),-1)
        return out  # 64-channel feature maps

class RelationNetwork(nn.Module):
    """Relation module: scores each concatenated (support, query) feature pair."""
    def __init__(self, input_size, hidden_size):
        super(RelationNetwork, self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(64 * 2, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64, 64, kernel_size=3, padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.fc1 = nn.Linear(input_size * 3 * 3, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.sigmoid(self.fc2(out))
        return out

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        n = m.weight.size(1)
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())

def main():
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=100000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=100000, gamma=0.5)

    if os.path.exists(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        feature_encoder.load_state_dict(torch.load(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")))
        print("load feature encoder success")
    if os.path.exists(str("./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")))
        print("load relation network success")

    total_accuracy = 0.0
    for episode in range(EPISODE):

        # test
        print("Testing...")

        accuracies = []
        for i in range(TEST_EPISODE):
            total_rewards = 0
            # each test task has 15 query images per class
            task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, 15)
            sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
            num_per_class = 5  # query images are loaded 5 per class at a time
            test_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="test", shuffle=False)

            sample_images, sample_labels = sample_dataloader.__iter__().next()
            for test_images, test_labels in test_dataloader:
                batch_size = test_labels.shape[0]
                # calculate features
                sample_features = feature_encoder(Variable(sample_images).cuda(GPU))  # (CLASS_NUM*SAMPLE_NUM) x 64 x 19 x 19
                sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)  # sum the shot features into one map per class
                test_features = feature_encoder(Variable(test_images).cuda(GPU))  # batch_size x 64 x 19 x 19

                # calculate relations
                # pair each query (batch) sample with every class-summed support feature,
                # forming batch_size*CLASS_NUM concatenated 128-channel inputs for the relation network
                sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

                test_features_ext = test_features.unsqueeze(0).repeat(1 * CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)
                relation_pairs = torch.cat((sample_features_ext, test_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations.data, 1)

                rewards = [1 if predict_labels[j] == test_labels[j] else 0 for j in range(batch_size)]

                total_rewards += np.sum(rewards)

            # CLASS_NUM classes x 15 query images per task
            accuracy = total_rewards / 1.0 / CLASS_NUM / 15
            accuracies.append(accuracy)

        test_accuracy, h = mean_confidence_interval(accuracies)

        print("test accuracy:", test_accuracy, "h:", h)

        total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)

if __name__ == '__main__':
    main()
```
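To make the support/query pairing in the test loop concrete, here is a minimal shape check (an illustrative sketch, not part of the repository; it assumes only PyTorch and uses the 5-way defaults with the encoder's 19x19 output maps):

```
import torch

CLASS_NUM, FEATURE_DIM, batch_size = 5, 64, 15

# one summed 64x19x19 support feature map per class
sample_features = torch.randn(CLASS_NUM, FEATURE_DIM, 19, 19)
# encoded query images for one test batch
test_features = torch.randn(batch_size, FEATURE_DIM, 19, 19)

# pair every query with every class feature map
sample_ext = sample_features.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
test_ext = test_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
test_ext = torch.transpose(test_ext, 0, 1)

# concatenate along the channel axis: each pair becomes a 128x19x19 input
pairs = torch.cat((sample_ext, test_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)
print(pairs.shape)  # torch.Size([75, 128, 19, 19]); 75 = batch_size * CLASS_NUM pairs
```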