This repository contains the PyTorch implementation of SimpleNet and MDNSal. The accompanying paper appears in the proceedings of the 21st International Conference on Intelligent Robots and Systems (IROS).
Please cite the paper using the following BibTeX entry:
@inproceedings{Navya-IROS-2020,
AUTHOR = {Navyasri Reddy and Samyak Jain and Pradeep Yarlagadda and Vineet Gandhi},
TITLE = {Tidying Deep Saliency Prediction Architectures},
BOOKTITLE = {IROS},
YEAR = {2020}
}
Learning computational models for visual attention (saliency estimation) is an effort to inch machines/robots closer to human visual cognitive abilities. Data-driven efforts have dominated the landscape since the introduction of deep neural network architectures. In deep learning research, the choices in architecture design are often empirical and frequently lead to more complex models than necessary. The complexity, in turn, hinders the application requirements. In this paper, we identify four key components of saliency models, i.e., input features, multi-level integration, readout architecture, and loss functions. We review the existing state-of-the-art models on these four components and propose novel and simpler alternatives. As a result, we propose two novel end-to-end architectures called SimpleNet and MDNSal, which are neater, minimal, more interpretable and achieve state-of-the-art performance on public saliency benchmarks. SimpleNet is an optimized encoder-decoder architecture and brings notable performance gains on the SALICON dataset (the largest saliency benchmark). MDNSal is a parametric model that directly predicts parameters of a GMM distribution and aims to bring more interpretability to the prediction maps. The proposed saliency models run at 25 fps, making them ideal for real-time applications.
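To make the idea of a parametric saliency map concrete, here is a minimal sketch of how a set of Gaussian mixture parameters could be rendered into a saliency map. It is not the MDNSal code: the function names, the tuple layout of a component, the use of normalized [0, 1] coordinates, and the restriction to axis-aligned (diagonal-covariance) Gaussians are all assumptions made for illustration, and the parameterization used in the paper may differ.

```python
import math


def gaussian_2d(x, y, mu_x, mu_y, sigma_x, sigma_y):
    """Density of an axis-aligned 2D Gaussian at the point (x, y)."""
    zx = (x - mu_x) / sigma_x
    zy = (y - mu_y) / sigma_y
    norm = 2.0 * math.pi * sigma_x * sigma_y
    return math.exp(-0.5 * (zx * zx + zy * zy)) / norm


def gmm_saliency_map(components, height, width):
    """Render mixture components onto a height x width grid.

    Each component is a hypothetical tuple
    (weight, mu_x, mu_y, sigma_x, sigma_y) in normalized [0, 1] coordinates.
    The returned map is a list of rows, rescaled to sum to one.
    """
    grid = []
    for row in range(height):
        y = (row + 0.5) / height  # evaluate at pixel centers
        grid.append([
            sum(w * gaussian_2d((col + 0.5) / width, y, mx, my, sx, sy)
                for w, mx, my, sx, sy in components)
            for col in range(width)
        ])
    total = sum(sum(r) for r in grid)
    return [[v / total for v in r] for r in grid]
```

Because the map is fully determined by a handful of means, spreads and weights, each predicted component can be inspected directly, which is the kind of interpretability the abstract refers to.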
Clone this repository and download the pretrained weights of SimpleNet, for multiple encoders, trained on the SALICON dataset, from this link. The trained weights for MobileNetV2 can be found here.
Then run the test script:
$ python3 test.py --val_img_dir path/to/test/images --results_dir path/to/results --model_val_path path/to/saved/models
This generates saliency maps for all images in the images directory and writes them to the results directory.
To train the model from scratch, download the pretrained weights of PNASNet from here and place them in the PNAS/ folder. Then run the following command to train:
$ python3 train.py --dataset_dir path/to/dataset
The dataset directory structure should be:
└── Dataset
    ├── fixations
    │   ├── train
    │   └── val
    ├── images
    │   ├── train
    │   └── val
    └── maps
        ├── train
        └── val
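Before starting a long training run, it can help to confirm that the dataset folder matches this layout. The sketch below is a hypothetical helper, not a script shipped with this repository; the name `missing_dataset_dirs` and the `EXPECTED_SUBDIRS` list are introduced here and are derived only from the tree shown above.

```python
from pathlib import Path

# Sub-directories implied by the directory tree in this README.
# This list and the helper below are illustrative, not part of the repo.
EXPECTED_SUBDIRS = [
    "fixations/train", "fixations/val",
    "images/train", "images/val",
    "maps/train", "maps/val",
]


def missing_dataset_dirs(dataset_dir):
    """Return the expected sub-directories that are absent under dataset_dir."""
    root = Path(dataset_dir)
    return [sub for sub in EXPECTED_SUBDIRS if not (root / sub).is_dir()]
```

Calling `missing_dataset_dirs("path/to/dataset")` returns an empty list when all six folders are present, and otherwise lists the ones that still need to be created.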
To train the model on the MIT1003 or CAT2000 dataset, first train it on the SALICON dataset and then fine-tune the resulting weights on MIT1003 or CAT2000.
For training, we provide encoders based on PNASNet, DenseNet-161, VGG-16, ResNet-50 and MobileNetV2. Run the command:
$ python3 train.py --enc_model <model> --train_enc <boolean value>
<model> : {"pnas", "densenet", "resnet", "vgg", "mobilenet"}
Set train_enc to 1 to fine-tune the encoder and to 0 otherwise.
Similarly, to test the model:
$ python3 test.py --enc_model <model> --model_val_path path/to/pretrained/model --save_results <binary> --validate <binary>
To save the generated saliency maps, set the save_results flag to 1; to evaluate the model quantitatively, set the validate flag to 1.
To train the model with a combination of loss functions, run the following command:
$ python3 train.py --<loss_function> True --<loss_function>_coeff <coefficient of the loss>
<loss_function> : {"kldiv", "cc", "nss", "sim"}
By default, the loss function is KLDiv with a coefficient of 1.0.
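As a rough illustration of what a weighted combination of these four terms means, the sketch below computes each term on flattened maps (plain Python lists) and sums them with user-supplied coefficients. It uses common textbook definitions of KLDiv, CC, NSS and SIM; the function names, the `EPS` value and the sign convention are assumptions, and the repository's own loss code (which operates on PyTorch tensors) may normalize or sign the terms differently.

```python
import math

EPS = 1e-8  # small constant guarding against division by zero (assumed value)


def _to_distribution(values):
    """Scale non-negative values so that they sum to approximately one."""
    total = sum(values)
    return [v / (total + EPS) for v in values]


def kldiv(pred, gt):
    """KL divergence between ground truth and prediction (lower is better)."""
    p, q = _to_distribution(pred), _to_distribution(gt)
    return sum(qi * math.log(qi / (pi + EPS) + EPS) for pi, qi in zip(p, q))


def cc(pred, gt):
    """Pearson correlation coefficient between the maps (higher is better)."""
    n = len(pred)
    mp, mg = sum(pred) / n, sum(gt) / n
    cov = sum((a - mp) * (b - mg) for a, b in zip(pred, gt))
    var_p = sum((a - mp) ** 2 for a in pred)
    var_g = sum((b - mg) ** 2 for b in gt)
    return cov / (math.sqrt(var_p * var_g) + EPS)


def nss(pred, fixations):
    """Mean standardized prediction at fixated pixels (higher is better)."""
    n = len(pred)
    mean = sum(pred) / n
    std = math.sqrt(sum((a - mean) ** 2 for a in pred) / n)
    fixated = [(a - mean) / (std + EPS) for a, f in zip(pred, fixations) if f > 0]
    return sum(fixated) / max(len(fixated), 1)


def sim(pred, gt):
    """Histogram intersection of the two distributions (higher is better)."""
    p, q = _to_distribution(pred), _to_distribution(gt)
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def combined_loss(pred, gt, fixations, coeffs):
    """Weighted sum of the selected terms; coeffs maps term name to weight."""
    terms = {
        "kldiv": lambda: kldiv(pred, gt),
        "cc": lambda: cc(pred, gt),
        "nss": lambda: nss(pred, fixations),
        "sim": lambda: sim(pred, gt),
    }
    return sum(weight * terms[name]() for name, weight in coeffs.items())
```

In this sketch, CC, NSS and SIM are similarity scores that grow as the prediction improves, so they would need negative coefficients to act as losses, for example `combined_loss(pred, gt, fix, {"kldiv": 1.0, "cc": -1.0})`. Whether train.py expects negative coefficients or negates these terms internally is not stated in this README, so check the training script before choosing values.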
The results of our models on the SALICON test set can be viewed here under the names SimpleNet and MDNSal.
Comparison with other state-of-the-art saliency detection models
Comparison with other state-of-the-art saliency detection models on MIT300 test set
If you have any questions, please contact [email protected], [email protected] or [email protected], or use the public issues section of this repository.
This code is distributed under the MIT License.