From 78ccf94afae90eb862c3dfe9a14cf50ded8e122c Mon Sep 17 00:00:00 2001 From: aptsunny Date: Mon, 2 Jan 2023 13:08:03 +0800 Subject: [PATCH 1/3] zenscore fly --- .gitignore | 3 + command.sh | 24 + .../backbones/1ms/init_plainnet.txt | 1 + .../backbones/PlainNet/SuperResIDWEXKX.py | 273 ++ .../backbones/PlainNet/SuperResK1KXK1.py | 199 ++ .../backbones/PlainNet/SuperResKXKX.py | 174 ++ .../backbones/PlainNet/__init__.py | 324 +++ .../backbones/PlainNet/basic_blocks.py | 2556 +++++++++++++++++ .../backbones/PlainNet/plainnet.py | 265 ++ .../backbones/PlainNet/super_blocks.py | 222 ++ .../SearchSpace/search_space_IDW_fixfc.py | 123 + .../SearchSpace/search_space_XXBL.py | 103 + .../ZeroShotProxy/compute_NASWOT_score.py | 126 + .../ZeroShotProxy/compute_gradnorm_score.py | 76 + .../ZeroShotProxy/compute_syncflow_score.py | 138 + .../ZeroShotProxy/compute_te_nas_score.py | 286 ++ .../ZeroShotProxy/compute_zen_score.py | 146 + .../backbones/benchmark_network_latency.py | 134 + .../architectures/backbones/global_utils.py | 282 ++ .../architectures/backbones/masternet.py | 601 ++++ slurm_zenscore.sh | 14 + tools/slurm_test_.sh | 24 + tools/slurm_train_.sh | 25 + 23 files changed, 6119 insertions(+) create mode 100755 command.sh create mode 100644 mmrazor/models/architectures/backbones/1ms/init_plainnet.txt create mode 100644 mmrazor/models/architectures/backbones/PlainNet/SuperResIDWEXKX.py create mode 100644 mmrazor/models/architectures/backbones/PlainNet/SuperResK1KXK1.py create mode 100644 mmrazor/models/architectures/backbones/PlainNet/SuperResKXKX.py create mode 100644 mmrazor/models/architectures/backbones/PlainNet/__init__.py create mode 100644 mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py create mode 100644 mmrazor/models/architectures/backbones/PlainNet/plainnet.py create mode 100644 mmrazor/models/architectures/backbones/PlainNet/super_blocks.py create mode 100644 mmrazor/models/architectures/backbones/SearchSpace/search_space_IDW_fixfc.py create 
mode 100644 mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py create mode 100644 mmrazor/models/architectures/backbones/ZeroShotProxy/compute_NASWOT_score.py create mode 100644 mmrazor/models/architectures/backbones/ZeroShotProxy/compute_gradnorm_score.py create mode 100644 mmrazor/models/architectures/backbones/ZeroShotProxy/compute_syncflow_score.py create mode 100644 mmrazor/models/architectures/backbones/ZeroShotProxy/compute_te_nas_score.py create mode 100644 mmrazor/models/architectures/backbones/ZeroShotProxy/compute_zen_score.py create mode 100644 mmrazor/models/architectures/backbones/benchmark_network_latency.py create mode 100644 mmrazor/models/architectures/backbones/global_utils.py create mode 100644 mmrazor/models/architectures/backbones/masternet.py create mode 100755 slurm_zenscore.sh create mode 100755 tools/slurm_test_.sh create mode 100755 tools/slurm_train_.sh diff --git a/.gitignore b/.gitignore index e14966259..c10f6edbb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +.base/* +.torch/* + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/command.sh b/command.sh new file mode 100755 index 000000000..8d93a0ff9 --- /dev/null +++ b/command.sh @@ -0,0 +1,24 @@ +#!/bin/bash + + +job_name=$1 +train_gpu=$2 +num_node=$3 +command=$4 +total_process=$((train_gpu*num_node)) + +mkdir -p log + +now=$(date +"%Y%m%d_%H%M%S") + +# nohup +GLOG_vmodule=MemcachedClient=-1 \ +srun --partition=pat_dev \ +--mpi=pmi2 -n$total_process \ +--gres=gpu:$train_gpu \ +--ntasks-per-node=$train_gpu \ +--job-name=$job_name \ +--kill-on-bad-exit=1 \ +--cpus-per-task=5 \ +$command 2>&1|tee -a log/$job_name.log & +# -x "SH-IDC1-10-198-4-[104,106]" \ \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/1ms/init_plainnet.txt b/mmrazor/models/architectures/backbones/1ms/init_plainnet.txt new file mode 100644 index 000000000..99f5cc71c --- /dev/null +++ 
b/mmrazor/models/architectures/backbones/1ms/init_plainnet.txt @@ -0,0 +1 @@ +SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,128,2,64,1)SuperResK3K3(128,256,2,128,1)SuperResK3K3(256,512,2,256,1)SuperConvK1BNRELU(256,512,1,1) \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/PlainNet/SuperResIDWEXKX.py b/mmrazor/models/architectures/backbones/PlainNet/SuperResIDWEXKX.py new file mode 100644 index 000000000..3f80571d8 --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/SuperResIDWEXKX.py @@ -0,0 +1,273 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. +''' + + +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import uuid + +import PlainNet +from PlainNet import _get_right_parentheses_index_ +from PlainNet.super_blocks import PlainNetSuperBlockClass +from torch import nn +import global_utils + + +class SuperResIDWEXKX(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, + kernel_size=None, expension=None, + no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): + super(SuperResIDWEXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.expension = expension + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + self.use_se = use_se + if self.use_se: + print('---debug use_se in ' + str(self)) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + # first DW + dw_channels = global_utils.smart_round(self.bottleneck_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, dw_channels, 1, 1) + if not self.no_BN: + inner_str += 
'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, self.kernel_size, current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, bottleneck_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(bottleneck_channels) + # inner_str += 'RELU({})'.format(bottleneck_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format(inner_str, self.out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format(inner_str, self.out_channels) + + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + + # second DW + inner_str = '' + dw_channels = global_utils.smart_round(self.out_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(bottleneck_channels, dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format(inner_str, self.out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + last_channels = out_channels + current_stride = 1 + pass + + self.block_list = PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + if not no_create: + self.module_list = 
nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size + ) + + def encode_structure(self): + return [self.out_channels, self.sub_layers, self.bottleneck_channels] + + def split(self, split_layer_threshold): + if self.sub_layers >= split_layer_threshold: + new_sublayers_1 = split_layer_threshold // 2 + new_sublayers_2 = self.sub_layers - new_sublayers_1 + new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, new_sublayers_1) + new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, self.out_channels, + 1, self.bottleneck_channels, + new_sublayers_2) + return new_block_str1 + new_block_str2 + else: + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_bottleneck_channels, new_sub_layers) + + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find 
block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + block_name=tmp_block_name, **kwargs),s[idx + 1:] + + +class SuperResIDWE1K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE1K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=1.0, + no_create=no_create, **kwargs) + +class SuperResIDWE2K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE2K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=2.0, + no_create=no_create, **kwargs) + +class SuperResIDWE4K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE4K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=4.0, + no_create=no_create, **kwargs) + +class SuperResIDWE6K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, 
out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE6K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=6.0, + no_create=no_create, **kwargs) + + +class SuperResIDWE1K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE1K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=1.0, + no_create=no_create, **kwargs) + +class SuperResIDWE2K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE2K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=2.0, + no_create=no_create, **kwargs) + +class SuperResIDWE4K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE4K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=4.0, + no_create=no_create, **kwargs) + +class SuperResIDWE6K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE6K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, 
expension=6.0, + no_create=no_create, **kwargs) + +class SuperResIDWE1K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE1K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=1.0, + no_create=no_create, **kwargs) + +class SuperResIDWE2K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE2K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=2.0, + no_create=no_create, **kwargs) + +class SuperResIDWE4K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE4K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=4.0, + no_create=no_create, **kwargs) + +class SuperResIDWE6K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE6K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=6.0, + no_create=no_create, **kwargs) + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResIDWE1K3': SuperResIDWE1K3, + 'SuperResIDWE2K3': SuperResIDWE2K3, + 'SuperResIDWE4K3': SuperResIDWE4K3, + 'SuperResIDWE6K3': SuperResIDWE6K3, + 
'SuperResIDWE1K5': SuperResIDWE1K5, + 'SuperResIDWE2K5': SuperResIDWE2K5, + 'SuperResIDWE4K5': SuperResIDWE4K5, + 'SuperResIDWE6K5': SuperResIDWE6K5, + 'SuperResIDWE1K7': SuperResIDWE1K7, + 'SuperResIDWE2K7': SuperResIDWE2K7, + 'SuperResIDWE4K7': SuperResIDWE4K7, + 'SuperResIDWE6K7': SuperResIDWE6K7, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/PlainNet/SuperResK1KXK1.py b/mmrazor/models/architectures/backbones/PlainNet/SuperResK1KXK1.py new file mode 100644 index 000000000..e222598f8 --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/SuperResK1KXK1.py @@ -0,0 +1,199 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. +''' + + +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import uuid + +import PlainNet +from PlainNet import _get_right_parentheses_index_ +from PlainNet.super_blocks import PlainNetSuperBlockClass +from torch import nn +import global_utils + + +class SuperResK1KXK1(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None ,sub_layers=None, kernel_size=None, + no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): + super(SuperResK1KXK1, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + if self.use_se: + print('---debug use_se in ' + str(self)) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + # first bl-block with reslink + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, self.bottleneck_channels, 1, 1) + if not 
self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.bottleneck_channels, + self.kernel_size, current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format(inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + # second bl-block with reslink + inner_str = '' + inner_str += 'ConvKX({},{},{},{})'.format(self.out_channels, self.bottleneck_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.bottleneck_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + self.block_list = 
PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size + ) + + def encode_structure(self): + return [self.out_channels, self.sub_layers, self.bottleneck_channels] + + def split(self, split_layer_threshold): + if self.sub_layers >= split_layer_threshold: + new_sublayers_1 = split_layer_threshold // 2 + new_sublayers_2 = self.sub_layers - new_sublayers_1 + new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, new_sublayers_1) + new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, self.out_channels, + 1, self.bottleneck_channels, + new_sublayers_2) + return new_block_str1 + new_block_str2 + else: + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_bottleneck_channels, new_sub_layers) + + + @classmethod + def create_from_str(cls, s, **kwargs): + 
assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + block_name=tmp_block_name, **kwargs),s[idx + 1:] + + +class SuperResK1K3K1(SuperResK1KXK1): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK1K3K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, **kwargs) + +class SuperResK1K5K1(SuperResK1KXK1): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK1K5K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, **kwargs) + + +class SuperResK1K7K1(SuperResK1KXK1): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK1K7K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, + 
no_create=no_create, **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResK1K3K1': SuperResK1K3K1, + 'SuperResK1K5K1': SuperResK1K5K1, + 'SuperResK1K7K1': SuperResK1K7K1, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/PlainNet/SuperResKXKX.py b/mmrazor/models/architectures/backbones/PlainNet/SuperResKXKX.py new file mode 100644 index 000000000..c2ba60f0b --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/SuperResKXKX.py @@ -0,0 +1,174 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. +''' + + +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import uuid + +import PlainNet +from PlainNet import _get_right_parentheses_index_ +from PlainNet.super_blocks import PlainNetSuperBlockClass +from torch import nn +import global_utils + + +class SuperResKXKX(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, kernel_size=None, + no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): + super(SuperResKXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + if self.use_se: + print('---debug use_se in ' + str(self)) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, self.bottleneck_channels, self.kernel_size, current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 
'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format(inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + self.block_list = PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def forward_pre_relu(self, x): + output = x + for block in self.block_list[0:-1]: + output = block(output) + return output + + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size + ) + + def encode_structure(self): + return [self.out_channels, self.sub_layers, self.bottleneck_channels] + + def split(self, split_layer_threshold): + if self.sub_layers >= split_layer_threshold: + new_sublayers_1 = split_layer_threshold // 2 + new_sublayers_2 = self.sub_layers - new_sublayers_1 + new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, new_sublayers_1) + new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, 
self.out_channels, + 1, self.bottleneck_channels, + new_sublayers_2) + return new_block_str1 + new_block_str2 + else: + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_bottleneck_channels, new_sub_layers) + + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + block_name=tmp_block_name, **kwargs), s[idx + 1:] + + +class SuperResK3K3(SuperResKXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK3K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, **kwargs) + +class 
SuperResK5K5(SuperResKXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK5K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, **kwargs) + +class SuperResK7K7(SuperResKXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK7K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, **kwargs) + + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResK3K3': SuperResK3K3, + 'SuperResK5K5': SuperResK5K5, + 'SuperResK7K7': SuperResK7K7, + + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/mmrazor/models/architectures/backbones/PlainNet/__init__.py b/mmrazor/models/architectures/backbones/PlainNet/__init__.py new file mode 100644 index 000000000..8b3942e11 --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/__init__.py @@ -0,0 +1,324 @@ +from .basic_blocks import AdaptiveAvgPool, BN, ConvDW, ConvKX, ConvKXG2, ConvKXG4, ConvKXG8, ConvKXG16, ConvKXG32, Flatten, Linear, MaxPool, MultiSumBlock, MultiCatBlock, PlainNetBasicBlockClass, RELU, ResBlock, ResBlockProj, Sequential, SE, Swish, SuperConvK1BNRELU, SuperConvK3BNRELU, SuperConvK5BNRELU, SuperConvK7BNRELU, SuperResK3K3, SuperResK5K5, SuperResK7K7, SuperResK1K3K1, SuperResK1K5K1, SuperResK1K7K1, SuperResIDWE1K3, SuperResIDWE2K3, SuperResIDWE4K3, SuperResIDWE6K3, SuperResIDWE1K5, SuperResIDWE2K5, SuperResIDWE4K5, SuperResIDWE6K5, SuperResIDWE1K7, SuperResIDWE2K7, SuperResIDWE4K7, SuperResIDWE6K7, PlainNetSuperBlockClass + + +__all__ = 
[ + 'AdaptiveAvgPool', + 'BN', + 'ConvDW', + 'ConvKX', + 'ConvKXG2', + 'ConvKXG4', 'ConvKXG8', 'ConvKXG16', 'ConvKXG32', 'Flatten', 'Linear', + 'MaxPool', 'MultiSumBlock', 'MultiCatBlock', 'PlainNetBasicBlockClass', 'RELU', 'ResBlock', 'ResBlockProj', 'Sequential', 'SE', 'Swish', + 'SuperConvK1BNRELU', + 'SuperConvK3BNRELU', + 'SuperConvK5BNRELU', + 'SuperConvK7BNRELU', + 'SuperResK3K3', + 'SuperResK5K5', + 'SuperResK7K7', + 'SuperResK1K3K1', + 'SuperResK1K5K1', + 'SuperResK1K7K1', + 'SuperResIDWE1K3', + 'SuperResIDWE2K3', + 'SuperResIDWE4K3', + 'SuperResIDWE6K3', + 'SuperResIDWE1K5', + 'SuperResIDWE2K5', + 'SuperResIDWE4K5', + 'SuperResIDWE6K5', + 'SuperResIDWE1K7', + 'SuperResIDWE2K7', + 'SuperResIDWE4K7', + 'SuperResIDWE6K7', + 'PlainNetSuperBlockClass', + # 'SuperBlock' +] + + +""" + +_all_netblocks_dict_ = {} + +def parse_cmd_options(argv, opt=None): + parser = argparse.ArgumentParser() + parser.add_argument('--plainnet_struct', type=str, default=None, help='PlainNet structure string') + parser.add_argument('--plainnet_struct_txt', type=str, default=None, help='PlainNet structure file name') + parser.add_argument('--num_classes', type=int, default=None, help='how to prune') + module_opt, _ = parser.parse_known_args(argv) + + return module_opt + + + +def pretty_format(plainnet_str, indent=2): + the_formated_str = '' + indent_str = '' + if indent >= 1: + indent_str = ''.join([' '] * indent) + + # print(indent_str, end='') + the_formated_str += indent_str + + s = plainnet_str + while len(s) > 0: + if s[0] == ';': + # print(';\n' + indent_str, end='') + the_formated_str += ';\n' + indent_str + s = s[1:] + + left_par_idx = s.find('(') + assert left_par_idx is not None + right_par_idx = _get_right_parentheses_index_(s) + the_block_class_name = s[0:left_par_idx] + + if the_block_class_name in ['MultiSumBlock', 'MultiCatBlock','MultiGroupBlock']: + # print('\n' + indent_str + the_block_class_name + '(') + sub_str = s[left_par_idx + 1:right_par_idx] + + # find 
block_name + tmp_idx = sub_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'no_name' + else: + tmp_block_name = sub_str[0:tmp_idx] + sub_str = sub_str[tmp_idx+1:] + + if len(tmp_block_name) > 8: + tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + + the_formated_str += '\n' + indent_str + the_block_class_name + '({}|\n'.format(tmp_block_name) + + the_formated_str += pretty_format(sub_str, indent + 1) + # print('\n' + indent_str + ')') + # print(indent_str, end='') + the_formated_str += '\n' + indent_str + ')\n' + indent_str + elif the_block_class_name in ['ResBlock']: + # print('\n' + indent_str + the_block_class_name + '(') + in_channels = None + the_stride = None + sub_str = s[left_par_idx + 1:right_par_idx] + # find block_name + tmp_idx = sub_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'no_name' + else: + tmp_block_name = sub_str[0:tmp_idx] + sub_str = sub_str[tmp_idx + 1:] + + first_comma_index = sub_str.find(',') + if first_comma_index < 0 or not sub_str[0:first_comma_index].isdigit(): + in_channels = None + else: + in_channels = int(sub_str[0:first_comma_index]) + sub_str = sub_str[first_comma_index+1:] + second_comma_index = sub_str.find(',') + if second_comma_index < 0 or not sub_str[0:second_comma_index].isdigit(): + the_stride = None + else: + the_stride = int(sub_str[0:second_comma_index]) + sub_str = sub_str[second_comma_index + 1:] + pass + pass + + if len(tmp_block_name) > 8: + tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + + the_formated_str += '\n' + indent_str + the_block_class_name + '({}|'.format(tmp_block_name) + if in_channels is not None: + the_formated_str += '{},'.format(in_channels) + else: + the_formated_str += ',' + + if the_stride is not None: + the_formated_str += '{},'.format(the_stride) + else: + the_formated_str += ',' + + the_formated_str += '\n' + + the_formated_str += pretty_format(sub_str, indent + 1) + # print('\n' + indent_str + ')') + # print(indent_str, end='') + the_formated_str += '\n' + 
indent_str + ')\n' + indent_str + else: + # print(s[0:right_par_idx+1], end='') + sub_str = s[left_par_idx + 1:right_par_idx] + # find block_name + tmp_idx = sub_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'no_name' + else: + tmp_block_name = sub_str[0:tmp_idx] + sub_str = sub_str[tmp_idx + 1:] + + if len(tmp_block_name) > 8: + tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + + the_formated_str += the_block_class_name + '({}|'.format(tmp_block_name) + sub_str + ')' + + s = s[right_par_idx+1:] + pass # end while + + return the_formated_str + +def _get_right_parentheses_index_(s): + # assert s[0] == '(' + left_paren_count = 0 + for index, x in enumerate(s): + + if x == '(': + left_paren_count += 1 + elif x == ')': + left_paren_count -= 1 + if left_paren_count == 0: + return index + else: + pass + return None + +def _create_netblock_list_from_str_(s, no_create=False, **kwargs): + block_list = [] + while len(s) > 0: + is_found_block_class = False + for the_block_class_name in _all_netblocks_dict_.keys(): + tmp_idx = s.find('(') + if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: # classname + is_found_block_class = True + the_block_class = _all_netblocks_dict_[the_block_class_name] + the_block, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) + if the_block is not None: + block_list.append(the_block) + s = remaining_s + if len(s) > 0 and s[0] == ';': + return block_list, s[1:] + break + pass # end if + pass # end for + assert is_found_block_class + pass # end while + return block_list, '' + +def create_netblock_list_from_str(s, no_create=False, **kwargs): + the_list, remaining_s = _create_netblock_list_from_str_(s, no_create=no_create, **kwargs) + assert len(remaining_s) == 0 + return the_list + +def add_SE_block(structure_str: str): + new_str = '' + RELU = 'RELU' + offset = 4 + + idx = structure_str.find(RELU) + while idx >= 0: + new_str += structure_str[0: idx] + structure_str = structure_str[idx:] + r_idx = 
_get_right_parentheses_index_(structure_str[offset:]) + offset + channels = structure_str[offset + 1:r_idx] + new_str += 'RELU({})SE({})'.format(channels, channels) + structure_str = structure_str[r_idx + 1:] + idx = structure_str.find(RELU) + pass + + new_str += structure_str + return new_str + + +class PlainNet(nn.Module): + def __init__(self, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False, + **kwargs): + super(PlainNet, self).__init__() + self.argv = argv + self.opt = opt + self.num_classes = num_classes + self.plainnet_struct = plainnet_struct + + self.module_opt = parse_cmd_options(self.argv) + + if self.num_classes is None: + self.num_classes = self.module_opt.num_classes + + if self.plainnet_struct is None and self.module_opt.plainnet_struct is not None: + self.plainnet_struct = self.module_opt.plainnet_struct + + if self.plainnet_struct is None: + # load structure from text file + if hasattr(opt, 'plainnet_struct_txt') and opt.plainnet_struct_txt is not None: + plainnet_struct_txt = opt.plainnet_struct_txt + else: + plainnet_struct_txt = self.module_opt.plainnet_struct_txt + + if plainnet_struct_txt is not None: + with open(plainnet_struct_txt, 'r') as fid: + the_line = fid.readlines()[0].strip() + self.plainnet_struct = the_line + pass + + if self.plainnet_struct is None: + return + + the_s = self.plainnet_struct # type: str + + block_list, remaining_s = _create_netblock_list_from_str_(the_s, no_create=no_create, **kwargs) + assert len(remaining_s) == 0 + + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) # register + + def forward(self, x): + output = x + for the_block in self.block_list: + output = the_block(output) + return output + + def __str__(self): + s = '' + for the_block in self.block_list: + s += str(the_block) + return s + + def __repr__(self): + return str(self) + + def get_FLOPs(self, input_resolution): + the_res = input_resolution + the_flops = 0 + for the_block in 
self.block_list: + the_flops += the_block.get_FLOPs(the_res) + the_res = the_block.get_output_resolution(the_res) + + return the_flops + + def get_model_size(self): + the_size = 0 + for the_block in self.block_list: + the_size += the_block.get_model_size() + + return the_size + + def replace_block(self, block_id, new_block): + self.block_list[block_id] = new_block + if block_id < len(self.block_list): + self.block_list[block_id + 1].set_in_channels(new_block.out_channels) + + self.module_list = nn.Module(self.block_list) + + + +from PlainNet import basic_blocks +_all_netblocks_dict_ = basic_blocks.register_netblocks_dict(_all_netblocks_dict_) + +# from PlainNet import super_blocks +# _all_netblocks_dict_ = super_blocks.register_netblocks_dict(_all_netblocks_dict_) + +# from PlainNet import SuperResKXKX +# _all_netblocks_dict_ = SuperResKXKX.register_netblocks_dict(_all_netblocks_dict_) + +# from PlainNet import SuperResK1KXK1 +# _all_netblocks_dict_ = SuperResK1KXK1.register_netblocks_dict(_all_netblocks_dict_) + +# from PlainNet import SuperResIDWEXKX +# _all_netblocks_dict_ = SuperResIDWEXKX.register_netblocks_dict(_all_netblocks_dict_) + +""" \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py b/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py new file mode 100644 index 000000000..8e69e1f2c --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py @@ -0,0 +1,2556 @@ +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np +import uuid +import global_utils + +# from PlainNet import _get_right_parentheses_index_, _create_netblock_list_from_str_, create_netblock_list_from_str +from mmrazor.registry import MODELS + +_all_netblocks_dict_ = {} + +def pretty_format(plainnet_str, indent=2): + the_formated_str = '' + indent_str = '' + if indent >= 1: + 
indent_str = ''.join([' '] * indent) + + # print(indent_str, end='') + the_formated_str += indent_str + + s = plainnet_str + while len(s) > 0: + if s[0] == ';': + # print(';\n' + indent_str, end='') + the_formated_str += ';\n' + indent_str + s = s[1:] + + left_par_idx = s.find('(') + assert left_par_idx is not None + right_par_idx = _get_right_parentheses_index_(s) + the_block_class_name = s[0:left_par_idx] + + if the_block_class_name in ['MultiSumBlock', 'MultiCatBlock','MultiGroupBlock']: + # print('\n' + indent_str + the_block_class_name + '(') + sub_str = s[left_par_idx + 1:right_par_idx] + + # find block_name + tmp_idx = sub_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'no_name' + else: + tmp_block_name = sub_str[0:tmp_idx] + sub_str = sub_str[tmp_idx+1:] + + if len(tmp_block_name) > 8: + tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + + the_formated_str += '\n' + indent_str + the_block_class_name + '({}|\n'.format(tmp_block_name) + + the_formated_str += pretty_format(sub_str, indent + 1) + # print('\n' + indent_str + ')') + # print(indent_str, end='') + the_formated_str += '\n' + indent_str + ')\n' + indent_str + elif the_block_class_name in ['ResBlock']: + # print('\n' + indent_str + the_block_class_name + '(') + in_channels = None + the_stride = None + sub_str = s[left_par_idx + 1:right_par_idx] + # find block_name + tmp_idx = sub_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'no_name' + else: + tmp_block_name = sub_str[0:tmp_idx] + sub_str = sub_str[tmp_idx + 1:] + + first_comma_index = sub_str.find(',') + if first_comma_index < 0 or not sub_str[0:first_comma_index].isdigit(): + in_channels = None + else: + in_channels = int(sub_str[0:first_comma_index]) + sub_str = sub_str[first_comma_index+1:] + second_comma_index = sub_str.find(',') + if second_comma_index < 0 or not sub_str[0:second_comma_index].isdigit(): + the_stride = None + else: + the_stride = int(sub_str[0:second_comma_index]) + sub_str = sub_str[second_comma_index + 1:] + 
pass + pass + + if len(tmp_block_name) > 8: + tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + + the_formated_str += '\n' + indent_str + the_block_class_name + '({}|'.format(tmp_block_name) + if in_channels is not None: + the_formated_str += '{},'.format(in_channels) + else: + the_formated_str += ',' + + if the_stride is not None: + the_formated_str += '{},'.format(the_stride) + else: + the_formated_str += ',' + + the_formated_str += '\n' + + the_formated_str += pretty_format(sub_str, indent + 1) + # print('\n' + indent_str + ')') + # print(indent_str, end='') + the_formated_str += '\n' + indent_str + ')\n' + indent_str + else: + # print(s[0:right_par_idx+1], end='') + sub_str = s[left_par_idx + 1:right_par_idx] + # find block_name + tmp_idx = sub_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'no_name' + else: + tmp_block_name = sub_str[0:tmp_idx] + sub_str = sub_str[tmp_idx + 1:] + + if len(tmp_block_name) > 8: + tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + + the_formated_str += the_block_class_name + '({}|'.format(tmp_block_name) + sub_str + ')' + + s = s[right_par_idx+1:] + pass # end while + + return the_formated_str + +def _get_right_parentheses_index_(s): + # assert s[0] == '(' + left_paren_count = 0 + for index, x in enumerate(s): + + if x == '(': + left_paren_count += 1 + elif x == ')': + left_paren_count -= 1 + if left_paren_count == 0: + return index + else: + pass + return None + +def _create_netblock_list_from_str_(s, no_create=False, **kwargs): + block_list = [] + while len(s) > 0: + is_found_block_class = False + for the_block_class_name in _all_netblocks_dict_.keys(): + tmp_idx = s.find('(') + if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: # classname + is_found_block_class = True + the_block_class = _all_netblocks_dict_[the_block_class_name] + the_block, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) + if the_block is not None: + block_list.append(the_block) + s = 
remaining_s + if len(s) > 0 and s[0] == ';': + return block_list, s[1:] + break + pass # end if + pass # end for + assert is_found_block_class + pass # end while + return block_list, '' + +def _build_netblock_list_from_str_(s, no_create=False, **kwargs): + block_list = [] + while len(s) > 0: + is_found_block_class = False + # for the_block_class_name in _all_netblocks_dict_.keys(): + tmp_idx = s.find('(') + # if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: # classname + if tmp_idx > 0: # classname + is_found_block_class = True + + mutable_cfg = dict(type=s[0:tmp_idx]) + the_block_class = MODELS.module_dict[s[0:tmp_idx]] + the_block_cfg, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) + mutable_cfg.update(the_block_cfg) + # print(mutable_cfg) + the_block = MODELS.build(mutable_cfg) + # the_block_class = _all_netblocks_dict_[the_block_class_name] + # the_block, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) + if the_block is not None: + block_list.append(the_block) + s = remaining_s + if len(s) > 0 and s[0] == ';': + return block_list, s[1:] + # break + pass # end if + # pass # end for + assert is_found_block_class + pass # end while + return block_list, '' + +def create_netblock_list_from_str(s, no_create=False, **kwargs): + the_list, remaining_s = _create_netblock_list_from_str_(s, no_create=no_create, **kwargs) + assert len(remaining_s) == 0 + return the_list + +def build_netblock_list_from_str(s, no_create=False, **kwargs): + the_list, remaining_s = _build_netblock_list_from_str_(s, no_create=no_create, **kwargs) + assert len(remaining_s) == 0 + return the_list + +def add_SE_block(structure_str: str): + new_str = '' + RELU = 'RELU' + offset = 4 + + idx = structure_str.find(RELU) + while idx >= 0: + new_str += structure_str[0: idx] + structure_str = structure_str[idx:] + r_idx = _get_right_parentheses_index_(structure_str[offset:]) + offset + channels = structure_str[offset + 1:r_idx] + 
new_str += 'RELU({})SE({})'.format(channels, channels) + structure_str = structure_str[r_idx + 1:] + idx = structure_str.find(RELU) + pass + + new_str += structure_str + return new_str + +@MODELS.register_module() +class PlainNetBasicBlockClass(nn.Module): + def __init__(self, in_channels=None, out_channels=None, stride=1, no_create=False, block_name=None, **kwargs): + super(PlainNetBasicBlockClass, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.no_create = no_create + self.block_name = block_name + if self.block_name is None: + self.block_name = 'uuid{}'.format(uuid.uuid4().hex) + + def forward(self, x): + raise RuntimeError('Not implemented') + + def __str__(self): + return type(self).__name__ + '({},{},{})'.format(self.in_channels, self.out_channels, self.stride) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{})'.format(self.block_name, self.in_channels, self.out_channels, self.stride) + + def get_output_resolution(self, input_resolution): + raise RuntimeError('Not implemented') + + def get_FLOPs(self, input_resolution): + raise RuntimeError('Not implemented') + + def get_model_size(self): + raise RuntimeError('Not implemented') + + def set_in_channels(self, c): + raise RuntimeError('Not implemented') + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert PlainNetBasicBlockClass.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + # return cls(in_channels=in_channels, out_channels=out_channels, 
stride=stride, + # block_name=tmp_block_name, no_create=no_create), s[idx + 1:] + return dict(in_channels=in_channels, out_channels=out_channels, stride=stride, + block_name=tmp_block_name, no_create=no_create), s[idx + 1:] + + @classmethod + def is_instance_from_str(cls, s): + if s.startswith(cls.__name__ + '(') and s[-1] == ')': + return True + else: + return False + +@MODELS.register_module() +class AdaptiveAvgPool(PlainNetBasicBlockClass): + def __init__(self, out_channels, output_size, no_create=False, **kwargs): + super(AdaptiveAvgPool, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.output_size = output_size + self.no_create = no_create + if not no_create: + self.netblock = nn.AdaptiveAvgPool2d(output_size=(self.output_size, self.output_size)) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return type(self).__name__ + '({},{})'.format(self.out_channels // self.output_size**2, self.output_size) + + def __repr__(self): + return type(self).__name__ + '({}|{},{})'.format(self.block_name, + self.out_channels // self.output_size ** 2, self.output_size) + + def get_output_resolution(self, input_resolution): + return self.output_size + + def get_FLOPs(self, input_resolution): + return 0 + + def get_model_size(self): + return 0 + + def set_in_channels(self, c): + self.in_channels = c + self.out_channels = c + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert AdaptiveAvgPool.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('AdaptiveAvgPool('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + out_channels = int(param_str_split[0]) + output_size = int(param_str_split[1]) + return 
dict(out_channels=out_channels, output_size=output_size, + block_name=tmp_block_name, no_create=no_create), s[idx + 1:] + # return AdaptiveAvgPool(out_channels=out_channels, output_size=output_size, + # block_name=tmp_block_name, no_create=no_create), s[idx + 1:] + +@MODELS.register_module() +class BN(PlainNetBasicBlockClass): + def __init__(self, out_channels=None, copy_from=None, no_create=False, **kwargs): + super(BN, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.BatchNorm2d) + self.in_channels = copy_from.weight.shape[0] + self.out_channels = copy_from.weight.shape[0] + assert out_channels is None or out_channels == self.out_channels + self.netblock = copy_from + + else: + self.in_channels = out_channels + self.out_channels = out_channels + if no_create: + return + else: + self.netblock = nn.BatchNorm2d(num_features=self.out_channels) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'BN({})'.format(self.out_channels) + + def __repr__(self): + return 'BN({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + def get_FLOPs(self, input_resolution): + return input_resolution ** 2 * self.out_channels + + def get_model_size(self): + return self.out_channels + + def set_in_channels(self, c): + self.in_channels = c + self.out_channels = c + if not self.no_create: + self.netblock = nn.BatchNorm2d(num_features=self.out_channels) + self.netblock.train() + self.netblock.requires_grad_(True) + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert BN.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('BN('):idx] + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = 
param_str[tmp_idx + 1:] + out_channels = int(param_str) + # return BN(out_channels=out_channels, block_name=tmp_block_name, no_create=no_create), s[idx + 1:] + return dict(out_channels=out_channels, block_name=tmp_block_name, no_create=no_create), s[idx + 1:] + +@MODELS.register_module() +class ConvKX(PlainNetBasicBlockClass): + def __init__(self, in_channels=None, out_channels=None, kernel_size=None, stride=None, groups=1, copy_from=None, + no_create=False, **kwargs): + super(ConvKX, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Conv2d) + self.in_channels = copy_from.in_channels + self.out_channels = copy_from.out_channels + self.kernel_size = copy_from.kernel_size[0] + self.stride = copy_from.stride[0] + self.groups = copy_from.groups + assert in_channels is None or in_channels == self.in_channels + assert out_channels is None or out_channels == self.out_channels + assert kernel_size is None or kernel_size == self.kernel_size + assert stride is None or stride == self.stride + self.netblock = copy_from + else: + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = kernel_size + self.padding = (self.kernel_size - 1) // 2 + if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \ + or self.stride == 0: + return + else: + self.netblock = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, + kernel_size=self.kernel_size, stride=self.stride, + padding=self.padding, bias=False, groups=self.groups) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{},{})'.format(self.block_name, self.in_channels, self.out_channels, self.kernel_size, self.stride) + + def 
get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + def get_FLOPs(self, input_resolution): + return self.in_channels * self.out_channels * self.kernel_size ** 2 * input_resolution ** 2 // self.stride ** 2 // self.groups + + def get_model_size(self): + return self.in_channels * self.out_channels * self.kernel_size ** 2 // self.groups + + def set_in_channels(self, c): + self.in_channels = c + if not self.no_create: + self.netblock = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, + kernel_size=self.kernel_size, stride=self.stride, + padding=self.padding, bias=False, groups=self.groups) # keep the configured grouping so the rebuilt conv matches get_FLOPs/get_model_size + self.netblock.train() + self.netblock.requires_grad_(True) + + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + in_channels = int(split_str[0]) + out_channels = int(split_str[1]) + kernel_size = int(split_str[2]) + stride = int(split_str[3]) + return dict(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, no_create=no_create, block_name=tmp_block_name), s[idx + 1:] + # return cls(in_channels=in_channels, out_channels=out_channels, + # kernel_size=kernel_size, stride=stride, no_create=no_create, block_name=tmp_block_name), s[idx + 1:] + +@MODELS.register_module() +class ConvDW(PlainNetBasicBlockClass): + def __init__(self, out_channels=None, kernel_size=None, stride=None, copy_from=None, + no_create=False, **kwargs): + super(ConvDW, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Conv2d) + self.in_channels = 
copy_from.in_channels + self.out_channels = copy_from.out_channels + self.kernel_size = copy_from.kernel_size[0] + self.stride = copy_from.stride[0] + assert self.in_channels == self.out_channels + assert out_channels is None or out_channels == self.out_channels + assert kernel_size is None or kernel_size == self.kernel_size + assert stride is None or stride == self.stride + + self.netblock = copy_from + else: + + self.in_channels = out_channels + self.out_channels = out_channels + self.stride = stride + self.kernel_size = kernel_size + + self.padding = (self.kernel_size - 1) // 2 + if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \ + or self.stride == 0: + return + else: + self.netblock = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, + kernel_size=self.kernel_size, stride=self.stride, + padding=self.padding, bias=False, groups=self.in_channels) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'ConvDW({},{},{})'.format(self.out_channels, self.kernel_size, self.stride) + + def __repr__(self): + return 'ConvDW({}|{},{},{})'.format(self.block_name, self.out_channels, self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + def get_FLOPs(self, input_resolution): + return self.out_channels * self.kernel_size ** 2 * input_resolution ** 2 // self.stride ** 2 + + def get_model_size(self): + return self.out_channels * self.kernel_size ** 2 + + def set_in_channels(self, c): + self.in_channels = c + self.out_channels=self.in_channels + if not self.no_create: + self.netblock = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, + kernel_size=self.kernel_size, stride=self.stride, + padding=self.padding, bias=False, groups=self.in_channels) + self.netblock.train() + self.netblock.requires_grad_(True) + + + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert 
ConvDW.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('ConvDW('):idx] + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + out_channels = int(split_str[0]) + kernel_size = int(split_str[1]) + stride = int(split_str[2]) + return dict(out_channels=out_channels, + kernel_size=kernel_size, stride=stride, no_create=no_create, block_name=tmp_block_name), s[idx + 1:] + # return ConvDW(out_channels=out_channels, + # kernel_size=kernel_size, stride=stride, no_create=no_create, block_name=tmp_block_name), s[idx + 1:] + +@MODELS.register_module() +class ConvKXG2(ConvKX): + def __init__(self, in_channels=None, out_channels=None, kernel_size=None, stride=None, copy_from=None, + no_create=False, **kwargs): + super(ConvKXG2, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, copy_from=copy_from, no_create=no_create, + groups=2, **kwargs) + +@MODELS.register_module() +class ConvKXG4(ConvKX): + def __init__(self, in_channels=None, out_channels=None, kernel_size=None, stride=None, copy_from=None, + no_create=False, **kwargs): + super(ConvKXG4, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, copy_from=copy_from, no_create=no_create, + groups=4, **kwargs) + +@MODELS.register_module() +class ConvKXG8(ConvKX): + def __init__(self, in_channels=None, out_channels=None, kernel_size=None, stride=None, copy_from=None, + no_create=False, **kwargs): + super(ConvKXG8, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, copy_from=copy_from, no_create=no_create, + groups=8, **kwargs) + +@MODELS.register_module() +class ConvKXG16(ConvKX): + def __init__(self, 
in_channels=None, out_channels=None, kernel_size=None, stride=None, copy_from=None, + no_create=False, **kwargs): + super(ConvKXG16, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, copy_from=copy_from, no_create=no_create, + groups=16, **kwargs) + +@MODELS.register_module() +class ConvKXG32(ConvKX): + def __init__(self, in_channels=None, out_channels=None, kernel_size=None, stride=None, copy_from=None, + no_create=False, **kwargs): + super(ConvKXG32, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, copy_from=copy_from, no_create=no_create, + groups=32, **kwargs) + +@MODELS.register_module() +class Flatten(PlainNetBasicBlockClass): + def __init__(self, out_channels, no_create=False, **kwargs): + super(Flatten, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.no_create = no_create + + def forward(self, x): + return torch.flatten(x, 1) + + def __str__(self): + return 'Flatten({})'.format(self.out_channels) + + def __repr__(self): + return 'Flatten({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return 1 + + def get_FLOPs(self, input_resolution): + return 0 + + def get_model_size(self): + return 0 + + def set_in_channels(self, c): + self.in_channels = c + self.out_channels = c + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Flatten.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('Flatten('):idx] + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + # return Flatten(out_channels=out_channels, no_create=no_create, block_name=tmp_block_name), s[idx + 1:] + 
return dict(out_channels=out_channels, no_create=no_create, block_name=tmp_block_name), s[idx + 1:] + +@MODELS.register_module() +class Linear(PlainNetBasicBlockClass): + def __init__(self, in_channels=None, out_channels=None, bias=True, copy_from=None, + no_create=False, **kwargs): + super(Linear, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Linear) + self.in_channels = copy_from.weight.shape[1] + self.out_channels = copy_from.weight.shape[0] + self.use_bias = copy_from.bias is not None + assert in_channels is None or in_channels == self.in_channels + assert out_channels is None or out_channels == self.out_channels + + self.netblock = copy_from + else: + + self.in_channels = in_channels + self.out_channels = out_channels + self.use_bias = bias + if not no_create: + self.netblock = nn.Linear(self.in_channels, self.out_channels, + bias=self.use_bias) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'Linear({},{},{})'.format(self.in_channels, self.out_channels, int(self.use_bias)) + + def __repr__(self): + return 'Linear({}|{},{},{})'.format(self.block_name, self.in_channels, self.out_channels, int(self.use_bias)) + + def get_output_resolution(self, input_resolution): + assert input_resolution == 1 + return 1 + + def get_FLOPs(self, input_resolution): + return self.in_channels * self.out_channels + + def get_model_size(self): + return self.in_channels * self.out_channels + int(self.use_bias) + + def set_in_channels(self, c): + self.in_channels = c + if not self.no_create: + self.netblock = nn.Linear(self.in_channels, self.out_channels, + bias=self.use_bias) + self.netblock.train() + self.netblock.requires_grad_(True) + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Linear.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('Linear('):idx] + # find block_name + tmp_idx = 
param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + in_channels = int(split_str[0]) + out_channels = int(split_str[1]) + use_bias = int(split_str[2]) + + return dict(in_channels=in_channels, out_channels=out_channels, bias=use_bias == 1, + block_name=tmp_block_name, no_create=no_create), s[idx+1 :] + # return Linear(in_channels=in_channels, out_channels=out_channels, bias=use_bias == 1, + # block_name=tmp_block_name, no_create=no_create), s[idx+1 :] + +@MODELS.register_module() +class MaxPool(PlainNetBasicBlockClass): + def __init__(self, out_channels, kernel_size, stride, no_create=False, **kwargs): + super(MaxPool, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = (kernel_size - 1) // 2 + self.no_create = no_create + if not no_create: + self.netblock = nn.MaxPool2d(kernel_size=self.kernel_size, stride=self.stride, padding=self.padding) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'MaxPool({},{},{})'.format(self.out_channels, self.kernel_size, self.stride) + + def __repr__(self): + return 'MaxPool({}|{},{},{})'.format(self.block_name, self.out_channels, self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + def get_FLOPs(self, input_resolution): + return 0 + + def get_model_size(self): + return 0 + + def set_in_channels(self, c): + self.in_channels = c + self.out_channels = c + if not self.no_create: + self.netblock = nn.MaxPool2d(kernel_size=self.kernel_size, stride=self.stride, padding=self.padding) + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MaxPool.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not 
@MODELS.register_module()
class Sequential(PlainNetBasicBlockClass):
    '''
    Sequential(inner_blocks_str): blocks applied one after another.

    in_channels comes from the first block, out_channels from the last one.
    stride is the overall down-sampling factor, measured by pushing a probe
    resolution of 1024 through every block.
    '''

    def __init__(self, block_list, no_create=False, **kwargs):
        super(Sequential, self).__init__(**kwargs)
        self.block_list = block_list
        if not no_create:
            # Register the children as sub-modules.
            self.module_list = nn.ModuleList(block_list)
        self.in_channels = block_list[0].in_channels
        self.out_channels = block_list[-1].out_channels
        self.no_create = no_create
        res = 1024
        for block in self.block_list:
            res = block.get_output_resolution(res)
        self.stride = 1024 // res

    def forward(self, x):
        '''Feed x through every block in order and return the final output.'''
        output = x
        for inner_block in self.block_list:
            output = inner_block(output)
        return output

    def __str__(self):
        return 'Sequential(' + ''.join(str(inner_block) for inner_block in self.block_list) + ')'

    def __repr__(self):
        return str(self)

    def get_output_resolution(self, input_resolution):
        '''Resolution after passing input_resolution through all blocks.'''
        the_res = input_resolution
        for the_block in self.block_list:
            the_res = the_block.get_output_resolution(the_res)
        return the_res

    def get_FLOPs(self, input_resolution):
        '''Sum of the blocks' FLOPs, each evaluated at the resolution it receives.'''
        the_res = input_resolution
        the_flops = 0
        for the_block in self.block_list:
            the_flops += the_block.get_FLOPs(the_res)
            the_res = the_block.get_output_resolution(the_res)
        return the_flops

    def get_model_size(self):
        '''Sum of the blocks' model sizes.'''
        the_size = 0
        for the_block in self.block_list:
            the_size += the_block.get_model_size()

        return the_size

    def set_in_channels(self, c):
        '''
        Change the input channel count.

        Only the first block is updated, plus a BN sitting directly behind it.
        '''
        self.in_channels = c
        if len(self.block_list) == 0:
            self.out_channels = c
            return

        self.block_list[0].set_in_channels(c)
        last_channels = self.block_list[0].out_channels
        if len(self.block_list) >= 2 and isinstance(self.block_list[1], BN):
            self.block_list[1].set_in_channels(last_channels)

    @classmethod
    def create_from_str(cls, s, no_create=False, **kwargs):
        '''
        Parse a leading 'Sequential(...)' from s.

        Returns (config, rest): config is a dict of constructor keyword
        arguments (or None when no inner block was found) and rest is the
        part of s that follows the closing parenthesis.
        '''
        assert Sequential.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        # 'Sequential(' already ends with the opening parenthesis, so the
        # parameters start right behind it. An extra +1 here would drop the
        # first character of the inner block string.
        param_str = s[len('Sequential('):idx]
        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        the_block_list, remaining_s = _build_netblock_list_from_str_(param_str, no_create=no_create)
        assert len(remaining_s) == 0
        # Hand back whatever follows the closing parenthesis so that blocks
        # written after this Sequential are not silently discarded.
        if the_block_list is None or len(the_block_list) == 0:
            return None, s[idx + 1:]
        return dict(block_list=the_block_list, no_create=no_create, block_name=tmp_block_name), s[idx + 1:]
self.block_list[1:]: + output2 = inner_block(x) + output = output + output2 + return output + + def __str__(self): + s = 'MultiSumBlock({}|'.format(self.block_name) + for inner_block in self.block_list: + s += str(inner_block) + ';' + s = s[:-1] + s += ')' + return s + + def __repr__(self): + return str(self) + + + def get_output_resolution(self, input_resolution): + the_res = self.block_list[0].get_output_resolution(input_resolution) + for the_block in self.block_list: + assert the_res == the_block.get_output_resolution(input_resolution) + + return the_res + + def get_FLOPs(self, input_resolution): + the_flops = 0 + for the_block in self.block_list: + the_flops += the_block.get_FLOPs(input_resolution) + + return the_flops + + def get_model_size(self): + the_size = 0 + for the_block in self.block_list: + the_size += the_block.get_model_size() + + return the_size + + def set_in_channels(self, c): + self.in_channels = c + for the_block in self.block_list: + the_block.set_in_channels(c) + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MultiSumBlock.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('MultiSumBlock('):idx] + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_s = param_str + + the_block_list = [] + while len(the_s) > 0: + # tmp_block_list, remaining_s = _create_netblock_list_from_str_(the_s, no_create=no_create) + tmp_block_list, remaining_s = _build_netblock_list_from_str_(the_s, no_create=no_create) + the_s = remaining_s + if tmp_block_list is None: + pass + elif len(tmp_block_list) == 1: + the_block_list.append(tmp_block_list[0]) + else: + the_block_list.append(Sequential(block_list=tmp_block_list, no_create=no_create)) + pass # end while + + if len(the_block_list) == 0: + return None, s[idx+1:] + 
@MODELS.register_module()
class MultiCatBlock(PlainNetBasicBlockClass):
    '''
    MultiCatBlock(name|branch;branch;...): every branch receives the same
    input and the branch outputs are concatenated along dim 1.
    '''

    def __init__(self, block_list, no_create=False, **kwargs):
        super(MultiCatBlock, self).__init__(**kwargs)
        self.block_list = block_list
        if not no_create:
            self.module_list = nn.ModuleList(block_list)
        self.in_channels = np.max([branch.in_channels for branch in block_list])
        self.out_channels = np.sum([branch.out_channels for branch in block_list])
        self.no_create = no_create

        # The stride is read off the first branch only, with a probe
        # resolution of 1024.
        probe_res = 1024
        self.stride = probe_res // self.block_list[0].get_output_resolution(probe_res)

    def forward(self, x):
        '''Run each branch on x and concatenate the results on dim 1.'''
        return torch.cat([branch(x) for branch in self.block_list], dim=1)

    def __str__(self):
        text = 'MultiCatBlock({}|'.format(self.block_name)
        text += ''.join(str(branch) + ';' for branch in self.block_list)
        # Drop the trailing separator before closing the parenthesis.
        return text[:-1] + ')'

    def __repr__(self):
        return str(self)

    def get_output_resolution(self, input_resolution):
        '''Output resolution of the first branch; all branches must agree.'''
        expected = self.block_list[0].get_output_resolution(input_resolution)
        for branch in self.block_list:
            assert expected == branch.get_output_resolution(input_resolution)

        return expected

    def get_FLOPs(self, input_resolution):
        '''Sum of the branches' FLOPs at input_resolution.'''
        return sum(branch.get_FLOPs(input_resolution) for branch in self.block_list)

    def get_model_size(self):
        '''Sum of the branches' model sizes.'''
        return sum(branch.get_model_size() for branch in self.block_list)

    def set_in_channels(self, c):
        '''Propagate the new input width to every branch and refresh out_channels.'''
        self.in_channels = c
        for branch in self.block_list:
            branch.set_in_channels(c)
        self.out_channels = np.sum([branch.out_channels for branch in self.block_list])

    @classmethod
    def create_from_str(cls, s, no_create=False, **kwargs):
        '''
        Parse a leading 'MultiCatBlock(...)' from s.

        Returns (config, rest): config is a dict of constructor keyword
        arguments (or None when no branch was parsed) and rest is the text
        following the closing parenthesis.
        '''
        assert MultiCatBlock.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        rest = s[idx + 1:]

        # Optional "<block_name>|" prefix; a random name is used without it.
        head, bar, tail = s[len('MultiCatBlock('):idx].partition('|')
        if bar:
            tmp_block_name = head
            pending = tail
        else:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
            pending = head

        branches = []
        while len(pending) > 0:
            parsed, pending = _build_netblock_list_from_str_(pending, no_create=no_create)
            if parsed is None:
                continue
            if len(parsed) == 1:
                branches.append(parsed[0])
            else:
                # Several consecutive blocks form one sequential branch.
                branches.append(Sequential(block_list=parsed, no_create=no_create))

        if len(branches) == 0:
            return None, rest

        return dict(block_list=branches, block_name=tmp_block_name,
                    no_create=no_create), rest
= _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len('RELU('):idx] + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + # return RELU(out_channels=out_channels, no_create=no_create, block_name=tmp_block_name), s[idx+1:] + return dict(out_channels=out_channels, no_create=no_create, block_name=tmp_block_name), s[idx+1:] + +@MODELS.register_module() +class ResBlock(PlainNetBasicBlockClass): + ''' + ResBlock(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels + ''' + def __init__(self, block_list, in_channels=None, stride=None, no_create=False, **kwargs): + super(ResBlock, self).__init__(**kwargs) + self.block_list = block_list + self.stride = stride + self.no_create = no_create + if not no_create: + self.module_list = nn.ModuleList(block_list) + + if in_channels is None: + self.in_channels = block_list[0].in_channels + else: + self.in_channels = in_channels + self.out_channels = block_list[-1].out_channels + + if self.stride is None: + tmp_input_res = 1024 + tmp_output_res = self.get_output_resolution(tmp_input_res) + self.stride = tmp_input_res // tmp_output_res + + self.proj = None + if self.stride > 1 or self.in_channels != self.out_channels: + self.proj = nn.Sequential( + nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride), + nn.BatchNorm2d(self.out_channels), + ) + + def forward(self, x): + if len(self.block_list) == 0: + return x + + output = x + for inner_block in self.block_list: + output = inner_block(output) + + if self.proj is not None: + output = output + self.proj(x) + else: + output = output + x + + return output + + def __str__(self): + s = 'ResBlock({},{},'.format(self.in_channels, self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + 
return s + + def __repr__(self): + s = 'ResBlock({}|{},{},'.format(self.block_name, self.in_channels, self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + + return the_res + + def get_FLOPs(self, input_resolution): + the_res = input_resolution + the_flops = 0 + for the_block in self.block_list: + the_flops += the_block.get_FLOPs(the_res) + the_res = the_block.get_output_resolution(the_res) + + if self.proj is not None: + the_flops += self.in_channels * self.out_channels * (the_res / self.stride) ** 2 + \ + (the_res / self.stride) ** 2 * self.out_channels + + return the_flops + + def get_model_size(self): + the_size = 0 + for the_block in self.block_list: + the_size += the_block.get_model_size() + + if self.proj is not None: + the_size += self.in_channels * self.out_channels + self.out_channels + + return the_size + + def set_in_channels(self, c): + self.in_channels = c + if len(self.block_list) == 0: + self.out_channels = c + return + + self.block_list[0].set_in_channels(c) + last_channels = self.block_list[0].out_channels + if len(self.block_list) >= 2 and \ + ( isinstance(self.block_list[0], ConvKX) or isinstance(self.block_list[0], ConvDW)) and \ + isinstance(self.block_list[1], BN): + self.block_list[1].set_in_channels(last_channels) + + self.proj = None + if not self.no_create: + if self.stride > 1 or self.in_channels != self.out_channels: + self.proj = nn.Sequential( + nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride), + nn.BatchNorm2d(self.out_channels), + ) + self.proj.train() + self.proj.requires_grad_(True) + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ResBlock.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + the_stride = None + param_str = 
@MODELS.register_module()
class ResBlockProj(PlainNetBasicBlockClass):
    '''
    ResBlockProj(in_channels, stride, inner_blocks_str): residual block whose
    shortcut is a 1x1 convolution followed by batch norm.

    If in_channels is missing, use block_list[0].in_channels as in_channels.
    If stride is missing, it is derived from the inner blocks with a probe
    resolution of 1024.
    '''

    def __init__(self, block_list, in_channels=None, stride=None, no_create=False, **kwargs):
        super(ResBlockProj, self).__init__(**kwargs)
        self.block_list = block_list
        self.stride = stride
        self.no_create = no_create
        if not no_create:
            self.module_list = nn.ModuleList(block_list)

        if in_channels is None:
            self.in_channels = block_list[0].in_channels
        else:
            self.in_channels = in_channels
        self.out_channels = block_list[-1].out_channels

        if self.stride is None:
            tmp_input_res = 1024
            tmp_output_res = self.get_output_resolution(tmp_input_res)
            self.stride = tmp_input_res // tmp_output_res

        self.proj = self._make_proj()

    def _make_proj(self):
        '''Build the 1x1 conv + BN shortcut for the current channels and stride.'''
        return nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride),
            nn.BatchNorm2d(self.out_channels),
        )

    def forward(self, x):
        '''Return inner_blocks(x) + proj(x); x is returned as-is with no inner blocks.'''
        if len(self.block_list) == 0:
            return x

        output = x
        for inner_block in self.block_list:
            output = inner_block(output)
        output = output + self.proj(x)
        return output

    def _format(self, prefix):
        '''Shared body of __str__ / __repr__: prefix + inner block strings + ")".'''
        return prefix + ''.join(str(inner_block) for inner_block in self.block_list) + ')'

    def __str__(self):
        return self._format('ResBlockProj({},{},'.format(self.in_channels, self.stride))

    def __repr__(self):
        return self._format('ResBlockProj({}|{},{},'.format(self.block_name, self.in_channels, self.stride))

    def get_output_resolution(self, input_resolution):
        '''Resolution after passing input_resolution through the inner blocks.'''
        the_res = input_resolution
        for the_block in self.block_list:
            the_res = the_block.get_output_resolution(the_res)

        return the_res

    def get_FLOPs(self, input_resolution):
        '''FLOPs of the inner blocks plus, when a projection exists, the shortcut term.'''
        the_res = input_resolution
        the_flops = 0
        for the_block in self.block_list:
            the_flops += the_block.get_FLOPs(the_res)
            the_res = the_block.get_output_resolution(the_res)

        if self.proj is not None:
            the_flops += self.in_channels * self.out_channels * (the_res / self.stride) ** 2 + \
                         (the_res / self.stride) ** 2 * self.out_channels

        return the_flops

    def get_model_size(self):
        '''Model size of the inner blocks plus, when a projection exists, the shortcut term.'''
        the_size = 0
        for the_block in self.block_list:
            the_size += the_block.get_model_size()

        if self.proj is not None:
            the_size += self.in_channels * self.out_channels + self.out_channels

        return the_size

    def set_in_channels(self, c):
        '''
        Change the input channel count.

        The first inner block is updated (plus a BN directly behind a
        ConvKX/ConvDW) and the shortcut projection is rebuilt for the new
        width.
        '''
        self.in_channels = c
        if len(self.block_list) == 0:
            self.out_channels = c
            return

        self.block_list[0].set_in_channels(c)
        last_channels = self.block_list[0].out_channels
        if len(self.block_list) >= 2 and \
                isinstance(self.block_list[0], (ConvKX, ConvDW)) and \
                isinstance(self.block_list[1], BN):
            self.block_list[1].set_in_channels(last_channels)

        if self.no_create:
            # Structure-only block: keep the original behaviour of dropping
            # the projection module.
            self.proj = None
            return

        # forward() always calls self.proj, so the projection must exist even
        # when stride == 1 and the channel counts match. Leaving it None in
        # that case made forward() fail.
        self.proj = self._make_proj()
        self.proj.train()
        self.proj.requires_grad_(True)

    @classmethod
    def create_from_str(cls, s, no_create=False, **kwargs):
        '''
        Parse a leading 'ResBlockProj(...)' from s.

        Accepted parameter layouts are "blocks", "in_channels,blocks" and
        "in_channels,stride,blocks", each optionally prefixed by "name|".
        Returns (config, rest): config is a dict of constructor keyword
        arguments (or None when no inner block was parsed) and rest is the
        text following the closing parenthesis.
        '''
        assert ResBlockProj.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('ResBlockProj('):idx]
        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        # Peel off the optional leading integers: first in_channels, then
        # (only if in_channels was present) stride. Anything that is not a
        # plain digit string belongs to the inner block description.
        in_channels = None
        the_stride = None
        comma_index = param_str.find(',')
        if comma_index >= 0 and param_str[0:comma_index].isdigit():
            in_channels = int(param_str[0:comma_index])
            param_str = param_str[comma_index + 1:]
            comma_index = param_str.find(',')
            if comma_index >= 0 and param_str[0:comma_index].isdigit():
                the_stride = int(param_str[0:comma_index])
                param_str = param_str[comma_index + 1:]

        the_block_list, remaining_s = _build_netblock_list_from_str_(param_str, no_create=no_create)

        assert len(remaining_s) == 0
        if the_block_list is None or len(the_block_list) == 0:
            return None, s[idx+1:]
        return dict(block_list=the_block_list, in_channels=in_channels,
                    stride=the_stride, no_create=no_create, block_name=tmp_block_name), s[idx+1:]
@MODELS.register_module()
class SwishImplementation(torch.autograd.Function):
    '''Swish activation x * sigmoid(x) with a hand-written backward pass.'''

    @staticmethod
    def forward(ctx, i):
        '''Return i * sigmoid(i); only the input is saved for backward.'''
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        '''Gradient of swish: sigmoid(i) * (1 + i * (1 - sigmoid(i))).'''
        # saved_tensors is the supported accessor; saved_variables is a
        # deprecated alias.
        i, = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
def _remove_bn_layer_(block_list):
    '''
    Return a new list equal to block_list without its BN blocks.

    Blocks that own a block_list are processed recursively: their own
    block_list and module_list are replaced in place by the filtered
    children, and the container itself is kept in the result.
    '''
    new_block_list = []
    for the_block in block_list:
        if isinstance(the_block, BN):
            continue

        inner_blocks = getattr(the_block, 'block_list', None)
        if inner_blocks is not None:
            # Filter the children into a separate list. Reusing
            # new_block_list here would discard the siblings collected so far
            # and would end up appending the container to its own
            # block_list.
            filtered_inner = _remove_bn_layer_(inner_blocks)
            the_block.module_list = nn.ModuleList(filtered_inner)
            the_block.block_list = filtered_inner

        new_block_list.append(the_block)

    return new_block_list
+ the_block.module_list = nn.ModuleList(new_block_list) + the_block.block_list = new_block_list + new_block_list.append(the_block) + else: + new_block_list.append(the_block) + pass + pass + + return new_block_list + +def _fuse_convkx_and_bn_(convkx, bn): + the_weight_scale = bn.weight / torch.sqrt(bn.running_var + bn.eps) + convkx.weight[:] = convkx.weight * the_weight_scale.view((-1, 1, 1, 1)) + the_bias_shift = (bn.weight * bn.running_mean) / \ + torch.sqrt(bn.running_var + bn.eps) + bn.weight[:] = 1 + bn.bias[:] = bn.bias - the_bias_shift + bn.running_var[:] = 1.0 - bn.eps + bn.running_mean[:] = 0.0 + + +def _fuse_bn_layer_for_blocks_list_(block_list): + last_block = None # type: ConvKX + with torch.no_grad(): + for the_block in block_list: + if isinstance(the_block, BN): + # assert isinstance(last_block, ConvKX) or isinstance(last_block, ConvDW) + if isinstance(last_block, ConvKX) or isinstance(last_block, ConvDW): + _fuse_convkx_and_bn_(last_block.netblock, the_block.netblock) + else: + print('--- warning! 
@MODELS.register_module()
class PlainNetSuperBlockClass(PlainNetBasicBlockClass):
    '''
    Base class for super blocks described by
    (in_channels, out_channels, stride, sub_layers).

    block_list and module_list start as None; the methods below iterate
    block_list, so it has to be filled in before they are used.
    '''

    def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs):
        # NOTE(review): **kwargs (e.g. block_name) is accepted here but not
        # forwarded to the base initializer; confirm this is intended.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.sub_layers = sub_layers
        self.no_create = no_create
        self.block_list = None
        self.module_list = None

    def forward(self, x):
        '''Apply every block of block_list in order.'''
        result = x
        for layer in self.block_list:
            result = layer(result)
        return result

    def __str__(self):
        fields = (self.in_channels, self.out_channels, self.stride, self.sub_layers)
        return type(self).__name__ + '({},{},{},{})'.format(*fields)

    def __repr__(self):
        fields = (self.block_name, self.in_channels, self.out_channels,
                  self.stride, self.sub_layers)
        return type(self).__name__ + '({}|{},{},{},{})'.format(*fields)

    def get_output_resolution(self, input_resolution):
        '''Resolution after passing input_resolution through all blocks.'''
        current = input_resolution
        for layer in self.block_list:
            current = layer.get_output_resolution(current)
        return current

    def get_FLOPs(self, input_resolution):
        '''Sum of the blocks' FLOPs, each evaluated at the resolution it receives.'''
        current = input_resolution
        total = 0.0
        for layer in self.block_list:
            total += layer.get_FLOPs(current)
            current = layer.get_output_resolution(current)
        return total

    def get_model_size(self):
        '''Sum of the blocks' model sizes (accumulated as a float).'''
        return sum((layer.get_model_size() for layer in self.block_list), 0.0)

    def set_in_channels(self, c):
        '''
        Change the input channel count.

        Only the first block is updated, plus a BN directly behind a leading
        ConvKX/ConvDW.
        '''
        self.in_channels = c
        layers = self.block_list
        if len(layers) == 0:
            self.out_channels = c
            return

        first = layers[0]
        first.set_in_channels(c)
        last_channels = first.out_channels
        if len(layers) < 2:
            return
        if isinstance(first, (ConvKX, ConvDW)) and isinstance(layers[1], BN):
            layers[1].set_in_channels(last_channels)

    def encode_structure(self):
        '''Return [out_channels, sub_layers].'''
        return [self.out_channels, self.sub_layers]

    @classmethod
    def create_from_str(cls, s, no_create=False, **kwargs):
        '''
        Parse a leading '<ClassName>(in,out,stride,sub_layers)' from s.

        Returns (config, rest): config is a dict of constructor keyword
        arguments (including any extra kwargs passed in) and rest is the
        text following the closing parenthesis.
        '''
        assert cls.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None

        # Optional "<block_name>|" prefix; a random name is used without it.
        head, bar, tail = s[len(cls.__name__ + '('):idx].partition('|')
        if bar:
            tmp_block_name = head
            param_str = tail
        else:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
            param_str = head

        tokens = param_str.split(',')
        in_channels = int(tokens[0])
        out_channels = int(tokens[1])
        stride = int(tokens[2])
        sub_layers = int(tokens[3])
        config = dict(in_channels=in_channels, out_channels=out_channels, stride=stride,
                      sub_layers=sub_layers, block_name=tmp_block_name, no_create=no_create,
                      **kwargs)
        return config, s[idx + 1:]
{} use no_reslink'.format(str(self))) + # if self.no_BN: + # print('Warning! {} use no_BN'.format(str(self))) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + if not self.no_BN: + inner_str = 'ConvKX({},{},{},{})BN({})RELU({})'.format(last_channels, self.out_channels, + self.kernel_size, + current_stride, + self.out_channels, self.out_channels) + else: + inner_str = 'ConvKX({},{},{},{})RELU({})'.format(last_channels, self.out_channels, + self.kernel_size, + current_stride, + self.out_channels) + full_str += inner_str + + last_channels = out_channels + current_stride = 1 + pass + + # self.block_list = create_netblock_list_from_str(full_str, no_create=no_create, + self.block_list = build_netblock_list_from_str(full_str, no_create=no_create, + no_reslink=no_reslink, no_BN=no_BN) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def forward_pre_relu(self, x): + output = x + for block in self.block_list[0:-1]: + output = block(output) + return output + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.sub_layers, self.kernel_size) + + def split(self, split_layer_threshold): + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_sub_layers) + 
+@MODELS.register_module() +class SuperConvK1BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK1BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=1, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperConvK3BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK3BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperConvK5BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK5BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperConvK7BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK7BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResKXKX(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, kernel_size=None, + no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): + super(SuperResKXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + 
self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + if self.use_se: + print('---debug use_se in ' + str(self)) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, self.bottleneck_channels, self.kernel_size, current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format(inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + # self.block_list = create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + self.block_list = build_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def forward_pre_relu(self, x): + output = x + for block in self.block_list[0:-1]: + output = block(output) + return output + + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, 
self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size + ) + + def encode_structure(self): + return [self.out_channels, self.sub_layers, self.bottleneck_channels] + + def split(self, split_layer_threshold): + if self.sub_layers >= split_layer_threshold: + new_sublayers_1 = split_layer_threshold // 2 + new_sublayers_2 = self.sub_layers - new_sublayers_1 + new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, new_sublayers_1) + new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, self.out_channels, + 1, self.bottleneck_channels, + new_sublayers_2) + return new_block_str1 + new_block_str2 + else: + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_bottleneck_channels, new_sub_layers) + + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = 
int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return dict(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + block_name=tmp_block_name, **kwargs), s[idx + 1:] + # return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + # bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + # block_name=tmp_block_name, **kwargs), s[idx + 1:] + +@MODELS.register_module() +class SuperResK3K3(SuperResKXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK3K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResK5K5(SuperResKXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK5K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResK7K7(SuperResKXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK7K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResK1KXK1(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None ,sub_layers=None, kernel_size=None, + no_create=False, no_reslink=False, 
no_BN=False, use_se=False, **kwargs): + super(SuperResK1KXK1, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + if self.use_se: + print('---debug use_se in ' + str(self)) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + # first bl-block with reslink + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, self.bottleneck_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.bottleneck_channels, + self.kernel_size, current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format(inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + # second bl-block with reslink + inner_str = '' + inner_str += 'ConvKX({},{},{},{})'.format(self.out_channels, self.bottleneck_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, 
self.bottleneck_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format(inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + # self.block_list = create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + self.block_list = build_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size + ) + + def encode_structure(self): + return [self.out_channels, self.sub_layers, self.bottleneck_channels] + + def split(self, split_layer_threshold): + if self.sub_layers >= split_layer_threshold: + new_sublayers_1 = split_layer_threshold // 2 + new_sublayers_2 = self.sub_layers - new_sublayers_1 + new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, new_sublayers_1) + new_block_str2 = type(self).__name__ + 
'({},{},{},{},{})'.format(self.out_channels, self.out_channels, + 1, self.bottleneck_channels, + new_sublayers_2) + return new_block_str1 + new_block_str2 + else: + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_bottleneck_channels, new_sub_layers) + + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return dict(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + block_name=tmp_block_name, **kwargs),s[idx + 1:] + # return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + # bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + # block_name=tmp_block_name, **kwargs),s[idx + 1:] + +@MODELS.register_module() +class SuperResK1K3K1(SuperResK1KXK1): + def __init__(self, in_channels=None, out_channels=None, stride=None, 
bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK1K3K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResK1K5K1(SuperResK1KXK1): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK1K5K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResK1K7K1(SuperResK1KXK1): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResK1K7K1, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWEXKX(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, + kernel_size=None, expension=None, + no_create=False, no_reslink=False, no_BN=False, use_se=False, **kwargs): + super(SuperResIDWEXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.expension = expension + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + self.use_se = use_se + if self.use_se: + print('---debug use_se in ' + str(self)) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in 
range(self.sub_layers): + inner_str = '' + # first DW + dw_channels = global_utils.smart_round(self.bottleneck_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, self.kernel_size, current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, bottleneck_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(bottleneck_channels) + # inner_str += 'RELU({})'.format(bottleneck_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format(inner_str, self.out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format(inner_str, self.out_channels) + + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + + # second DW + inner_str = '' + dw_channels = global_utils.smart_round(self.out_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(bottleneck_channels, dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format(inner_str, self.out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + 
full_str += res_str + last_channels = out_channels + current_stride = 1 + pass + + # self.block_list = create_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + self.block_list = build_netblock_list_from_str(full_str, no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.bottleneck_channels, self.sub_layers, self.kernel_size + ) + + def encode_structure(self): + return [self.out_channels, self.sub_layers, self.bottleneck_channels] + + def split(self, split_layer_threshold): + if self.sub_layers >= split_layer_threshold: + new_sublayers_1 = split_layer_threshold // 2 + new_sublayers_2 = self.sub_layers - new_sublayers_1 + new_block_str1 = type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.bottleneck_channels, new_sublayers_1) + new_block_str2 = type(self).__name__ + '({},{},{},{},{})'.format(self.out_channels, self.out_channels, + 1, self.bottleneck_channels, + new_sublayers_2) + return new_block_str1 + new_block_str2 + else: + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_bottleneck_channels = global_utils.smart_round(self.bottleneck_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * 
sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_bottleneck_channels, new_sub_layers) + + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return dict(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + block_name=tmp_block_name, **kwargs),s[idx + 1:] + # return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + # bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + # block_name=tmp_block_name, **kwargs),s[idx + 1:] + +@MODELS.register_module() +class SuperResIDWE1K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE1K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=1.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE2K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE2K3, self).__init__(in_channels=in_channels, out_channels=out_channels, 
stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=2.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE4K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE4K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=4.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE6K3(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE6K3, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=3, expension=6.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE1K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE1K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=1.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE2K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE2K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=2.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class 
SuperResIDWE4K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE4K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=4.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE6K5(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE6K5, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=5, expension=6.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE1K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE1K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=1.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE2K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE2K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=2.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE4K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + 
super(SuperResIDWE4K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=4.0, + no_create=no_create, **kwargs) + +@MODELS.register_module() +class SuperResIDWE6K7(SuperResIDWEXKX): + def __init__(self, in_channels=None, out_channels=None, stride=None, bottleneck_channels=None, sub_layers=None, no_create=False, **kwargs): + super(SuperResIDWE6K7, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + bottleneck_channels=bottleneck_channels, sub_layers=sub_layers, + kernel_size=7, expension=6.0, + no_create=no_create, **kwargs) + + + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'AdaptiveAvgPool': AdaptiveAvgPool, + 'BN': BN, + 'ConvDW': ConvDW, + 'ConvKX': ConvKX, + 'ConvKXG2': ConvKXG2, + 'ConvKXG4': ConvKXG4, + 'ConvKXG8': ConvKXG8, + 'ConvKXG16': ConvKXG16, + 'ConvKXG32': ConvKXG32, + 'Flatten': Flatten, + 'Linear': Linear, + 'MaxPool': MaxPool, + 'MultiSumBlock': MultiSumBlock, + 'MultiCatBlock': MultiCatBlock, + 'PlainNetBasicBlockClass': PlainNetBasicBlockClass, + 'RELU': RELU, + 'ResBlock': ResBlock, + 'ResBlockProj': ResBlockProj, + 'Sequential': Sequential, + 'SE': SE, + 'Swish': Swish, + # super_blocks + 'SuperConvK1BNRELU': SuperConvK1BNRELU, + 'SuperConvK3BNRELU': SuperConvK3BNRELU, + 'SuperConvK5BNRELU': SuperConvK5BNRELU, + 'SuperConvK7BNRELU': SuperConvK7BNRELU, + # SuperResKXKX + 'SuperResK3K3': SuperResK3K3, + 'SuperResK5K5': SuperResK5K5, + 'SuperResK7K7': SuperResK7K7, + # SuperResK1KXK1 + 'SuperResK1K3K1': SuperResK1K3K1, + 'SuperResK1K5K1': SuperResK1K5K1, + 'SuperResK1K7K1': SuperResK1K7K1, + # SuperResIDWEXKX + 'SuperResIDWE1K3': SuperResIDWE1K3, + 'SuperResIDWE2K3': SuperResIDWE2K3, + 'SuperResIDWE4K3': SuperResIDWE4K3, + 'SuperResIDWE6K3': SuperResIDWE6K3, + 'SuperResIDWE1K5': SuperResIDWE1K5, + 'SuperResIDWE2K5': SuperResIDWE2K5, + 
'SuperResIDWE4K5': SuperResIDWE4K5, + 'SuperResIDWE6K5': SuperResIDWE6K5, + 'SuperResIDWE1K7': SuperResIDWE1K7, + 'SuperResIDWE2K7': SuperResIDWE2K7, + 'SuperResIDWE4K7': SuperResIDWE4K7, + 'SuperResIDWE6K7': SuperResIDWE6K7, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/mmrazor/models/architectures/backbones/PlainNet/plainnet.py b/mmrazor/models/architectures/backbones/PlainNet/plainnet.py new file mode 100644 index 000000000..dde985156 --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/plainnet.py @@ -0,0 +1,265 @@ +# import torch, argparse +# from torch import nn + +# def parse_cmd_options(argv, opt=None): +# parser = argparse.ArgumentParser() +# parser.add_argument('--plainnet_struct', type=str, default=None, help='PlainNet structure string') +# parser.add_argument('--plainnet_struct_txt', type=str, default=None, help='PlainNet structure file name') +# parser.add_argument('--num_classes', type=int, default=None, help='how to prune') +# module_opt, _ = parser.parse_known_args(argv) + +# return module_opt + + +# def _get_right_parentheses_index_(s): +# # assert s[0] == '(' +# left_paren_count = 0 +# for index, x in enumerate(s): + +# if x == '(': +# left_paren_count += 1 +# elif x == ')': +# left_paren_count -= 1 +# if left_paren_count == 0: +# return index +# else: +# pass +# return None + +# def pretty_format(plainnet_str, indent=2): +# the_formated_str = '' +# indent_str = '' +# if indent >= 1: +# indent_str = ''.join([' '] * indent) + +# # print(indent_str, end='') +# the_formated_str += indent_str + +# s = plainnet_str +# while len(s) > 0: +# if s[0] == ';': +# # print(';\n' + indent_str, end='') +# the_formated_str += ';\n' + indent_str +# s = s[1:] + +# left_par_idx = s.find('(') +# assert left_par_idx is not None +# right_par_idx = _get_right_parentheses_index_(s) +# the_block_class_name = s[0:left_par_idx] + +# if the_block_class_name in ['MultiSumBlock', 'MultiCatBlock','MultiGroupBlock']: 
+# # print('\n' + indent_str + the_block_class_name + '(') +# sub_str = s[left_par_idx + 1:right_par_idx] + +# # find block_name +# tmp_idx = sub_str.find('|') +# if tmp_idx < 0: +# tmp_block_name = 'no_name' +# else: +# tmp_block_name = sub_str[0:tmp_idx] +# sub_str = sub_str[tmp_idx+1:] + +# if len(tmp_block_name) > 8: +# tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + +# the_formated_str += '\n' + indent_str + the_block_class_name + '({}|\n'.format(tmp_block_name) + +# the_formated_str += pretty_format(sub_str, indent + 1) +# # print('\n' + indent_str + ')') +# # print(indent_str, end='') +# the_formated_str += '\n' + indent_str + ')\n' + indent_str +# elif the_block_class_name in ['ResBlock']: +# # print('\n' + indent_str + the_block_class_name + '(') +# in_channels = None +# the_stride = None +# sub_str = s[left_par_idx + 1:right_par_idx] +# # find block_name +# tmp_idx = sub_str.find('|') +# if tmp_idx < 0: +# tmp_block_name = 'no_name' +# else: +# tmp_block_name = sub_str[0:tmp_idx] +# sub_str = sub_str[tmp_idx + 1:] + +# first_comma_index = sub_str.find(',') +# if first_comma_index < 0 or not sub_str[0:first_comma_index].isdigit(): +# in_channels = None +# else: +# in_channels = int(sub_str[0:first_comma_index]) +# sub_str = sub_str[first_comma_index+1:] +# second_comma_index = sub_str.find(',') +# if second_comma_index < 0 or not sub_str[0:second_comma_index].isdigit(): +# the_stride = None +# else: +# the_stride = int(sub_str[0:second_comma_index]) +# sub_str = sub_str[second_comma_index + 1:] +# pass +# pass + +# if len(tmp_block_name) > 8: +# tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + +# the_formated_str += '\n' + indent_str + the_block_class_name + '({}|'.format(tmp_block_name) +# if in_channels is not None: +# the_formated_str += '{},'.format(in_channels) +# else: +# the_formated_str += ',' + +# if the_stride is not None: +# the_formated_str += '{},'.format(the_stride) +# else: +# the_formated_str += ',' + +# 
the_formated_str += '\n' + +# the_formated_str += pretty_format(sub_str, indent + 1) +# # print('\n' + indent_str + ')') +# # print(indent_str, end='') +# the_formated_str += '\n' + indent_str + ')\n' + indent_str +# else: +# # print(s[0:right_par_idx+1], end='') +# sub_str = s[left_par_idx + 1:right_par_idx] +# # find block_name +# tmp_idx = sub_str.find('|') +# if tmp_idx < 0: +# tmp_block_name = 'no_name' +# else: +# tmp_block_name = sub_str[0:tmp_idx] +# sub_str = sub_str[tmp_idx + 1:] + +# if len(tmp_block_name) > 8: +# tmp_block_name = tmp_block_name[0:4] + tmp_block_name[-4:] + +# the_formated_str += the_block_class_name + '({}|'.format(tmp_block_name) + sub_str + ')' + +# s = s[right_par_idx+1:] +# pass # end while + +# return the_formated_str + +# def _create_netblock_list_from_str_(s, no_create=False, **kwargs): +# block_list = [] +# while len(s) > 0: +# is_found_block_class = False +# for the_block_class_name in _all_netblocks_dict_.keys(): +# tmp_idx = s.find('(') +# if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: +# is_found_block_class = True +# the_block_class = _all_netblocks_dict_[the_block_class_name] +# the_block, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) +# if the_block is not None: +# block_list.append(the_block) +# s = remaining_s +# if len(s) > 0 and s[0] == ';': +# return block_list, s[1:] +# break +# pass # end if +# pass # end for +# assert is_found_block_class +# pass # end while +# return block_list, '' + +# def create_netblock_list_from_str(s, no_create=False, **kwargs): +# the_list, remaining_s = _create_netblock_list_from_str_(s, no_create=no_create, **kwargs) +# assert len(remaining_s) == 0 +# return the_list + +# def add_SE_block(structure_str: str): +# new_str = '' +# RELU = 'RELU' +# offset = 4 + +# idx = structure_str.find(RELU) +# while idx >= 0: +# new_str += structure_str[0: idx] +# structure_str = structure_str[idx:] +# r_idx = 
_get_right_parentheses_index_(structure_str[offset:]) + offset +# channels = structure_str[offset + 1:r_idx] +# new_str += 'RELU({})SE({})'.format(channels, channels) +# structure_str = structure_str[r_idx + 1:] +# idx = structure_str.find(RELU) +# pass + +# new_str += structure_str +# return new_str + + +# class PlainNet(nn.Module): +# def __init__(self, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False, +# **kwargs): +# super(PlainNet, self).__init__() +# self.argv = argv +# self.opt = opt +# self.num_classes = num_classes +# self.plainnet_struct = plainnet_struct + +# self.module_opt = parse_cmd_options(self.argv) + +# if self.num_classes is None: +# self.num_classes = self.module_opt.num_classes + +# if self.plainnet_struct is None and self.module_opt.plainnet_struct is not None: +# self.plainnet_struct = self.module_opt.plainnet_struct + +# if self.plainnet_struct is None: +# # load structure from text file +# if hasattr(opt, 'plainnet_struct_txt') and opt.plainnet_struct_txt is not None: +# plainnet_struct_txt = opt.plainnet_struct_txt +# else: +# plainnet_struct_txt = self.module_opt.plainnet_struct_txt + +# if plainnet_struct_txt is not None: +# with open(plainnet_struct_txt, 'r') as fid: +# the_line = fid.readlines()[0].strip() +# self.plainnet_struct = the_line +# pass + +# if self.plainnet_struct is None: +# return + +# the_s = self.plainnet_struct # type: str + +# block_list, remaining_s = _create_netblock_list_from_str_(the_s, no_create=no_create, **kwargs) +# assert len(remaining_s) == 0 + +# self.block_list = block_list +# if not no_create: +# self.module_list = nn.ModuleList(block_list) # register + +# def forward(self, x): +# output = x +# for the_block in self.block_list: +# output = the_block(output) +# return output + +# def __str__(self): +# s = '' +# for the_block in self.block_list: +# s += str(the_block) +# return s + +# def __repr__(self): +# return str(self) + +# def get_FLOPs(self, input_resolution): +# the_res 
= input_resolution +# the_flops = 0 +# for the_block in self.block_list: +# the_flops += the_block.get_FLOPs(the_res) +# the_res = the_block.get_output_resolution(the_res) + +# return the_flops + +# def get_model_size(self): +# the_size = 0 +# for the_block in self.block_list: +# the_size += the_block.get_model_size() + +# return the_size + +# def replace_block(self, block_id, new_block): +# self.block_list[block_id] = new_block +# if block_id < len(self.block_list): +# self.block_list[block_id + 1].set_in_channels(new_block.out_channels) + +# self.module_list = nn.Module(self.block_list) diff --git a/mmrazor/models/architectures/backbones/PlainNet/super_blocks.py b/mmrazor/models/architectures/backbones/PlainNet/super_blocks.py new file mode 100644 index 000000000..4466e9f68 --- /dev/null +++ b/mmrazor/models/architectures/backbones/PlainNet/super_blocks.py @@ -0,0 +1,222 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. +''' + + +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np +import uuid +import global_utils + +import PlainNet +from PlainNet import _get_right_parentheses_index_, basic_blocks + +class PlainNetSuperBlockClass(basic_blocks.PlainNetBasicBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(PlainNetSuperBlockClass, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.sub_layers = sub_layers + self.no_create = no_create + self.block_list = None + self.module_list = None + + def forward(self, x): # + output = x + for block in self.block_list: + output = block(output) + return output + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.sub_layers) + + def __repr__(self): + return 
type(self).__name__ + '({}|{},{},{},{})'.format(self.block_name, self.in_channels, self.out_channels, + self.stride, self.sub_layers) + + def get_output_resolution(self, input_resolution): + resolution = input_resolution + for block in self.block_list: + resolution = block.get_output_resolution(resolution) + return resolution + + def get_FLOPs(self, input_resolution): + resolution = input_resolution + flops = 0.0 + for block in self.block_list: + flops += block.get_FLOPs(resolution) + resolution = block.get_output_resolution(resolution) + return flops + + + def get_model_size(self): + model_size = 0.0 + for block in self.block_list: + model_size += block.get_model_size() + return model_size + + def set_in_channels(self, c): + self.in_channels = c + if len(self.block_list) == 0: + self.out_channels = c + return + + self.block_list[0].set_in_channels(c) + last_channels = self.block_list[0].out_channels + if len(self.block_list) >= 2 and \ + (isinstance(self.block_list[0], basic_blocks.ConvKX) or isinstance(self.block_list[0], basic_blocks.ConvDW)) and \ + isinstance(self.block_list[1], basic_blocks.BN): + self.block_list[1].set_in_channels(last_channels) + + def encode_structure(self): + return [self.out_channels, self.sub_layers] + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert cls.is_instance_from_str(s) + idx = _get_right_parentheses_index_(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + # find block_name + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + sub_layers = int(param_str_split[3]) + return cls(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, 
block_name=tmp_block_name, no_create=no_create, + **kwargs),\ + s[idx + 1:] + + +class SuperConvKXBNRELU(PlainNetSuperBlockClass): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, kernel_size=None, + no_create=False, no_reslink=False, no_BN=False, **kwargs): + super(SuperConvKXBNRELU, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + # if self.no_reslink: + # print('Warning! {} use no_reslink'.format(str(self))) + # if self.no_BN: + # print('Warning! {} use no_BN'.format(str(self))) + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + if not self.no_BN: + inner_str = 'ConvKX({},{},{},{})BN({})RELU({})'.format(last_channels, self.out_channels, + self.kernel_size, + current_stride, + self.out_channels, self.out_channels) + else: + inner_str = 'ConvKX({},{},{},{})RELU({})'.format(last_channels, self.out_channels, + self.kernel_size, + current_stride, + self.out_channels) + full_str += inner_str + + last_channels = out_channels + current_stride = 1 + pass + + self.block_list = PlainNet.create_netblock_list_from_str(full_str, no_create=no_create, + no_reslink=no_reslink, no_BN=no_BN) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def forward_pre_relu(self, x): + output = x + for block in self.block_list[0:-1]: + output = block(output) + return output + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, self.out_channels, + self.stride, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|in={},out={},stride={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, self.sub_layers, 
self.kernel_size) + + def split(self, split_layer_threshold): + return str(self) + + def structure_scale(self, scale=1.0, channel_scale=None, sub_layer_scale=None): + if channel_scale is None: + channel_scale = scale + if sub_layer_scale is None: + sub_layer_scale = scale + + new_out_channels = global_utils.smart_round(self.out_channels * channel_scale) + new_sub_layers = max(1, round(self.sub_layers * sub_layer_scale)) + + return type(self).__name__ + '({},{},{},{})'.format(self.in_channels, new_out_channels, + self.stride, new_sub_layers) + + + +class SuperConvK1BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK1BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=1, + no_create=no_create, **kwargs) + +class SuperConvK3BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK3BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, **kwargs) + +class SuperConvK5BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK5BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, **kwargs) + + +class SuperConvK7BNRELU(SuperConvKXBNRELU): + def __init__(self, in_channels=None, out_channels=None, stride=None, sub_layers=None, no_create=False, **kwargs): + super(SuperConvK7BNRELU, self).__init__(in_channels=in_channels, out_channels=out_channels, stride=stride, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + 
this_py_file_netblocks_dict = { + 'SuperConvK1BNRELU': SuperConvK1BNRELU, + 'SuperConvK3BNRELU': SuperConvK3BNRELU, + 'SuperConvK5BNRELU': SuperConvK5BNRELU, + 'SuperConvK7BNRELU': SuperConvK7BNRELU, + + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/SearchSpace/search_space_IDW_fixfc.py b/mmrazor/models/architectures/backbones/SearchSpace/search_space_IDW_fixfc.py new file mode 100644 index 000000000..68261958c --- /dev/null +++ b/mmrazor/models/architectures/backbones/SearchSpace/search_space_IDW_fixfc.py @@ -0,0 +1,123 @@ +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +import itertools + +import global_utils +# from PlainNet import basic_blocks, super_blocks, SuperResKXKX, SuperResK1KXK1, SuperResIDWEXKX +from PlainNet import basic_blocks + +seach_space_block_type_list_list = [ + [basic_blocks.SuperResIDWE1K3, basic_blocks.SuperResIDWE2K3, basic_blocks.SuperResIDWE4K3, + basic_blocks.SuperResIDWE6K3, + basic_blocks.SuperResIDWE1K5, basic_blocks.SuperResIDWE2K5, basic_blocks.SuperResIDWE4K5, + basic_blocks.SuperResIDWE6K5, + basic_blocks.SuperResIDWE1K7, basic_blocks.SuperResIDWE2K7, basic_blocks.SuperResIDWE4K7, + basic_blocks.SuperResIDWE6K7], +] + +__block_type_round_channels_base_dict__ = { + basic_blocks.SuperResIDWE1K3: 8, + basic_blocks.SuperResIDWE2K3: 8, + basic_blocks.SuperResIDWE4K3: 8, + basic_blocks.SuperResIDWE6K3: 8, + basic_blocks.SuperResIDWE1K5: 8, + basic_blocks.SuperResIDWE2K5: 8, + basic_blocks.SuperResIDWE4K5: 8, + basic_blocks.SuperResIDWE6K5: 8, + basic_blocks.SuperResIDWE1K7: 8, + basic_blocks.SuperResIDWE2K7: 8, + basic_blocks.SuperResIDWE4K7: 8, + basic_blocks.SuperResIDWE6K7: 8, +} + +__block_type_min_channels_base_dict__ = { + basic_blocks.SuperResIDWE1K3: 8, + basic_blocks.SuperResIDWE2K3: 8, + basic_blocks.SuperResIDWE4K3: 8, + basic_blocks.SuperResIDWE6K3: 
8, + basic_blocks.SuperResIDWE1K5: 8, + basic_blocks.SuperResIDWE2K5: 8, + basic_blocks.SuperResIDWE4K5: 8, + basic_blocks.SuperResIDWE6K5: 8, + basic_blocks.SuperResIDWE1K7: 8, + basic_blocks.SuperResIDWE2K7: 8, + basic_blocks.SuperResIDWE4K7: 8, + basic_blocks.SuperResIDWE6K7: 8, +} + + +def get_select_student_channels_list(out_channels): + the_list = [out_channels * 2.5, out_channels * 2, out_channels * 1.5, out_channels * 1.25, + out_channels, + out_channels / 1.25, out_channels / 1.5, out_channels / 2, out_channels / 2.5] + the_list = [max(8, x) for x in the_list] + the_list = [global_utils.smart_round(x, base=8) for x in the_list] + the_list = list(set(the_list)) + the_list.sort(reverse=True) + return the_list + + +def get_select_student_sublayers_list(sub_layers): + the_list = [sub_layers, + sub_layers + 1, sub_layers + 2, + sub_layers - 1, sub_layers - 2, ] + the_list = [max(0, round(x)) for x in the_list] + the_list = list(set(the_list)) + the_list.sort(reverse=True) + return the_list + + +def gen_search_space(block_list, block_id): + the_block = block_list[block_id] + student_blocks_list_list = [] + + if isinstance(the_block, basic_blocks.SuperConvKXBNRELU): + student_blocks_list = [] + + if the_block.kernel_size == 1: # last fc layer, never change fc + student_out_channels_list = [the_block.out_channels] + else: + student_out_channels_list = get_select_student_channels_list(the_block.out_channels) + + for student_out_channels in student_out_channels_list: + tmp_block_str = type(the_block).__name__ + '({},{},{},1)'.format( + the_block.in_channels, student_out_channels, the_block.stride) + student_blocks_list.append(tmp_block_str) + pass + student_blocks_list = list(set(student_blocks_list)) + assert len(student_blocks_list) >= 1 + student_blocks_list_list.append(student_blocks_list) + else: + for student_block_type_list in seach_space_block_type_list_list: + student_blocks_list = [] + student_out_channels_list = 
get_select_student_channels_list(the_block.out_channels) + student_sublayers_list = get_select_student_sublayers_list(sub_layers=the_block.sub_layers) + student_bottleneck_channels_list = get_select_student_channels_list(the_block.bottleneck_channels) + for student_block_type in student_block_type_list: + for student_out_channels, student_sublayers, student_bottleneck_channels in itertools.product( + student_out_channels_list, student_sublayers_list, student_bottleneck_channels_list): + + # filter smallest possible channel for this block type + min_possible_channels = __block_type_round_channels_base_dict__[student_block_type] + channel_round_base = __block_type_round_channels_base_dict__[student_block_type] + student_out_channels = global_utils.smart_round(student_out_channels, channel_round_base) + student_bottleneck_channels = global_utils.smart_round(student_bottleneck_channels, + channel_round_base) + + if student_out_channels < min_possible_channels or student_bottleneck_channels < min_possible_channels: + continue + if student_sublayers <= 0: # no empty layer + continue + tmp_block_str = student_block_type.__name__ + '({},{},{},{},{})'.format( + the_block.in_channels, student_out_channels, the_block.stride, student_bottleneck_channels, + student_sublayers) + student_blocks_list.append(tmp_block_str) + pass + student_blocks_list = list(set(student_blocks_list)) + assert len(student_blocks_list) >= 1 + student_blocks_list_list.append(student_blocks_list) + pass + pass # end for student_block_type_list in seach_space_block_type_list_list: + pass + return student_blocks_list_list diff --git a/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py b/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py new file mode 100644 index 000000000..9850633cd --- /dev/null +++ b/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py @@ -0,0 +1,103 @@ +import os, sys + 
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +import itertools + +import global_utils +# from PlainNet import basic_blocks, super_blocks, SuperResKXKX, SuperResK1KXK1 +from PlainNet import basic_blocks + +seach_space_block_type_list_list = [ + [basic_blocks.SuperResK1K3K1, basic_blocks.SuperResK1K5K1, basic_blocks.SuperResK1K7K1], + [basic_blocks.SuperResK3K3, basic_blocks.SuperResK5K5, basic_blocks.SuperResK7K7], +] + +__block_type_round_channels_base_dict__ = { + basic_blocks.SuperResK3K3: 8, + basic_blocks.SuperResK5K5: 8, + basic_blocks.SuperResK7K7: 8, + basic_blocks.SuperResK1K3K1: 8, + basic_blocks.SuperResK1K5K1: 8, + basic_blocks.SuperResK1K7K1: 8, +} + +__block_type_min_channels_base_dict__ = { + basic_blocks.SuperResK3K3: 8, + basic_blocks.SuperResK5K5: 8, + basic_blocks.SuperResK7K7: 8, + basic_blocks.SuperResK1K3K1: 8, + basic_blocks.SuperResK1K5K1: 8, + basic_blocks.SuperResK1K7K1: 8, +} + + +def get_select_student_channels_list(out_channels): + the_list = [out_channels * 2.5, out_channels * 2, out_channels * 1.5, out_channels * 1.25, + out_channels, + out_channels / 1.25, out_channels / 1.5, out_channels / 2, out_channels / 2.5] + the_list = [min(2048, max(8, x)) for x in the_list] + the_list = [global_utils.smart_round(x, base=8) for x in the_list] + the_list = list(set(the_list)) + the_list.sort(reverse=True) + return the_list + + +def get_select_student_sublayers_list(sub_layers): + the_list = [sub_layers, + sub_layers + 1, sub_layers + 2, + sub_layers - 1, sub_layers - 2, ] + the_list = [max(0, round(x)) for x in the_list] + the_list = list(set(the_list)) + the_list.sort(reverse=True) + return the_list + + +def gen_search_space(block_list, block_id): + the_block = block_list[block_id] + student_blocks_list_list = [] + + if isinstance(the_block, basic_blocks.SuperConvKXBNRELU): + student_blocks_list = [] + student_out_channels_list = get_select_student_channels_list(the_block.out_channels) + for 
student_out_channels in student_out_channels_list: + tmp_block_str = type(the_block).__name__ + '({},{},{},1)'.format( + the_block.in_channels, student_out_channels, the_block.stride) + student_blocks_list.append(tmp_block_str) + pass + student_blocks_list = list(set(student_blocks_list)) + assert len(student_blocks_list) >= 1 + student_blocks_list_list.append(student_blocks_list) + else: + for student_block_type_list in seach_space_block_type_list_list: + student_blocks_list = [] + student_out_channels_list = get_select_student_channels_list(the_block.out_channels) + student_sublayers_list = get_select_student_sublayers_list(sub_layers=the_block.sub_layers) + student_bottleneck_channels_list = get_select_student_channels_list(the_block.bottleneck_channels) + for student_block_type in student_block_type_list: + for student_out_channels, student_sublayers, student_bottleneck_channels in itertools.product( + student_out_channels_list, student_sublayers_list, student_bottleneck_channels_list): + + # filter smallest possible channel for this block type + min_possible_channels = __block_type_round_channels_base_dict__[student_block_type] + channel_round_base = __block_type_round_channels_base_dict__[student_block_type] + student_out_channels = global_utils.smart_round(student_out_channels, channel_round_base) + student_bottleneck_channels = global_utils.smart_round(student_bottleneck_channels, + channel_round_base) + + if student_out_channels < min_possible_channels or student_bottleneck_channels < min_possible_channels: + continue + if student_sublayers <= 0: # no empty layer + continue + tmp_block_str = student_block_type.__name__ + '({},{},{},{},{})'.format( + the_block.in_channels, student_out_channels, the_block.stride, student_bottleneck_channels, + student_sublayers) + student_blocks_list.append(tmp_block_str) + pass + student_blocks_list = list(set(student_blocks_list)) + assert len(student_blocks_list) >= 1 + student_blocks_list_list.append(student_blocks_list) 
+ pass + pass # end for student_block_type_list in seach_space_block_type_list_list: + pass + return student_blocks_list_list diff --git a/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_NASWOT_score.py b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_NASWOT_score.py new file mode 100644 index 000000000..ffbb549e0 --- /dev/null +++ b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_NASWOT_score.py @@ -0,0 +1,126 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. + +The implementation of NASWOT score is modified from: +https://github.com/BayesWatch/nas-without-training +''' + + + +import os, sys, time +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +from torch import nn +import numpy as np +from PlainNet import basic_blocks +import global_utils, argparse, time + +def network_weight_gaussian_init(net: nn.Module): + with torch.no_grad(): + for m in net.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + else: + continue + + return net + +def logdet(K): + s, ld = np.linalg.slogdet(K) + return ld + +def get_batch_jacobian(net, x): + net.zero_grad() + x.requires_grad_(True) + y = net(x) + y.backward(torch.ones_like(y)) + jacob = x.grad.detach() + # return jacob, target.detach(), y.detach() + return jacob, y.detach() + +def compute_nas_score(gpu, model, resolution, batch_size): + if gpu is not None: + torch.cuda.set_device(gpu) + model = model.cuda(gpu) + + network_weight_gaussian_init(model) + input = torch.randn(size=[batch_size, 3, resolution, resolution]) + if gpu is not None: + input = input.cuda(gpu) + + model.K = np.zeros((batch_size, 
batch_size)) + + def counting_forward_hook(module, inp, out): + try: + if not module.visited_backwards: + return + if isinstance(inp, tuple): + inp = inp[0] + inp = inp.view(inp.size(0), -1) + x = (inp > 0).float() + K = x @ x.t() + K2 = (1. - x) @ (1. - x.t()) + model.K = model.K + K.cpu().numpy() + K2.cpu().numpy() + except Exception as err: + print('---- error on model : ') + print(model) + raise err + + + def counting_backward_hook(module, inp, out): + module.visited_backwards = True + + for name, module in model.named_modules(): + # if 'ReLU' in str(type(module)): + if isinstance(module, basic_blocks.RELU): + # hooks[name] = module.register_forward_hook(counting_hook) + module.visited_backwards = True + module.register_forward_hook(counting_forward_hook) + module.register_backward_hook(counting_backward_hook) + + x = input + jacobs, y = get_batch_jacobian(model, x) + + score = logdet(model.K) + + return float(score) + + + +def parse_cmd_options(argv): + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', type=int, default=16, help='number of instances in one mini-batch.') + parser.add_argument('--input_image_size', type=int, default=None, + help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.') + parser.add_argument('--repeat_times', type=int, default=32) + parser.add_argument('--gpu', type=int, default=None) + module_opt, _ = parser.parse_known_args(argv) + return module_opt + +if __name__ == "__main__": + import ModelLoader + opt = global_utils.parse_cmd_options(sys.argv) + args = parse_cmd_options(sys.argv) + the_model = ModelLoader.get_model(opt, sys.argv) + if args.gpu is not None: + the_model = the_model.cuda(args.gpu) + + + start_timer = time.time() + + for repeat_count in range(args.repeat_times): + the_score = compute_nas_score(gpu=args.gpu, model=the_model, + resolution=args.input_image_size, batch_size=args.batch_size) + + time_cost = (time.time() - start_timer) / args.repeat_times + + 
print(f'NASWOT={the_score:.4g}, time cost={time_cost:.4g} second(s)') \ No newline at end of file diff --git a/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_gradnorm_score.py b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_gradnorm_score.py new file mode 100644 index 000000000..06bd914ce --- /dev/null +++ b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_gradnorm_score.py @@ -0,0 +1,76 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. +''' + + +import os, sys, time +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +from torch import nn +import numpy as np + +def network_weight_gaussian_init(net: nn.Module): + with torch.no_grad(): + for m in net.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + else: + continue + + return net + +import torch.nn.functional as F +def cross_entropy(logit, target): + # target must be one-hot format!! 
+ prob_logit = F.log_softmax(logit, dim=1) + loss = -(target * prob_logit).sum(dim=1).mean() + return loss + +def compute_nas_score(gpu, model, resolution, batch_size): + + model.train() + model.requires_grad_(True) + + model.zero_grad() + + if gpu is not None: + torch.cuda.set_device(gpu) + model = model.cuda(gpu) + + network_weight_gaussian_init(model) + input = torch.randn(size=[batch_size, 3, resolution, resolution]) + if gpu is not None: + input = input.cuda(gpu) + output = model(input) + # y_true = torch.rand(size=[batch_size, output.shape[1]], device=torch.device('cuda:{}'.format(gpu))) + 1e-10 + # y_true = y_true / torch.sum(y_true, dim=1, keepdim=True) + + num_classes = output.shape[1] + y = torch.randint(low=0, high=num_classes, size=[batch_size]) + + one_hot_y = F.one_hot(y, num_classes).float() + if gpu is not None: + one_hot_y = one_hot_y.cuda(gpu) + + loss = cross_entropy(output, one_hot_y) + loss.backward() + norm2_sum = 0 + with torch.no_grad(): + for p in model.parameters(): + if hasattr(p, 'grad') and p.grad is not None: + norm2_sum += torch.norm(p.grad) ** 2 + + grad_norm = float(torch.sqrt(norm2_sum)) + + return grad_norm + + diff --git a/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_syncflow_score.py b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_syncflow_score.py new file mode 100644 index 000000000..1222bcdc5 --- /dev/null +++ b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_syncflow_score.py @@ -0,0 +1,138 @@ +''' +Copyright (C) 2010-2021 Alibaba Group Holding Limited. +This file is modified from +https://github.com/SamsungLabs/zero-cost-nas +''' + + +# Copyright 2021 Samsung Electronics Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +import os, sys, time +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +from torch import nn +import numpy as np + + + +import torch + +def network_weight_gaussian_init(net: nn.Module): + with torch.no_grad(): + for m in net.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + else: + continue + + return net + +def get_layer_metric_array(net, metric, mode): + metric_array = [] + + for layer in net.modules(): + if mode == 'channel' and hasattr(layer, 'dont_ch_prune'): + continue + if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): + metric_array.append(metric(layer)) + + return metric_array + + + +def compute_synflow_per_weight(net, inputs, mode): + device = inputs.device + + # convert params to their abs. Keep sign for converting it back. 
+ @torch.no_grad() + def linearize(net): + signs = {} + for name, param in net.state_dict().items(): + signs[name] = torch.sign(param) + param.abs_() + return signs + + # convert to orig values + @torch.no_grad() + def nonlinearize(net, signs): + for name, param in net.state_dict().items(): + if 'weight_mask' not in name: + param.mul_(signs[name]) + + # keep signs of all params + signs = linearize(net) + + # Compute gradients with input of 1s + net.zero_grad() + net.double() + input_dim = list(inputs[0, :].shape) + inputs = torch.ones([1] + input_dim).double().to(device) + output = net.forward(inputs) + torch.sum(output).backward() + + # select the gradients that we want to use for search/prune + def synflow(layer): + if layer.weight.grad is not None: + return torch.abs(layer.weight * layer.weight.grad) + else: + return torch.zeros_like(layer.weight) + + grads_abs = get_layer_metric_array(net, synflow, mode) + + # apply signs of all params + nonlinearize(net, signs) + + return grads_abs + +def do_compute_nas_score(gpu, model, resolution, batch_size): + model.train() + model.requires_grad_(True) + + model.zero_grad() + + if gpu is not None: + torch.cuda.set_device(gpu) + model = model.cuda(gpu) + + network_weight_gaussian_init(model) + input = torch.randn(size=[batch_size, 3, resolution, resolution]) + if gpu is not None: + input = input.cuda(gpu) + + grads_abs_list = compute_synflow_per_weight(net=model, inputs=input, mode='') + score = 0 + for grad_abs in grads_abs_list: + if len(grad_abs.shape) == 4: + score += float(torch.mean(torch.sum(grad_abs, dim=[1,2,3]))) + elif len(grad_abs.shape) == 2: + score += float(torch.mean(torch.sum(grad_abs, dim=[1]))) + else: + raise RuntimeError('!!!') + + + return -1 * score + + diff --git a/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_te_nas_score.py b/mmrazor/models/architectures/backbones/ZeroShotProxy/compute_te_nas_score.py new file mode 100644 index 000000000..7b7fe5216 --- /dev/null +++ 
class LinearRegionCount(object):
    """Count the distinct binary activation patterns seen over a set of samples.

    Rows given to :meth:`update2D` are binarised with ``torch.sign`` and
    written into a ``(n_samples, n_neuron)`` buffer. :meth:`calc_LR` then
    counts how many distinct rows that buffer holds. Rows that were never
    written stay all-zero and are counted as one (all-zero) pattern.
    """

    def __init__(self, n_samples, gpu=None):
        """
        Args:
            n_samples: number of rows reserved in the activation buffer.
            gpu: CUDA device index for the buffer, or ``None`` for CPU.
        """
        self.ActPattern = {}
        self.n_LR = -1  # -1 marks "not computed yet"
        self.n_samples = n_samples
        self.ptr = 0  # next free row in the buffer
        self.activations = None  # allocated lazily on the first update2D
        self.gpu = gpu

    @torch.no_grad()
    def update2D(self, activations):
        """Append a ``(n_batch, n_neuron)`` batch of activations to the buffer."""
        n_batch = activations.size()[0]
        n_neuron = activations.size()[1]
        self.n_neuron = n_neuron
        if self.activations is None:
            self.activations = torch.zeros(self.n_samples, n_neuron)
            if self.gpu is not None:
                self.activations = self.activations.cuda(self.gpu)
        # sign() maps positive values to 1 and zeros to 0 (inputs are
        # expected to be post-ReLU, i.e. non-negative).
        self.activations[self.ptr:self.ptr + n_batch] = torch.sign(activations)
        self.ptr += n_batch

    @torch.no_grad()
    def calc_LR(self):
        """Compute the number of distinct rows and store it in ``n_LR``.

        For 0/1 rows, ``a_i . (1 - a_j) + a_j . (1 - a_i)`` is zero exactly
        when rows ``i`` and ``j`` are identical. Each row then contributes
        ``1 / (number of rows identical to it)``, so the contributions of one
        pattern sum to 1 and the total is the number of distinct patterns.
        The buffer is released afterwards.
        """
        # Half precision saves memory on GPU; on CPU half matmul is not
        # available in every PyTorch release, so float32 is used there.
        dtype = torch.half if self.activations.is_cuda else torch.float32
        acts = self.activations.to(dtype)
        res = torch.matmul(acts, (1 - acts).T)
        # Out-of-place on purpose: ``res += res.T`` reads and writes the same
        # storage through differently-strided views.
        res = res + res.T
        res = 1 - torch.sign(res)
        res = res.sum(1)
        res = 1. / res.float()
        self.n_LR = res.sum().item()
        del acts, res
        self.activations = None
        if self.gpu is not None:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def update1D(self, activationList):
        """Record one activation pattern built from a dict of 1-D tensors.

        Every element of every value contributes one character ('1' if it is
        positive, '0' otherwise); the resulting string is stored as a key of
        ``ActPattern``.
        """
        bits = []
        for value in activationList.values():
            n_neuron = value.size()[0]
            for i in range(n_neuron):
                bits.append('1' if value[i] > 0 else '0')
        code_string = ''.join(bits)
        if code_string not in self.ActPattern:
            self.ActPattern[code_string] = 1

    def getLinearReginCount(self):
        """Return the pattern count, computing it on first use."""
        if self.n_LR == -1:
            self.calc_LR()
        return self.n_LR


class Linear_Region_Collector:
    """Feed random inputs through models and count their activation patterns.

    A forward hook on every ``nn.ReLU`` collects the ReLU outputs; for each
    model these are flattened per sample, concatenated and handed to a
    :class:`LinearRegionCount`.
    """

    def __init__(self, models=None, input_size=(64, 3, 32, 32), gpu=None,
                 sample_batch=1, dataset=None, data_path=None, seed=0):
        """
        Args:
            models: list of networks to analyse (defaults to an empty list).
            input_size: shape of one random input batch, ``(B, C, H, W)``.
            gpu: CUDA device index, or ``None`` for CPU.
            sample_batch: number of random batches fed per model.
            dataset: stored on the instance; not used by the visible code.
            data_path: stored on the instance; not used by the visible code.
            seed: stored as the current seed (see :meth:`reinit`).
        """
        # A fresh list per instance instead of a shared mutable default.
        models = [] if models is None else models
        self.models = []
        self.input_size = input_size  # BCHW
        self.sample_batch = sample_batch
        self.interFeature = []
        self.dataset = dataset
        self.data_path = data_path
        self.seed = seed
        self.gpu = gpu
        self.device = torch.device('cuda:{}'.format(self.gpu)) if self.gpu is not None else torch.device('cpu')
        self._hook_handles = []  # handles of the hooks this collector owns

        self.reinit(models, input_size, sample_batch, seed)

    def reinit(self, models=None, input_size=None, sample_batch=None, seed=None):
        """Replace models / sizes / seed; arguments left as ``None`` are kept.

        RNGs are re-seeded only when ``seed`` differs from the stored one.
        """
        # Apply the new sizes first so the counters built below are sized for
        # the batches that forward_batch_sample will actually feed.
        if input_size is not None:
            self.input_size = input_size  # BCHW
        if sample_batch is not None:
            self.sample_batch = sample_batch
        if models is not None:
            assert isinstance(models, list)
            # Drop hooks from a previous reinit so the same ReLU is never
            # hooked twice by this collector.
            self.remove_hooks()
            del self.models
            self.models = models
            for model in self.models:
                self.register_hook(model)
            self.LRCounts = self._new_counters()
        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            if self.gpu is not None:
                torch.cuda.manual_seed(seed)
        del self.interFeature
        self.interFeature = []
        if self.gpu is not None:
            torch.cuda.empty_cache()

    def _new_counters(self):
        """Build one empty counter per model, on this collector's device."""
        n_samples = self.input_size[0] * self.sample_batch
        return [LinearRegionCount(n_samples, gpu=self.gpu) for _ in range(len(self.models))]

    def clear(self):
        """Reset all counters and drop collected features."""
        # Counters must live on the same device as the collector.
        self.LRCounts = self._new_counters()
        del self.interFeature
        self.interFeature = []
        if self.gpu is not None:
            torch.cuda.empty_cache()

    def register_hook(self, model):
        """Attach :meth:`hook_in_forward` to every ``nn.ReLU`` in ``model``."""
        for m in model.modules():
            if isinstance(m, nn.ReLU):
                handle = m.register_forward_hook(hook=self.hook_in_forward)
                self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach every forward hook registered by this collector."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def hook_in_forward(self, module, input, output):
        """Forward hook: keep the output when the module input is 4-D."""
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())  # for ReLU

    def forward_batch_sample(self):
        """Feed ``sample_batch`` random batches; return one count per model."""
        for _ in range(self.sample_batch):
            inputs = torch.randn(self.input_size, device=self.device)

            for model, LRCount in zip(self.models, self.LRCounts):
                self.forward(model, LRCount, inputs)
        return [LRCount.getLinearReginCount() for LRCount in self.LRCounts]

    def forward(self, model, LRCount, input_data):
        """Run ``model`` on ``input_data`` and push the hooked features."""
        self.interFeature = []
        with torch.no_grad():
            model.forward(input_data)
        if len(self.interFeature) == 0:
            return
        feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1)
        LRCount.update2D(feature_data)
def get_ntk_n(networks, recalbn=0, train_mode=False, num_batch=None,
              batch_size=None, image_size=None, gpu=None):
    """Return the condition number of an empirical NTK for each network.

    For every random input sample the gradient of the summed logits with
    respect to all parameters whose name contains ``'weight'`` is collected.
    The Gram matrix of these per-sample gradients is formed, and the ratio
    of its largest to smallest eigenvalue is reported.

    Args:
        networks: list of ``nn.Module``; a module may return a tuple, in
            which case element ``[1]`` is used as the logits.
        recalbn: kept for interface compatibility; not used here.
        train_mode: put the networks in train mode if true, else eval mode.
        num_batch: number of random batches to draw.
        batch_size: samples per batch.
        image_size: side length of the square 3-channel random input.
        gpu: CUDA device index, or ``None`` for CPU.

    Returns:
        A list with one condition number per network (NaN/inf replaced by
        ``np.nan_to_num`` defaults).
    """
    if gpu is not None:
        device = torch.device('cuda:{}'.format(gpu))
    else:
        device = torch.device('cpu')

    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    grads = [[] for _ in range(len(networks))]

    for _ in range(num_batch):
        # Created directly on the target device, so no later transfer needed.
        inputs = torch.randn((batch_size, 3, image_size, image_size), device=device)
        for net_idx, network in enumerate(networks):
            network.zero_grad()
            inputs_ = inputs.clone()

            logit = network(inputs_)
            if isinstance(logit, tuple):
                logit = logit[1]  # networks that return (features, logits)
            for _idx in range(len(inputs_)):
                # One backward per sample; the graph is reused for the next.
                logit[_idx:_idx + 1].backward(torch.ones_like(logit[_idx:_idx + 1]), retain_graph=True)
                grad = [
                    W.grad.reshape(-1).detach()
                    for name, W in network.named_parameters()
                    if 'weight' in name and W.grad is not None
                ]
                # torch.cat copies, so the zero_grad below cannot alter it.
                grads[net_idx].append(torch.cat(grad, -1))
                network.zero_grad()
                if gpu is not None:
                    torch.cuda.empty_cache()

    grads = [torch.stack(_grads, 0) for _grads in grads]
    ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
    conds = []
    for ntk in ntks:
        # torch.symeig is gone from current PyTorch; eigvalsh also returns
        # the eigenvalues in ascending order. Fall back for old releases.
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigvalsh'):
            eigenvalues = torch.linalg.eigvalsh(ntk)
        else:
            eigenvalues, _ = torch.symeig(ntk)
        conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True))
    return conds
def network_weight_gaussian_init(net: nn.Module):
    """Re-initialise ``net`` in place and return it.

    * ``nn.Conv2d`` / ``nn.Linear``: weights drawn from N(0, 1), biases zeroed.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight set to 1, bias set to 0.
    * Every other module type (containers, activations, project wrapper
      blocks, ...) is left untouched.

    Args:
        net: the network whose parameters are overwritten.

    Returns:
        The same ``net`` object, for call chaining.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False expose weight/bias as
                # None; initialising them unconditionally would raise.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    return net
def compute_nas_score(gpu, model, mixup_gamma, resolution, batch_size, repeat, fp16=False):
    """Compute the zen-score statistics of ``model`` over ``repeat`` trials.

    Each trial re-initialises the model with ``network_weight_gaussian_init``,
    draws a random input and a slightly perturbed copy
    (``input + mixup_gamma * noise``), and measures how far apart the two
    ``model.forward_pre_GAP`` outputs are. The log of that distance plus the
    summed log of ``sqrt(mean(running_var))`` over all ``nn.BatchNorm2d``
    layers is the score of the trial.

    Args:
        gpu: CUDA device index for the random inputs, or ``None`` for CPU.
        model: network exposing ``forward_pre_GAP``; its weights are
            overwritten on every trial.
        mixup_gamma: scale of the input perturbation.
        resolution: side length of the square 3-channel input.
        batch_size: number of random inputs per trial.
        repeat: number of trials.
        fp16: draw the inputs in half precision instead of float32.

    Returns:
        dict with ``avg_nas_score``, ``std_nas_score`` and ``avg_precision``
        (``1.96 * std / sqrt(repeat)``), all Python floats.
    """
    device = torch.device('cuda:{}'.format(gpu)) if gpu is not None else torch.device('cpu')
    dtype = torch.half if fp16 else torch.float32
    input_shape = [batch_size, 3, resolution, resolution]

    nas_score_list = []
    with torch.no_grad():
        for _ in range(repeat):
            network_weight_gaussian_init(model)

            base_input = torch.randn(size=input_shape, device=device, dtype=dtype)
            noise = torch.randn(size=input_shape, device=device, dtype=dtype)
            # Small perturbation of the input, scaled by mixup_gamma.
            perturbed_input = base_input + mixup_gamma * noise

            base_features = model.forward_pre_GAP(base_input)
            perturbed_features = model.forward_pre_GAP(perturbed_input)

            # Per-sample L1 distance between the two feature maps, then the
            # batch mean.
            feature_gap = torch.abs(base_features - perturbed_features)
            mean_gap = torch.mean(torch.sum(feature_gap, dim=[1, 2, 3]))

            # BN scaling term: sum of log(sqrt(mean running variance)).
            log_bn_scaling_factor = 0.0
            for m in model.modules():
                if not isinstance(m, nn.BatchNorm2d):
                    continue
                bn_scaling_factor = torch.sqrt(torch.mean(m.running_var))
                log_bn_scaling_factor += torch.log(bn_scaling_factor)

            trial_score = torch.log(mean_gap) + log_bn_scaling_factor
            nas_score_list.append(float(trial_score))

    std_nas_score = np.std(nas_score_list)
    # Half-width of the 95% confidence interval of the mean.
    avg_precision = 1.96 * std_nas_score / np.sqrt(len(nas_score_list))
    avg_nas_score = np.mean(nas_score_list)

    info = {}
    info['avg_nas_score'] = float(avg_nas_score)
    info['std_nas_score'] = float(std_nas_score)
    info['avg_precision'] = float(avg_precision)
    return info
def __get_latency__(model, batch_size, resolution, channel, gpu, benchmark_repeat_times, fp16):
    """Measure the mean per-image forward latency of ``model`` on a CUDA device.

    Args:
        model: network to benchmark; moved to ``gpu`` and put in eval mode
            (and converted to half precision when ``fp16`` is set).
        batch_size: images per forward pass.
        resolution: side length of the square input.
        channel: number of input channels.
        gpu: CUDA device index (required).
        benchmark_repeat_times: number of timed forward passes.
        fp16: run the model and the input in float16.

    Returns:
        Seconds per image, averaged over all timed passes.
    """
    device = torch.device('cuda:{}'.format(gpu))
    torch.backends.cudnn.benchmark = True

    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)
    if fp16:
        model = model.half()
        dtype = torch.float16
    else:
        dtype = torch.float32

    the_image = torch.randn(batch_size, channel, resolution, resolution, dtype=dtype,
                            device=device)
    model.eval()
    warmup_T = 3
    with torch.no_grad():
        for _ in range(warmup_T):
            model(the_image)
        # CUDA kernels are launched asynchronously: wait for the warm-up to
        # finish before starting the clock, and for the timed passes to
        # finish before stopping it, otherwise only launch time is measured.
        torch.cuda.synchronize(device)
        start_timer = time.perf_counter()
        for _ in range(benchmark_repeat_times):
            model(the_image)
        torch.cuda.synchronize(device)
        end_timer = time.perf_counter()

    the_latency = (end_timer - start_timer) / float(benchmark_repeat_times) / batch_size
    return the_latency
def get_model_latency(model, batch_size, resolution, in_channels, gpu, repeat_times, fp16):
    """Measure the mean per-image forward latency of ``model``.

    The model is used on whatever device it already lives on; only the
    random input is created on the device selected by ``gpu``.

    Args:
        model: network to benchmark; put in eval mode (and converted to half
            precision when ``fp16`` is set).
        batch_size: images per forward pass.
        resolution: side length of the square input.
        in_channels: number of input channels.
        gpu: CUDA device index for the input, or ``None`` for CPU.
        repeat_times: number of timed forward passes.
        fp16: run the model and the input in float16.

    Returns:
        Seconds per image, averaged over all timed passes.
    """
    if gpu is not None:
        device = torch.device('cuda:{}'.format(gpu))
    else:
        device = torch.device('cpu')

    if fp16:
        model = model.half()
        dtype = torch.float16
    else:
        dtype = torch.float32

    the_image = torch.randn(batch_size, in_channels, resolution, resolution, dtype=dtype,
                            device=device)

    model.eval()
    warmup_T = 3
    with torch.no_grad():
        for _ in range(warmup_T):
            model(the_image)
        # CUDA work is asynchronous; without synchronising, the clock would
        # only cover kernel launches, not their execution.
        if gpu is not None:
            torch.cuda.synchronize(device)
        start_timer = time.perf_counter()
        for _ in range(repeat_times):
            model(the_image)
        if gpu is not None:
            torch.cuda.synchronize(device)
        end_timer = time.perf_counter()

    the_latency = (end_timer - start_timer) / float(repeat_times) / batch_size
    return the_latency
class _InfToLargeFloat(ast.NodeTransformer):
    """Replace the bare name ``inf`` with the float literal ``1e20``."""

    def visit_Name(self, node):
        if node.id == 'inf':
            return ast.copy_location(ast.Constant(value=1e20), node)
        return node


def load_pyobj(filename):
    """Load a Python literal that was written with ``pprint`` (see ``save_pyobj``).

    ``pprint`` renders ``float('inf')`` as the bare token ``inf``, which
    ``ast.literal_eval`` rejects, so every such token is read back as
    ``1e20`` (and ``-inf`` as ``-1e20``). Only the token itself is replaced:
    strings and keys that merely contain the letters "inf" (e.g. ``'info'``)
    are left intact.

    Args:
        filename: path of the text file to read.

    Returns:
        The evaluated Python object.

    Raises:
        SyntaxError, ValueError: if the file does not hold a valid literal.
    """
    with open(filename, 'r') as fid:
        the_s = fid.read()

    tree = ast.parse(the_s.strip(), mode='eval')
    # literal_eval accepts an already-parsed Expression node.
    pyobj = ast.literal_eval(_InfToLargeFloat().visit(tree))
    return pyobj
for distributed training') + parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training') + + parser.add_argument('--gpu', default=None, type=int, help='GPU id to use. Used by torch.distributed package') + parser.add_argument('--sync_bn', action='store_true', help='Use synchronized BN.') + + parser.add_argument('--num_job_splits', default=None, type=str, help='Split jobs into multiple groups.') + parser.add_argument('--job_id', default=None, type=int, help='The id of this job node.') + + # horovod setting + parser.add_argument('--fp16_allreduce', action='store_true', help='use fp16 compression during allreduce.') + parser.add_argument('--batches_per_allreduce', + type=int, + default=1, + help='number of batches processed locally before ' + 'executing allreduce across workers; it multiplies ' + 'total batch size.') + + # learning rate setting + parser.add_argument('--lr', default=None, type=float, help='initial learning rate per 256 batch size') + parser.add_argument('--target_lr', default=None, type=float, help='target learning rate') + parser.add_argument('--lr_per_256', default=0.1, type=float, help='initial learning rate per 256 batch size') + parser.add_argument('--target_lr_per_256', default=0.0, type=float, help='target learning rate') + parser.add_argument('--lr_mode', default=None, type=str, help='learning rate decay mode.') + parser.add_argument('--warmup', default=0, type=int, help='epochs for warmup.') + parser.add_argument('--epoch_offset', default=0.0, type=float, help='Make the learning rate decaying as epochs + epoch_offset but start from epoch_offset. 
') + + parser.add_argument('--lr_stage_list', default=None, type=str, help='stage-wise learning epoch list.') + parser.add_argument('--lr_stage_decay', default=None, type=float, help='stage-wise learning epoch list.') + + # optimizer + parser.add_argument('--optimizer', default='sgd', type=str, help='sgd optimizer') + parser.add_argument('--momentum', default=0.9, type=float, help='momentum') + parser.add_argument('--adadelta_rho', default=0.9, type=float) + parser.add_argument('--adadelta_eps', default=1e-9, type=float) + + parser.add_argument('--wd', + '--weight_decay', + default=4e-5, + type=float, + help='weight decay (default: 4e-5)', + dest='weight_decay') + + # training settings + + parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint (default: none)') + parser.add_argument('--auto_resume', action='store_true', help='auto resume from latest check point') + parser.add_argument('--load_parameters_from', default=None, type=str, help='Only load parameters from pth file.') + parser.add_argument('--strict_load', action='store_true', help='Mixed precision training using apex.') + + parser.add_argument('--start_epoch', default=0, type=int, help='manual epoch number (useful on restarts)') + parser.add_argument('--epochs', default=90, type=int, metavar='N', help='number of total epochs to run') + parser.add_argument('--save_dir', default=None, type=str, help='where to save models.') + parser.add_argument('--save_freq', default=10, type=int, help='How many epochs to save a model.') + parser.add_argument('--print_freq', default=100, type=int, help='print frequency (default: 100)') + + # training tricks + parser.add_argument('--label_smoothing', action='store_true') + parser.add_argument('--weight_init', type=str, default='None', help='How to initialize parameters') + parser.add_argument('--nesterov', action='store_true') + parser.add_argument('--grad_clip', type=float, default=None) + + # BN layer + parser.add_argument('--bn_momentum', 
type=float, default=None) + parser.add_argument('--bn_eps', type=float, default=None) + + # data augmentation + parser.add_argument('--mixup', action='store_true') + parser.add_argument('--random_erase', action='store_true') + parser.add_argument('--auto_augment', action='store_true') + parser.add_argument('--no_data_augment', action='store_true') + + # for loading dataset + parser.add_argument('--data_dir', type=str, default=None, help='path to dataset') + parser.add_argument('--dataset', type=str, default=None, help='name of the dataset') + parser.add_argument('--workers_per_gpu', + default=6, + type=int, + help='number of data loading workers per gpu. default 6.') + parser.add_argument( + '--batch_size', + default=None, + type=int, + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel', + ) + + parser.add_argument('--batch_size_per_gpu', default=None, type=int, help='batch size per GPU.') + parser.add_argument('--auto_batch_size', action='store_true', help='allow adjust batch size smartly.') + parser.add_argument('--num_cv_folds', type=int, default=None, help='Number of cross-validation folds.') + parser.add_argument('--cv_id', type=int, default=None, help='Current ID of cross-validation fold.') + parser.add_argument('--input_image_size', type=int, default=224, help='input image size.') + parser.add_argument('--input_image_crop', type=float, default=0.875, help='crop ratio of input image') + + # for loading model + parser.add_argument('--arch', default=None, help='model names/module to load') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') + parser.add_argument('--num_classes', type=int, default=None, help='number of classes.') + + # for testing + parser.add_argument('--dataloader_testing', action='store_true', help='Testing data loader.') + + # for teacher-student distillation + 
class LearningRateScheduler():
    """Learning-rate schedule driven by the number of training instances seen.

    Supported modes are ``'cosine'``, ``'linear'`` and ``'stagedecay'``, each
    optionally preceded by a linear warm-up.
    """

    def __init__(self,
                 mode,
                 lr,
                 target_lr=None,
                 num_training_instances=None,
                 stop_epoch=None,
                 warmup_epoch=None,
                 stage_list=None,
                 stage_decay=None,
                 ):
        """
        Args:
            mode: ``'cosine'``, ``'linear'`` or ``'stagedecay'``.
            lr: initial learning rate.
            target_lr: final learning rate for cosine/linear (default 0).
            num_training_instances: instances per epoch (default 1).
            stop_epoch: epoch at which the schedule ends (default infinity).
            warmup_epoch: number of warm-up epochs (default 0).
            stage_list: comma-separated epoch boundaries for ``'stagedecay'``.
            stage_decay: multiplier applied after each boundary (default 0).
        """
        self.mode = mode
        self.lr = lr
        self.target_lr = target_lr if target_lr is not None else 0
        self.num_training_instances = num_training_instances if num_training_instances is not None else 1
        self.stop_epoch = stop_epoch if stop_epoch is not None else np.inf
        self.warmup_epoch = warmup_epoch if warmup_epoch is not None else 0
        self.stage_list = stage_list
        self.stage_decay = stage_decay if stage_decay is not None else 0

        self.num_received_training_instances = 0

        if self.stage_list is not None:
            self.stage_list = [int(x) for x in self.stage_list.split(',')]

    def update_lr(self, batch_size):
        """Advance the internal instance counter by ``batch_size``."""
        self.num_received_training_instances += batch_size

    def get_lr(self, num_received_training_instances=None):
        """Return the learning rate after a given number of instances.

        Args:
            num_received_training_instances: position in the schedule;
                defaults to the internal counter maintained by ``update_lr``.

        Returns:
            The learning rate as a float.

        Raises:
            AssertionError: if the schedule stops before the warm-up ends.
            RuntimeError: if ``mode`` is not a known schedule.
        """
        # Local import: this block needs ``math`` and must not rely on a
        # file-level import being present.
        import math

        if num_received_training_instances is None:
            num_received_training_instances = self.num_received_training_instances

        stop_instances = self.num_training_instances * self.stop_epoch
        warmup_instances = self.num_training_instances * self.warmup_epoch

        assert stop_instances > warmup_instances

        # Derived from the requested position, not the internal counter, so
        # 'stagedecay' honours an explicit argument like the other modes do.
        current_epoch = num_received_training_instances // self.num_training_instances

        if num_received_training_instances < warmup_instances:
            return float(num_received_training_instances + 1) / float(warmup_instances) * self.lr

        ratio_epoch = float(num_received_training_instances - warmup_instances + 1) / \
                      float(stop_instances - warmup_instances)

        if self.mode == 'cosine':
            # ``np.math`` is no longer available in NumPy; use ``math``.
            factor = (1 - math.cos(math.pi * ratio_epoch)) / 2.0
            return self.lr + (self.target_lr - self.lr) * factor
        elif self.mode == 'stagedecay':
            stage_lr = self.lr
            for stage_epoch in self.stage_list:
                if current_epoch <= stage_epoch:
                    return stage_lr
                stage_lr *= self.stage_decay
            return stage_lr
        elif self.mode == 'linear':
            factor = ratio_epoch
            return self.lr + (self.target_lr - self.lr) * factor
        else:
            raise RuntimeError('Unknown learning rate mode: ' + self.mode)
class SuperBlock(nn.Module):
    """Sequential container of PlainNet blocks built from a structure string.

    Args:
        fix_subnet: Optional path of a text file whose first line holds the
            structure string. When given it overrides ``plainnet_struct``.
        argv: Raw command line, stored unchanged.
        opt: Parsed options, stored unchanged.
        num_classes: Number of output classes, stored unchanged.
        plainnet_struct: Structure string such as
            ``SuperConvK3BNRELU(3,32,2,1)SuperResK3K3(32,64,2,32,1)``.
        no_create: Forwarded to the block builder; when True the blocks are
            not registered as sub-modules.
        **kwargs: Extra options forwarded to the block builder.
    """

    def __init__(self, fix_subnet=None, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False,
                 **kwargs):
        super(SuperBlock, self).__init__()
        self.argv = argv
        self.opt = opt
        self.num_classes = num_classes
        self.plainnet_struct = plainnet_struct
        self.plainnet_struct_txt = fix_subnet
        # Always defined, so the helper methods also work on an empty net.
        self.block_list = []

        if self.plainnet_struct_txt is not None:
            # Only the first line of the structure file is used.
            with open(self.plainnet_struct_txt, 'r') as fid:
                self.plainnet_struct = fid.readline().strip()

        if self.plainnet_struct is None:
            return

        block_list, remaining_s = _build_netblock_list_from_str_(
            self.plainnet_struct, no_create=no_create, **kwargs)
        # The whole string must be consumed by the parser.
        assert len(remaining_s) == 0

        self.block_list = block_list
        if not no_create:
            self.module_list = nn.ModuleList(block_list)  # register

    def forward(self, x):
        """Apply all blocks in order and return the last feature map."""
        output = x
        for the_block in self.block_list:
            output = the_block(output)
        return output

    def __str__(self):
        return ''.join(str(the_block) for the_block in self.block_list)

    def __repr__(self):
        return str(self)

    def get_FLOPs(self, input_resolution):
        """Sum the per-block FLOPs, tracking the resolution between blocks."""
        the_res = input_resolution
        the_flops = 0
        for the_block in self.block_list:
            the_flops += the_block.get_FLOPs(the_res)
            the_res = the_block.get_output_resolution(the_res)

        return the_flops

    def get_model_size(self):
        """Sum the model size reported by every block."""
        return sum(the_block.get_model_size() for the_block in self.block_list)

    def replace_block(self, block_id, new_block):
        """Replace one block and rewire the input width of its successor.

        Args:
            block_id: Index of the block to replace (negative values count
                from the end).
            new_block: Replacement block exposing ``out_channels``.
        """
        if block_id < 0:
            block_id += len(self.block_list)
        self.block_list[block_id] = new_block
        # The last block has no successor to rewire.
        if block_id < len(self.block_list) - 1:
            self.block_list[block_id + 1].set_in_channels(new_block.out_channels)

        self.module_list = nn.ModuleList(self.block_list)


class MasterNet(SuperBlock):
    """Classification network: PlainNet blocks, global pooling, linear head.

    Args:
        fix_subnet: Optional path of a structure text file.
        argv: Raw command line, stored unchanged.
        opt: Parsed options, stored unchanged.
        num_classes: Number of output classes of the linear head.
        plainnet_struct: Structure string of the backbone.
        no_create: When True only the structure is built, without parameters.
        no_reslink, no_BN, use_se: Forwarded to the block builder.

    Raises:
        ValueError: If neither a structure string nor a structure file
            yields at least one block.
    """

    def __init__(self, fix_subnet=None, argv=None, opt=None, num_classes=1000, plainnet_struct=None, no_create=False,
                 no_reslink=None, no_BN=None, use_se=None):
        super().__init__(fix_subnet=fix_subnet, argv=argv, opt=opt, num_classes=num_classes,
                         plainnet_struct=plainnet_struct, no_create=no_create, no_reslink=no_reslink,
                         no_BN=no_BN, use_se=use_se)

        if not self.block_list:
            raise ValueError('MasterNet needs a non-empty structure: pass plainnet_struct or fix_subnet.')

        self.last_channels = self.block_list[-1].out_channels
        self.fc_linear = Linear(in_channels=self.last_channels, out_channels=self.num_classes, no_create=no_create)

        self.no_create = no_create
        self.no_reslink = no_reslink
        self.no_BN = no_BN
        self.use_se = use_se

        # bn eps
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eps = 1e-3

    def extract_stage_features_and_logit(self, x, target_downsample_ratio=None):
        """Run the network and collect intermediate stage features.

        A block output is collected when its down-sampling ratio relative to
        the input equals ``target_downsample_ratio``; the target is doubled
        after each hit.

        Returns:
            Tuple ``(stage_features_list, logit)``.
        """
        stage_features_list = []
        image_size = x.shape[2]
        output = x

        for the_block in self.block_list:
            output = the_block(output)
            downsample_ratio = round(image_size / output.shape[2])
            if downsample_ratio == target_downsample_ratio:
                stage_features_list.append(output)
                target_downsample_ratio *= 2

        output = F.adaptive_avg_pool2d(output, output_size=1)
        output = torch.flatten(output, 1)
        logit = self.fc_linear(output)
        return stage_features_list, logit

    def forward(self, x):
        """Return the classification logits for ``x``."""
        output = self.forward_pre_GAP(x)
        output = F.adaptive_avg_pool2d(output, output_size=1)
        output = torch.flatten(output, 1)
        return self.fc_linear(output)

    def forward_pre_GAP(self, x):
        """Return the feature map right before global average pooling."""
        output = x
        for the_block in self.block_list:
            output = the_block(output)
        return output

    def get_FLOPs(self, input_resolution):
        """FLOPs of all blocks plus the linear head."""
        the_res = input_resolution
        the_flops = 0
        for the_block in self.block_list:
            the_flops += the_block.get_FLOPs(the_res)
            the_res = the_block.get_output_resolution(the_res)

        the_flops += self.fc_linear.get_FLOPs(the_res)

        return the_flops

    def get_model_size(self):
        """Model size of all blocks plus the linear head."""
        the_size = sum(the_block.get_model_size() for the_block in self.block_list)
        return the_size + self.fc_linear.get_model_size()

    def get_num_layers(self):
        """Total ``sub_layers`` over all super blocks."""
        num_layers = 0
        for block in self.block_list:
            assert isinstance(block, PlainNetSuperBlockClass)
            num_layers += block.sub_layers
        return num_layers

    def replace_block(self, block_id, new_block):
        """Replace one block and keep the neighbouring widths consistent.

        The successor block, or the linear head when the last block is
        replaced, gets its input width updated if it no longer matches.
        """
        if block_id < 0:
            block_id += len(self.block_list)
        self.block_list[block_id] = new_block

        if block_id < len(self.block_list) - 1:
            if self.block_list[block_id + 1].in_channels != new_block.out_channels:
                self.block_list[block_id + 1].set_in_channels(new_block.out_channels)
        else:
            assert block_id == len(self.block_list) - 1
            self.last_channels = self.block_list[-1].out_channels
            if self.fc_linear.in_channels != self.last_channels:
                self.fc_linear.set_in_channels(self.last_channels)

        self.module_list = nn.ModuleList(self.block_list)

    def split(self, split_layer_threshold):
        """Concatenate the ``split`` structure strings of all blocks."""
        return ''.join(block.split(split_layer_threshold=split_layer_threshold) for block in self.block_list)

    def init_parameters(self):
        """Initialise weights; zero the last BN scale of every residual block."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=3.26033)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 3.26033 * np.sqrt(2 / (m.weight.shape[0] + m.weight.shape[1])))
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)

        for superblock in self.block_list:
            if not isinstance(superblock, PlainNetSuperBlockClass):
                continue
            for block in superblock.block_list:
                if not isinstance(block, (ResBlock, ResBlockProj)):
                    continue
                # Find the last BN of the residual branch.
                last_bn_block = None
                for inner_resblock in block.block_list:
                    if isinstance(inner_resblock, BN):
                        last_bn_block = inner_resblock
                assert last_bn_block is not None
                # A zero scale makes the residual branch start as identity.
                nn.init.zeros_(last_bn_block.netblock.weight)


def get_splitted_structure_str(AnyPlainNet, structure_str, num_classes):
    """Return ``structure_str`` re-expressed through the network's ``split``."""
    the_net = AnyPlainNet(num_classes=num_classes, plainnet_struct=structure_str, no_create=True)
    assert hasattr(the_net, 'split')
    return the_net.split(split_layer_threshold=6)


def get_new_random_structure_str(AnyPlainNet, structure_str, num_classes, get_search_space_func,
                                 num_replaces=1):
    """Mutate up to ``num_replaces`` randomly chosen blocks of a structure.

    Args:
        AnyPlainNet: Network class (a ``SuperBlock`` subclass) used to parse
            ``structure_str``.
        structure_str: Structure string to mutate.
        num_classes: Forwarded to ``AnyPlainNet``.
        get_search_space_func: Callable ``(block_list, block_id)`` returning
            a list of lists of candidate block strings; an empty string
            means "delete this block".
        num_replaces: Number of mutation attempts; a repeated index is
            skipped, so fewer blocks may change.

    Returns:
        The mutated structure string.
    """
    the_net = AnyPlainNet(num_classes=num_classes, plainnet_struct=structure_str, no_create=True)
    assert isinstance(the_net, SuperBlock)
    selected_random_id_set = set()
    for _ in range(num_replaces):
        random_id = random.randint(0, len(the_net.block_list) - 1)
        if random_id in selected_random_id_set:
            continue
        selected_random_id_set.add(random_id)
        candidate_lists = get_search_space_func(the_net.block_list, random_id)

        candidates = [x for sublist in candidate_lists for x in sublist]
        new_student_block_str = random.choice(candidates)

        if len(new_student_block_str) > 0:
            new_student_block = build_netblock_list_from_str(new_student_block_str, no_create=True)
            assert len(new_student_block) == 1
            new_student_block = new_student_block[0]
            # The nearest surviving predecessor defines the input width; an
            # earlier iteration may have deleted the direct neighbour.
            for prev_block in reversed(the_net.block_list[:random_id]):
                if prev_block is not None:
                    new_student_block.set_in_channels(prev_block.out_channels)
                    break
            the_net.block_list[random_id] = new_student_block
        else:
            # replace with empty block
            the_net.block_list[random_id] = None

    # adjust channels and remove empty layer
    tmp_new_block_list = [x for x in the_net.block_list if x is not None]
    if tmp_new_block_list:
        # Start from the first surviving block; index 0 of the old list may
        # have been deleted above.
        last_channels = tmp_new_block_list[0].out_channels
        for block in tmp_new_block_list[1:]:
            block.set_in_channels(last_channels)
            last_channels = block.out_channels
    the_net.block_list = tmp_new_block_list

    return the_net.split(split_layer_threshold=6)


def _measure_latency(model, batch_size, resolution, in_channels, gpu, repeat_times, fp16):
    """Time ``repeat_times`` forward passes; return seconds per image."""
    if gpu is not None:
        device = torch.device('cuda:{}'.format(gpu))
    else:
        device = torch.device('cpu')

    if fp16:
        model = model.half()
        dtype = torch.float16
    else:
        dtype = torch.float32

    the_image = torch.randn(batch_size, in_channels, resolution, resolution, dtype=dtype,
                            device=device)

    model.eval()
    warmup_T = 3
    with torch.no_grad():
        for _ in range(warmup_T):
            model(the_image)
        # CUDA kernels run asynchronously: wait for them before reading the
        # clock, otherwise only the launch overhead is measured.
        if gpu is not None:
            torch.cuda.synchronize(device)
        start_timer = time.perf_counter()
        for _ in range(repeat_times):
            model(the_image)
        if gpu is not None:
            torch.cuda.synchronize(device)
        end_timer = time.perf_counter()

    return (end_timer - start_timer) / float(repeat_times) / batch_size


def get_latency(AnyPlainNet, random_structure_str, gpu, args):
    """Build the network and measure its forward latency per image.

    Args:
        AnyPlainNet: Network class to instantiate.
        random_structure_str: Structure string of the candidate.
        gpu: CUDA device index, or None to measure on the CPU.
        args: Options providing ``num_classes``, ``batch_size`` and
            ``input_image_size``.

    Returns:
        Seconds per image (wall time divided by the batch size).
    """
    the_model = AnyPlainNet(num_classes=args.num_classes, plainnet_struct=random_structure_str,
                            no_create=False, no_reslink=False)
    if gpu is not None:
        the_model = the_model.cuda(gpu)

    the_latency = _measure_latency(
        model=the_model,
        batch_size=args.batch_size,
        resolution=args.input_image_size,
        in_channels=3, gpu=gpu, repeat_times=1,
        # Half precision only on GPU; CPU half-precision convolutions are
        # not reliably available.
        fp16=gpu is not None)

    del the_model
    torch.cuda.empty_cache()
    return the_latency


# Score names understood by compute_nas_score().
_ZERO_SHOT_SCORES = ('Zen', 'TE-NAS', 'Syncflow', 'GradNorm', 'Flops', 'Params', 'Random', 'NASWOT')


def compute_nas_score(AnyPlainNet, random_structure_str, gpu, args):
    """Compute the zero-shot proxy score of one candidate structure.

    Args:
        AnyPlainNet: Network class to instantiate.
        random_structure_str: Structure string of the candidate.
        gpu: CUDA device index, or None to keep the model on the CPU.
        args: Options providing ``zero_shot_score``, ``num_classes``,
            ``input_image_size`` and, depending on the score,
            ``batch_size`` and ``gamma``.

    Returns:
        The score, or ``-9999`` when the scorer raised an exception.

    Raises:
        ValueError: If ``args.zero_shot_score`` is not a known score name.
    """
    if args.zero_shot_score not in _ZERO_SHOT_SCORES:
        raise ValueError('Unknown zero_shot_score: {!r}; expected one of {}'.format(
            args.zero_shot_score, ', '.join(_ZERO_SHOT_SCORES)))

    the_model = AnyPlainNet(num_classes=args.num_classes, plainnet_struct=random_structure_str,
                            no_create=False, no_reslink=True)
    # Same device convention as get_latency(): no index means CPU.
    if gpu is not None:
        the_model = the_model.cuda(gpu)
    try:
        if args.zero_shot_score == 'Zen':
            the_nas_core_info = compute_zen_score.compute_nas_score(model=the_model, gpu=gpu,
                                                                    resolution=args.input_image_size,
                                                                    mixup_gamma=args.gamma,
                                                                    batch_size=args.batch_size,
                                                                    repeat=1)
            the_nas_core = the_nas_core_info['avg_nas_score']
        elif args.zero_shot_score == 'TE-NAS':
            the_nas_core = compute_te_nas_score.compute_NTK_score(model=the_model, gpu=gpu,
                                                                  resolution=args.input_image_size,
                                                                  batch_size=args.batch_size)
        elif args.zero_shot_score == 'Syncflow':
            the_nas_core = compute_syncflow_score.do_compute_nas_score(model=the_model, gpu=gpu,
                                                                       resolution=args.input_image_size,
                                                                       batch_size=args.batch_size)
        elif args.zero_shot_score == 'GradNorm':
            the_nas_core = compute_gradnorm_score.compute_nas_score(model=the_model, gpu=gpu,
                                                                    resolution=args.input_image_size,
                                                                    batch_size=args.batch_size)
        elif args.zero_shot_score == 'Flops':
            the_nas_core = the_model.get_FLOPs(args.input_image_size)
        elif args.zero_shot_score == 'Params':
            the_nas_core = the_model.get_model_size()
        elif args.zero_shot_score == 'Random':
            the_nas_core = np.random.randn()
        else:  # 'NASWOT'
            the_nas_core = compute_NASWOT_score.compute_nas_score(gpu=gpu, model=the_model,
                                                                  resolution=args.input_image_size,
                                                                  batch_size=args.batch_size)
    except Exception as err:
        # Best effort: a failing candidate gets a very low score so the
        # search continues without it.
        logging.info(str(err))
        logging.info('--- Failed structure: ')
        logging.info(str(the_model))
        the_nas_core = -9999

    del the_model
    torch.cuda.empty_cache()
    return the_nas_core


def search_cmd_options(argv):
    """Parse the evolution-search command line; unknown options are ignored."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=None)
    parser.add_argument('--zero_shot_score', type=str, default=None,
                        help='one of: Zen, TE-NAS, Syncflow, GradNorm, Flops, Params, Random, NASWOT')
    parser.add_argument('--search_space', type=str, default=None,
                        help='.py file to specify the search space.')
    parser.add_argument('--evolution_max_iter', type=int, default=int(48e4),
                        help='max iterations of evolution.')
    parser.add_argument('--budget_model_size', type=float, default=None,
                        help='budget of model size ( number of parameters), e.g., 1e6 means 1M params')
    parser.add_argument('--budget_flops', type=float, default=None,
                        help='budget of flops, e.g. , 1.8e9 means 1.8 GFLOPS')
    parser.add_argument('--budget_latency', type=float, default=None,
                        help='latency of forward inference per image, e.g., 1e-3 means 1ms.')
    parser.add_argument('--max_layers', type=int, default=None, help='max number of layers of the network.')
    parser.add_argument('--batch_size', type=int, default=None, help='number of instances in one mini-batch.')
    parser.add_argument('--input_image_size', type=int, default=None,
                        help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.')
    parser.add_argument('--population_size', type=int, default=512, help='population size of evolution.')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='output directory')
    parser.add_argument('--gamma', type=float, default=1e-2,
                        help='noise perturbation coefficient')
    parser.add_argument('--num_classes', type=int, default=None,
                        help='number of classes')
    parser.add_argument('--plainnet_struct', type=str, default=None, help='PlainNet structure string')
    parser.add_argument('--plainnet_struct_txt', type=str, default=None, help='PlainNet structure file name')
    parser.add_argument('--no_BN', action='store_true')
    parser.add_argument('--no_reslink', action='store_true')
    parser.add_argument('--use_se', action='store_true')
    module_opt, _ = parser.parse_known_args(argv)
    return module_opt


def _exceeds_static_budgets(structure_str, args):
    """Return True if the structure violates the layer/size/FLOPs budgets."""
    if args.max_layers is None and args.budget_model_size is None and args.budget_flops is None:
        return False

    # Structure-only instance: enough for counting, no parameters created.
    the_model = MasterNet(num_classes=args.num_classes, plainnet_struct=structure_str,
                          no_create=True, no_reslink=False)
    if args.max_layers is not None and args.max_layers < the_model.get_num_layers():
        return True
    if args.budget_model_size is not None and args.budget_model_size < the_model.get_model_size():
        return True
    if args.budget_flops is not None and args.budget_flops < the_model.get_FLOPs(args.input_image_size):
        return True
    return False


def main(args, argv):
    """Run the evolutionary search for a high-scoring structure.

    Args:
        args: Options produced by ``search_cmd_options``.
        argv: Raw command line, forwarded to ``MasterNet``.

    Returns:
        ``None`` if ``best_structure.txt`` already exists in
        ``args.save_dir``; otherwise the tuple ``(structures, scores,
        latencies)`` describing the final population.
    """
    gpu = args.gpu

    # Skip the search when a result is already present.
    best_structure_txt = os.path.join(args.save_dir, 'best_structure.txt')
    if os.path.isfile(best_structure_txt):
        print('skip ' + best_structure_txt)
        return None

    # load search space config .py file
    select_search_space = global_utils.load_py_module_from_path(args.search_space)

    # load masternet; a structure file takes precedence over the string
    masternet = MasterNet(fix_subnet=args.plainnet_struct_txt, num_classes=args.num_classes,
                          plainnet_struct=getattr(args, 'plainnet_struct', None),
                          opt=args, argv=argv, no_create=True)
    initial_structure_str = str(masternet)

    popu_structure_list = []
    popu_zero_shot_score_list = []
    popu_latency_list = []

    start_timer = time.time()
    for loop_count in range(args.evolution_max_iter):
        # too many networks in the population pool, remove one with the smallest score
        while len(popu_structure_list) > args.population_size:
            min_zero_shot_score = min(popu_zero_shot_score_list)
            tmp_idx = popu_zero_shot_score_list.index(min_zero_shot_score)
            popu_zero_shot_score_list.pop(tmp_idx)
            popu_structure_list.pop(tmp_idx)
            popu_latency_list.pop(tmp_idx)

        # The population is still empty if every candidate so far was
        # rejected by a budget; there is nothing to report then.
        if loop_count >= 1 and loop_count % 100 == 0 and popu_zero_shot_score_list:
            max_score = max(popu_zero_shot_score_list)
            min_score = min(popu_zero_shot_score_list)
            elasp_time = time.time() - start_timer
            logging.info(f'loop_count={loop_count}/{args.evolution_max_iter}, max_score={max_score:4g}, min_score={min_score:4g}, time={elasp_time/3600:4g}h')

        # ----- generate a random structure ----- #
        if len(popu_structure_list) <= 10:
            # Small population: keep mutating the initial structure.
            random_structure_str = get_new_random_structure_str(
                AnyPlainNet=MasterNet, structure_str=initial_structure_str, num_classes=args.num_classes,
                get_search_space_func=select_search_space.gen_search_space, num_replaces=1)
        else:
            tmp_idx = random.randint(0, len(popu_structure_list) - 1)
            tmp_random_structure_str = popu_structure_list[tmp_idx]
            random_structure_str = get_new_random_structure_str(
                AnyPlainNet=MasterNet, structure_str=tmp_random_structure_str, num_classes=args.num_classes,
                get_search_space_func=select_search_space.gen_search_space, num_replaces=2)

        random_structure_str = get_splitted_structure_str(MasterNet, random_structure_str,
                                                          num_classes=args.num_classes)

        # Cheap structural budgets first, the measured latency last.
        if _exceeds_static_budgets(random_structure_str, args):
            continue

        the_latency = np.inf
        if args.budget_latency is not None:
            the_latency = get_latency(MasterNet, random_structure_str, gpu, args)
            if args.budget_latency < the_latency:
                continue

        the_nas_core = compute_nas_score(MasterNet, random_structure_str, gpu, args)

        popu_structure_list.append(random_structure_str)
        popu_zero_shot_score_list.append(the_nas_core)
        popu_latency_list.append(the_latency)

    return popu_structure_list, popu_zero_shot_score_list, popu_latency_list


if __name__ == '__main__':
    args = search_cmd_options(sys.argv)
    if args.save_dir is None:
        sys.exit('--save_dir is required')
    log_fn = os.path.join(args.save_dir, 'evolution_search.log')
    global_utils.create_logging(log_fn)

    info = main(args, sys.argv)
    if info is None:
        sys.exit()

    popu_structure_list, popu_zero_shot_score_list, popu_latency_list = info
    if not popu_structure_list:
        # Every candidate was rejected by the budgets.
        logging.info('no structure satisfied the budgets; nothing to export')
        sys.exit(1)

    # export best structure
    best_score = max(popu_zero_shot_score_list)
    best_idx = popu_zero_shot_score_list.index(best_score)
    best_structure_str = popu_structure_list[best_idx]

    best_structure_txt = os.path.join(args.save_dir, 'best_structure.txt')
    global_utils.mkfilepath(best_structure_txt)
    with open(best_structure_txt, 'w') as fid:
        fid.write(best_structure_str)
+SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:5} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + --preempt \ + ${SRUN_ARGS} \ + python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} From bdb9a96653c963c931ce8e8885b4ecc1439f7267 Mon Sep 17 00:00:00 2001 From: aptsunny Date: Tue, 3 Jan 2023 17:20:44 +0800 Subject: [PATCH 2/3] zennas test --- .../zennas_plainnet_search_8xb128_in1k.py | 18 ++ .../zennas_plainnet_supernet_8xb128_in1k.py | 39 +++ mmrazor/engine/runner/__init__.py | 3 +- mmrazor/engine/runner/zero_shot_loop.py | 103 ++++++++ mmrazor/models/algorithms/nas/__init__.py | 4 +- mmrazor/models/algorithms/nas/zennas.py | 135 ++++++++++ .../architectures/backbones/__init__.py | 4 +- .../architectures/backbones/masternet.py | 7 +- .../backbones/searchable_plainnet.py | 240 ++++++++++++++++++ 9 files changed, 547 insertions(+), 6 deletions(-) create mode 100644 configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py create mode 100644 configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py create mode 100644 mmrazor/engine/runner/zero_shot_loop.py create mode 100644 mmrazor/models/algorithms/nas/zennas.py create mode 100644 mmrazor/models/architectures/backbones/searchable_plainnet.py diff --git a/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py b/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py new file mode 100644 index 000000000..439c0cce7 --- /dev/null +++ b/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py @@ -0,0 +1,18 @@ +_base_ = ['./zennas_plainnet_supernet_8xb128_in1k.py'] + +# model = dict(norm_training=True) + +train_cfg = dict( + _delete_=True, + type='mmrazor.ZeroShotLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, +) + # max_epochs=20, 
+ # num_candidates=50, + # top_k=10, + # num_mutation=25, + # num_crossover=25, + # mutate_prob=0.1, + # constraints_range=dict(flops=(0, 330)), + # score_key='accuracy/top1') diff --git a/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py b/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py new file mode 100644 index 000000000..fb99019b9 --- /dev/null +++ b/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py @@ -0,0 +1,39 @@ +_base_ = [ + 'mmrazor::_base_/settings/imagenet_bs1024_spos.py', + # 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py', + 'mmcls::_base_/default_runtime.py', +] + +nas_backbone = dict( + _scope_='mmrazor', + type='MasterNet', + fix_subnet='./work_dirs/1ms/init_plainnet.txt', + no_create=True, + num_classes=1000) + +# model +supernet = dict( + type='ImageClassifier', + data_preprocessor=_base_.preprocess_cfg, + backbone=nas_backbone, + # neck=dict(type='GlobalAveragePooling'), + # head=dict( + # type='LinearClsHead', + # num_classes=1000, + # in_channels=1024, + # loss=dict( + # type='LabelSmoothLoss', + # num_classes=1000, + # label_smooth_val=0.1, + # mode='original', + # loss_weight=1.0), + # topk=(1, 5)) +) + +model = dict( + type='mmrazor.ZenNAS', + architecture=supernet, + # mutator=dict(type='mmrazor.OneShotModuleMutator') +) + +find_unused_parameters = True diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 0f2b88d27..13ec1ed08 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -6,10 +6,11 @@ from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop from .subnet_val_loop import SubnetValLoop +from .zero_shot_loop import ZeroShotLoop __all__ = [ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop' + 
@LOOPS.register_module()
class ZeroShotLoop(ValLoop):
    """Validation-style loop that evaluates subnets of the wrapped model.

    Depending on the configuration the loop evaluates either the current
    (fixed) subnet, or each subnet kind listed in ``model.sample_kinds``
    (``'max'``, ``'min'`` and any kind containing ``'random'``). Resource
    metrics from the estimator are merged into every result.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False.
        evaluate_fixed_subnet (bool): Whether to evaluate a fixed subnet only
            or not. Defaults to False.
        calibrate_sample_num (int): Number of samples passed to
            ``calibrate_bn_statistics`` when that method is available.
            Defaults to 4096.
        estimator_cfg (dict, Optional): Used for building a resource estimator.
            Defaults to dict(type='mmrazor.ResourceEstimator').
    """

    def __init__(
        self,
        runner,
        dataloader: Union[DataLoader, Dict],
        evaluator: Union[Evaluator, Dict, List],
        fp16: bool = False,
        evaluate_fixed_subnet: bool = False,
        calibrate_sample_num: int = 4096,
        estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
    ) -> None:
        # The base loop builds ``self.evaluator``, which _evaluate_once()
        # reads; it therefore has to receive the evaluator and fp16 flag.
        super().__init__(runner, dataloader, evaluator, fp16)

        if self.runner.distributed:
            model = self.runner.model.module
        else:
            model = self.runner.model

        self.model = model
        self.evaluate_fixed_subnet = evaluate_fixed_subnet
        self.calibrate_sample_num = calibrate_sample_num
        if estimator_cfg is None:
            estimator_cfg = dict(type='mmrazor.ResourceEstimator')
        self.estimator = TASK_UTILS.build(estimator_cfg)

        # remove CheckpointHook to avoid extra problems.
        for hook in self.runner._hooks:
            if isinstance(hook, CheckpointHook):
                self.runner._hooks.remove(hook)
                # Stop right away: the list was just modified.
                break

    def run(self):
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')

        all_metrics = dict()

        if self.evaluate_fixed_subnet:
            metrics = self._evaluate_once()
            all_metrics.update(add_prefix(metrics, 'fix_subnet'))
        elif hasattr(self.model, 'sample_kinds'):
            for kind in self.model.sample_kinds:
                if kind == 'max':
                    self.model.set_max_subnet()
                    metrics = self._evaluate_once()
                    all_metrics.update(add_prefix(metrics, 'max_subnet'))
                elif kind == 'min':
                    self.model.set_min_subnet()
                    metrics = self._evaluate_once()
                    all_metrics.update(add_prefix(metrics, 'min_subnet'))
                elif 'random' in kind:
                    self.model.set_subnet(self.model.sample_subnet())
                    metrics = self._evaluate_once()
                    all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))

        self.runner.call_hook('after_val_epoch', metrics=all_metrics)
        self.runner.call_hook('after_val')

    def _evaluate_once(self) -> Dict:
        """Evaluate the currently active subnet once.

        Returns:
            The evaluator metrics merged with the estimator's resource
            metrics.
        """
        # This class does not define ``calibrate_bn_statistics`` itself, so
        # re-calibrate BN only when a mixin or subclass supplies the method.
        calibrate = getattr(self, 'calibrate_bn_statistics', None)
        if calibrate is not None:
            calibrate(self.runner.train_dataloader, self.calibrate_sample_num)

        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        resource_metrics = self.estimator.estimate(self.model)
        metrics.update(resource_metrics)

        return metrics
+ + The logic of the search part is implemented in + :class:`mmrazor.core.EvolutionSearch` + + Args: + architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel` + or built model. Corresponding to supernet in NAS algorithm. + mutator (dict|:obj:`OneShotModuleMutator`): The config of + :class:`OneShotModuleMutator` or built mutator. + fix_subnet (str | dict | :obj:`FixSubnet`): The path of a yaml file or + loaded dict or built :obj:`FixSubnet`. + norm_training (bool): Whether to set norm layers to training mode, + namely, not to freeze running stats (mean and var). Note: Affects + Batch Norm and its variants only. Defaults to False. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. Defaults to None. + init_cfg (dict): Init config for ``BaseModule``. + + Note: + SPOS has two training modes: supernet training and subnet retraining. + If `fix_subnet` is None, it means supernet training. + If `fix_subnet` is not None, it means subnet retraining. + + Note: + During supernet training, since each op is not fully trained, the + statistics of :obj:`_BatchNorm` are inaccurate. This affects the + evaluation of the performance of each subnet in the search phase. There + are usually two ways to solve this problem; both require setting + `norm_training` to True: + + 1) Using a large batch size, BNs use the mean and variance of the + current batch during forward. + 2) Recalibrate the statistics of BN before searching. + + Note: + SPOS only uses one mutator. If you want to inherit SPOS to develop + more complex algorithms, it is also feasible to use multiple mutators. + For example, one part of the supernet uses SPOS(OneShotModuleMutator) + to search, and the other part uses Darts(DiffModuleMutator) to search. + """ + + # TODO fix ea's name in doc-string. 
+ + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator: Optional[Union[OneShotModuleMutator, Dict]] = None, + fix_subnet: Optional[ValidFixMutable] = None, + norm_training: bool = False, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + super().__init__(architecture, data_preprocessor, init_cfg) + + # SPOS has two training mode: supernet training and subnet retraining. + # fix_subnet is not None, means subnet retraining. + if fix_subnet: + # Avoid circular import + from mmrazor.structures import load_fix_subnet + + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self.architecture, fix_subnet) + self.is_supernet = False + else: + # assert mutator is not None, \ + # 'mutator cannot be None when fix_subnet is None.' + # if isinstance(mutator, OneShotModuleMutator): + # self.mutator = mutator + # elif isinstance(mutator, dict): + # self.mutator = MODELS.build(mutator) + # else: + # raise TypeError('mutator should be a `dict` or ' + # f'`OneShotModuleMutator` instance, but got ' + # f'{type(mutator)}') + + # Mutator is an essential component of the NAS algorithm. It + # provides some APIs commonly used by NAS. + # Before using it, you must do some preparations according to + # the supernet. 
+ # self.mutator.prepare_from_supernet(self.architecture) + self.is_supernet = True + + self.norm_training = norm_training + + def sample_subnet(self) -> SingleMutatorRandomSubnet: + """Random sample subnet by mutator.""" + return self.mutator.sample_choices() + + def set_subnet(self, subnet: SingleMutatorRandomSubnet): + """Set the subnet sampled by :meth:sample_subnet.""" + self.mutator.set_choices(subnet) + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + if self.is_supernet: + random_subnet = self.sample_subnet() + self.set_subnet(random_subnet) + return self.architecture(batch_inputs, data_samples, mode='loss') + else: + return self.architecture(batch_inputs, data_samples, mode='loss') + + def train(self, mode=True): + """Convert the model into eval mode while keep normalization layer + unfreezed.""" + + super().train(mode) + if self.norm_training and not mode: + for module in self.architecture.modules(): + if isinstance(module, _BatchNorm): + module.training = True diff --git a/mmrazor/models/architectures/backbones/__init__.py b/mmrazor/models/architectures/backbones/__init__.py index 99313ee7e..7f32d18ce 100644 --- a/mmrazor/models/architectures/backbones/__init__.py +++ b/mmrazor/models/architectures/backbones/__init__.py @@ -5,8 +5,10 @@ from .searchable_mobilenet_v3 import AttentiveMobileNetV3 from .searchable_shufflenet_v2 import SearchableShuffleNetV2 from .wideresnet import WideResNet +from .searchable_plainnet import MasterNet __all__ = [ 'DartsBackbone', 'AutoformerBackbone', 'SearchableMobileNetV2', - 'AttentiveMobileNetV3', 'SearchableShuffleNetV2', 'WideResNet' + 'AttentiveMobileNetV3', 'SearchableShuffleNetV2', 'WideResNet', + 'MasterNet' ] diff --git a/mmrazor/models/architectures/backbones/masternet.py b/mmrazor/models/architectures/backbones/masternet.py index 1614e81e2..788e5e0b3 100644 --- 
a/mmrazor/models/architectures/backbones/masternet.py +++ b/mmrazor/models/architectures/backbones/masternet.py @@ -30,8 +30,8 @@ class SuperBlock(nn.Module): def __init__(self, fix_subnet=None, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False, **kwargs): super(SuperBlock, self).__init__() - self.argv = argv - self.opt = opt + # self.argv = argv + # self.opt = opt self.num_classes = num_classes self.plainnet_struct = plainnet_struct self.plainnet_struct_txt = fix_subnet @@ -494,7 +494,8 @@ def main(args, argv): # load masternet fix_subnet = args.plainnet_struct_txt - masternet = MasterNet(fix_subnet=fix_subnet, num_classes=args.num_classes, opt=args, argv=argv, no_create=True) + # masternet = MasterNet(fix_subnet=fix_subnet, num_classes=args.num_classes, opt=args, argv=argv, no_create=True) + masternet = MasterNet(fix_subnet=fix_subnet, num_classes=args.num_classes, no_create=True) initial_structure_str = str(masternet) popu_structure_list = [] diff --git a/mmrazor/models/architectures/backbones/searchable_plainnet.py b/mmrazor/models/architectures/backbones/searchable_plainnet.py new file mode 100644 index 000000000..15dda4e87 --- /dev/null +++ b/mmrazor/models/architectures/backbones/searchable_plainnet.py @@ -0,0 +1,240 @@ +import torch +from torch import nn +import torch.nn.functional as F +from mmrazor.models.architectures.backbones.PlainNet import Linear, PlainNetSuperBlockClass, ResBlock, ResBlockProj, BN +from mmrazor.models.architectures.backbones.PlainNet.basic_blocks import _build_netblock_list_from_str_, build_netblock_list_from_str + +from mmrazor.registry import MODELS + + +class SuperBlock(nn.Module): + def __init__(self, fix_subnet=None, argv=None, opt=None, num_classes=None, plainnet_struct=None, no_create=False, + **kwargs): + super(SuperBlock, self).__init__() + self.argv = argv + self.opt = opt + self.num_classes = num_classes + self.plainnet_struct = plainnet_struct + self.plainnet_struct_txt = fix_subnet + + if 
self.plainnet_struct_txt is not None: + with open(self.plainnet_struct_txt, 'r') as fid: + the_line = fid.readlines()[0].strip() + self.plainnet_struct = the_line + pass + + if self.plainnet_struct is None: + return + + the_s = self.plainnet_struct # type: str + block_list, remaining_s = _build_netblock_list_from_str_(the_s, no_create=no_create, **kwargs) + # block_list, remaining_s = _create_netblock_list_from_str_(the_s, no_create=no_create, **kwargs) + assert len(remaining_s) == 0 + + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) # register + + def forward(self, x): + output = x + for the_block in self.block_list: + output = the_block(output) + return output + + def __str__(self): + s = '' + for the_block in self.block_list: + s += str(the_block) + return s + + def __repr__(self): + return str(self) + + def get_FLOPs(self, input_resolution): + the_res = input_resolution + the_flops = 0 + for the_block in self.block_list: + the_flops += the_block.get_FLOPs(the_res) + the_res = the_block.get_output_resolution(the_res) + + return the_flops + + def get_model_size(self): + the_size = 0 + for the_block in self.block_list: + the_size += the_block.get_model_size() + + return the_size + + def replace_block(self, block_id, new_block): + self.block_list[block_id] = new_block + if block_id < len(self.block_list): + self.block_list[block_id + 1].set_in_channels(new_block.out_channels) + + self.module_list = nn.Module(self.block_list) + +@MODELS.register_module() +class MasterNet(SuperBlock): + def __init__(self, fix_subnet=None, argv=None, opt=None, num_classes=1000, plainnet_struct=None, no_create=False, + no_reslink=None, no_BN=None, use_se=None): + + super().__init__(fix_subnet=fix_subnet, argv=argv, opt=opt, num_classes=num_classes, plainnet_struct=plainnet_struct, + no_create=no_create, no_reslink=no_reslink, no_BN=no_BN, use_se=use_se) + + self.last_channels = self.block_list[-1].out_channels + self.fc_linear = 
Linear(in_channels=self.last_channels, out_channels=self.num_classes, no_create=no_create) + + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + # bn eps + for layer in self.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.eps = 1e-3 + + def extract_stage_features_and_logit(self, x, target_downsample_ratio=None): + stage_features_list = [] + image_size = x.shape[2] + output = x + + for block_id, the_block in enumerate(self.block_list): + output = the_block(output) + dowsample_ratio = round(image_size / output.shape[2]) + if dowsample_ratio == target_downsample_ratio: + stage_features_list.append(output) + target_downsample_ratio *= 2 + pass + pass + + output = F.adaptive_avg_pool2d(output, output_size=1) + output = torch.flatten(output, 1) + logit = self.fc_linear(output) + return stage_features_list, logit + + def forward(self, x): + output = x + for block_id, the_block in enumerate(self.block_list): + output = the_block(output) + + output = F.adaptive_avg_pool2d(output, output_size=1) + + output = torch.flatten(output, 1) + output = self.fc_linear(output) + return output + + def forward_pre_GAP(self, x): + output = x # torch.Size([64, 3, 224, 224]) + for the_block in self.block_list: + output = the_block(output) + return output # torch.Size([64, 512, 7, 7]) + + def get_FLOPs(self, input_resolution): + the_res = input_resolution + the_flops = 0 + for the_block in self.block_list: + the_flops += the_block.get_FLOPs(the_res) + the_res = the_block.get_output_resolution(the_res) + + the_flops += self.fc_linear.get_FLOPs(the_res) + + return the_flops + + def get_model_size(self): + the_size = 0 + for the_block in self.block_list: + the_size += the_block.get_model_size() + + the_size += self.fc_linear.get_model_size() + + return the_size + + def get_num_layers(self): + num_layers = 0 + for block in self.block_list: + assert isinstance(block, PlainNetSuperBlockClass) + num_layers += block.sub_layers + 
return num_layers + + def replace_block(self, block_id, new_block): + self.block_list[block_id] = new_block + + if block_id < len(self.block_list) - 1: + if self.block_list[block_id + 1].in_channels != new_block.out_channels: + self.block_list[block_id + 1].set_in_channels(new_block.out_channels) + else: + assert block_id == len(self.block_list) - 1 + self.last_channels = self.block_list[-1].out_channels + if self.fc_linear.in_channels != self.last_channels: + self.fc_linear.set_in_channels(self.last_channels) + + self.module_list = nn.ModuleList(self.block_list) + + def split(self, split_layer_threshold): + new_str = '' + for block in self.block_list: + new_str += block.split(split_layer_threshold=split_layer_threshold) + return new_str + + def init_parameters(self): + + for m in self.modules(): # 176 + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain=3.26033) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 3.26033 * np.sqrt(2 / (m.weight.shape[0] + m.weight.shape[1]))) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.zeros_(m.bias) + else: + pass + + for superblock in self.block_list: + if not isinstance(superblock, PlainNetSuperBlockClass): + continue + for block in superblock.block_list: + if not (isinstance(block, ResBlock) or isinstance(block, ResBlockProj)): + continue + # print('---debug set bn weight zero in resblock {}:{}'.format(superblock, block)) + last_bn_block = None + for inner_resblock in block.block_list: + if isinstance(inner_resblock, BN): + last_bn_block = inner_resblock + pass + pass # end for + assert last_bn_block is not None + # print('-------- last_bn_block={}'.format(last_bn_block)) + nn.init.zeros_(last_bn_block.netblock.weight) + """ + i = 0 + for superblock in self.block_list: + 
print('\n-start-{}'.format(superblock)) + if not isinstance(superblock, super_blocks.PlainNetSuperBlockClass): + print('i-', i) + continue + j = 0 + for block in superblock.block_list: + if not (isinstance(block, basic_blocks.ResBlock) or isinstance(block, basic_blocks.ResBlockProj)): + print('--j', j) + continue + print('---debug set bn weight zero in resblock {}:{}'.format(superblock, block)) + last_bn_block = None + for inner_resblock in block.block_list: + if isinstance(inner_resblock, basic_blocks.BN): + last_bn_block = inner_resblock + pass + pass # end for + assert last_bn_block is not None + print('-------- last_bn_block={}'.format(last_bn_block)) + nn.init.zeros_(last_bn_block.netblock.weight) + print('--j', j) + j = j + 1 + print('i', i) + print('-end-{}\n'.format(superblock)) + i = i + 1 + """ \ No newline at end of file From 1cd469236f5b10be15698b71b2ea5c4f1b9b1964 Mon Sep 17 00:00:00 2001 From: aptsunny Date: Thu, 5 Jan 2023 15:26:41 +0800 Subject: [PATCH 3/3] zero-shot loop fly --- .../zennas_plainnet_search_8xb128_in1k.py | 15 +- .../zennas_plainnet_supernet_8xb128_in1k.py | 25 +- mmrazor/engine/runner/zero_shot_loop.py | 323 +++++++++++++++--- mmrazor/models/algorithms/nas/zennas.py | 32 ++ .../backbones/PlainNet/basic_blocks.py | 4 +- .../SearchSpace/search_space_XXBL.py | 7 +- slurm_zenscore.sh | 35 +- 7 files changed, 347 insertions(+), 94 deletions(-) diff --git a/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py b/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py index 439c0cce7..090783772 100644 --- a/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/zennas/zennas_plainnet_search_8xb128_in1k.py @@ -7,12 +7,9 @@ type='mmrazor.ZeroShotLoop', dataloader=_base_.val_dataloader, evaluator=_base_.val_evaluator, -) - # max_epochs=20, - # num_candidates=50, - # top_k=10, - # num_mutation=25, - # num_crossover=25, - # mutate_prob=0.1, - # constraints_range=dict(flops=(0, 330)), - # 
score_key='accuracy/top1') + search_space='mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py', + plainnet_struct_txt='./work_dirs/1ms/init_plainnet.txt', + max_epochs=480000, # evolution_max_iter + population_size=512, + num_classes=1000, +) \ No newline at end of file diff --git a/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py b/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py index fb99019b9..9370b2cfa 100644 --- a/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py +++ b/configs/nas/mmcls/zennas/zennas_plainnet_supernet_8xb128_in1k.py @@ -1,9 +1,12 @@ _base_ = [ 'mmrazor::_base_/settings/imagenet_bs1024_spos.py', - # 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py', 'mmcls::_base_/default_runtime.py', ] +# optim_wrapper=None +model_wrapper_cfg = None +optim_wrapper = dict(_delete_=True, type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)) + nas_backbone = dict( _scope_='mmrazor', type='MasterNet', @@ -15,25 +18,11 @@ supernet = dict( type='ImageClassifier', data_preprocessor=_base_.preprocess_cfg, - backbone=nas_backbone, - # neck=dict(type='GlobalAveragePooling'), - # head=dict( - # type='LinearClsHead', - # num_classes=1000, - # in_channels=1024, - # loss=dict( - # type='LabelSmoothLoss', - # num_classes=1000, - # label_smooth_val=0.1, - # mode='original', - # loss_weight=1.0), - # topk=(1, 5)) -) + backbone=nas_backbone) model = dict( type='mmrazor.ZenNAS', - architecture=supernet, - # mutator=dict(type='mmrazor.OneShotModuleMutator') + architecture=supernet ) -find_unused_parameters = True +# find_unused_parameters = True diff --git a/mmrazor/engine/runner/zero_shot_loop.py b/mmrazor/engine/runner/zero_shot_loop.py index d372a8736..1aa8b15c2 100644 --- a/mmrazor/engine/runner/zero_shot_loop.py +++ b/mmrazor/engine/runner/zero_shot_loop.py @@ -3,16 +3,71 @@ from mmengine.evaluator import Evaluator from mmengine.hooks import CheckpointHook -from mmengine.runner import ValLoop, 
BaseLoop, TestLoop +from mmengine.runner import ValLoop, BaseLoop, TestLoop, EpochBasedTrainLoop from torch.utils.data import DataLoader from mmrazor.models.utils import add_prefix from mmrazor.registry import LOOPS, TASK_UTILS -# from .utils import CalibrateBNMixin + +import mmrazor.models.architectures.backbones.global_utils as global_utils +import torch, time, random +import numpy as np + +from mmrazor.models.architectures.backbones import MasterNet +from mmrazor.models.architectures.backbones.PlainNet.basic_blocks import build_netblock_list_from_str +# from mmrazor.models.architectures.backbones.ZeroShotProxy import compute_zen_score, compute_te_nas_score, compute_syncflow_score, compute_gradnorm_score, compute_NASWOT_score +from mmrazor.models.architectures.backbones.ZeroShotProxy import compute_zen_score + +def get_splitted_structure_str(AnyPlainNet, structure_str, num_classes): + the_net = AnyPlainNet(num_classes=num_classes, plainnet_struct=structure_str, no_create=True) + assert hasattr(the_net, 'split') + splitted_net_str = the_net.split(split_layer_threshold=6) + return splitted_net_str + +def get_new_random_structure_str(AnyPlainNet, structure_str, num_classes, get_search_space_func, + num_replaces=1): + the_net = AnyPlainNet(num_classes=num_classes, plainnet_struct=structure_str, no_create=True) + assert isinstance(the_net, MasterNet) + selected_random_id_set = set() + for replace_count in range(num_replaces): + random_id = random.randint(0, len(the_net.block_list) - 1) + if random_id in selected_random_id_set: + continue + selected_random_id_set.add(random_id) + to_search_student_blocks_list_list = get_search_space_func(the_net.block_list, random_id) + + to_search_student_blocks_list = [x for sublist in to_search_student_blocks_list_list for x in sublist] + new_student_block_str = random.choice(to_search_student_blocks_list) + + if len(new_student_block_str) > 0: + # new_student_block = PlainNet.create_netblock_list_from_str(new_student_block_str, 
no_create=True) + # new_student_block = create_netblock_list_from_str(new_student_block_str, no_create=True) + new_student_block = build_netblock_list_from_str(new_student_block_str, no_create=True) + assert len(new_student_block) == 1 + new_student_block = new_student_block[0] + if random_id > 0: + last_block_out_channels = the_net.block_list[random_id - 1].out_channels + new_student_block.set_in_channels(last_block_out_channels) + the_net.block_list[random_id] = new_student_block + else: + # replace with empty block + the_net.block_list[random_id] = None + pass # end for + + # adjust channels and remove empty layer + tmp_new_block_list = [x for x in the_net.block_list if x is not None] + last_channels = the_net.block_list[0].out_channels + for block in tmp_new_block_list[1:]: + block.set_in_channels(last_channels) + last_channels = block.out_channels + the_net.block_list = tmp_new_block_list + + new_random_structure_str = the_net.split(split_layer_threshold=6) + return new_random_structure_str @LOOPS.register_module() -class ZeroShotLoop(ValLoop): +class ZeroShotLoop(EpochBasedTrainLoop): """Loop for subnet validation in NAS with BN re-calibration. 
Args: @@ -36,68 +91,224 @@ def __init__( runner, dataloader: Union[DataLoader, Dict], evaluator: Union[Evaluator, Dict, List], + search_space: Optional[str] = '', + plainnet_struct_txt: Optional[str] = '', + zero_shot_score: Optional[str] = 'Zen', + max_epochs: int = 20, + population_size: int = 512, + num_classes: int = 1000, + max_layers: int = 10, + budget_model_size: Optional[int] = None, + budget_flops: Optional[int] = None, + batch_size: Optional[int] = 64, + input_image_size: Optional[int] = 224, + gamma: Optional[float] = 1e-2, fp16: bool = False, evaluate_fixed_subnet: bool = False, calibrate_sample_num: int = 4096, estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator') ) -> None: - super().__init__(runner, dataloader) - # super().__init__(runner, dataloader, evaluator, fp16) + super().__init__(runner, dataloader, max_epochs) - if self.runner.distributed: - model = self.runner.model.module - else: - model = self.runner.model + # if self.runner.distributed: + # model = self.runner.model.module + # else: + model = self.runner.model + + self.gpu=0 + self.gamma = gamma + self.input_image_size = input_image_size + self.batch_size = batch_size + self.zero_shot_score = zero_shot_score + self.population_size = population_size + self.num_classes = num_classes + self.max_layers = max_layers + self.budget_model_size = budget_model_size + self.budget_flops = budget_flops self.model = model self.evaluate_fixed_subnet = evaluate_fixed_subnet self.calibrate_sample_num = calibrate_sample_num - self.estimator = TASK_UTILS.build(estimator_cfg) + # self.estimator = TASK_UTILS.build(estimator_cfg) # remove CheckpointHook to avoid extra problems. 
- for hook in self.runner._hooks: - if isinstance(hook, CheckpointHook): - self.runner._hooks.remove(hook) - break - - def run(self): - """Launch validation.""" - self.runner.call_hook('before_val') - self.runner.call_hook('before_val_epoch') - - all_metrics = dict() - - if self.evaluate_fixed_subnet: - metrics = self._evaluate_once() - all_metrics.update(add_prefix(metrics, 'fix_subnet')) - elif hasattr(self.model, 'sample_kinds'): - for kind in self.model.sample_kinds: - if kind == 'max': - self.model.set_max_subnet() - metrics = self._evaluate_once() - all_metrics.update(add_prefix(metrics, 'max_subnet')) - elif kind == 'min': - self.model.set_min_subnet() - metrics = self._evaluate_once() - all_metrics.update(add_prefix(metrics, 'min_subnet')) - elif 'random' in kind: - self.model.set_subnet(self.model.sample_subnet()) - metrics = self._evaluate_once() - all_metrics.update(add_prefix(metrics, f'{kind}_subnet')) - - self.runner.call_hook('after_val_epoch', metrics=all_metrics) - self.runner.call_hook('after_val') - - def _evaluate_once(self) -> Dict: - """Evaluate a subnet once with BN re-calibration.""" - self.calibrate_bn_statistics(self.runner.train_dataloader, - self.calibrate_sample_num) - self.runner.model.eval() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - resource_metrics = self.estimator.estimate(self.model) - metrics.update(resource_metrics) - - return metrics + # for hook in self.runner._hooks: + # if isinstance(hook, CheckpointHook): + # self.runner._hooks.remove(hook) + # break + + ## + # load search space config .py file + self.select_search_space = global_utils.load_py_module_from_path(search_space) + + # load masternet + # fix_subnet = plainnet_struct_txt + self.initial_structure_str = str(self.model.architecture.backbone) + + def run(self) -> None: + """Launch searching.""" + self.runner.call_hook('before_train') + + # if 
self.predictor_cfg is not None: + # self._init_predictor() + + # if self.resume_from: + # self._resume() + + self.popu_structure_list = [] + self.popu_zero_shot_score_list = [] + self.popu_latency_list = [] + + # while self._epoch < self._max_epochs: + self.start_timer = time.time() + for loop_count in range(self._max_epochs): + self.run_epoch(loop_count) + # self._save_searcher_ckpt() + + # self._save_best_fix_subnet() + + self.runner.call_hook('after_train') + + def run_epoch(self, loop_count) -> None: + """Iterate one epoch. + + Steps: + 1. Sample some new candidates from the supernet. Then Append them + to the candidates, Thus make its number equal to the specified + number. + 2. Validate these candidates(step 1) and update their scores. + 3. Pick the top k candidates based on the scores(step 2), which + will be used in mutation and crossover. + 4. Implement Mutation and crossover, generate better candidates. + """ + + while len(self.popu_structure_list) > self.population_size: # 512个候选网络 + min_zero_shot_score = min(self.popu_zero_shot_score_list) + tmp_idx = self.popu_zero_shot_score_list.index(min_zero_shot_score) + self.popu_zero_shot_score_list.pop(tmp_idx) + self.popu_structure_list.pop(tmp_idx) + self.popu_latency_list.pop(tmp_idx) + pass + + if loop_count >= 1 and loop_count % 100 == 0: + max_score = max(self.popu_zero_shot_score_list) + min_score = min(self.popu_zero_shot_score_list) + elasp_time = time.time() - self.start_timer + self.runner.logger.info(f'loop_count={loop_count}/{self._max_epochs}, max_score={max_score:4g}, min_score={min_score:4g}, time={elasp_time/3600:4g}h') + + # ----- generate a random structure ----- # + random_structure_str = self.sample_candidates() + + the_model = None + # 经过筛选 + # max_layers / budget_model_size / budget_flops / budget_latency=0.0001 + if self.max_layers is not None: # 10 + if the_model is None: + the_model = MasterNet(num_classes=self.num_classes, plainnet_struct=random_structure_str, + no_create=True, 
no_reslink=False) # 这里去除VCNN指的是只有包含RELU激活函数的卷积层构成的网络,不包含BN、残差块等算子,并且只取到GAP(global average pool)层的前一层以保留更多的信息。 + the_layers = the_model.get_num_layers() + if self.max_layers < the_layers: + return + + if self.budget_model_size is not None: + if the_model is None: + the_model = MasterNet(num_classes=self.num_classes, plainnet_struct=random_structure_str, + no_create=True, no_reslink=False) + the_model_size = the_model.get_model_size() + if self.budget_model_size < the_model_size: + return + + if self.budget_flops is not None: + if the_model is None: + the_model = MasterNet(num_classes=self.num_classes, plainnet_struct=random_structure_str, + no_create=True, no_reslink=False) + the_model_flops = the_model.get_FLOPs(self.input_image_size) + if self.budget_flops < the_model_flops: + return + + the_latency = np.inf + """ # latency + if self.budget_latency is not None: # 0.0001 + the_latency = get_latency(MasterNet, random_structure_str, gpu, args) + if self.budget_latency < the_latency: + return + """ + + the_nas_core = self.compute_nas_score(MasterNet, random_structure_str, self.gpu) + + self.popu_structure_list.append(random_structure_str) + self.popu_zero_shot_score_list.append(the_nas_core) + self.popu_latency_list.append(the_latency) + + self._epoch += 1 + + def sample_candidates(self): + # ----- generate a random structure ----- # + if len(self.popu_structure_list) <= 10: + random_structure_str = get_new_random_structure_str( + AnyPlainNet=MasterNet, structure_str=self.initial_structure_str, num_classes=self.num_classes, + get_search_space_func=self.select_search_space.gen_search_space, num_replaces=1) + else: + tmp_idx = random.randint(0, len(self.popu_structure_list) - 1) + tmp_random_structure_str = self.popu_structure_list[tmp_idx] + random_structure_str = get_new_random_structure_str( + AnyPlainNet=MasterNet, structure_str=tmp_random_structure_str, num_classes=self.num_classes, + get_search_space_func=self.select_search_space.gen_search_space, num_replaces=2) 
+ + random_structure_str = get_splitted_structure_str(MasterNet, random_structure_str, + num_classes=self.num_classes) + return random_structure_str + + def compute_nas_score(self, AnyPlainNet, random_structure_str, gpu): + # compute network zero-shot proxy score + the_model = AnyPlainNet(num_classes=self.num_classes, plainnet_struct=random_structure_str, + no_create=False, no_reslink=True) + the_model = the_model.cuda(gpu) + try: + if self.zero_shot_score == 'Zen': + the_nas_core_info = compute_zen_score.compute_nas_score(model=the_model, gpu=gpu, + resolution=self.input_image_size, + mixup_gamma=self.gamma, batch_size=self.batch_size, + repeat=1) + the_nas_core = the_nas_core_info['avg_nas_score'] + """ + elif self.zero_shot_score == 'TE-NAS': + the_nas_core = compute_te_nas_score.compute_NTK_score(model=the_model, gpu=gpu, + resolution=self.input_image_size, + batch_size=self.batch_size) + + elif self.zero_shot_score == 'Syncflow': + the_nas_core = compute_syncflow_score.do_compute_nas_score(model=the_model, gpu=gpu, + resolution=self.input_image_size, + batch_size=self.batch_size) + + elif self.zero_shot_score == 'GradNorm': + the_nas_core = compute_gradnorm_score.compute_nas_score(model=the_model, gpu=gpu, + resolution=self.input_image_size, + batch_size=self.batch_size) + + elif self.zero_shot_score == 'Flops': + the_nas_core = the_model.get_FLOPs(self.input_image_size) + + elif self.zero_shot_score == 'Params': + the_nas_core = the_model.get_model_size() + + elif self.zero_shot_score == 'Random': + the_nas_core = np.random.randn() + + elif self.zero_shot_score == 'NASWOT': + the_nas_core = compute_NASWOT_score.compute_nas_score(gpu=gpu, model=the_model, + resolution=self.input_image_size, + batch_size=self.batch_size) + """ + except Exception as err: + # logging.info(str(err)) + # logging.info('--- Failed structure: ') + # logging.info(str(the_model)) + # raise err + the_nas_core = -9999 + + + del the_model + torch.cuda.empty_cache() + return the_nas_core 
diff --git a/mmrazor/models/algorithms/nas/zennas.py b/mmrazor/models/algorithms/nas/zennas.py index f7afd6aa0..35cc7af4e 100644 --- a/mmrazor/models/algorithms/nas/zennas.py +++ b/mmrazor/models/algorithms/nas/zennas.py @@ -103,6 +103,38 @@ def __init__(self, self.norm_training = norm_training + # def parameters(self): + # """split network weights into to categlories, + # one are weights in conv layer and linear layer, + # others are other learnable paramters(conv bias, + # bn weights, bn bias, linear bias) + # Args: + # net: network architecture + + # Returns: + # a dictionary of params splite into to categlories + # """ + + # decay = [] + # no_decay = [] + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + # decay.append(m.weight) + + # if m.bias is not None: + # no_decay.append(m.bias) + + # else: + # if hasattr(m, 'weight'): + # no_decay.append(m.weight) + # if hasattr(m, 'bias'): + # no_decay.append(m.bias) + + # assert len(list(self.parameters())) == len(decay) + len(no_decay) + + # return [dict(params=decay), dict(params=no_decay, weight_decay=0)] + def sample_subnet(self) -> SingleMutatorRandomSubnet: """Random sample subnet by mutator.""" return self.mutator.sample_choices() diff --git a/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py b/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py index 8e69e1f2c..f366afd90 100644 --- a/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py +++ b/mmrazor/models/architectures/backbones/PlainNet/basic_blocks.py @@ -176,7 +176,9 @@ def _build_netblock_list_from_str_(s, no_create=False, **kwargs): the_block_cfg, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) mutable_cfg.update(the_block_cfg) # print(mutable_cfg) - the_block = MODELS.build(mutable_cfg) + # the_block = MODELS.build(mutable_cfg) + mutable_cfg.pop('type') + the_block = the_block_class(**mutable_cfg) # the_block_class = 
_all_netblocks_dict_[the_block_class_name] # the_block, remaining_s = the_block_class.create_from_str(s, no_create=no_create, **kwargs) if the_block is not None: diff --git a/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py b/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py index 9850633cd..ea434f805 100644 --- a/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py +++ b/mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py @@ -1,12 +1,13 @@ -import os, sys +# import os, sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) import itertools import global_utils # from PlainNet import basic_blocks, super_blocks, SuperResKXKX, SuperResK1KXK1 -from PlainNet import basic_blocks +# from PlainNet import basic_blocks +from mmrazor.models.architectures.backbones.PlainNet import basic_blocks seach_space_block_type_list_list = [ [basic_blocks.SuperResK1K3K1, basic_blocks.SuperResK1K5K1, basic_blocks.SuperResK1K7K1], diff --git a/slurm_zenscore.sh b/slurm_zenscore.sh index b72ce5aa1..1ac4949b9 100755 --- a/slurm_zenscore.sh +++ b/slurm_zenscore.sh @@ -1,14 +1,35 @@ set -e +save_dir=work_dirs/Zen_NAS_cifar_params1M +mkdir -p ${save_dir} + +echo "SuperConvK3BNRELU(3,8,1,1)SuperResK3K3(8,16,1,8,1)SuperResK3K3(16,32,2,16,1)SuperResK3K3(32,64,2,32,1)SuperResK3K3(64,64,2,32,1)SuperConvK1BNRELU(64,128,1,1)" \ +> ${save_dir}/init_plainnet.txt + ./command.sh search_kk 1 1 "python mmrazor/models/architectures/backbones/masternet.py --gpu 0 \ --zero_shot_score Zen \ --search_space mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py \ - --budget_latency 1e-4 \ - --max_layers 10 \ + --budget_model_size 1e6 \ + --max_layers 18 \ --batch_size 64 \ - --input_image_size 224 \ - --plainnet_struct_txt ./work_dirs/1ms/init_plainnet.txt \ - --num_classes 
1000 \ - --evolution_max_iter 20000 \ + --input_image_size 32 \ + --plainnet_struct_txt work_dirs/Zen_NAS_cifar_params1M/init_plainnet.txt \ + --num_classes 100 \ + --evolution_max_iter 480000 \ --population_size 512 \ - --save_dir ./work_dirs/1ms" \ No newline at end of file + --save_dir work_dirs/Zen_NAS_cifar_params1M" + + +# 1ms +# ./command.sh search_kk 1 1 "python mmrazor/models/architectures/backbones/masternet.py --gpu 0 \ +# --zero_shot_score Zen \ +# --search_space mmrazor/models/architectures/backbones/SearchSpace/search_space_XXBL.py \ +# --budget_latency 1e-4 \ +# --max_layers 10 \ +# --batch_size 64 \ +# --input_image_size 224 \ +# --plainnet_struct_txt ./work_dirs/1ms/init_plainnet.txt \ +# --num_classes 1000 \ +# --evolution_max_iter 20000 \ +# --population_size 512 \ +# --save_dir ./work_dirs/1ms" \ No newline at end of file