Enable torch JIT #70

Merged
merged 4 commits on Aug 3, 2024

Changes from all commits

server/voice_changer/RVC/inferencer/RVCInferencerv2.py (21 additions, 11 deletions)
@@ -1,18 +1,21 @@
 import torch
 import json
+import logging
 from safetensors import safe_open
 from const import EnumInferenceTypes
 from voice_changer.common.deviceManager.DeviceManager import DeviceManager
 from voice_changer.RVC.inferencer.Inferencer import Inferencer
 from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid
 from voice_changer.common.SafetensorsUtils import load_model
 
+logger = logging.getLogger(__name__)
 
 class RVCInferencerv2(Inferencer):
     def load_model(self, file: str):
         device_manager = DeviceManager.get_instance()
         dev = device_manager.device
         is_half = device_manager.use_fp16()
+        use_jit_compile = device_manager.use_jit_compile()
         self.set_props(EnumInferenceTypes.pyTorchRVCv2, file)
 
         # Keep torch.load for backward compatibility, but discourage the use of this loading method
@@ -25,12 +28,18 @@ def load_model(self, file: str):
         cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu')
         model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half).to(dev)
         model.load_state_dict(cpt["weight"], strict=False)
+        model = model.eval()
 
-        model.eval().remove_weight_norm()
+        model.remove_weight_norm()
 
         if is_half:
             model = model.half()
 
+        self.use_jit_eager = not use_jit_compile
+        if use_jit_compile:
+            logger.info('Compiling JIT model...')
+            model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer'])
+
         self.model = model
         return self

@@ -47,15 +56,16 @@ def infer(
     ) -> torch.Tensor:
         assert pitch is not None or pitchf is not None, "Pitch or Pitchf is not found."
 
-        res = self.model.infer(
-            feats,
-            pitch_length,
-            pitch,
-            pitchf,
-            sid,
-            skip_head=skip_head,
-            return_length=return_length,
-            formant_length=formant_length
-        )
+        with torch.jit.optimized_execution(self.use_jit_eager):
+            res = self.model.infer(
+                feats,
+                pitch_length,
+                pitch,
+                pitchf,
+                sid,
+                skip_head=skip_head,
+                return_length=return_length,
+                formant_length=formant_length
+            )
         res = res[0][0, 0]
         return torch.clip(res, -1.0, 1.0, out=res)
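
This load path is the core pattern of the PR: script the model once at load time, freeze it with optimize_for_inference (listing infer in other_methods so that entry point survives freezing), and gate each call with optimized_execution. Below is a minimal self-contained sketch of the same pattern; ToyModel, its sizes, and the hard-coded use_jit_compile flag are invented for illustration and are not repo code.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

    @torch.jit.export
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        # Separate inference entry point, as in the RVC synthesizers.
        return torch.relu(self.proj(x))

use_jit_compile = True  # stand-in for DeviceManager.use_jit_compile()
model = ToyModel().eval()

use_jit_eager = not use_jit_compile
if use_jit_compile:
    # script() compiles the module to TorchScript; optimize_for_inference()
    # freezes it and runs inference-only passes. other_methods names extra
    # entry points to preserve and optimize besides forward.
    model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer'])

with torch.jit.optimized_execution(use_jit_eager):
    out = model.infer(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 16])

Note the ordering: freezing inlines parameters as constants, which is presumably why half() runs before torch.jit.script in load_model above.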

server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py (19 additions, 9 deletions)
@@ -1,18 +1,21 @@
 import torch
 import json
+import logging
 from safetensors import safe_open
 from const import EnumInferenceTypes
 from voice_changer.common.deviceManager.DeviceManager import DeviceManager
 from voice_changer.RVC.inferencer.Inferencer import Inferencer
 from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid_nono
 from voice_changer.common.SafetensorsUtils import load_model
 
+logger = logging.getLogger(__name__)
 
 class RVCInferencerv2Nono(Inferencer):
     def load_model(self, file: str):
         device_manager = DeviceManager.get_instance()
         dev = device_manager.device
         is_half = device_manager.use_fp16()
+        use_jit_compile = device_manager.use_jit_compile()
         self.set_props(EnumInferenceTypes.pyTorchRVCv2Nono, file, is_half)
 
         # Keep torch.load for backward compatibility, but discourage the use of this loading method
@@ -25,12 +28,18 @@ def load_model(self, file: str):
         cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu')
         model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=is_half).to(dev)
         model.load_state_dict(cpt["weight"], strict=False)
+        model = model.eval()
 
-        model.eval().remove_weight_norm()
+        model.remove_weight_norm()
 
         if is_half:
             model = model.half()
 
+        self.use_jit_eager = not use_jit_compile
+        if use_jit_compile:
+            logger.info('Compiling JIT model...')
+            model = torch.jit.optimize_for_inference(torch.jit.script(model), other_methods=['infer'])
+
         self.model = model
         return self

@@ -45,13 +54,14 @@ def infer(
         return_length: int,
         formant_length: int,
     ) -> torch.Tensor:
-        res = self.model.infer(
-            feats,
-            pitch_length,
-            sid,
-            skip_head=skip_head,
-            return_length=return_length,
-            formant_length=formant_length
-        )
+        with torch.jit.optimized_execution(self.use_jit_eager):
+            res = self.model.infer(
+                feats,
+                pitch_length,
+                sid,
+                skip_head=skip_head,
+                return_length=return_length,
+                formant_length=formant_length
+            )
         res = res[0][0, 0]
         return torch.clip(res, -1.0, 1.0, out=res)
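
In both inferencers the flag handed to torch.jit.optimized_execution is use_jit_eager, the inverse of use_jit_compile. The apparent intent: once the module has been frozen by optimize_for_inference there is nothing left for the JIT executor's own optimization passes to do, so they are disabled; for an eager (non-scripted) model the context manager has no effect. A small sketch of the context manager in isolation, with a throwaway scripted function:

import torch

@torch.jit.script
def postprocess(x: torch.Tensor) -> torch.Tensor:
    return torch.clamp(x, -1.0, 1.0)

x = torch.randn(8)
with torch.jit.optimized_execution(False):
    y = postprocess(x)  # executor re-optimization disabled for this call
with torch.jit.optimized_execution(True):
    y = postprocess(x)  # default optimizing behavior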

RVC synthesizer models (file path not shown; imported above as .rvc_models.infer_pack.models)

@@ -531,7 +531,8 @@ def forward(
         har_source, noi_source, uv = self.m_source(f0, float(self.upp))
         har_source = har_source.transpose(1, 2)
         if n_res is not None:
-            if (n := n_res * self.upp) != har_source.shape[-1]:
+            n = n_res * self.upp
+            if n != har_source.shape[-1]:
                 har_source = F.interpolate(har_source, size=n, mode="linear")
             if n_res != x.shape[-1]:
                 x = F.interpolate(x, size=n_res, mode="linear")
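
This change is not cosmetic: TorchScript's parser does not support assignment expressions (the := walrus operator), so torch.jit.script would fail on the old line. Hoisting the assignment out of the condition is the standard workaround; a sketch with invented names:

import torch

@torch.jit.script
def upsampled_length(n_res: int, upp: int, current: int) -> int:
    # `if (n := n_res * upp) != current:` fails to compile under TorchScript;
    # assign first, then test.
    n = n_res * upp
    if n != current:
        current = n
    return current
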
@@ -716,7 +717,7 @@ def __prepare_scriptable__(self):
                 torch.nn.utils.remove_weight_norm(self.enc_q)
         return self
 
-    # @torch.jit.ignore
+    @torch.jit.ignore
     def forward(
         self,
         phone: torch.Tensor,
@@ -741,7 +742,7 @@ def forward(
         o = self.dec(z_slice, pitchf, g=g)
         return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
 
-    # @torch.jit.export
+    @torch.jit.export
     def infer(
         self,
         phone: torch.Tensor,
@@ -937,7 +938,7 @@ def __prepare_scriptable__(self):
                 torch.nn.utils.remove_weight_norm(self.enc_q)
         return self
 
-    # @torch.jit.ignore
+    @torch.jit.ignore
     def forward(self, phone, phone_lengths, y, y_lengths, ds):  # here ds is the speaker id, [bs, 1]
         g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]; the 1 is t, broadcast
         m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
@@ -949,7 +950,7 @@ def forward(self, phone, phone_lengths, y, y_lengths, ds):
         o = self.dec(z_slice, g=g)
         return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
 
-    # @torch.jit.export
+    @torch.jit.export
     def infer(
         self,
         phone: torch.Tensor,
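
Uncommenting these decorators is what makes the synthesizers scriptable as inference-only modules: torch.jit.script compiles forward and everything reachable from it by default, @torch.jit.export forces the extra infer entry point to be compiled, and @torch.jit.ignore leaves the training-only forward as plain eager Python. A minimal sketch of the two decorators, independent of the RVC classes:

import torch
import torch.nn as nn

class Demo(nn.Module):
    @torch.jit.ignore
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Training-only path: kept as eager Python, never compiled
        # (calling it on the scripted module falls back to Python).
        return x * 2

    @torch.jit.export
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        # Compiled even though nothing reachable from forward calls it.
        return x + 1

scripted = torch.jit.script(Demo())
print(scripted.infer(torch.ones(3)))  # tensor([2., 2., 2.])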

FCPE pitch extractor (file path not shown)
@@ -14,6 +14,7 @@ def __init__(self, file: str):
         device_manager = DeviceManager.get_instance()
         # self.is_half = device_manager.use_fp16()
         # NOTE: FCPE doesn't seem to behave correctly in FP16 mode.
+        # NOTE: FCPE doesn't work with torch JIT. Need to debug and find the issue.
         self.is_half = False
 
         model = spawn_infer_model_from_pt(file, self.is_half, device_manager.device, bundled_model=True)

RMVPE pitch extractor (file path not shown)
@@ -12,7 +12,7 @@ def __init__(self, file: str):
         self.type: PitchExtractorType = "rmvpe"
 
         device_manager = DeviceManager.get_instance()
-        self.rmvpe = RMVPE(model_path=file, is_half=device_manager.use_fp16(), device=device_manager.device)
+        self.rmvpe = RMVPE(model_path=file, is_half=device_manager.use_fp16(), use_jit_compile=device_manager.use_jit_compile(), device=device_manager.device)
 
     def extract(
         self,

server/voice_changer/common/deviceManager/DeviceManager.py (4 additions, 0 deletions)

@@ -71,6 +71,10 @@ def set_device(self, id: int):
     def use_fp16(self):
         return self.fp16_available and not self.force_fp32
 
+    def use_jit_compile(self):
+        # FIXME: DirectML backend seems to have issues with JIT. Disable it for now.
+        return self.device_metadata['backend'] != 'directml'
+
     # TODO: This function should also accept backend type
     def _get_device(self, dev_id: int) -> tuple[torch.device, DevicePresentation]:
         if dev_id == -1:

server/voice_changer/common/rmvpe/rmvpe.py (12 additions, 4 deletions)
@@ -1,5 +1,5 @@
 from typing import List
-
+import logging
 import torch.nn as nn
 import torch.nn.functional as F
 import torch
@@ -8,6 +8,7 @@
 from voice_changer.common.SafetensorsUtils import load_model
 from librosa.filters import mel
 
+logger = logging.getLogger(__file__)
 
 class BiGRU(nn.Module):
     def __init__(self, input_features, hidden_features, num_layers):
@@ -327,18 +328,24 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
 
 
 class RMVPE:
-    def __init__(self, model_path: str, is_half: bool, device: torch.device):
+    def __init__(self, model_path: str, is_half: bool, use_jit_compile: bool, device: torch.device):
         model = E2E(4, 1, (2, 2))
         if model_path.endswith('.safetensors'):
             with safe_open(model_path, 'pt', device=str(device) if device.type == 'cuda' else 'cpu') as cpt:
                 load_model(model, cpt, strict=False)
         else:
             cpt = torch.load(model_path, map_location=device if device.type == 'cuda' else 'cpu')
             model.load_state_dict(cpt, strict=False)
-        model.eval().to(device)
+        model = model.eval().to(device)
 
         if is_half:
             model = model.half()
 
+        self.use_jit_eager = not use_jit_compile
+        if use_jit_compile:
+            logger.info('Compiling JIT model...')
+            model = torch.jit.optimize_for_inference(torch.jit.script(model))
+
         self.model = model
 
         self.mel_extractor = MelSpectrogram(
@@ -350,7 +357,8 @@ def __init__(self, model_path: str, is_half: bool, device: torch.device):
     def mel2hidden(self, mel: torch.Tensor) -> torch.Tensor:
         n_frames = mel.shape[-1]
         mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect')
-        return self.model(mel)[:, :n_frames]
+        with torch.jit.optimized_execution(self.use_jit_eager):
+            return self.model(mel)[:, :n_frames]
 
     def decode(self, hidden: torch.Tensor, threshold: float):
         center = torch.argmax(hidden, dim=2, keepdim=True)  # [B, T, 1]
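
The F.pad expression in mel2hidden rounds the frame count up to the next multiple of 32 before the tensor reaches the model; the likely reason (an inference, not stated in the diff) is that RMVPE's U-Net-style E2E network downsamples by powers of two, so its input length must be 32-aligned. The arithmetic in isolation:

def pad_amount(n_frames: int, multiple: int = 32) -> int:
    # multiple * ((n - 1) // multiple + 1) is n rounded up to a multiple.
    return multiple * ((n_frames - 1) // multiple + 1) - n_frames

assert pad_amount(100) == 28  # 100 -> 128
assert pad_amount(128) == 0   # already aligned
assert pad_amount(129) == 31  # 129 -> 160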