Skip to content

Commit

Permalink
Merge pull request ben0oil1#3 from CNXudiandian/directml
Browse files Browse the repository at this point in the history
增加对AMD用户的友好功能~
  • Loading branch information
ben0oil1 authored Apr 11, 2024
2 parents 7f20a19 + d4c4b50 commit 5e3f926
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 60 deletions.
10 changes: 10 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"cnhubert_path":"./pretrained/chinese-hubert-base",
"bert_path":"./pretrained/chinese-roberta-wwm-ext-large",
"sovits_path":"./models/sovits.pth",
"gpt_path":"./models/gpt.ckpt",
"default_refer_path":"./models/referaudio.wav",
"default_refer_text":"参考文本",
"default_refer_language":"zh",
"is_half":false
}
74 changes: 29 additions & 45 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,95 +5,79 @@ anyio==4.3.0
async-timeout==4.0.3
attrs==23.2.0
audioread==3.0.1
blinker==1.4
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cn2an==0.5.22
contourpy==1.2.0
cryptography==3.4.8
colorama==0.4.6
contourpy==1.2.1
cycler==0.12.1
dbus-python==1.2.18
decorator==5.1.1
distro==1.7.0
distro-info==1.1+ubuntu0.2
einops==0.7.0
exceptiongroup==1.2.0
fastapi==0.110.0
feature-extractor==0.0.1
ffmpeg-python==0.1.18
filelock==3.13.1
fonttools==4.49.0
ffmpeg-python==0.2.0
filelock==3.13.3
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2024.2.0
fsspec==2024.3.1
future==1.0.0
h11==0.14.0
httplib2==0.20.2
huggingface-hub==0.21.4
huggingface-hub==0.22.2
idna==3.6
importlib-metadata==4.6.4
jeepney==0.7.1
importlib_resources==6.4.0
jieba==0.42.1
jieba-fast==0.53
Jinja2==3.1.3
joblib==1.3.2
keyring==23.5.0
kiwisolver==1.4.5
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
librosa==0.9.2
lightning-utilities==0.10.1
llvmlite==0.39.1
lazy_loader==0.4
librosa==0.10.1
lightning-utilities==0.11.2
llvmlite==0.42.0
MarkupSafe==2.1.5
matplotlib==3.8.3
more-itertools==8.10.0
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
networkx==3.2.1
numba==0.56.4
numpy==1.23.5
oauthlib==3.2.0
packaging==23.2
pillow==10.2.0
platformdirs==2.5.1
numba==0.59.1
numpy==1.26.4
packaging==24.0
pillow==10.3.0
platformdirs==4.2.0
pooch==1.8.1
proces==0.1.7
pycparser==2.21
pydantic==2.6.3
pycparser==2.22
pydantic==2.6.4
pydantic_core==2.16.3
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==3.1.2
pypinyin==0.50.0
python-apt==2.4.0+ubuntu3
python-dateutil==2.9.0.post0
pytorch-lightning==2.2.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
resampy==0.4.3
safetensors==0.4.2
scikit-learn==1.4.1.post1
scipy==1.12.0
SecretStorage==3.3.1
scipy==1.13.0
six==1.16.0
sniffio==1.3.1
soundfile==0.12.1
soxr==0.3.7
starlette==0.36.3
sympy==1.12
threadpoolctl==3.3.0
threadpoolctl==3.4.0
tokenizers==0.15.2
torch==2.2.1
torchaudio==2.2.1
torchmetrics==1.3.1
torch==2.0.0
torch-directml==0.2.0.dev230426
torchmetrics==1.3.2
torchvision==0.15.1
tqdm==4.66.2
transformers==4.38.2
typing_extensions==4.10.0
unattended-upgrades==0.1
typing_extensions==4.11.0
urllib3==2.2.1
uvicorn==0.27.1
wadllib==1.3.6
yarl==1.9.4
zipp==1.0.0
zipp==3.18.1
42 changes: 27 additions & 15 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import soundfile as sf
import torch
import torch_directml
import uvicorn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from fastapi import FastAPI, HTTPException, Request
Expand All @@ -22,21 +23,32 @@
from transformers import AutoModelForMaskedLM, AutoTokenizer
from text import cleaned_text_to_sequence
from text.cleaner import clean_text

# -------pretrained_models和训练好的模型,丢在下面的位置----------------
cnhubert_path = "./data/pretrained_models/chinese-hubert-base"
bert_path = "./data/pretrained_models/chinese-roberta-wwm-ext-large"
# 从云端下载回来的训练好的内容
sovits_path = './data/_models/svc/0128-0359_e12_s144.pth'
gpt_path = './data/_models/gpt/0128-0359-e30.ckpt'
# 推理引用的音频文件地址和文本信息
default_refer_path = './data/_models/000.wav'
default_refer_text = "云南凤庆给您发货,一斤装四十九,两斤装九十五,三斤装一百二十九,规格越大价格越划算。"
# 语言,我个人把日语韩语乱七八糟的直接删除了,因为我用不上,大家需要的话自己做适配
default_refer_language = 'zh'
# 使用cuda就是用英伟达的gpu,cpu就是用cpu
device = 'cpu' # cpu cuda
is_half = False
import json

# 增加一个读取配置文件的功能
with open("config.json", "r", encoding="utf-8") as f:
config = json.load(f)
cnhubert_path = config["cnhubert_path"]
bert_path = config["bert_path"]
sovits_path = config["sovits_path"]
gpt_path = config["gpt_path"]
default_refer_path = config["default_refer_path"]
default_refer_text = config["default_refer_text"]
default_refer_language = config["default_refer_language"]
is_half = config["is_half"]
# 自动判断环境是否支持CUDA和DirectML
if(torch.cuda.is_available()):
print("CUDA可用,将使用CUDA进行推理加速。")
print("设备名称:",torch.cuda.get_device_name(0))
device = "cuda"
else:
if(torch_directml.is_available()==False):
device = "cpu"
print("在本机没有发现可以用于加速的显卡,使用CPU进行推理运算。")
else:
device = torch_directml.device(0)
print("DirectML可用,将使用DirectML进行推理加速。")
print("设备名称:",torch_directml.device_name(0))
# -----------------------

# 如果要增加更多的参数选项,在这里设定
Expand Down

0 comments on commit 5e3f926

Please sign in to comment.