This repository implements methods for explainability in Vision Transformers.
- Gradient Attention Rollout for class-specific explainability.
*This is our attempt to further build upon and improve Attention Rollout.*

- TBD: Attention Flow is a work in progress.

Includes some tweaks and tricks to get it working:
- Different Attention Head fusion methods.

If the category index isn't specified, Attention Rollout will be used; otherwise, Gradient Attention Rollout will be used.
Notice that by default, this uses the 'Tiny' model from [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)
hosted on torch hub.
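
For reference, a minimal sketch of loading that model from torch hub (the `facebookresearch/deit` hub repo and the `deit_tiny_patch16_224` entry point are assumptions taken from the DeiT release, not part of this README):

```python
import torch

# Load the DeiT-Tiny vision transformer from torch hub
# (hub repo and entry-point names assumed from the DeiT release).
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True)
model.eval()
```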

## Where did the Transformer pay attention to in this image?

| Image | Vanilla Attention Rollout | With discard_ratio+max fusion |
| -------------------------|-------------------------|------------------------- |
| ![](examples/both.png) | ![](examples/both_attention_rollout_0.000_mean.png) | ![](examples/both_attention_rollout_0.990_max.png) |
| ![](examples/plane.png) | ![](examples/plane_attention_rollout_0.000_mean.png) | ![](examples/plane_attention_rollout_0.900_max.png) |
| ![](examples/dogbird.png) | ![](examples/dogbird_attention_rollout_0.000_mean.png) | ![](examples/dogbird_attention_rollout_0.900_max.png) |
| ![](examples/plane2.png) | ![](examples/plane2_attention_rollout_0.000_mean.png) | ![](examples/plane2_attention_rollout_0.900_max.png) |

## Gradient Attention Rollout for class-specific explainability

The attention that flows through the transformer passes along information belonging to different classes.
Attention Rollout shows us which locations the network paid attention to,
but it tells us nothing about whether it ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).
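
As a rough illustration of that idea (a minimal sketch, not this repository's exact code), assuming the per-layer attention maps and their gradients with respect to the target class score have already been collected as `[1, heads, tokens, tokens]` tensors, e.g. with forward and backward hooks on the attention softmax:

```python
import torch

def grad_rollout(attentions, gradients):
    """Gradient-weighted attention rollout (sketch).

    attentions, gradients: one [1, heads, tokens, tokens] tensor per block,
    where gradients are taken w.r.t. the target class score.
    """
    tokens = attentions[0].size(-1)
    result = torch.eye(tokens)
    with torch.no_grad():
        for attn, grad in zip(attentions, gradients):
            # Weight the attention by the class gradient, mask out negative
            # contributions, and average over the heads.
            fused = (grad * attn).clamp(min=0).mean(dim=1)[0]

            # Add the residual (identity) connection, renormalize the rows,
            # and accumulate the product across layers as in Attention Rollout.
            fused = fused + torch.eye(tokens)
            fused = fused / fused.sum(dim=-1, keepdim=True)
            result = fused @ result

    # Attention flowing from the class token to the image patches.
    mask = result[0, 1:]
    return mask / mask.max()
```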


### Where does the Transformer see a Dog (category 243), and a Cat (category 282)?
![](examples/both_grad_rollout_243_0.900_max.png) ![](examples/both_grad_rollout_282_0.900_max.png)

### Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?
![](examples/dogbird_grad_rollout_161_0.900_max.png) ![](examples/dogbird_grad_rollout_87_0.900_max.png)

## Tricks and Tweaks to get this working

### Filtering the lowest attentions in every layer

`--discard_ratio <value between 0 and 1>`
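
For illustration, a hedged sketch of what such a filter might look like on a single fused `[tokens, tokens]` attention map, before the identity is added and the rollout product is taken (function name and details are assumptions, not this repository's code):

```python
import torch

def discard_lowest(attention, discard_ratio=0.9):
    """Zero out the lowest `discard_ratio` fraction of values in a fused
    [tokens, tokens] attention map before it enters the rollout."""
    attention = attention.clone()
    flat = attention.view(-1)
    num_discard = int(flat.numel() * discard_ratio)
    if num_discard > 0:
        # Indices of the smallest attention values across the whole map.
        _, idx = flat.topk(num_discard, largest=False)
        flat[idx] = 0  # in-place on the flattened view, so `attention` is updated
    return attention
```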

Results for different values:

![](examples/both_discard_ratio.gif) ![](examples/plane_discard_ratio.gif)

### Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads,
but empirically it looks like taking the minimum value, or the maximum value combined with removing the lowest attentions, works better.
| Image | Mean Fusion | Min Fusion |
| -------------------------|-------------------------|------------------------- |
| ![](examples/both.png) | ![](examples/both_attention_rollout_0.000_mean.png) | ![](examples/both_attention_rollout_0.000_min.png) |
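
As a small illustrative sketch (not this repository's code), the three fusion options applied to a `[heads, tokens, tokens]` attention tensor:

```python
import torch

def fuse_heads(attention, mode="mean"):
    """Fuse a [heads, tokens, tokens] attention tensor into [tokens, tokens]."""
    if mode == "mean":
        return attention.mean(dim=0)
    if mode == "min":
        return attention.min(dim=0).values
    if mode == "max":
        return attention.max(dim=0).values
    raise ValueError(f"unknown fusion mode: {mode}")
```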

## References
- [Quantifying Attention Flow in Transformers](https://arxiv.org/abs/2005.00928)
- [timm: a great collection of models in PyTorch](https://github.com/rwightman/pytorch-image-models)
and especially [the vision transformer implementation](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)

- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
- Credit to https://github.com/jeonsworld/ViT-pytorch for being a good starting point.