From 9c747f33cfab03e349efba2786176ca848743f50 Mon Sep 17 00:00:00 2001 From: Yury Date: Fri, 7 Jun 2024 20:48:54 +0300 Subject: [PATCH 1/4] Enable JIT for inference --- server/const.py | 4 ++ .../RVC/inferencer/RVCInferencerv2.py | 61 ++++++++++++------- .../RVC/inferencer/RVCInferencerv2Nono.py | 57 +++++++++++------ .../rvc_models/infer_pack/models.py | 8 +-- server/voice_changer/common/rmvpe/rmvpe.py | 31 +++++++--- 5 files changed, 108 insertions(+), 53 deletions(-) diff --git a/server/const.py b/server/const.py index 74b006ef5..ab1123093 100644 --- a/server/const.py +++ b/server/const.py @@ -16,6 +16,10 @@ DOTENV_FILE = os.path.join(ROOT_PATH, '.env') STORED_SETTING_FILE = os.path.join(ROOT_PATH, 'stored_setting.json') ASSETS_FILE = os.path.join(ROOT_PATH, 'assets.json') +# TODO: Need JIT cache invalidation strategy +JIT_DIR = os.path.join(ROOT_PATH, '.jit') + +os.makedirs(JIT_DIR, exist_ok=True) SERVER_DEVICE_SAMPLE_RATES = [16000, 32000, 44100, 48000, 96000, 192000] diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py index 066600ed0..a3b23fceb 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py @@ -1,7 +1,8 @@ import torch import json +import os from safetensors import safe_open -from const import EnumInferenceTypes +from const import EnumInferenceTypes, JIT_DIR from voice_changer.common.deviceManager.DeviceManager import DeviceManager from voice_changer.RVC.inferencer.Inferencer import Inferencer from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid @@ -15,18 +16,33 @@ def load_model(self, file: str): is_half = device_manager.use_fp16() self.set_props(EnumInferenceTypes.pyTorchRVCv2, file) - # Keep torch.load for backward compatibility, but discourage the use of this loading method - if file.endswith('.safetensors'): - with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: - config = json.loads(cpt.metadata()['config']) - model = SynthesizerTrnMs768NSFsid(*config, is_half=is_half).to(dev) - load_model(model, cpt, strict=False) - else: - cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') - model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half).to(dev) - model.load_state_dict(cpt["weight"], strict=False) + filename = os.path.splitext(os.path.basename(file))[0] + jit_filename = f'{filename}_{dev.type}_{dev.index}.torchscript' if dev.index is not None else f'{filename}_{dev.type}.torchscript' + jit_file = os.path.join(JIT_DIR, jit_filename) + if not os.path.exists(jit_file): + # Keep torch.load for backward compatibility, but discourage the use of this loading method + if file.endswith('.safetensors'): + with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: + config = json.loads(cpt.metadata()['config']) + model = SynthesizerTrnMs768NSFsid(*config, is_half=is_half).to(dev) + load_model(model, cpt, strict=False) + else: + cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') + model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half).to(dev) + model.load_state_dict(cpt["weight"], strict=False) + model = model.eval() - model.eval().remove_weight_norm() + model.remove_weight_norm() + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. 
+ if dev.type == 'privateuseone': + self.use_jit = True + else: + model = torch.jit.freeze(torch.jit.script(model)) + torch.jit.save(model, jit_file) + self.use_jit = False + else: + model = torch.jit.load(jit_file) + self.use_jit = False if is_half: model = model.half() @@ -47,15 +63,16 @@ def infer( ) -> torch.Tensor: assert pitch is not None or pitchf is not None, "Pitch or Pitchf is not found." - res = self.model.infer( - feats, - pitch_length, - pitch, - pitchf, - sid, - skip_head=skip_head, - return_length=return_length, - formant_length=formant_length - ) + with torch.jit.optimized_execution(self.use_jit): + res = self.model.infer( + feats, + pitch_length, + pitch, + pitchf, + sid, + skip_head=skip_head, + return_length=return_length, + formant_length=formant_length + ) res = res[0][0, 0] return torch.clip(res, -1.0, 1.0, out=res) diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py index 0628f2814..2dda3735c 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py @@ -1,7 +1,8 @@ import torch import json +import os from safetensors import safe_open -from const import EnumInferenceTypes +from const import EnumInferenceTypes, JIT_DIR from voice_changer.common.deviceManager.DeviceManager import DeviceManager from voice_changer.RVC.inferencer.Inferencer import Inferencer from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid_nono @@ -15,18 +16,33 @@ def load_model(self, file: str): is_half = device_manager.use_fp16() self.set_props(EnumInferenceTypes.pyTorchRVCv2Nono, file, is_half) - # Keep torch.load for backward compatibility, but discourage the use of this loading method - if file.endswith('.safetensors'): - with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: - config = json.loads(cpt.metadata()['config']) - model = SynthesizerTrnMs768NSFsid_nono(*config, is_half=is_half).to(dev) - load_model(model, cpt, strict=False) - else: - cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') - model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=is_half).to(dev) - model.load_state_dict(cpt["weight"], strict=False) + filename = os.path.splitext(os.path.basename(file))[0] + jit_filename = f'{filename}_{dev.type}_{dev.index}.torchscript' if dev.index is not None else f'{filename}_{dev.type}.torchscript' + jit_file = os.path.join(JIT_DIR, jit_filename) + if not os.path.exists(jit_file): + # Keep torch.load for backward compatibility, but discourage the use of this loading method + if file.endswith('.safetensors'): + with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: + config = json.loads(cpt.metadata()['config']) + model = SynthesizerTrnMs768NSFsid_nono(*config, is_half=is_half).to(dev) + load_model(model, cpt, strict=False) + else: + cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') + model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=is_half).to(dev) + model.load_state_dict(cpt["weight"], strict=False) + model = model.eval() - model.eval().remove_weight_norm() + model.remove_weight_norm() + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. 
+ if dev.type == 'privateuseone': + self.use_jit = True + else: + model = torch.jit.freeze(torch.jit.script(model)) + torch.jit.save(model, jit_file) + self.use_jit = False + else: + model = torch.jit.load(jit_file) + self.use_jit = False if is_half: model = model.half() @@ -45,13 +61,14 @@ def infer( return_length: int, formant_length: int, ) -> torch.Tensor: - res = self.model.infer( - feats, - pitch_length, - sid, - skip_head=skip_head, - return_length=return_length, - formant_length=formant_length - ) + with torch.jit.optimized_execution(self.use_jit): + res = self.model.infer( + feats, + pitch_length, + sid, + skip_head=skip_head, + return_length=return_length, + formant_length=formant_length + ) res = res[0][0, 0] return torch.clip(res, -1.0, 1.0, out=res) diff --git a/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py b/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py index aed513c44..00543dbb4 100644 --- a/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py +++ b/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py @@ -716,7 +716,7 @@ def __prepare_scriptable__(self): torch.nn.utils.remove_weight_norm(self.enc_q) return self - # @torch.jit.ignore + @torch.jit.ignore def forward( self, phone: torch.Tensor, @@ -741,7 +741,7 @@ def forward( o = self.dec(z_slice, pitchf, g=g) return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - # @torch.jit.export + @torch.jit.export def infer( self, phone: torch.Tensor, @@ -937,7 +937,7 @@ def __prepare_scriptable__(self): torch.nn.utils.remove_weight_norm(self.enc_q) return self - # @torch.jit.ignore + @torch.jit.ignore def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1] g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) @@ -949,7 +949,7 @@ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[b o = self.dec(z_slice, g=g) return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - # @torch.jit.export + @torch.jit.export def infer( self, phone: torch.Tensor, diff --git a/server/voice_changer/common/rmvpe/rmvpe.py b/server/voice_changer/common/rmvpe/rmvpe.py index c57b12115..f4fcf0f78 100644 --- a/server/voice_changer/common/rmvpe/rmvpe.py +++ b/server/voice_changer/common/rmvpe/rmvpe.py @@ -1,3 +1,4 @@ +import os from typing import List import torch.nn as nn @@ -7,6 +8,7 @@ from safetensors import safe_open from voice_changer.common.SafetensorsUtils import load_model from librosa.filters import mel +from const import JIT_DIR class BiGRU(nn.Module): @@ -328,14 +330,29 @@ def forward(self, audio, keyshift=0, speed=1, center=True): class RMVPE: def __init__(self, model_path: str, is_half: bool, device: torch.device): - model = E2E(4, 1, (2, 2)) - if model_path.endswith('.safetensors'): - with safe_open(model_path, 'pt', device=str(device) if device.type == 'cuda' else 'cpu') as cpt: - load_model(model, cpt, strict=False) + filename = os.path.splitext(os.path.basename(model_path))[0] + jit_filename = f'{filename}_{device.type}_{device.index}.torchscript' if device.index is not None else f'{filename}_{device.type}.torchscript' + jit_file = os.path.join(JIT_DIR, jit_filename) + if not os.path.exists(jit_file): + model = E2E(4, 1, (2, 2)) + if model_path.endswith('.safetensors'): + with safe_open(model_path, 'pt', device=str(device) if device.type == 'cuda' else 'cpu') as cpt: + load_model(model, cpt, strict=False) + else: + cpt = torch.load(model_path, 
map_location=device if device.type == 'cuda' else 'cpu') + model.load_state_dict(cpt, strict=False) + model = model.eval().to(device) + + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. + if device.type == 'privateuseone': + self.use_jit = True + else: + model = torch.jit.freeze(torch.jit.script(model)) + torch.jit.save(model, jit_file) + self.use_jit = False else: - cpt = torch.load(model_path, map_location=device if device.type == 'cuda' else 'cpu') - model.load_state_dict(cpt, strict=False) - model.eval().to(device) + model = torch.jit.load(jit_file) + self.use_jit = False if is_half: model = model.half() From b2504b7698e88a7440b85454a8d638ffa86c36d5 Mon Sep 17 00:00:00 2001 From: Yury Date: Sat, 8 Jun 2024 16:57:29 +0300 Subject: [PATCH 2/4] Use optimize_for_inference --- server/voice_changer/RVC/inferencer/RVCInferencerv2.py | 2 +- server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py | 2 +- server/voice_changer/common/rmvpe/rmvpe.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py index a3b23fceb..938a7c981 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py @@ -37,7 +37,7 @@ def load_model(self, file: str): if dev.type == 'privateuseone': self.use_jit = True else: - model = torch.jit.freeze(torch.jit.script(model)) + model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) torch.jit.save(model, jit_file) self.use_jit = False else: diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py index 2dda3735c..c1797c50d 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py @@ -37,7 +37,7 @@ def load_model(self, file: str): if dev.type == 'privateuseone': self.use_jit = True else: - model = torch.jit.freeze(torch.jit.script(model)) + model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) torch.jit.save(model, jit_file) self.use_jit = False else: diff --git a/server/voice_changer/common/rmvpe/rmvpe.py b/server/voice_changer/common/rmvpe/rmvpe.py index f4fcf0f78..84ff5ad61 100644 --- a/server/voice_changer/common/rmvpe/rmvpe.py +++ b/server/voice_changer/common/rmvpe/rmvpe.py @@ -347,7 +347,7 @@ def __init__(self, model_path: str, is_half: bool, device: torch.device): if device.type == 'privateuseone': self.use_jit = True else: - model = torch.jit.freeze(torch.jit.script(model)) + model = torch.jit.optimize_for_inference(torch.jit.script(model)) torch.jit.save(model, jit_file) self.use_jit = False else: From 4fb77b12492394032b72c8a3d8ecb5b37aa9aa1d Mon Sep 17 00:00:00 2001 From: Yury Date: Wed, 31 Jul 2024 20:37:54 +0300 Subject: [PATCH 3/4] Fix JIT and align with FP16 changes --- .../RVC/inferencer/RVCInferencerv2.py | 13 +++++++++---- .../RVC/inferencer/RVCInferencerv2Nono.py | 13 +++++++++---- .../inferencer/rvc_models/infer_pack/models.py | 3 ++- .../RVC/pitchExtractor/FcpePitchExtractor.py | 1 + server/voice_changer/common/rmvpe/rmvpe.py | 15 ++++++++++----- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py index 938a7c981..6c9e3a636 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py 
+++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py @@ -1,6 +1,7 @@ import torch import json import os +import logging from safetensors import safe_open from const import EnumInferenceTypes, JIT_DIR from voice_changer.common.deviceManager.DeviceManager import DeviceManager @@ -8,6 +9,7 @@ from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid from voice_changer.common.SafetensorsUtils import load_model +logger = logging.getLogger(__name__) class RVCInferencerv2(Inferencer): def load_model(self, file: str): @@ -17,7 +19,8 @@ def load_model(self, file: str): self.set_props(EnumInferenceTypes.pyTorchRVCv2, file) filename = os.path.splitext(os.path.basename(file))[0] - jit_filename = f'{filename}_{dev.type}_{dev.index}.torchscript' if dev.index is not None else f'{filename}_{dev.type}.torchscript' + fp_prefix = 'fp16' if is_half else 'fp32' + jit_filename = f'{filename}_{dev.type}_{dev.index}_{fp_prefix}.torchscript' if dev.index is not None else f'{filename}_{dev.type}_{fp_prefix}.torchscript' jit_file = os.path.join(JIT_DIR, jit_filename) if not os.path.exists(jit_file): # Keep torch.load for backward compatibility, but discourage the use of this loading method @@ -33,10 +36,15 @@ def load_model(self, file: str): model = model.eval() model.remove_weight_norm() + + if is_half: + model = model.half() + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. if dev.type == 'privateuseone': self.use_jit = True else: + logger.info('Compiling JIT model...') model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) torch.jit.save(model, jit_file) self.use_jit = False @@ -44,9 +52,6 @@ def load_model(self, file: str): model = torch.jit.load(jit_file) self.use_jit = False - if is_half: - model = model.half() - self.model = model return self diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py index c1797c50d..6271c8661 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py @@ -1,6 +1,7 @@ import torch import json import os +import logging from safetensors import safe_open from const import EnumInferenceTypes, JIT_DIR from voice_changer.common.deviceManager.DeviceManager import DeviceManager @@ -8,6 +9,7 @@ from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid_nono from voice_changer.common.SafetensorsUtils import load_model +logger = logging.getLogger(__name__) class RVCInferencerv2Nono(Inferencer): def load_model(self, file: str): @@ -17,7 +19,8 @@ def load_model(self, file: str): self.set_props(EnumInferenceTypes.pyTorchRVCv2Nono, file, is_half) filename = os.path.splitext(os.path.basename(file))[0] - jit_filename = f'{filename}_{dev.type}_{dev.index}.torchscript' if dev.index is not None else f'{filename}_{dev.type}.torchscript' + fp_prefix = 'fp16' if is_half else 'fp32' + jit_filename = f'{filename}_{dev.type}_{dev.index}_{fp_prefix}.torchscript' if dev.index is not None else f'{filename}_{dev.type}_{fp_prefix}.torchscript' jit_file = os.path.join(JIT_DIR, jit_filename) if not os.path.exists(jit_file): # Keep torch.load for backward compatibility, but discourage the use of this loading method @@ -33,10 +36,15 @@ def load_model(self, file: str): model = model.eval() model.remove_weight_norm() + + if is_half: + model = model.half() + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. 
if dev.type == 'privateuseone': self.use_jit = True else: + logger.info('Compiling JIT model...') model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) torch.jit.save(model, jit_file) self.use_jit = False @@ -44,9 +52,6 @@ def load_model(self, file: str): model = torch.jit.load(jit_file) self.use_jit = False - if is_half: - model = model.half() - self.model = model return self diff --git a/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py b/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py index 00543dbb4..734a392d1 100644 --- a/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py +++ b/server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py @@ -531,7 +531,8 @@ def forward( har_source, noi_source, uv = self.m_source(f0, float(self.upp)) har_source = har_source.transpose(1, 2) if n_res is not None: - if (n := n_res * self.upp) != har_source.shape[-1]: + n = n_res * self.upp + if n != har_source.shape[-1]: har_source = F.interpolate(har_source, size=n, mode="linear") if n_res != x.shape[-1]: x = F.interpolate(x, size=n_res, mode="linear") diff --git a/server/voice_changer/RVC/pitchExtractor/FcpePitchExtractor.py b/server/voice_changer/RVC/pitchExtractor/FcpePitchExtractor.py index 8abd90a73..5f39a3476 100644 --- a/server/voice_changer/RVC/pitchExtractor/FcpePitchExtractor.py +++ b/server/voice_changer/RVC/pitchExtractor/FcpePitchExtractor.py @@ -14,6 +14,7 @@ def __init__(self, file: str): device_manager = DeviceManager.get_instance() # self.is_half = device_manager.use_fp16() # NOTE: FCPE doesn't seem to be behave correctly in FP16 mode. + # NOTE: FCPE doesn't work with torch JIT. Need to debug and find the issue. self.is_half = False model = spawn_infer_model_from_pt(file, self.is_half, device_manager.device, bundled_model=True) diff --git a/server/voice_changer/common/rmvpe/rmvpe.py b/server/voice_changer/common/rmvpe/rmvpe.py index 84ff5ad61..b44ae3ac5 100644 --- a/server/voice_changer/common/rmvpe/rmvpe.py +++ b/server/voice_changer/common/rmvpe/rmvpe.py @@ -1,6 +1,6 @@ import os from typing import List - +import logging import torch.nn as nn import torch.nn.functional as F import torch @@ -10,6 +10,7 @@ from librosa.filters import mel from const import JIT_DIR +logger = logging.getLogger(__file__) class BiGRU(nn.Module): def __init__(self, input_features, hidden_features, num_layers): @@ -331,7 +332,8 @@ def forward(self, audio, keyshift=0, speed=1, center=True): class RMVPE: def __init__(self, model_path: str, is_half: bool, device: torch.device): filename = os.path.splitext(os.path.basename(model_path))[0] - jit_filename = f'{filename}_{device.type}_{device.index}.torchscript' if device.index is not None else f'{filename}_{device.type}.torchscript' + fp_prefix = 'fp16' if is_half else 'fp32' + jit_filename = f'{filename}_{device.type}_{device.index}_{fp_prefix}.torchscript' if device.index is not None else f'{filename}_{device.type}_{fp_prefix}.torchscript' jit_file = os.path.join(JIT_DIR, jit_filename) if not os.path.exists(jit_file): model = E2E(4, 1, (2, 2)) @@ -343,19 +345,22 @@ def __init__(self, model_path: str, is_half: bool, device: torch.device): model.load_state_dict(cpt, strict=False) model = model.eval().to(device) + if is_half: + model = model.half() + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. 
if device.type == 'privateuseone': self.use_jit = True else: - model = torch.jit.optimize_for_inference(torch.jit.script(model)) + logger.info('Compiling JIT model...') + # NOTE: jit.optimize_for_inference produces unserializable model on CPU + model = torch.jit.freeze(torch.jit.script(model)) torch.jit.save(model, jit_file) self.use_jit = False else: model = torch.jit.load(jit_file) self.use_jit = False - if is_half: - model = model.half() self.model = model self.mel_extractor = MelSpectrogram( From eb7b4e9507f96697ab06499142ee634a2e537724 Mon Sep 17 00:00:00 2001 From: Yury Date: Sat, 3 Aug 2024 17:59:29 +0300 Subject: [PATCH 4/4] Compile JIT online --- server/const.py | 4 -- .../RVC/inferencer/RVCInferencerv2.py | 54 ++++++++----------- .../RVC/inferencer/RVCInferencerv2Nono.py | 52 +++++++----------- .../RVC/pitchExtractor/RMVPEPitchExtractor.py | 2 +- .../common/deviceManager/DeviceManager.py | 4 ++ server/voice_changer/common/rmvpe/rmvpe.py | 50 +++++++---------- 6 files changed, 64 insertions(+), 102 deletions(-) diff --git a/server/const.py b/server/const.py index ab1123093..74b006ef5 100644 --- a/server/const.py +++ b/server/const.py @@ -16,10 +16,6 @@ DOTENV_FILE = os.path.join(ROOT_PATH, '.env') STORED_SETTING_FILE = os.path.join(ROOT_PATH, 'stored_setting.json') ASSETS_FILE = os.path.join(ROOT_PATH, 'assets.json') -# TODO: Need JIT cache invalidation strategy -JIT_DIR = os.path.join(ROOT_PATH, '.jit') - -os.makedirs(JIT_DIR, exist_ok=True) SERVER_DEVICE_SAMPLE_RATES = [16000, 32000, 44100, 48000, 96000, 192000] diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py index 6c9e3a636..66d34885a 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2.py @@ -1,9 +1,8 @@ import torch import json -import os import logging from safetensors import safe_open -from const import EnumInferenceTypes, JIT_DIR +from const import EnumInferenceTypes from voice_changer.common.deviceManager.DeviceManager import DeviceManager from voice_changer.RVC.inferencer.Inferencer import Inferencer from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid @@ -16,41 +15,30 @@ def load_model(self, file: str): device_manager = DeviceManager.get_instance() dev = device_manager.device is_half = device_manager.use_fp16() + use_jit_compile = device_manager.use_jit_compile() self.set_props(EnumInferenceTypes.pyTorchRVCv2, file) - filename = os.path.splitext(os.path.basename(file))[0] - fp_prefix = 'fp16' if is_half else 'fp32' - jit_filename = f'{filename}_{dev.type}_{dev.index}_{fp_prefix}.torchscript' if dev.index is not None else f'{filename}_{dev.type}_{fp_prefix}.torchscript' - jit_file = os.path.join(JIT_DIR, jit_filename) - if not os.path.exists(jit_file): - # Keep torch.load for backward compatibility, but discourage the use of this loading method - if file.endswith('.safetensors'): - with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: - config = json.loads(cpt.metadata()['config']) - model = SynthesizerTrnMs768NSFsid(*config, is_half=is_half).to(dev) - load_model(model, cpt, strict=False) - else: - cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') - model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half).to(dev) - model.load_state_dict(cpt["weight"], strict=False) - model = model.eval() + # Keep torch.load for backward compatibility, but discourage the use of this loading method + if 
file.endswith('.safetensors'): + with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: + config = json.loads(cpt.metadata()['config']) + model = SynthesizerTrnMs768NSFsid(*config, is_half=is_half).to(dev) + load_model(model, cpt, strict=False) + else: + cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') + model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half).to(dev) + model.load_state_dict(cpt["weight"], strict=False) + model = model.eval() - model.remove_weight_norm() + model.remove_weight_norm() - if is_half: - model = model.half() + if is_half: + model = model.half() - # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. - if dev.type == 'privateuseone': - self.use_jit = True - else: - logger.info('Compiling JIT model...') - model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) - torch.jit.save(model, jit_file) - self.use_jit = False - else: - model = torch.jit.load(jit_file) - self.use_jit = False + self.use_jit_eager = not use_jit_compile + if use_jit_compile: + logger.info('Compiling JIT model...') + model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) self.model = model return self @@ -68,7 +56,7 @@ def infer( ) -> torch.Tensor: assert pitch is not None or pitchf is not None, "Pitch or Pitchf is not found." - with torch.jit.optimized_execution(self.use_jit): + with torch.jit.optimized_execution(self.use_jit_eager): res = self.model.infer( feats, pitch_length, diff --git a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py index 6271c8661..956dda0d9 100644 --- a/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py +++ b/server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py @@ -1,9 +1,8 @@ import torch import json -import os import logging from safetensors import safe_open -from const import EnumInferenceTypes, JIT_DIR +from const import EnumInferenceTypes from voice_changer.common.deviceManager.DeviceManager import DeviceManager from voice_changer.RVC.inferencer.Inferencer import Inferencer from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid_nono @@ -16,41 +15,30 @@ def load_model(self, file: str): device_manager = DeviceManager.get_instance() dev = device_manager.device is_half = device_manager.use_fp16() + use_jit_compile = device_manager.use_jit_compile() self.set_props(EnumInferenceTypes.pyTorchRVCv2Nono, file, is_half) - filename = os.path.splitext(os.path.basename(file))[0] - fp_prefix = 'fp16' if is_half else 'fp32' - jit_filename = f'{filename}_{dev.type}_{dev.index}_{fp_prefix}.torchscript' if dev.index is not None else f'{filename}_{dev.type}_{fp_prefix}.torchscript' - jit_file = os.path.join(JIT_DIR, jit_filename) - if not os.path.exists(jit_file): - # Keep torch.load for backward compatibility, but discourage the use of this loading method - if file.endswith('.safetensors'): - with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: - config = json.loads(cpt.metadata()['config']) - model = SynthesizerTrnMs768NSFsid_nono(*config, is_half=is_half).to(dev) - load_model(model, cpt, strict=False) - else: - cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') - model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=is_half).to(dev) - model.load_state_dict(cpt["weight"], strict=False) - model = model.eval() + # Keep torch.load for backward compatibility, but 
discourage the use of this loading method + if file.endswith('.safetensors'): + with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt: + config = json.loads(cpt.metadata()['config']) + model = SynthesizerTrnMs768NSFsid_nono(*config, is_half=is_half).to(dev) + load_model(model, cpt, strict=False) + else: + cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu') + model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=is_half).to(dev) + model.load_state_dict(cpt["weight"], strict=False) + model = model.eval() - model.remove_weight_norm() + model.remove_weight_norm() - if is_half: - model = model.half() + if is_half: + model = model.half() - # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. - if dev.type == 'privateuseone': - self.use_jit = True - else: - logger.info('Compiling JIT model...') - model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) - torch.jit.save(model, jit_file) - self.use_jit = False - else: - model = torch.jit.load(jit_file) - self.use_jit = False + self.use_jit_eager = not use_jit_compile + if use_jit_compile: + logger.info('Compiling JIT model...') + model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer']) self.model = model return self diff --git a/server/voice_changer/RVC/pitchExtractor/RMVPEPitchExtractor.py b/server/voice_changer/RVC/pitchExtractor/RMVPEPitchExtractor.py index 05617a322..03654c4e7 100644 --- a/server/voice_changer/RVC/pitchExtractor/RMVPEPitchExtractor.py +++ b/server/voice_changer/RVC/pitchExtractor/RMVPEPitchExtractor.py @@ -12,7 +12,7 @@ def __init__(self, file: str): self.type: PitchExtractorType = "rmvpe" device_manager = DeviceManager.get_instance() - self.rmvpe = RMVPE(model_path=file, is_half=device_manager.use_fp16(), device=device_manager.device) + self.rmvpe = RMVPE(model_path=file, is_half=device_manager.use_fp16(), use_jit_compile=device_manager.use_jit_compile(), device=device_manager.device) def extract( self, diff --git a/server/voice_changer/common/deviceManager/DeviceManager.py b/server/voice_changer/common/deviceManager/DeviceManager.py index d8d69352c..27e3be51e 100644 --- a/server/voice_changer/common/deviceManager/DeviceManager.py +++ b/server/voice_changer/common/deviceManager/DeviceManager.py @@ -71,6 +71,10 @@ def set_device(self, id: int): def use_fp16(self): return self.fp16_available and not self.force_fp32 + def use_jit_compile(self): + # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. 
+ return self.device_metadata['backend'] != 'directml' + # TODO: This function should also accept backend type def _get_device(self, dev_id: int) -> tuple[torch.device, DevicePresentation]: if dev_id == -1: diff --git a/server/voice_changer/common/rmvpe/rmvpe.py b/server/voice_changer/common/rmvpe/rmvpe.py index b44ae3ac5..4f84734c4 100644 --- a/server/voice_changer/common/rmvpe/rmvpe.py +++ b/server/voice_changer/common/rmvpe/rmvpe.py @@ -1,4 +1,3 @@ -import os from typing import List import logging import torch.nn as nn @@ -8,7 +7,6 @@ from safetensors import safe_open from voice_changer.common.SafetensorsUtils import load_model from librosa.filters import mel -from const import JIT_DIR logger = logging.getLogger(__file__) @@ -330,36 +328,23 @@ def forward(self, audio, keyshift=0, speed=1, center=True): class RMVPE: - def __init__(self, model_path: str, is_half: bool, device: torch.device): - filename = os.path.splitext(os.path.basename(model_path))[0] - fp_prefix = 'fp16' if is_half else 'fp32' - jit_filename = f'{filename}_{device.type}_{device.index}_{fp_prefix}.torchscript' if device.index is not None else f'{filename}_{device.type}_{fp_prefix}.torchscript' - jit_file = os.path.join(JIT_DIR, jit_filename) - if not os.path.exists(jit_file): - model = E2E(4, 1, (2, 2)) - if model_path.endswith('.safetensors'): - with safe_open(model_path, 'pt', device=str(device) if device.type == 'cuda' else 'cpu') as cpt: - load_model(model, cpt, strict=False) - else: - cpt = torch.load(model_path, map_location=device if device.type == 'cuda' else 'cpu') - model.load_state_dict(cpt, strict=False) - model = model.eval().to(device) - - if is_half: - model = model.half() - - # FIXME: DirectML backend seems to have issues with JIT. Disable it for now. - if device.type == 'privateuseone': - self.use_jit = True - else: - logger.info('Compiling JIT model...') - # NOTE: jit.optimize_for_inference produces unserializable model on CPU - model = torch.jit.freeze(torch.jit.script(model)) - torch.jit.save(model, jit_file) - self.use_jit = False + def __init__(self, model_path: str, is_half: bool, use_jit_compile: bool, device: torch.device): + model = E2E(4, 1, (2, 2)) + if model_path.endswith('.safetensors'): + with safe_open(model_path, 'pt', device=str(device) if device.type == 'cuda' else 'cpu') as cpt: + load_model(model, cpt, strict=False) else: - model = torch.jit.load(jit_file) - self.use_jit = False + cpt = torch.load(model_path, map_location=device if device.type == 'cuda' else 'cpu') + model.load_state_dict(cpt, strict=False) + model = model.eval().to(device) + + if is_half: + model = model.half() + + self.use_jit_eager = not use_jit_compile + if use_jit_compile: + logger.info('Compiling JIT model...') + model = torch.jit.optimize_for_inference(torch.jit.script(model)) self.model = model @@ -372,7 +357,8 @@ def __init__(self, model_path: str, is_half: bool, device: torch.device): def mel2hidden(self, mel: torch.Tensor) -> torch.Tensor: n_frames = mel.shape[-1] mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect') - return self.model(mel)[:, :n_frames] + with torch.jit.optimized_execution(self.use_jit_eager): + return self.model(mel)[:, :n_frames] def decode(self, hidden: torch.Tensor, threshold: float): center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1]
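
For reference, the on-disk TorchScript cache introduced in PATCH 1/4 and refined in PATCH 3/4 (and dropped again in PATCH 4/4 in favour of compiling at load time) condenses to the pattern below. This is a minimal sketch: JIT_DIR and the filename convention follow the diff, while the helper names (jit_cache_path, load_or_compile) are illustrative only.

import os
import torch

JIT_DIR = '.jit'  # in the diff: os.path.join(ROOT_PATH, '.jit')
os.makedirs(JIT_DIR, exist_ok=True)

def jit_cache_path(model_path: str, device: torch.device, is_half: bool) -> str:
    # Cache key: model file name + device type (+ index) + precision,
    # e.g. "rmvpe_cuda_0_fp16.torchscript".
    name = os.path.splitext(os.path.basename(model_path))[0]
    dev = device.type if device.index is None else f'{device.type}_{device.index}'
    precision = 'fp16' if is_half else 'fp32'
    return os.path.join(JIT_DIR, f'{name}_{dev}_{precision}.torchscript')

def load_or_compile(model: torch.nn.Module, cache_path: str) -> torch.jit.ScriptModule:
    # Reuse a previously serialized module if present; otherwise script, freeze
    # and persist it. freeze() is used here because, per the NOTE in rmvpe.py,
    # optimize_for_inference can produce an unserializable module on CPU; for the
    # synthesizer models the diff additionally preserves the 'infer' entry point.
    if os.path.exists(cache_path):
        return torch.jit.load(cache_path)
    scripted = torch.jit.freeze(torch.jit.script(model.eval()))
    torch.jit.save(scripted, cache_path)
    return scripted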
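
The models.py change that enables @torch.jit.ignore on forward and @torch.jit.export on infer is what makes the synthesizers scriptable: only the inference entry point is needed at runtime, so the training forward stays in Python while infer is compiled and kept through freezing via other_methods=['infer']. The related rewrite of (n := n_res * self.upp) into a plain assignment is presumably for the same reason: the TorchScript frontend does not accept the walrus operator. A toy illustration of the decorator pattern, not code from the repository:

import torch
import torch.nn as nn

class TinySynth(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Training-only path: left as plain Python, never compiled.
        return self.proj(x) + torch.rand_like(x)

    @torch.jit.export
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        # Inference path: compiled when the module is scripted.
        return self.proj(x)

scripted = torch.jit.script(TinySynth().eval())
print(scripted.infer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])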
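
Condensed, the loading and inference flow that PATCH 4/4 settles on (no disk cache, FP16 conversion before scripting, JIT gated per backend through DeviceManager.use_jit_compile()) looks roughly like the sketch below. It mirrors RVCInferencerv2.load_model and infer; prepare_model itself is an illustrative wrapper, not a function from the diff.

import torch

def prepare_model(model: torch.nn.Module, is_half: bool, use_jit_compile: bool):
    model = model.eval()
    # (for the synthesizers the diff also calls model.remove_weight_norm() here)
    if is_half:
        # Convert before scripting: optimize_for_inference freezes the weights
        # into the graph, so the dtype has to be final at this point.
        model = model.half()
    if use_jit_compile:
        # use_jit_compile() returns False on DirectML ('privateuseone'),
        # where scripted models are known to misbehave.
        model = torch.jit.optimize_for_inference(torch.jit.script(model),
                                                 other_methods=['infer'])
    # optimized_execution() is only turned on for the eager fallback at call time.
    use_jit_eager = not use_jit_compile
    return model, use_jit_eager

# At inference time, as in RVCInferencerv2.infer:
#   with torch.jit.optimized_execution(use_jit_eager):
#       out = model.infer(feats, pitch_length, pitch, pitchf, sid, ...)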