Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support rewriting the origin model in mmrazor #1915

Open
wants to merge 7 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions mmdeploy/apis/onnx/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,15 @@ def _add_or_update(cfg: dict, key: str, val: Any):
if 'onnx_custom_passes' not in context_info:
onnx_custom_passes = optimize_onnx if optimize else None
context_info['onnx_custom_passes'] = onnx_custom_passes

with RewriterContext(**context_info), torch.no_grad():

from mmrazor.models import MMArchitectureQuant
from_mmrazor = isinstance(patched_model, MMArchitectureQuant)
if from_mmrazor:
quantizer = patched_model.quantizer
patched_model = patched_model.get_deploy_model()

# patch input_metas
if input_metas is not None:
assert isinstance(
Expand All @@ -128,17 +136,31 @@ def wrapper(*arg, **kwargs):
patched_model.forward = partial(patched_model.forward,
**input_metas)

torch.onnx.export(
patched_model,
args,
output_path,
export_params=True,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
verbose=verbose)
if from_mmrazor:
quantizer.export_onnx(
patched_model,
args,
output_path,
export_params=True,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
verbose=verbose)

else:
torch.onnx.export(
patched_model,
args,
output_path,
export_params=True,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
verbose=verbose)

if input_metas is not None:
patched_model.forward = model_forward
2 changes: 1 addition & 1 deletion mmdeploy/apis/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def build_task_processor(model_cfg: mmengine.Config,
BaseTask: A task processor.
"""
check_backend_device(deploy_cfg=deploy_cfg, device=device)
codebase_type = get_codebase(deploy_cfg)
codebase_type = get_codebase(deploy_cfg, model_cfg=model_cfg)
custom_module_list = get_codebase_external_module(deploy_cfg)
import_codebase(codebase_type, custom_module_list)
codebase = get_codebase_class(codebase_type)
Expand Down
4 changes: 4 additions & 0 deletions mmdeploy/codebase/mmrazor/deploy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmrazor import MMCodebase, MMRazor

__all__ = ['MMRazor', 'MMCodebase']
135 changes: 135 additions & 0 deletions mmdeploy/codebase/mmrazor/deploy/mmrazor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry

from mmdeploy.apis.utils import build_task_processor
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task

MMRAZOR_TASK = Registry('mmrazor_tasks')


@CODEBASE.register_module(Codebase.MMRAZOR.value)
class MMRazor(MMCodebase):
    """Codebase adapter that plugs mmrazor into the mmdeploy pipeline."""

    # Tasks belonging to this codebase are looked up in this registry.
    task_registry = MMRAZOR_TASK

    @classmethod
    def register_deploy_modules(cls):
        """Register rewriters for mmrazor (none are needed at present)."""
        pass

    @classmethod
    def register_all_modules(cls):
        """Register every mmrazor module with the mmengine registries."""
        from mmrazor.utils import register_all_modules as _register_all
        _register_all(True)

    @classmethod
    def build_task_processor(cls, model_cfg: Config, deploy_cfg: Config,
                             device: str):
        """Build the task processor for mmrazor models.

        ``ModelCompress`` is currently the only supported task, so it is
        returned unconditionally.
        """
        processor = ModelCompress(
            model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
        return processor


@MMRAZOR_TASK.register_module(Task.ModelCompress.value)
class ModelCompress(BaseTask):
    """General model compress task for mmrazor.

    The mmrazor-wrapped model config is reverted to the config of the
    original (uncompressed) architecture, and a task processor for that
    architecture's codebase is built internally; most task-level
    operations are delegated to it.

    Args:
        model_cfg (Config): Original PyTorch model config file.
        deploy_cfg (Config): Deployment config file or loaded Config
            object.
        device (str): A string represents device type.
        experiment_name (str, optional): Name of current experiment.
            If not specified, timestamp will be used as
            ``experiment_name``. Defaults to ``'BaseTask'``.
    """

    def __init__(self,
                 model_cfg: Config,
                 deploy_cfg: Config,
                 device: str,
                 experiment_name: str = 'BaseTask'):

        super().__init__(model_cfg, deploy_cfg, device, experiment_name)
        # Build a task processor for the original architecture's codebase;
        # most methods below simply delegate to it.
        self.origin_model_cfg = self.revert_model_cfg(model_cfg)
        self.base_task = build_task_processor(self.origin_model_cfg,
                                              deploy_cfg, device)

    def revert_model_cfg(self, model_cfg: Config):
        """Restore the original model config from the model config of the
        compressed model.

        Args:
            model_cfg (Config): Config of the mmrazor-compressed model.

        Returns:
            Config: A deep copy of ``model_cfg`` whose ``model`` field is
            the wrapped original architecture.

        Raises:
            NotImplementedError: If the compressed config layout exposes
                neither ``model.architecture`` nor
                ``model.algorithm.architecture``.
        """
        origin_model_cfg = copy.deepcopy(model_cfg)
        # Read from the deep copy, not from ``model_cfg`` itself; otherwise
        # the returned config would alias the caller's dicts and mutating
        # one would silently corrupt the other.
        model = origin_model_cfg['model']
        if 'architecture' in model:
            origin_model = model['architecture']
        elif 'algorithm' in model:
            origin_model = model['algorithm']['architecture']
        else:
            raise NotImplementedError()
        origin_model_cfg['model'] = origin_model
        # Promote the architecture's data_preprocessor to the top level so
        # downstream task processors can find it.
        if 'data_preprocessor' in origin_model:
            origin_model_cfg['data_preprocessor'] = origin_model[
                'data_preprocessor']
        return origin_model_cfg

    # abstract method

    def build_backend_model(self,
                            model_files=None,
                            data_preprocessor_updater=None,
                            **kwargs) -> torch.nn.Module:
        """Build backend model using the base task."""
        return self.base_task.build_backend_model(model_files,
                                                  data_preprocessor_updater,
                                                  **kwargs)

    def create_input(self,
                     imgs: Union[str, np.ndarray],
                     input_shape=None,
                     data_preprocessor: Optional[BaseDataPreprocessor] = None,
                     **kwargs) -> Tuple[Dict, torch.Tensor]:
        """Create model input using the base task."""
        return self.base_task.create_input(imgs, input_shape,
                                           data_preprocessor, **kwargs)

    def get_model_name(self, *args, **kwargs) -> str:
        """Get model name using the base task."""
        return self.base_task.get_model_name(*args, **kwargs)

    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get data preprocess config using the base task."""
        return self.base_task.get_preprocess(*args, **kwargs)

    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get data postprocess config using the base task."""
        return self.base_task.get_postprocess(*args, **kwargs)

    @staticmethod
    def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
        """Get a certain partition config.

        Raises:
            NotImplementedError: Partitioning is not supported for
                compressed models.
        """
        raise NotImplementedError()

    def build_pytorch_model(self,
                            model_checkpoint: Optional[str] = None,
                            cfg_options: Optional[Dict] = None,
                            **kwargs) -> torch.nn.Module:
        """Build PyTorch model for mmrazor and execute post process for
        mmdeploy.

        If the built model provides ``post_process_for_mmdeploy`` (an
        mmrazor hook preparing the model for export), it is invoked before
        the model is returned.
        """
        model = super().build_pytorch_model(model_checkpoint, cfg_options,
                                            **kwargs)
        if hasattr(model, 'post_process_for_mmdeploy'):
            model.post_process_for_mmdeploy()

        return model
9 changes: 8 additions & 1 deletion mmdeploy/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def register_codebase(codebase: str) -> Codebase:
return Codebase.get(codebase)


def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
def get_codebase(deploy_cfg: Union[str, mmengine.Config],
model_cfg=None) -> Codebase:
"""Get the codebase from the config.

Args:
Expand All @@ -92,6 +93,12 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
Returns:
Codebase : An enumeration denotes the codebase type.
"""
if model_cfg is not None:
# using mmrazor codebase if the model is a mmrazor model.
model_cfg: dict = model_cfg['model']
if model_cfg.get('_scope_', None) == 'mmrazor'\
or model_cfg['type'].startswith('mmrazor.'):
return register_codebase('mmrazor')
codebase_config = get_codebase_config(deploy_cfg)
assert 'type' in codebase_config, 'The codebase config of deploy config'\
'requires a "type" field'
Expand Down
2 changes: 2 additions & 0 deletions mmdeploy/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Task(AdvancedEnum):
POSE_DETECTION = 'PoseDetection'
ROTATED_DETECTION = 'RotatedDetection'
VIDEO_RECOGNITION = 'VideoRecognition'
ModelCompress = 'ModelCompress'


class Codebase(AdvancedEnum):
Expand All @@ -41,6 +42,7 @@ class Codebase(AdvancedEnum):
MMPOSE = 'mmpose'
MMROTATE = 'mmrotate'
MMACTION = 'mmaction'
MMRAZOR = 'mmrazor'


class IR(AdvancedEnum):
Expand Down