From bb07d920d3639675d7302ef0d534d3febc4018ff Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Thu, 2 Mar 2023 12:33:54 +0800 Subject: [PATCH 1/2] clear distillation-related log scalars after stopping distillation --- mmrazor/engine/hooks/stop_distillation_hook.py | 13 +++++++++++-- .../test_hooks/test_stop_distillation_hook.py | 17 ++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/mmrazor/engine/hooks/stop_distillation_hook.py b/mmrazor/engine/hooks/stop_distillation_hook.py index 3b907dc61..f45a4d703 100644 --- a/mmrazor/engine/hooks/stop_distillation_hook.py +++ b/mmrazor/engine/hooks/stop_distillation_hook.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.hooks import Hook +from mmengine.logging import MessageHub from mmengine.model import is_model_wrapper from mmrazor.registry import HOOKS @@ -13,11 +14,17 @@ class StopDistillHook(Hook): stop_epoch (int): Stop distillation at this epoch. """ - priority = 'LOW' - def __init__(self, stop_epoch: int) -> None: self.stop_epoch = stop_epoch + def _clear_message_hub(self): + """Private method to clear distillation-related log scalars.""" + message_hub = MessageHub.get_current_instance() + log_scalars = message_hub.log_scalars + keys_del = [key for key in log_scalars.keys() if 'distill' in key] + for key in keys_del: + del log_scalars[key] + def before_train_epoch(self, runner) -> None: """Stop distillation.""" if runner.epoch >= self.stop_epoch: @@ -29,3 +36,5 @@ def before_train_epoch(self, runner) -> None: runner.logger.info('Distillation has been stopped!') model.distillation_stopped = True + + self._clear_message_hub() diff --git a/tests/test_engine/test_hooks/test_stop_distillation_hook.py b/tests/test_engine/test_hooks/test_stop_distillation_hook.py index 49e753b64..f22878d04 100644 --- a/tests/test_engine/test_hooks/test_stop_distillation_hook.py +++ b/tests/test_engine/test_hooks/test_stop_distillation_hook.py @@ -1,14 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. +import random from unittest import TestCase from unittest.mock import Mock +from mmengine.logging import MessageHub + from mmrazor.engine import StopDistillHook class TestStopDistillHook(TestCase): def setUp(self): - self.hook = StopDistillHook(stop_epoch=5) + self.stop_epoch = 5 + self.hook = StopDistillHook(stop_epoch=self.stop_epoch) runner = Mock() runner.model = Mock() runner.model.distillation_stopped = False @@ -16,6 +20,9 @@ def setUp(self): runner.epoch = 0 self.runner = runner + message_hub = dict(name='test') + self.message_hub = MessageHub.get_instance(**message_hub) + def test_before_train_epoch(self): max_epochs = 10 target = [False] * 5 + [True] * 5 @@ -24,3 +31,11 @@ def test_before_train_epoch(self): self.assertEquals(self.runner.model.distillation_stopped, target[epoch]) self.runner.epoch += 1 + + if not self.runner.model.distillation_stopped: + self.message_hub.update_scalar('distill.loss', random.random()) + + if self.runner.model.distillation_stopped: + self.assertNotIn('distill.loss', self.message_hub.log_scalars) + else: + self.assertIn('distill.loss', self.message_hub.log_scalars) From eb77787fe4286932b9bf103039b7e42e0f292ce4 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Tue, 7 Mar 2023 11:00:16 +0800 Subject: [PATCH 2/2] fix stop distillation hook by registering a distillation_stopped buffer --- mmrazor/engine/hooks/stop_distillation_hook.py | 4 ++-- .../distill/configurable/single_teacher_distill.py | 2 +- .../test_engine/test_hooks/test_stop_distillation_hook.py | 7 +++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mmrazor/engine/hooks/stop_distillation_hook.py b/mmrazor/engine/hooks/stop_distillation_hook.py index f45a4d703..c70080381 100644 --- a/mmrazor/engine/hooks/stop_distillation_hook.py +++ b/mmrazor/engine/hooks/stop_distillation_hook.py @@ -27,7 +27,7 @@ def _clear_message_hub(self): def before_train_epoch(self, runner) -> None: """Stop distillation.""" - if runner.epoch >= self.stop_epoch: + if runner.epoch == self.stop_epoch: model = runner.model # TODO: refactor after mmengine using model wrapper if is_model_wrapper(model): @@ -35,6 +35,6 @@ def before_train_epoch(self, runner) -> None: assert hasattr(model, 'distillation_stopped') runner.logger.info('Distillation has been stopped!') - model.distillation_stopped = True + model.distillation_stopped[0] = True self._clear_message_hub() diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py index 97139d256..4cf562c38 100644 --- a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py +++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py @@ -90,7 +90,7 @@ def __init__(self, self.distiller.prepare_from_teacher(self.teacher) # may be modified by stop distillation hook - self.distillation_stopped = False + self.register_buffer('distillation_stopped', torch.tensor([False])) @property def student(self) -> nn.Module: diff --git a/tests/test_engine/test_hooks/test_stop_distillation_hook.py b/tests/test_engine/test_hooks/test_stop_distillation_hook.py index f22878d04..151b67de3 100644 --- a/tests/test_engine/test_hooks/test_stop_distillation_hook.py +++ b/tests/test_engine/test_hooks/test_stop_distillation_hook.py @@ -3,6 +3,8 @@ from unittest import TestCase from unittest.mock import Mock +import torch +import torch.nn as nn from mmengine.logging import MessageHub from mmrazor.engine import StopDistillHook @@ -14,8 +16,9 @@ def setUp(self): self.stop_epoch = 5 self.hook = StopDistillHook(stop_epoch=self.stop_epoch) runner = Mock() - runner.model = Mock() - runner.model.distillation_stopped = False + runner.model = nn.Module() + runner.model.register_buffer('distillation_stopped', + torch.tensor([False])) runner.epoch = 0 self.runner = runner