diff --git a/configs/_base_/mmdet_runtime.py b/configs/_base_/mmdet_runtime.py
index 55097c5b2..091b032e4 100644
--- a/configs/_base_/mmdet_runtime.py
+++ b/configs/_base_/mmdet_runtime.py
@@ -14,3 +14,11 @@
 load_from = None
 resume_from = None
 workflow = [('train', 1)]
+
+# Default setting for scaling LR automatically
+# - `auto_scale_lr` enables or disables automatic LR scaling
+#   by default.
+# - `default_batch_size` = (8 GPUs) x (2 samples per GPU).
+# - `default_initial_lr` = the default initial LR.
+auto_scale_lr_config = dict(
+    auto_scale_lr=False, default_batch_size=16, default_initial_lr=0.01)
diff --git a/configs/nas/detnas/detnas_supernet_frcnn_shufflenetv2_fpn_1x_coco.py b/configs/nas/detnas/detnas_supernet_frcnn_shufflenetv2_fpn_1x_coco.py
index 95f432441..91106c2a3 100644
--- a/configs/nas/detnas/detnas_supernet_frcnn_shufflenetv2_fpn_1x_coco.py
+++ b/configs/nas/detnas/detnas_supernet_frcnn_shufflenetv2_fpn_1x_coco.py
@@ -142,3 +142,8 @@
 )
 
 find_unused_parameters = True
+
+# NOTE: `auto_scale_lr_config` is used for automatic LR scaling;
+# USERS SHOULD NOT CHANGE ITS VALUES.
+# default_batch_size = (8 GPUs) x (2 samples per GPU)
+auto_scale_lr_config = dict(default_batch_size=16, default_initial_lr=0.02)
diff --git a/mmrazor/apis/__init__.py b/mmrazor/apis/__init__.py
index bc7b8f485..7b140bdc3 100644
--- a/mmrazor/apis/__init__.py
+++ b/mmrazor/apis/__init__.py
@@ -2,6 +2,6 @@
 from .mmcls import *  # noqa: F401,F403
 from .mmdet import *  # noqa: F401,F403
 from .mmseg import *  # noqa: F401,F403
-from .utils import init_random_seed, set_random_seed  # noqa: F401
+from .utils import auto_scale_lr, init_random_seed, set_random_seed
 
-__all__ = ['init_random_seed', 'set_random_seed']
+__all__ = ['init_random_seed', 'set_random_seed', 'auto_scale_lr']
diff --git a/mmrazor/apis/mmcls/train.py b/mmrazor/apis/mmcls/train.py
index f86c502c9..7a88b573c 100644
--- a/mmrazor/apis/mmcls/train.py
+++ b/mmrazor/apis/mmcls/train.py
@@ -11,6 +11,7 @@
 from mmcv.runner import EpochBasedRunner, Fp16OptimizerHook, build_runner
 from mmcv.runner.hooks import DistEvalHook, EvalHook
 
+from mmrazor.apis.utils import auto_scale_lr
 # Differences from mmclassification.
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
@@ -127,6 +128,7 @@ def train_mmcls_model(model,
     # build optimizers
     # Difference from mmclassification.
     # In some algorithms, there will be multi optimizers.
+    auto_scale_lr(cfg, distributed, logger)
     optimizer = build_optimizers(model, cfg.optimizer)
 
     if cfg.get('runner') is None:
diff --git a/mmrazor/apis/mmdet/train.py b/mmrazor/apis/mmdet/train.py
index 828cf6f89..91e9ab650 100644
--- a/mmrazor/apis/mmdet/train.py
+++ b/mmrazor/apis/mmdet/train.py
@@ -12,6 +12,7 @@
                             replace_ImageToTensor)
 from mmdet.utils import get_root_logger
 
+from mmrazor.apis.utils import auto_scale_lr
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
@@ -114,6 +115,7 @@ def train_mmdet_model(model,
     # build optimizers
     # Difference from mmdetection.
     # In some algorithms, there will be multi optimizers.
+    auto_scale_lr(cfg, distributed, logger)
     optimizer = build_optimizers(model, cfg.optimizer)
 
     # build runner
diff --git a/mmrazor/apis/mmseg/train.py b/mmrazor/apis/mmseg/train.py
index 16de892f9..b3c7d0f32 100644
--- a/mmrazor/apis/mmseg/train.py
+++ b/mmrazor/apis/mmseg/train.py
@@ -10,6 +10,7 @@
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.utils import get_root_logger
 
+from mmrazor.apis.utils import auto_scale_lr
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.optimizer import build_optimizers
 from mmrazor.utils import find_latest_checkpoint
@@ -96,6 +97,7 @@ def train_mmseg_model(model,
     # build optimizers
     # Difference from mmdetection.
     # In some algorithms, there will be multi optimizers.
+    auto_scale_lr(cfg, distributed, logger)
     optimizer = build_optimizers(model, cfg.optimizer)
 
     # build runner
diff --git a/mmrazor/apis/utils.py b/mmrazor/apis/utils.py
index e00c57806..1e4ed2688 100644
--- a/mmrazor/apis/utils.py
+++ b/mmrazor/apis/utils.py
@@ -55,3 +55,79 @@ def set_random_seed(seed: int, deterministic: bool = False) -> None:
     if deterministic:
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
+
+
+def auto_scale_lr(cfg, distributed, logger):
+    """Automatically scale LR according to GPU number and samples per GPU.
+
+    Args:
+        cfg (config): Training config.
+        distributed (bool): Whether distributed training is used.
+        logger (logging.Logger): Logger.
+    """
+    warning_msg = 'in your configuration file. Please update all the ' \
+                  'configuration files to mmdet >= 2.24.0. ' \
+                  'Automatic scaling of the learning rate is disabled.'
+
+    # default config of auto scale lr
+    if 'auto_scale_lr_config' not in cfg:
+        logger.warning(f'Cannot find "auto_scale_lr_config" {warning_msg}')
+        return
+
+    # Get flag from config
+    auto_scale_lr_flag = cfg.auto_scale_lr_config.get('auto_scale_lr', False)
+    if auto_scale_lr_flag is False:
+        logger.info('Automatic scaling of learning rate (LR)'
+                    ' has been disabled.')
+        return
+
+    # Get default batch size from config
+    default_batch_size = cfg.auto_scale_lr_config.get('default_batch_size', 0)
+    if default_batch_size == 0:
+        logger.warning(f'Cannot find "default_batch_size" {warning_msg}')
+        return
+
+    # Get default initial LR from config
+    default_initial_lr = cfg.auto_scale_lr_config.get('default_initial_lr', 0)
+    if default_initial_lr == 0:
+        logger.warning(f'Cannot find "default_initial_lr" {warning_msg}')
+        return
+
+    # Get GPU number
+    if distributed:
+        _, world_size = get_dist_info()
+        num_gpus = len(range(world_size))
+    else:
+        num_gpus = len(cfg.gpu_ids)
+
+    # calculate the batch size
+    batch_size = num_gpus * cfg.data.samples_per_gpu
+
+    logger.info(f'You are using {num_gpus} GPU(s) '
+                f'and {cfg.data.samples_per_gpu} samples per GPU. '
+                f'Total batch size is {batch_size}.')
+
+    if batch_size != default_batch_size:
+
+        if cfg.optimizer.lr != default_initial_lr:
+            logger.warning(
+                'It seems that you changed "cfg.optimizer.lr" to '
+                f'{cfg.optimizer.lr}, which is not the default initial lr '
+                f'({default_initial_lr}) from the config file. Automatic '
+                'LR scaling will use "cfg.optimizer.lr" to calculate the '
+                'new LR, which may not lead to the best training results. '
+                'If you know what you are doing, ignore this '
+                'warning message.')
+
+        # scale LR with
+        # [linear scaling rule](https://arxiv.org/abs/1706.02677)
+        scaled_lr = (batch_size / default_batch_size) * cfg.optimizer.lr
+        logger.info('LR has been automatically scaled '
+                    f'from {cfg.optimizer.lr} to {scaled_lr}')
+
+        cfg.optimizer.lr = scaled_lr
+
+    else:
+        logger.info('The batch size matches the '
+                    f'default batch size: {default_batch_size}, '
+                    f'so the LR ({cfg.optimizer.lr}) will not be scaled.')
diff --git a/tools/mmcls/train_mmcls.py b/tools/mmcls/train_mmcls.py
index 43034750f..83821dede 100644
--- a/tools/mmcls/train_mmcls.py
+++ b/tools/mmcls/train_mmcls.py
@@ -79,6 +79,10 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--auto-scale-lr',
+        action='store_true',
+        help='enable automatic scaling of LR.')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -93,6 +97,17 @@ def main():
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
+    if args.auto_scale_lr:
+        if 'auto_scale_lr_config' in cfg and \
+                'auto_scale_lr' in cfg.auto_scale_lr_config:
+            cfg.auto_scale_lr_config.auto_scale_lr = True
+        else:
+            warnings.warn('Cannot find "auto_scale_lr_config" or '
+                          '"auto_scale_lr" in your configuration file. '
+                          'Please update all the configuration files '
+                          'to mmrazor >= 0.3.0. Automatic scaling of '
+                          'the learning rate is disabled.')
+
     # set multi-process settings
     setup_multi_processes(cfg)
 
diff --git a/tools/mmdet/train_mmdet.py b/tools/mmdet/train_mmdet.py
index 688f86732..7f90e77ff 100644
--- a/tools/mmdet/train_mmdet.py
+++ b/tools/mmdet/train_mmdet.py
@@ -86,6 +86,10 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--auto-scale-lr',
+        action='store_true',
+        help='enable automatic scaling of LR.')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -100,6 +104,17 @@ def main():
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
+    if args.auto_scale_lr:
+        if 'auto_scale_lr_config' in cfg and \
+                'auto_scale_lr' in cfg.auto_scale_lr_config:
+            cfg.auto_scale_lr_config.auto_scale_lr = True
+        else:
+            warnings.warn('Cannot find "auto_scale_lr_config" or '
+                          '"auto_scale_lr" in your configuration file. '
+                          'Please update all the configuration files '
+                          'to mmrazor >= 0.3.0. Automatic scaling of '
+                          'the learning rate is disabled.')
+
     # set multi-process settings
     setup_multi_processes(cfg)
 
diff --git a/tools/mmseg/train_mmseg.py b/tools/mmseg/train_mmseg.py
index d80978d92..49a56ea10 100644
--- a/tools/mmseg/train_mmseg.py
+++ b/tools/mmseg/train_mmseg.py
@@ -89,6 +89,10 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--auto-scale-lr',
+        action='store_true',
+        help='enable automatic scaling of LR.')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -107,6 +111,17 @@ def main():
     # set multi-process settings
     setup_multi_processes(cfg)
 
+    if args.auto_scale_lr:
+        if 'auto_scale_lr_config' in cfg and \
+                'auto_scale_lr' in cfg.auto_scale_lr_config:
+            cfg.auto_scale_lr_config.auto_scale_lr = True
+        else:
+            warnings.warn('Cannot find "auto_scale_lr_config" or '
+                          '"auto_scale_lr" in your configuration file. '
+                          'Please update all the configuration files '
+                          'to mmrazor >= 0.3.0. Automatic scaling of '
+                          'the learning rate is disabled.')
+
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
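For reference, the new `auto_scale_lr` helper applies the linear scaling rule (https://arxiv.org/abs/1706.02677): the LR is multiplied by the ratio of the actual batch size to `default_batch_size`. The snippet below is a minimal standalone sketch of that arithmetic, not part of the patch; it uses the detnas defaults added above (`default_batch_size=16`, `default_initial_lr=0.02`) and a hypothetical run on 4 GPUs with 2 samples per GPU.

# Minimal sketch of the linear scaling rule used by auto_scale_lr.
# The GPU count and samples_per_gpu below are hypothetical examples.
default_batch_size = 16    # (8 GPUs) x (2 samples per GPU), from the config
default_initial_lr = 0.02  # default LR from the detnas config above

num_gpus = 4               # hypothetical run
samples_per_gpu = 2
batch_size = num_gpus * samples_per_gpu  # 4 x 2 = 8

scaled_lr = (batch_size / default_batch_size) * default_initial_lr
print(scaled_lr)           # 0.01: half the default batch size, half the LR

With the patch applied, the same scaling is triggered by passing `--auto-scale-lr` to the training scripts, e.g. roughly `python tools/mmdet/train_mmdet.py <config> --auto-scale-lr`.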