Commit 5b1b0e1

Enable JIT for inference
deiteris committed Jun 7, 2024
1 parent a5458f2 commit 5b1b0e1
Showing 5 changed files with 129 additions and 73 deletions.
server/const.py (4 additions, 0 deletions)
@@ -20,6 +20,10 @@

 STORED_SETTING_FILE = "stored_setting.json"
 ASSETS_FILE = 'assets.json'
+# TODO: Need JIT cache invalidation strategy
+JIT_DIR = '.jit'
+
+os.makedirs(JIT_DIR, exist_ok=True)

 SERVER_DEVICE_SAMPLE_RATES = [16000, 32000, 44100, 48000, 96000, 192000]
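The TODO above leaves cache invalidation open: nothing removes or refreshes a stored .torchscript artifact when the checkpoint it was compiled from changes. One possible strategy, sketched here as an assumption rather than part of this commit (jit_cache_path is a hypothetical helper), is to key the cache file name on a digest of the checkpoint contents, so a modified or replaced model never resolves to a stale entry:

import hashlib
import os

JIT_DIR = '.jit'

def jit_cache_path(model_file: str, device_tag: str) -> str:
    # Hash the checkpoint bytes; a changed file yields a new cache name,
    # so the stale artifact is simply never looked up again.
    digest = hashlib.sha256()
    with open(model_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    name = os.path.splitext(os.path.basename(model_file))[0]
    return os.path.join(JIT_DIR, f'{name}_{device_tag}_{digest.hexdigest()[:16]}.torchscript')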
server/voice_changer/RVC/inferencer/RVCInferencerv2.py (30 additions, 13 deletions)
@@ -1,7 +1,8 @@
 import torch
 import json
+import os
 from safetensors import safe_open
-from const import EnumInferenceTypes
+from const import EnumInferenceTypes, JIT_DIR
 from voice_changer.common.deviceManager.DeviceManager import DeviceManager
 from voice_changer.RVC.inferencer.Inferencer import Inferencer
 from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid
@@ -15,18 +16,33 @@ def loadModel(self, file: str, gpu: int):
         isHalf = False
         self.setProps(EnumInferenceTypes.pyTorchRVCv2, file, isHalf, gpu)

-        # Keep torch.load for backward compatibility, but discourage the use of this loading method
-        if '.safetensors' in file:
-            with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt:
-                config = json.loads(cpt.metadata()['config'])
-                model = SynthesizerTrnMs768NSFsid(*config, is_half=False).to(dev)
-                load_model(model, cpt, strict=False)
-        else:
-            cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu')
-            model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False).to(dev)
-            model.load_state_dict(cpt["weight"], strict=False)
+        filename = os.path.splitext(os.path.basename(file))[0]
+        jit_filename = f'{filename}_{dev.type}_{dev.index}.torchscript' if dev.index is not None else f'{filename}_{dev.type}.torchscript'
+        jit_file = os.path.join(JIT_DIR, jit_filename)
+        if not os.path.exists(jit_file):
+            # Keep torch.load for backward compatibility, but discourage the use of this loading method
+            if '.safetensors' in file:
+                with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt:
+                    config = json.loads(cpt.metadata()['config'])
+                    model = SynthesizerTrnMs768NSFsid(*config, is_half=False).to(dev)
+                    load_model(model, cpt, strict=False)
+            else:
+                cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu')
+                model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False).to(dev)
+                model.load_state_dict(cpt["weight"], strict=False)

-        model.eval().remove_weight_norm()
+            model.remove_weight_norm()
+            # FIXME: DirectML backend seems to have issues with JIT. Disable it for now.
+            if dev.type == 'privateuseone':
+                model = model.eval()
+                self.use_jit = True
+            else:
+                model = torch.jit.freeze(torch.jit.script(model.eval()))
+                torch.jit.save(model, jit_file)
+                self.use_jit = False
+        else:
+            model = torch.jit.load(jit_file)
+            self.use_jit = False

         self.model = model
         return self
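The rewritten loadModel implements a compile-once cache: the first load for a given device deserializes the checkpoint, strips weight norm, scripts and freezes the model (except on DirectML), and saves the TorchScript artifact under JIT_DIR with a device-qualified name; every later load for that device goes straight to torch.jit.load. A condensed sketch of the pattern, with a hypothetical TinyNet standing in for SynthesizerTrnMs768NSFsid:

import os
import torch

class TinyNet(torch.nn.Module):
    # Hypothetical stand-in for the RVC synthesizer.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

def load_or_compile(jit_file: str, dev: torch.device) -> torch.nn.Module:
    if os.path.exists(jit_file):
        # Cache hit: skip checkpoint loading and scripting entirely.
        return torch.jit.load(jit_file, map_location=dev)
    model = TinyNet().to(dev)
    # script() compiles the module to TorchScript; freeze() inlines the
    # parameters and specializes the module for inference.
    scripted = torch.jit.freeze(torch.jit.script(model.eval()))
    torch.jit.save(scripted, jit_file)
    return scripted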
@@ -44,6 +60,7 @@ def infer(
         if pitch is None or pitchf is None:
             raise RuntimeError("[Voice Changer] Pitch or Pitchf is not found.")

-        res = self.model.infer(feats, pitch_length, pitch, pitchf, sid, skip_head=skip_head, return_length=return_length)
+        with torch.jit.optimized_execution(self.use_jit):
+            res = self.model.infer(feats, pitch_length, pitch, pitchf, sid, skip_head=skip_head, return_length=return_length)
         res = res[0][0, 0].to(dtype=torch.float32)
         return torch.clip(res, -1.0, 1.0, out=res)
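For context: torch.jit.optimized_execution is a context manager that toggles the TorchScript executor's graph optimizations for calls made inside the block, and it has no effect on plain eager modules; the per-backend use_jit flag recorded in loadModel controls that behavior at the single call site above. A minimal, standalone illustration (not code from this repo):

import torch

scripted = torch.jit.script(torch.nn.Linear(4, 4).eval())
x = torch.randn(1, 4)
# Selects between the optimizing and the simple executor for the
# TorchScript calls made within the block.
with torch.jit.optimized_execution(True):
    y = scripted(x)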
server/voice_changer/RVC/inferencer/RVCInferencerv2Nono.py (30 additions, 13 deletions)
@@ -1,7 +1,8 @@
 import torch
 import json
+import os
 from safetensors import safe_open
-from const import EnumInferenceTypes
+from const import EnumInferenceTypes, JIT_DIR
 from voice_changer.common.deviceManager.DeviceManager import DeviceManager
 from voice_changer.RVC.inferencer.Inferencer import Inferencer
 from .rvc_models.infer_pack.models import SynthesizerTrnMs768NSFsid_nono
@@ -15,18 +16,33 @@ def loadModel(self, file: str, gpu: int):
         isHalf = False
         self.setProps(EnumInferenceTypes.pyTorchRVCv2Nono, file, isHalf, gpu)

-        # Keep torch.load for backward compatibility, but discourage the use of this loading method
-        if '.safetensors' in file:
-            with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt:
-                config = json.loads(cpt.metadata()['config'])
-                model = SynthesizerTrnMs768NSFsid_nono(*config, is_half=False).to(dev)
-                load_model(model, cpt, strict=False)
-        else:
-            cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu')
-            model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=False).to(dev)
-            model.load_state_dict(cpt["weight"], strict=False)
+        filename = os.path.splitext(os.path.basename(file))[0]
+        jit_filename = f'{filename}_{dev.type}_{dev.index}.torchscript' if dev.index is not None else f'{filename}_{dev.type}.torchscript'
+        jit_file = os.path.join(JIT_DIR, jit_filename)
+        if not os.path.exists(jit_file):
+            # Keep torch.load for backward compatibility, but discourage the use of this loading method
+            if '.safetensors' in file:
+                with safe_open(file, 'pt', device=str(dev) if dev.type == 'cuda' else 'cpu') as cpt:
+                    config = json.loads(cpt.metadata()['config'])
+                    model = SynthesizerTrnMs768NSFsid_nono(*config, is_half=False).to(dev)
+                    load_model(model, cpt, strict=False)
+            else:
+                cpt = torch.load(file, map_location=dev if dev.type == 'cuda' else 'cpu')
+                model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"], is_half=False).to(dev)
+                model.load_state_dict(cpt["weight"], strict=False)

-        model.eval().remove_weight_norm()
+            model.remove_weight_norm()
+            # FIXME: DirectML backend seems to have issues with JIT. Disable it for now.
+            if dev.type == 'privateuseone':
+                model = model.eval()
+                self.use_jit = True
+            else:
+                model = torch.jit.freeze(torch.jit.script(model.eval()))
+                torch.jit.save(model, jit_file)
+                self.use_jit = False
+        else:
+            model = torch.jit.load(jit_file)
+            self.use_jit = False

         self.model = model
         return self
@@ -41,6 +57,7 @@ def infer(
         skip_head: torch.Tensor | None,
         return_length: torch.Tensor | None,
     ) -> torch.Tensor:
-        res = self.model.infer(feats, pitch_length, sid, skip_head=skip_head, return_length=return_length)
+        with torch.jit.optimized_execution(self.use_jit):
+            res = self.model.infer(feats, pitch_length, sid, skip_head=skip_head, return_length=return_length)
         res = res[0][0, 0].to(dtype=torch.float32)
         return torch.clip(res, -1.0, 1.0, out=res)
server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py

@@ -751,7 +751,7 @@ def __prepare_scriptable__(self):
             torch.nn.utils.remove_weight_norm(self.enc_q)
         return self

-    # @torch.jit.ignore
+    @torch.jit.ignore
     def forward(
         self,
         phone: torch.Tensor,
@@ -776,7 +776,7 @@ def forward

         o = self.dec(z_slice, pitchf, g=g)
         return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

-    # @torch.jit.export
+    @torch.jit.export
     def infer(
         self,
         phone: torch.Tensor,
@@ -920,7 +920,7 @@ def __prepare_scriptable__(self):

             torch.nn.utils.remove_weight_norm(self.enc_q)
         return self

-    # @torch.jit.ignore
+    @torch.jit.ignore
     def forward(
         self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
     ):  # ds is the speaker id here, [bs, 1]
@@ -938,7 +938,7 @@ def forward

         o = self.dec(z_slice, pitchf, g=g)
         return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

-    # @torch.jit.export
+    @torch.jit.export
     def infer(
         self,
         phone: torch.Tensor,
@@ -1079,7 +1079,7 @@ def __prepare_scriptable__(self):

             torch.nn.utils.remove_weight_norm(self.enc_q)
         return self

-    # @torch.jit.ignore
+    @torch.jit.ignore
     def forward(self, phone, phone_lengths, y, y_lengths, ds):  # ds is the speaker id here, [bs, 1]
         g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]; the 1 is t, broadcast
         m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
@@ -1091,7 +1091,7 @@ def forward(self, phone, phone_lengths, y, y_lengths, ds):

         o = self.dec(z_slice, g=g)
         return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

-    # @torch.jit.export
+    @torch.jit.export
     def infer(
         self,
         phone: torch.Tensor,
@@ -1228,7 +1228,7 @@ def __prepare_scriptable__(self):

             torch.nn.utils.remove_weight_norm(self.enc_q)
         return self

-    # @torch.jit.ignore
+    @torch.jit.ignore
     def forward(self, phone, phone_lengths, y, y_lengths, ds):  # ds is the speaker id here, [bs, 1]
         g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]; the 1 is t, broadcast
         m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
@@ -1240,7 +1240,7 @@ def forward(self, phone, phone_lengths, y, y_lengths, ds):

         o = self.dec(z_slice, g=g)
         return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

-    # @torch.jit.export
+    @torch.jit.export
     def infer(
         self,
         phone: torch.Tensor,
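Uncommenting these decorators is what makes torch.jit.script(model) workable in the inferencers above: @torch.jit.ignore keeps the training-only forward out of the compiled module, while @torch.jit.export forces compilation of infer, which scripting would otherwise skip because only forward and methods reachable from it are compiled by default. A minimal illustration with a hypothetical module:

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @torch.jit.ignore
    def forward(self, x):
        # Left uncompiled; may use constructs TorchScript cannot script.
        return self.proj(x)

    @torch.jit.export
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        # Compiled even though it is not forward and is not called from it.
        return torch.relu(self.proj(x))

scripted = torch.jit.script(Demo())
out = scripted.infer(torch.randn(2, 8))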
(The diff for the fifth changed file did not load in this view.)

0 comments on commit 5b1b0e1
