[Feature (WIP)] Add FGD #401

Open · wants to merge 10 commits into base: 0.x
30 changes: 30 additions & 0 deletions configs/distill/fgd/README.md
# FGD

> [Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation has been applied to image classification successfully. However, object detection is much more sophisticated and most knowledge distillation methods have failed on it. In this paper, we point out that in object detection, the features of the teacher and student vary greatly in different areas, especially in the foreground and background. If we distill them equally, the uneven differences between feature maps will negatively affect the distillation. Thus, we propose Focal and Global Distillation (FGD). Focal distillation separates the foreground and background, forcing the student to focus on the teacher's critical pixels and channels. Global distillation rebuilds the relation between different pixels and transfers it from teachers to students, compensating for missing global information in focal distillation. As our method only needs to calculate the loss on the feature map, FGD can be applied to various detectors. We experiment on various detectors with different backbones and the results show that the student detector achieves excellent mAP improvement. For example, ResNet-50 based RetinaNet, Faster RCNN, RepPoints and Mask RCNN with our distillation method achieve 40.7%, 42.0%, 42.0% and 42.1% mAP on COCO2017, which are 3.3, 3.6, 3.4 and 2.9 higher than the baseline, respectively.

![pipeline](https://user-images.githubusercontent.com/41630003/220037957-25a1440f-fcb3-413a-a350-97937bf6a042.png)
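The focal/global split described in the abstract can be sketched in plain PyTorch. This is a simplified illustration, not the official `FGDLoss` implementation: the foreground mask construction, the attention term, and the Gram-matrix stand-in for the paper's global relation module are all assumptions made for brevity; the default weights mirror the `alpha_fgd`/`beta_fgd`/`gamma_fgd`/`lambda_fgd` values used in the config below.

```python
import torch
import torch.nn.functional as F


def fgd_loss_sketch(feat_s, feat_t, fg_mask,
                    alpha=0.001, beta=0.0005, gamma=0.0005, lam=5e-6):
    """Illustrative FGD-style loss on one FPN level (simplified sketch).

    feat_s, feat_t: (N, C, H, W) student / teacher features.
    fg_mask: (N, 1, H, W) binary foreground mask (1 inside GT boxes).
    """
    bg_mask = 1.0 - fg_mask
    diff2 = (feat_s - feat_t) ** 2

    # Focal part: distill foreground and background separately so the
    # (usually much larger) background area cannot drown out the foreground.
    l_fg = (diff2 * fg_mask).sum() / fg_mask.sum().clamp(min=1)
    l_bg = (diff2 * bg_mask).sum() / bg_mask.sum().clamp(min=1)

    # Attention part: pull the student's spatial attention toward the
    # teacher's, highlighting the teacher's critical pixels.
    att_s = feat_s.abs().mean(dim=1)  # (N, H, W)
    att_t = feat_t.abs().mean(dim=1)
    l_att = F.l1_loss(att_s, att_t)

    # Global part: match per-image channel correlation (Gram) matrices as a
    # simplified stand-in for the paper's global relation module.
    n, c, h, w = feat_s.shape
    fs, ft = feat_s.flatten(2), feat_t.flatten(2)  # (N, C, H*W)
    gram_s = torch.bmm(fs, fs.transpose(1, 2)) / (h * w)
    gram_t = torch.bmm(ft, ft.transpose(1, 2)) / (h * w)
    l_global = F.mse_loss(gram_s, gram_t)

    return alpha * l_fg + beta * l_bg + gamma * l_att + lam * l_global
```

Because the loss only touches feature maps (not detection heads), the same function applies unchanged to any detector that exposes FPN-level features, which is why FGD transfers across RetinaNet, Faster R-CNN, RepPoints, and Mask R-CNN.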

## Results and models

### Detection

| Location | Dataset | Teacher | Student | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :---------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :--: | :----: | :----: | :-------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | COCO | [retina_x101_1x](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_64x4d_fpn_1x_coco.py) | [retina_r50_2x](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 40.5 | 41.0 | 37.4 | [config](./fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/distill/fgd/fgd_retina_x101_retina_r50_2x_coco_20221216_114845-c4c7496d.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/distill/fgd/fgd_retina_x101_retina_r50_2x_coco_20221216_114845-c4c7496d.json) |

## Citation

```latex
@article{yang2021focal,
title={Focal and Global Knowledge Distillation for Detectors},
author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun},
journal={arXiv preprint arXiv:2111.11837},
year={2021}
}
```
228 changes: 228 additions & 0 deletions configs/distill/fgd/fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py
_base_ = [
'../../_base_/datasets/mmdet/coco_detection.py',
'../../_base_/schedules/mmdet/schedule_2x.py',
'../../_base_/mmdet_runtime.py'
]

# model settings
t_weight = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501

student = dict(
type='mmdet.RetinaNet',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))

teacher = dict(
type='mmdet.RetinaNet',
init_cfg=dict(type='Pretrained', checkpoint=t_weight),
backbone=dict(
type='ResNeXt',
depth=101,
groups=64,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))

# algorithm setting
in_channels = 256
temp = 0.5
alpha_fgd = 0.001
beta_fgd = 0.0005
gamma_fgd = 0.0005
lambda_fgd = 0.000005
algorithm = dict(
type='GeneralDistill',
architecture=dict(
type='MMDetArchitecture',
model=student,
),
distiller=dict(
type='SingleTeacherDistiller',
teacher=teacher,
teacher_trainable=False,
components=[
dict(
student_module='neck.fpn_convs.0.conv',
teacher_module='neck.fpn_convs.0.conv',
losses=[
dict(
type='FGDLoss',
name='loss_fgd_0',
in_channels=in_channels,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]),
dict(
student_module='neck.fpn_convs.1.conv',
teacher_module='neck.fpn_convs.1.conv',
losses=[
dict(
type='FGDLoss',
name='loss_fgd_1',
in_channels=in_channels,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]),
dict(
student_module='neck.fpn_convs.2.conv',
teacher_module='neck.fpn_convs.2.conv',
losses=[
dict(
type='FGDLoss',
name='loss_fgd_2',
in_channels=in_channels,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]),
dict(
student_module='neck.fpn_convs.3.conv',
teacher_module='neck.fpn_convs.3.conv',
losses=[
dict(
type='FGDLoss',
name='loss_fgd_3',
in_channels=in_channels,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]),
dict(
student_module='neck.fpn_convs.4.conv',
teacher_module='neck.fpn_convs.4.conv',
losses=[
dict(
type='FGDLoss',
name='loss_fgd_4',
in_channels=in_channels,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]),
]),
)

find_unused_parameters = True

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
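The five per-level `components` entries in the distiller above differ only in the FPN level index. As a design note, an equivalent and more compact construction is possible in plain Python (assuming nothing beyond the config values already defined):

```python
# Generate the five per-FPN-level FGD components programmatically instead of
# writing each dict out by hand. Produces the same structure as the config.
in_channels = 256
alpha_fgd, beta_fgd, gamma_fgd, lambda_fgd = 0.001, 0.0005, 0.0005, 0.000005

components = [
    dict(
        student_module=f'neck.fpn_convs.{i}.conv',
        teacher_module=f'neck.fpn_convs.{i}.conv',
        losses=[
            dict(
                type='FGDLoss',
                name=f'loss_fgd_{i}',
                in_channels=in_channels,
                alpha_fgd=alpha_fgd,
                beta_fgd=beta_fgd,
                gamma_fgd=gamma_fgd,
                lambda_fgd=lambda_fgd,
            )
        ]) for i in range(5)
]
```

The explicit per-level form in the config is more verbose but keeps the file greppable and diff-friendly, which matches OpenMMLab config conventions.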
32 changes: 32 additions & 0 deletions configs/distill/fgd/metafile.yml
Collections:
- Name: FGD
Metadata:
Training Data:
- COCO
Paper:
URL: https://arxiv.org/abs/2111.11837
Title: Focal and Global Knowledge Distillation for Detectors
README: configs/distill/fgd/README.md
Code:
URL:
Version: v0.1.0
Converted From:
Code:
- https://github.com/yzd-v/FGD
Models:
- Name: fgd_retina_x101_fpn_retina_r50_fpn_2x_coco
In Collection: FGD
Metadata:
Location: FPN
Student: retinanet-r50
Teacher: retinanet-x101
Teacher Checkpoint: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.5
box AP(S): 37.4
box AP(T): 41.0
Config: configs/distill/fgd/fgd_retina_x101_fpn_retina_r50_fpn_2x_coco.py
Weights: https://download.openmmlab.com/mmrazor/v0.1/distill/fgd/fgd_retina_x101_retina_r50_2x_coco_20221216_114845-c4c7496d.pth
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .cwd import ChannelWiseDivergence
from .fgd import FGDLoss
from .kl_divergence import KLDivergence
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
from .weighted_soft_label_distillation import WSLD

__all__ = [
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
    'WSLD', 'FGDLoss'
]