-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
485 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Learning Student Networks in the Wild (DFND) | ||
|
||
> [Learning Student Networks in the Wild](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Data-free learning for student networks is a new paradigm for solving users’ anxiety caused by the privacy problem of using original training data. Since the architectures of modern convolutional neural networks (CNNs) are compact and sophisticated, the alternative images or meta-data generated from the teacher network are often broken. Thus, the student network cannot achieve the comparable performance to that of the pre-trained teacher network especially on the large-scale image dataset. Different to previous works, we present to maximally utilize the massive available unlabeled data in the wild. Specifically, we first thoroughly analyze the output differences between teacher and student network on the original data and develop a data collection method. Then, a noisy knowledge distillation algorithm is proposed for achieving the performance of the student network. In practice, an adaptation matrix is learned with the student network for correcting the label noise produced by the teacher network on the collected unlabeled images. The effectiveness of our DFND (DataFree Noisy Distillation) method is then verified on several benchmarks to demonstrate its superiority over state-of-theart data-free distillation methods. Experiments on various datasets demonstrate that the student networks learned by the proposed method can achieve comparable performance with those using the original dataset. | ||
|
||
<img width="910" alt="pipeline" src="./dfnd.PNG"> | ||
|
||
## Results and models | ||
|
||
### Classification | ||
|
||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | | | ||
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 94.78 | 95.34 | 94.82 | [config](./dfnd_logits_resnet34_resnet18_8xb32_cifar10.py) | [student](https://drive.google.com/file/d/1_MekfTkCsEl68meWPqtdNZIxdJO2R2Eb/view?usp=drive_link) | | ||
|
||
## Citation | ||
|
||
```latex | ||
@inproceedings{chen2021learning, | ||
title={Learning student networks in the wild}, | ||
author={Chen, Hanting and Guo, Tianyu and Xu, Chang and Li, Wenshuo and Xu, Chunjing and Xu, Chao and Wang, Yunhe}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
pages={6428--6437}, | ||
year={2021} | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
100 changes: 100 additions & 0 deletions
100
configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
_base_ = ['mmcls::_base_/default_runtime.py'] | ||
|
||
# optimizer | ||
optim_wrapper = dict( | ||
optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)) | ||
# learning policy | ||
param_scheduler = dict( | ||
type='MultiStepLR', by_epoch=True, milestones=[320, 640], gamma=0.1) | ||
|
||
# train, val, test setting | ||
train_cfg = dict(by_epoch=True, max_epochs=800, val_interval=1) | ||
test_cfg = dict() | ||
|
||
# NOTE: `auto_scale_lr` is for automatically scaling LR | ||
# based on the actual training batch size. | ||
auto_scale_lr = dict(base_batch_size=128) | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='RandomResizedCrop', scale=32), | ||
dict(type='RandomFlip', prob=0.5, direction='horizontal'), | ||
dict(type='PackClsInputs'), | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=256, | ||
num_workers=5, | ||
dataset=dict( | ||
type='ImageNet', | ||
data_root='/cache/data/imagenet/', | ||
data_prefix='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
) | ||
|
||
test_pipeline = [ | ||
dict(type='PackClsInputs'), | ||
] | ||
|
||
val_dataloader = dict( | ||
batch_size=16, | ||
num_workers=2, | ||
dataset=dict( | ||
type='CIFAR10', | ||
data_prefix='/cache/data/cifar', | ||
test_mode=True, | ||
pipeline=test_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
) | ||
val_evaluator = dict(type='Accuracy', topk=(1, )) | ||
|
||
test_dataloader = val_dataloader | ||
test_evaluator = val_evaluator | ||
|
||
teacher_ckpt = '/cache/models/resnet_model.pth' # noqa: E501 | ||
|
||
model = dict( | ||
_scope_='mmrazor', | ||
type='DFNDDistill', | ||
calculate_student_loss=False, | ||
data_preprocessor=dict( | ||
type='ImgDataPreprocessor', | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
bgr_to_rgb=True), | ||
val_data_preprocessor=dict( | ||
type='ImgDataPreprocessor', | ||
# RGB format normalization parameters | ||
mean=[125.307, 122.961, 113.8575], | ||
std=[51.5865, 50.847, 51.255], | ||
# convert image from BGR to RGB | ||
bgr_to_rgb=False), | ||
architecture=dict( | ||
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False), | ||
teacher=dict( | ||
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=False), | ||
teacher_ckpt=teacher_ckpt, | ||
distiller=dict( | ||
type='ConfigurableDistiller', | ||
student_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
teacher_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
distill_losses=dict( | ||
loss_kl=dict( | ||
type='DFNDLoss', | ||
tau=4, | ||
loss_weight=1, | ||
num_classes=10, | ||
batch_select=0.5)), | ||
loss_forward_mappings=dict( | ||
loss_kl=dict( | ||
preds_S=dict(from_student=True, recorder='fc'), | ||
preds_T=dict(from_student=False, recorder='fc'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(type='mmrazor.DFNDValLoop') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.