DRIVE (Deep ReInforced Accident Anticipation with Visual Explanation)

Project | Paper & Supp | Demo

Wentao Bao, Qi Yu, Yu Kong

International Conference on Computer Vision (ICCV), 2021.

Table of Contents

  1. Introduction
  2. Installation
  3. Datasets
  4. Testing
  5. Training
  6. Citation

Introduction

We propose the DRIVE model, which uses Deep Reinforcement Learning (DRL) to solve the explainable traffic accident anticipation problem. The model simulates both bottom-up and top-down visual attention mechanisms in a dashcam observation environment, so that the decisions of the proposed stochastic multi-task agent can be visually explained by attentive regions. Moreover, the proposed dense anticipation reward and sparse fixation reward effectively train the DRIVE model with an improved Soft Actor-Critic (SAC) DRL algorithm.

[Figure: DRIVE]

Installation

Note: This repo was developed with PyTorch 1.4.0 on Ubuntu 18.04 LTS with CUDA 10.1. More recent PyTorch and CUDA versions, such as PyTorch 1.7.1 and CUDA 11.3, are also compatible.

a. Create a conda virtual environment for this repo and activate it:

conda create -n pyRL python=3.7 -y
conda activate pyRL

b. Install the official PyTorch release, for example pytorch==1.4.0:

conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch

c. Install the remaining dependencies:

pip install -r requirements.txt
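After installation, a quick sanity check can confirm which PyTorch version is active and whether CUDA is visible. The sketch below is not part of the repo; the file name (e.g. a hypothetical check_env.py) is up to you, and treating 1.4.0 as the minimum version is an assumption based on the note above.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn a version string such as '1.7.1+cu110' into (1, 7, 1)."""
    # Keep only the leading dotted-numeric part; drop local tags like '+cu110'.
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) if part else 0 for part in match.groups())


def torch_version_ok(version: str, minimum: str = "1.4.0") -> bool:
    """Return True if `version` >= `minimum` (1.4.0 is an assumed minimum)."""
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        import torch  # provided by the conda environment created above
    except ImportError:
        print("PyTorch is not installed in this environment")
    else:
        print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
        if not torch_version_ok(torch.__version__):
            print("warning: PyTorch is older than 1.4.0, the version this repo was developed with")
```

Run it inside the activated pyRL environment; if CUDA is reported as unavailable, check that the installed cudatoolkit matches your GPU driver.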

Datasets

This repo currently supports a down-sized version of the DADA-2000 dataset. Specifically, we halved the image size and trimmed the videos into accident clips of at most 450 frames. For details, see the script data/reduce_data.py.
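The two reduction steps can be illustrated in a few lines. This is a hypothetical sketch, not the code in data/reduce_data.py: it assumes integer halving of width and height and, purely for illustration, keeps the first 450 frames of a video, whereas the real script selects the accident clip.

```python
def reduced_size(width: int, height: int) -> tuple:
    """Halve a frame's spatial size (integer division is an assumption)."""
    return width // 2, height // 2


def trim_clip(frame_ids: list, max_frames: int = 450) -> list:
    """Keep at most `max_frames` frames of a video.

    Assumption for illustration only: the clip starts at the first frame.
    The actual accident-clip selection lives in data/reduce_data.py.
    """
    return frame_ids[:max_frames]


if __name__ == "__main__":
    print(reduced_size(1280, 720))                # example size, not a DADA-2000 value
    print(len(trim_clip(list(range(1000)))))      # a 1000-frame video is cut to 450
```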

Since the original DADA-2000 dataset has been updated, we provide the processed DADA-2000-small.zip for your convenience. Download it and unzip it into the data folder:

cd data
unzip DADA-2000-small.zip -d ./

Testing

a. Download the pretrained saliency models.

The pre-trained saliency models are available here: saliency_models. This repo uses mlnet_25.pth by default; place the file at models/saliency/mlnet_25.pth.

b. Download the pre-trained DRIVE model:

The pre-trained DRIVE model is available here: DADA2KS_Full_SACAE_Final. Place the model file at output/DADA2KS_Full_SACAE_Final/checkpoints/sac_epoch_50.pt.
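Before launching the test script, it can help to confirm that both downloaded files sit at the expected paths. The following is a minimal sketch, not part of the repo (save it as, say, a hypothetical check_files.py in the repo root); the two paths are the ones given in steps a and b.

```python
from pathlib import Path

# Paths from steps a and b, relative to the repo root.
REQUIRED_FILES = [
    Path("models/saliency/mlnet_25.pth"),
    Path("output/DADA2KS_Full_SACAE_Final/checkpoints/sac_epoch_50.pt"),
]


def missing_files(repo_root=".", required=REQUIRED_FILES) -> list:
    """Return the relative paths in `required` that are not files under `repo_root`."""
    root = Path(repo_root)
    return [rel for rel in required if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing pre-trained files:")
        for path in missing:
            print("  " + str(path))
    else:
        print("All pre-trained files are in place.")
```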

c. Run DRIVE testing:

bash script_RL.sh test 0 4 DADA2KS_Full_SACAE_Final

The results are reported once the run finishes.

Training

This repo supports training DRIVE models with two DRL algorithms (REINFORCE and SAC) and two kinds of visual saliency features (MLNet and TASED-Net). By default, we use SAC + MLNet, which gives the best trade-off between speed and accuracy.
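To show how the two DRL options differ in what they learn from, here are the textbook learning targets of each: REINFORCE weights the policy gradient by the Monte Carlo return-to-go, while SAC regresses its critics toward a soft Bellman target with clipped double-Q and an entropy term. This is a generic sketch, not the repo's implementation; the improved SAC used by DRIVE and its reward design may differ, and the gamma and alpha defaults are illustrative assumptions.

```python
def discounted_returns(rewards, gamma=0.99):
    """REINFORCE: return-to-go G_t = r_t + gamma * G_{t+1} for one episode."""
    returns = []
    running = 0.0
    # Walk the episode backwards so each step reuses the next step's return.
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def sac_soft_target(reward, done, q1_next, q2_next, logp_next, gamma=0.99, alpha=0.2):
    """SAC: soft Bellman target with clipped double-Q.

    y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))
    where a' is sampled from the current policy at the next state s'.
    """
    soft_value = min(q1_next, q2_next) - alpha * logp_next
    return reward + gamma * (1.0 - float(done)) * soft_value


if __name__ == "__main__":
    print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))
    print(sac_soft_target(1.0, False, 2.0, 3.0, -1.0, gamma=0.5, alpha=0.5))
```

REINFORCE needs complete episodes to compute its returns, whereas SAC bootstraps from its critics and can learn from single off-policy transitions stored in a replay buffer.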

a. Download the pretrained saliency models.

The pre-trained saliency models are available here: saliency_models. This repo uses mlnet_25.pth by default; place the file at models/saliency/mlnet_25.pth.

b. Run DRIVE training:

bash script_RL.sh train 0 4 DADA2KS_Full_SACAE_Final

c. Monitor the training on TensorBoard.

Visualize the training curves (losses, accuracies, etc.) on TensorBoard with the following commands:

cd output/DADA2KS_Full_SACAE_Final/tensorboard
tensorboard --logdir=./ --port 6008

TensorBoard then serves its dashboard at http://localhost:6008. Open this address in a web browser (such as Chrome) to monitor the training status.

Tips:

If you are connected to a remote server over SSH and it has no display, you can still view TensorBoard on your local machine by forwarding the port through SSH:

ssh -L 16008:localhost:6008 {your_remote_name}@{your_remote_ip}

Then open http://localhost:16008 in your local browser to monitor TensorBoard.

Citation

If you find the code useful in your research, please cite:

@inproceedings{BaoICCV2021DRIVE,
  author = "Bao, Wentao and Yu, Qi and Kong, Yu",
  title = "Deep Reinforced Accident Anticipation with Visual Explanation",
  booktitle = "International Conference on Computer Vision (ICCV)",
  year = "2021"
}

License

See MIT License

Acknowledgement

We sincerely thank the following great repos: pytorch-soft-actor-critic, pytorch-REINFORCE, MLNet-Pytorch, and TASED-Net.