Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Fix a bug in base distiller and add yolox distill example #192

Open
wants to merge 1 commit into
base: 0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions configs/distill/cwd/cwd_fpn_yolox_s_yolox_s_300e_coco_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
_base_ = [
'../../_base_/schedules/mmdet/schedule_1x.py',
'../../_base_/mmdet_runtime.py'
]

img_scale = (640, 640) # height, width

student = dict(
type='mmdet.YOLOX',
input_size=img_scale,
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(
type='YOLOXPAFPN',
in_channels=[128, 256, 512],
out_channels=128,
num_csp_blocks=1),
bbox_head=dict(
type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
# In order to align the source code, the threshold of the val phase is
# 0.01, and the threshold of the test phase is 0.001.
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth' # noqa: E501
teacher = dict(
init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
type='mmdet.YOLOX',
input_size=img_scale,
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(
type='YOLOXPAFPN',
in_channels=[128, 256, 512],
out_channels=128,
num_csp_blocks=1),
bbox_head=dict(
type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
# In order to align the source code, the threshold of the val phase is
# 0.01, and the threshold of the test phase is 0.001.
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

algorithm = dict(
type='AlignMethodDistill',
architecture=dict(type='MMDetArchitecture', model=student),
with_student_loss=True,
with_teacher_loss=False,
distiller=dict(
type='SingleTeacherDistiller',
teacher=teacher,
teacher_trainable=False,
align_methods=[
dict(method='YOLOX._preprocess', import_module='mmdet.models')
],
components=[
dict(
student_module='neck.out_convs.0.conv',
teacher_module='neck.out_convs.0.conv',
losses=[
dict(
type='ChannelWiseDivergence',
name='loss_cwd_logits',
tau=1,
loss_weight=5)
]),
dict(
student_module='neck.out_convs.1.conv',
teacher_module='neck.out_convs.1.conv',
losses=[
dict(
type='ChannelWiseDivergence',
name='loss_cwd_logits',
tau=1,
loss_weight=5)
]),
dict(
student_module='neck.out_convs.2.conv',
teacher_module='neck.out_convs.2.conv',
losses=[
dict(
type='ChannelWiseDivergence',
name='loss_cwd_logits',
tau=1,
loss_weight=5)
])
]))

# dataset settings
data_root = 'data/coco/'
dataset_type = 'CocoDataset'

train_pipeline = [
dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='MixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
# According to the official implementation, multi-scale
# training is not considered here but in the
# 'mmdet/models/detectors/yolox.py'.
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
dict(
type='Pad',
pad_to_square=True,
# If the image is three-channel, the pad value needs
# to be set separately for each channel.
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

train_dataset = dict(
type='MultiImageMixDataset',
dataset=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True)
],
filter_empty_gt=False,
),
pipeline=train_pipeline)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Pad',
pad_to_square=True,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
])
]

data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
persistent_workers=True,
train=train_dataset,
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))

# optimizer
# default 8 gpu
optimizer = dict(
type='SGD',
lr=0.01,
momentum=0.9,
weight_decay=5e-4,
nesterov=True,
paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=None)

max_epochs = 300
num_last_epochs = 15
resume_from = None
interval = 10

# learning policy
lr_config = dict(
_delete_=True,
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=5, # 5 epoch
num_last_epochs=num_last_epochs,
min_lr_ratio=0.05)

runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)

custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=num_last_epochs,
interval=interval,
priority=48),
dict(
type='ExpMomentumEMAHook',
resume_from=resume_from,
momentum=0.0001,
priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(
save_best='auto',
# The evaluation interval is 'interval' when running epoch is
# less than ‘max_epochs - num_last_epochs’.
# The evaluation interval is 1 when running epoch is greater than
# or equal to ‘max_epochs - num_last_epochs’.
interval=interval,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
metric='bbox')
log_config = dict(interval=50)

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)

find_unused_parameters = True
4 changes: 1 addition & 3 deletions mmrazor/models/distillers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@ def _set_method(self, method):
def __enter__(self):
"""Rewrite the function."""
self.method_impl = eval(self.method_exec_str)

if self.method_impl:
self._set_method(
function_wrapper(self.ctx, self.method_impl, self.method_str,
self.align_mode))
function_wrapper(self.ctx, self.method_impl, self.method_str))

def __exit__(self, exc_type, exc_value, traceback):
"""Restore the function."""
Expand Down