diff --git a/mmdeploy/apis/onnx/export.py b/mmdeploy/apis/onnx/export.py index 92a9002d8d..a5d6adfb1d 100644 --- a/mmdeploy/apis/onnx/export.py +++ b/mmdeploy/apis/onnx/export.py @@ -109,7 +109,15 @@ def _add_or_update(cfg: dict, key: str, val: Any): if 'onnx_custom_passes' not in context_info: onnx_custom_passes = optimize_onnx if optimize else None context_info['onnx_custom_passes'] = onnx_custom_passes + with RewriterContext(**context_info), torch.no_grad(): + + from mmrazor.models import MMArchitectureQuant + from_mmrazor = isinstance(patched_model, MMArchitectureQuant) + if from_mmrazor: + quantizer = patched_model.quantizer + patched_model = patched_model.get_deploy_model() + # patch input_metas if input_metas is not None: assert isinstance( @@ -128,17 +136,31 @@ def wrapper(*arg, **kwargs): patched_model.forward = partial(patched_model.forward, **input_metas) - torch.onnx.export( - patched_model, - args, - output_path, - export_params=True, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=keep_initializers_as_inputs, - verbose=verbose) + if from_mmrazor: + quantizer.export_onnx( + patched_model, + args, + output_path, + export_params=True, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose) + + else: + torch.onnx.export( + patched_model, + args, + output_path, + export_params=True, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose) if input_metas is not None: patched_model.forward = model_forward diff --git a/mmdeploy/apis/utils/utils.py b/mmdeploy/apis/utils/utils.py index d7630e6637..7fbba34d1a 100644 --- a/mmdeploy/apis/utils/utils.py +++ b/mmdeploy/apis/utils/utils.py @@ -41,7 
+41,7 @@ def build_task_processor(model_cfg: mmengine.Config, BaseTask: A task processor. """ check_backend_device(deploy_cfg=deploy_cfg, device=device) - codebase_type = get_codebase(deploy_cfg) + codebase_type = get_codebase(deploy_cfg, model_cfg=model_cfg) custom_module_list = get_codebase_external_module(deploy_cfg) import_codebase(codebase_type, custom_module_list) codebase = get_codebase_class(codebase_type) diff --git a/mmdeploy/codebase/mmrazor/deploy/__init__.py b/mmdeploy/codebase/mmrazor/deploy/__init__.py new file mode 100644 index 0000000000..45068849c3 --- /dev/null +++ b/mmdeploy/codebase/mmrazor/deploy/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mmrazor import MMCodebase, MMRazor + +__all__ = ['MMRazor', 'MMCodebase'] diff --git a/mmdeploy/codebase/mmrazor/deploy/mmrazor.py b/mmdeploy/codebase/mmrazor/deploy/mmrazor.py new file mode 100644 index 0000000000..19428909b4 --- /dev/null +++ b/mmdeploy/codebase/mmrazor/deploy/mmrazor.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine import Config +from mmengine.model import BaseDataPreprocessor +from mmengine.registry import Registry + +from mmdeploy.apis.utils import build_task_processor +from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase +from mmdeploy.utils import Codebase, Task + +MMRAZOR_TASK = Registry('mmrazor_tasks') + + +@CODEBASE.register_module(Codebase.MMRAZOR.value) +class MMRazor(MMCodebase): + """MMRazor codebase class.""" + task_registry = MMRAZOR_TASK + + @classmethod + def register_deploy_modules(cls): + """Register all rewriters for mmrazor.""" + pass + + @classmethod + def register_all_modules(cls): + """Register all related modules and rewriters for mmrazor.""" + from mmrazor.utils import register_all_modules + register_all_modules(True) + + @classmethod + def build_task_processor(cls, model_cfg: Config, deploy_cfg: Config, + device: str): + """Build task processor for mmrazor. + + Now we use ModelCompress by default. + """ + return ModelCompress( + model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device) + + +@MMRAZOR_TASK.register_module(Task.ModelCompress.value) +class ModelCompress(BaseTask): + """General model compress task for mmrazor. + + Args: + model_cfg (Config): Original PyTorch model config file + deploy_cfg (Config): Deployment config file or loaded Config + object. + device (str): A string represents device type. + experiment_name (str, optional): Name of current experiment. + If not specified, timestamp will be used as + ``experiment_name``. Defaults to ``None``. 
+ """ + + def __init__(self, + model_cfg: Config, + deploy_cfg: Config, + device: str, + experiment_name: str = 'BaseTask'): + + super().__init__(model_cfg, deploy_cfg, device, experiment_name) + self.origin_model_cfg = self.revert_model_cfg(model_cfg) + self.base_task = build_task_processor(self.origin_model_cfg, + deploy_cfg, device) + + def revert_model_cfg(self, model_cfg: Config): + """Restore the original model config from the model config of the + compressed model.""" + origin_model_cfg = copy.deepcopy(model_cfg) + model = model_cfg['model'] + if 'architecture' in model: + origin_model = model['architecture'] + elif 'algorithm' in model: + origin_model = model['algorithm']['architecture'] + else: + raise NotImplementedError() + origin_model_cfg['model'] = origin_model + if 'data_preprocessor' in origin_model: + origin_model_cfg['data_preprocessor'] = origin_model[ + 'data_preprocessor'] + return origin_model_cfg + + # abstract method + + def build_backend_model(self, + model_files=None, + data_preprocessor_updater=None, + **kwargs) -> torch.nn.Module: + """Build backend model for using base task.""" + return self.base_task.build_backend_model(model_files, + data_preprocessor_updater, + **kwargs) + + def create_input(self, + imgs: Union[str, np.ndarray], + input_shape=None, + data_preprocessor: Optional[BaseDataPreprocessor] = None, + **kwargs) -> Tuple[Dict, torch.Tensor]: + """Create input using base task.""" + return self.base_task.create_input(imgs, input_shape, + data_preprocessor, **kwargs) + + def get_model_name(self, *args, **kwargs) -> str: + """Get model name using base task.""" + return self.base_task.get_model_name(*args, **kwargs) + + def get_preprocess(self, *args, **kwargs) -> Dict: + """Get data preprocess name using base task.""" + return self.base_task.get_preprocess(*args, **kwargs) + + def get_postprocess(self, *args, **kwargs) -> Dict: + """Get data postprocess name using base task.""" + return self.base_task.get_postprocess(*args, 
**kwargs) + + @staticmethod + def get_partition_cfg(partition_type: str, **kwargs) -> Dict: + """Get a certain partition config.""" + raise NotImplementedError() + + def build_pytorch_model(self, + model_checkpoint: Optional[str] = None, + cfg_options: Optional[Dict] = None, + **kwargs) -> torch.nn.Module: + """Build PyTorch model for mmrazor and execute post process for + mmdeploy.""" + model = super().build_pytorch_model(model_checkpoint, cfg_options, + **kwargs) + if hasattr(model, 'post_process_for_mmdeploy'): + model.post_process_for_mmdeploy() + + return model diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index fb2db9af5f..5565596fee 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -83,7 +83,8 @@ def register_codebase(codebase: str) -> Codebase: return Codebase.get(codebase) -def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase: +def get_codebase(deploy_cfg: Union[str, mmengine.Config], + model_cfg=None) -> Codebase: """Get the codebase from the config. Args: @@ -92,6 +93,12 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase: Returns: Codebase : An enumeration denotes the codebase type. """ + if model_cfg is not None: + # using mmrazor codebase if the model is a mmrazor model. 
+ model_cfg: dict = model_cfg['model'] + if model_cfg.get('_scope_', None) == 'mmrazor'\ + or model_cfg['type'].startswith('mmrazor.'): + return register_codebase('mmrazor') codebase_config = get_codebase_config(deploy_cfg) assert 'type' in codebase_config, 'The codebase config of deploy config'\ 'requires a "type" field' diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index c56175cbb0..9b929b4cb7 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -28,6 +28,7 @@ class Task(AdvancedEnum): POSE_DETECTION = 'PoseDetection' ROTATED_DETECTION = 'RotatedDetection' VIDEO_RECOGNITION = 'VideoRecognition' + ModelCompress = 'ModelCompress' class Codebase(AdvancedEnum): @@ -41,6 +42,7 @@ class Codebase(AdvancedEnum): MMPOSE = 'mmpose' MMROTATE = 'mmrotate' MMACTION = 'mmaction' + MMRAZOR = 'mmrazor' class IR(AdvancedEnum):