diff --git a/mmrazor/engine/hooks/stop_distillation_hook.py b/mmrazor/engine/hooks/stop_distillation_hook.py
index 3b907dc61..c70080381 100644
--- a/mmrazor/engine/hooks/stop_distillation_hook.py
+++ b/mmrazor/engine/hooks/stop_distillation_hook.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.hooks import Hook
+from mmengine.logging import MessageHub
 from mmengine.model import is_model_wrapper
 
 from mmrazor.registry import HOOKS
@@ -13,14 +14,20 @@ class StopDistillHook(Hook):
         stop_epoch (int): Stop distillation at this epoch.
     """
 
-    priority = 'LOW'
-
     def __init__(self, stop_epoch: int) -> None:
         self.stop_epoch = stop_epoch
 
+    def _clear_message_hub(self):
+        """Private method to clear distillation-related log scalars."""
+        message_hub = MessageHub.get_current_instance()
+        log_scalars = message_hub.log_scalars
+        keys_del = [key for key in log_scalars.keys() if 'distill' in key]
+        for key in keys_del:
+            del log_scalars[key]
+
     def before_train_epoch(self, runner) -> None:
         """Stop distillation."""
-        if runner.epoch >= self.stop_epoch:
+        if runner.epoch == self.stop_epoch:
             model = runner.model
             # TODO: refactor after mmengine using model wrapper
             if is_model_wrapper(model):
@@ -28,4 +35,6 @@ def before_train_epoch(self, runner) -> None:
             assert hasattr(model, 'distillation_stopped')
 
             runner.logger.info('Distillation has been stopped!')
-            model.distillation_stopped = True
+            model.distillation_stopped[0] = True
+
+            self._clear_message_hub()
diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
index 97139d256..4cf562c38 100644
--- a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
+++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
@@ -90,7 +90,7 @@ def __init__(self,
         self.distiller.prepare_from_teacher(self.teacher)
 
         # may be modified by stop distillation hook
-        self.distillation_stopped = False
+        self.register_buffer('distillation_stopped', torch.tensor([False]))
 
     @property
     def student(self) -> nn.Module:
diff --git a/tests/test_engine/test_hooks/test_stop_distillation_hook.py b/tests/test_engine/test_hooks/test_stop_distillation_hook.py
index 49e753b64..151b67de3 100644
--- a/tests/test_engine/test_hooks/test_stop_distillation_hook.py
+++ b/tests/test_engine/test_hooks/test_stop_distillation_hook.py
@@ -1,21 +1,31 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import random
 from unittest import TestCase
 from unittest.mock import Mock
 
+import torch
+import torch.nn as nn
+from mmengine.logging import MessageHub
+
 from mmrazor.engine import StopDistillHook
 
 
 class TestStopDistillHook(TestCase):
 
     def setUp(self):
-        self.hook = StopDistillHook(stop_epoch=5)
+        self.stop_epoch = 5
+        self.hook = StopDistillHook(stop_epoch=self.stop_epoch)
         runner = Mock()
-        runner.model = Mock()
-        runner.model.distillation_stopped = False
+        runner.model = nn.Module()
+        runner.model.register_buffer('distillation_stopped',
+                                     torch.tensor([False]))
         runner.epoch = 0
         self.runner = runner
 
+        message_hub = dict(name='test')
+        self.message_hub = MessageHub.get_instance(**message_hub)
+
     def test_before_train_epoch(self):
         max_epochs = 10
 
         target = [False] * 5 + [True] * 5
@@ -24,3 +34,11 @@
             self.assertEquals(self.runner.model.distillation_stopped,
                               target[epoch])
             self.runner.epoch += 1
+
+            if not self.runner.model.distillation_stopped:
+                self.message_hub.update_scalar('distill.loss', random.random())
+
+            if self.runner.model.distillation_stopped:
+                self.assertNotIn('distill.loss', self.message_hub.log_scalars)
+            else:
+                self.assertIn('distill.loss', self.message_hub.log_scalars)
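
Note on the `register_buffer` change: the sketch below is a standalone illustration, not part of the patch, and `ToyDistill` is a hypothetical stand-in for `SingleTeacherDistill`. It shows the apparent rationale for storing the flag as a one-element tensor buffer rather than a plain `bool` attribute: buffers are captured by `state_dict()`, so the stopped flag is saved with checkpoints and restored on resume, whereas a plain attribute would silently reset to `False`. This is also why `before_train_epoch` can now fire only at `runner.epoch == self.stop_epoch` instead of on every later epoch.

```python
# Standalone sketch (hypothetical ToyDistill, not mmrazor's API): a registered
# buffer persists through a state_dict() save/load round trip.
import torch
import torch.nn as nn


class ToyDistill(nn.Module):

    def __init__(self):
        super().__init__()
        # Same pattern as the patch: a one-element bool tensor as a buffer.
        self.register_buffer('distillation_stopped', torch.tensor([False]))


model = ToyDistill()

# The hook mutates element 0 in place. A plain rebind
# (model.distillation_stopped = True) would raise TypeError, since nn.Module
# only accepts a tensor (or None) for a registered buffer name.
model.distillation_stopped[0] = True

# The flag travels with the checkpoint ...
ckpt = model.state_dict()
assert bool(ckpt['distillation_stopped'][0])

# ... and is restored into a fresh instance on resume.
resumed = ToyDistill()
resumed.load_state_dict(ckpt)
assert bool(resumed.distillation_stopped[0])
```

Under this reading, `_clear_message_hub` complements the persistence fix: once distillation stops, the hook drops stale `distill.*` scalars from the MessageHub so they presumably no longer linger in the smoothed log output, which is what the new tail of `test_before_train_epoch` asserts.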