Skip to content

Commit

Permalink
Merge pull request #139 from deiteris/refactoring-4
Browse files Browse the repository at this point in the history
Refactoring
  • Loading branch information
deiteris authored Jul 20, 2024
2 parents c32e57e + 5d6e7c1 commit ffc0cb7
Show file tree
Hide file tree
Showing 48 changed files with 302 additions and 2,236 deletions.
8 changes: 0 additions & 8 deletions server/Exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
class NoModeLoadedException(Exception):
def __init__(self, framework):
self.framework = framework

def __str__(self):
return repr(f"No model for {self.framework} loaded. Please confirm the model uploaded.")


class VoiceChangerIsNotSelectedException(Exception):
def __str__(self):
return repr("Voice Changer is not selected.")
Expand Down
1 change: 0 additions & 1 deletion server/MMVCServerSIO.spec
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ from PyInstaller.utils.hooks import collect_data_files, collect_all, collect_dyn
import sys
import os.path
import site
import logging

sys.setrecursionlimit(sys.getrecursionlimit() * 5)

Expand Down
8 changes: 3 additions & 5 deletions server/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import traceback

from main import setupArgParser, main
from utils.strtobool import strtobool
from mods.log_control import VoiceChangaerLogger

VoiceChangaerLogger.get_instance().initialize(initialize=True)
logger = VoiceChangaerLogger.get_instance().getLogger()
import logging
logger = logging.getLogger(__name__)

if __name__ == "__main__":
parser = setupArgParser()
Expand All @@ -16,5 +14,5 @@
try:
asyncio.run(main(args))
except Exception as e:
print(traceback.format_exc())
logger.exception(e)
input('Press Enter to continue...')
1 change: 0 additions & 1 deletion server/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class EnumInferenceTypes(Enum):
pyTorchRVCv2Nono = "pyTorchRVCv2Nono"
pyTorchWebUI = "pyTorchWebUI"
pyTorchWebUINono = "pyTorchWebUINono"
pyTorchVoRASbeta = "pyTorchVoRASbeta"
onnxRVC = "onnxRVC"
onnxRVCNono = "onnxRVCNono"

Expand Down
5 changes: 3 additions & 2 deletions server/data/ModelSlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import os
import json

import logging
logger = logging.getLogger(__name__)

@dataclass
class ModelSlot:
Expand Down Expand Up @@ -76,7 +77,7 @@ def loadAllSlotInfo(model_dir: str):

def saveSlotInfo(model_dir: str, slotIndex: int, slotInfo: ModelSlots):
slotDir = os.path.join(model_dir, str(slotIndex))
print("SlotInfo:::", slotInfo)
logger.info(f"SlotInfo::: {slotInfo}")
slotInfoDict = asdict(slotInfo)
slotInfo.slotIndex = -1 # スロットインデックスは動的に注入
json.dump(slotInfoDict, open(os.path.join(slotDir, "params.json"), "w"), indent=4)
8 changes: 4 additions & 4 deletions server/downloader/Downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from downloader.HttpClient import HttpClient
from tqdm import tqdm
from threading import Lock
from mods.log_control import VoiceChangaerLogger
from xxhash import xxh128
from utils.hasher import compute_hash
from const import ASSETS_FILE
from Exceptions import DownloadVerificationException

logger = VoiceChangaerLogger.get_instance().getLogger()
import logging
logger = logging.getLogger(__name__)

lock = Lock()

Expand Down Expand Up @@ -44,13 +44,13 @@ async def download(params: dict):
# If hash was provided with the file - verify against provided hash
if expected_hash is not None:
if hash == expected_hash:
logger.info(f'[Voice Changer] Verified {saveTo}')
logger.info(f'Verified {saveTo}')
return
# If hash was not provided - verify against local cache
elif saveTo in files:
fhash = files[saveTo]
if hash == fhash:
logger.info(f'[Voice Changer] Verified {saveTo}')
logger.info(f'Verified {saveTo}')
return
else:
hash = None
Expand Down
13 changes: 6 additions & 7 deletions server/downloader/SampleDownloader.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import json
import os
import sys
import asyncio
from typing import Any, Tuple

from const import RVCSampleMode, getSampleJsonAndModelIds
from data.ModelSample import ModelSamples, generateModelSample
from data.ModelSlot import ModelSlot, RVCModelSlot
from mods.log_control import VoiceChangaerLogger
import logging
from voice_changer.ModelSlotManager import ModelSlotManager
from voice_changer.RVC.RVCModelSlotGenerator import RVCModelSlotGenerator
from downloader.Downloader import download

logger = VoiceChangaerLogger.get_instance().getLogger()
logger = logging.getLogger(__name__)


async def downloadInitialSamples(mode: RVCSampleMode, model_dir: str):
Expand Down Expand Up @@ -86,7 +85,7 @@ async def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tup
match = True
break
if match is False:
logger.warn(f"[Voice Changer] initiail sample not found. {targetSampleId}")
logger.warn(f"Initial sample not found: {targetSampleId}")
continue

# 検出されたら、、、
Expand Down Expand Up @@ -145,17 +144,17 @@ async def _downloadSamples(samples: list[ModelSamples], sampleModelIds: list[Tup
slotInfo.isONNX = slotInfo.modelFile.endswith(".onnx")
modelSlotManager.save_model_slot(targetSlotIndex, slotInfo)
else:
logger.warn(f"[Voice Changer] {sample.voiceChangerType} is not supported.")
logger.warn(f"{sample.voiceChangerType} is not supported.")

# ダウンロード
logger.info("[Voice Changer] Downloading model files...")
logger.info("Downloading model files...")
tasks: list[asyncio.Task] = []
for file in downloadParams:
tasks.append(asyncio.ensure_future(download(file)))
await asyncio.gather(*tasks)

# メタデータ作成
logger.info("[Voice Changer] Generating metadata...")
logger.info("Generating metadata...")
for targetSlotIndex in slotIndex:
slotInfo = modelSlotManager.get_slot_info(targetSlotIndex)
modelPath = os.path.join(model_dir, str(slotInfo.slotIndex), os.path.basename(slotInfo.modelFile))
Expand Down
10 changes: 5 additions & 5 deletions server/downloader/WeightDownloader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import asyncio

from downloader.Downloader import download
from mods.log_control import VoiceChangaerLogger
import logging
from settings import ServerSettings
from Exceptions import WeightDownloadException

logger = VoiceChangaerLogger.get_instance().getLogger()
logger = logging.getLogger(__name__)

async def downloadWeight(params: ServerSettings):
logger.info('[Voice Changer] Loading weights.')
logger.info('Loading weights.')
file_params = [
# {
# "url": "https://huggingface.co/ddPn08/rvc-webui-models/resolve/main/embeddings/hubert_base.pt",
Expand Down Expand Up @@ -82,8 +82,8 @@ async def downloadWeight(params: ServerSettings):
for res in await asyncio.gather(*tasks, return_exceptions=True):
if isinstance(res, Exception):
fail = True
logger.error(f'[Voice Changer] {res}')
logger.exception(res)
if fail:
raise WeightDownloadException()

logger.info('[Voice Changer] All weights are loaded!')
logger.info('All weights are loaded!')
95 changes: 34 additions & 61 deletions server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,48 @@
import multiprocessing as mp
# NOTE: This is required to avoid recursive process call bug for macOS
mp.freeze_support()
from const import SSL_KEY_DIR, DOTENV_FILE, ROOT_PATH, UPLOAD_DIR, TMP_DIR, get_version, get_edition
from const import SSL_KEY_DIR, ROOT_PATH, UPLOAD_DIR, TMP_DIR, LOG_FILE, get_version, get_edition
# NOTE: This is required to fix current working directory on macOS
os.chdir(ROOT_PATH)

import sys
import uvicorn
import asyncio
import traceback

import threading
import socket
import time
from dotenv import set_key
import logging
from utils.strtobool import strtobool
from datetime import datetime
import platform
import argparse
from downloader.WeightDownloader import downloadWeight
from downloader.SampleDownloader import downloadInitialSamples
from mods.ssl import create_self_signed_cert
from webbrowser import open_new_tab
from settings import ServerSettings
from mods.log_control import VoiceChangaerLogger

VoiceChangaerLogger.get_instance().initialize(initialize=True)
logger = VoiceChangaerLogger.get_instance().getLogger()
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)

logging.basicConfig(
level=logging.INFO,
format="%(asctime)-15s %(levelname)-8s [%(module)s] %(message)s",
handlers=[logging.FileHandler(LOG_FILE), stream_handler]
)
logger = logging.getLogger(__name__)
settings = ServerSettings()

def setupArgParser():
parser = argparse.ArgumentParser()
parser.add_argument("--logLevel", type=str, default="error", help="Log level info|critical|error. (default: error)")
parser.add_argument("--log-level", type=str, default="error", help="Log level info|critical|error.")
parser.add_argument("--https", type=strtobool, default=False, help="use https")
parser.add_argument("--httpsKey", type=str, default="ssl.key", help="path for the key of https")
parser.add_argument("--httpsCert", type=str, default="ssl.cert", help="path for the cert of https")
parser.add_argument("--httpsSelfSigned", type=strtobool, default=True, help="generate self-signed certificate")
parser.add_argument("--https-key", type=str, default="ssl.key", help="path for the key of https")
parser.add_argument("--https-cert", type=str, default="ssl.cert", help="path for the cert of https")
parser.add_argument("--https-self-signed", type=strtobool, default=True, help="generate self-signed certificate")

return parser

def printMessage(message, level=0):
pf = platform.system()
if pf == "Windows":
if level == 0:
message = f"{message}"
elif level == 1:
message = f" {message}"
elif level == 2:
message = f" {message}"
else:
message = f" {message}"
else:
if level == 0:
message = f"\033[17m{message}\033[0m"
elif level == 1:
message = f"\033[34m {message}\033[0m"
elif level == 2:
message = f"\033[32m {message}\033[0m"
else:
message = f"\033[47m {message}\033[0m"
logger.info(message)

def check_port(port) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", port))
Expand All @@ -73,13 +55,13 @@ def wait_for_server(proto: str, launch_browser: bool):
result = sock.connect_ex(('127.0.0.1', settings.port))
if result == 0:
break
print('-' * 8)
print(f"The server is listening on {proto}://{settings.host}:{settings.port}/")
print('-' * 8)
logger.info('-' * 8)
logger.info(f"The server is listening on {proto}://{settings.host}:{settings.port}/")
logger.info('-' * 8)
if launch_browser:
open_new_tab(f'{proto}://127.0.0.1:{settings.port}')

async def runServer(host: str, port: int, launch_browser: bool = False, logLevel: str = 'critical', key_path: str | None = None, cert_path: str | None = None):
async def runServer(host: str, port: int, launch_browser: bool = False, log_level: str = 'error', key_path: str | None = None, cert_path: str | None = None):
check_port(port)

config = uvicorn.Config(
Expand All @@ -89,7 +71,7 @@ async def runServer(host: str, port: int, launch_browser: bool = False, logLevel
reload=False,
ssl_keyfile=key_path,
ssl_certfile=cert_path,
log_level=logLevel
log_level=log_level
)
server = uvicorn.Server(config)

Expand All @@ -101,31 +83,25 @@ async def runServer(host: str, port: int, launch_browser: bool = False, logLevel
async def main(args):
logger.debug(args)

if not os.path.exists(DOTENV_FILE):
for key, value in settings.model_dump().items():
set_key(DOTENV_FILE, key.upper(), str(value))

printMessage(f"Python: {sys.version}", level=2)
printMessage(f"Voice changer version: {get_version()} {get_edition()}", level=2)
# printMessage("Voice Changerを起動しています。", level=2)
printMessage("Activating the Voice Changer.", level=2)
logger.info(f"Python: {sys.version}")
logger.info(f"Voice changer version: {get_version()} {get_edition()}")
# ダウンロード(Weight)

await downloadWeight(settings)

try:
await downloadInitialSamples(settings.sample_mode, settings.model_dir)
except Exception as e:
print(traceback.format_exc())
printMessage(f"Failed to download samples. Reason: {e}", level=2)
logger.error(f"Failed to download samples.")
logger.exception(e)

# FIXME: Need to refactor samples download logic
os.makedirs(settings.model_dir, exist_ok=True)
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)

# HTTPS key/cert作成
if args.https and args.httpsSelfSigned:
if args.https and args.https_self_signed:
# HTTPS(おれおれ証明書生成)
os.makedirs(SSL_KEY_DIR, exist_ok=True)
key_base_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
Expand All @@ -145,35 +121,32 @@ async def main(args):
)
key_path = os.path.join(SSL_KEY_DIR, keyname)
cert_path = os.path.join(SSL_KEY_DIR, certname)
printMessage(f"protocol: HTTPS(self-signed), key:{key_path}, cert:{cert_path}", level=1)
logger.info(f"protocol: HTTPS(self-signed), key:{key_path}, cert:{cert_path}")

elif args.https and not args.httpsSelfSigned:
elif args.https and not args.https_self_signed:
# HTTPS
key_path = args.httpsKey
cert_path = args.httpsCert
printMessage(f"protocol: HTTPS, key:{key_path}, cert:{cert_path}", level=1)
key_path = args.https_key
cert_path = args.https_cert
logger.info(f"protocol: HTTPS, key:{key_path}, cert:{cert_path}")
else:
# HTTP
printMessage("protocol: HTTP", level=1)
printMessage("-- ---- -- ", level=1)
logger.info("protocol: HTTP")

# サーバ起動
if args.https:
# HTTPS サーバ起動
await runServer(settings.host, settings.port, args.launch_browser, args.logLevel, key_path, cert_path)
await runServer(settings.host, settings.port, args.launch_browser, args.log_level, key_path, cert_path)
else:
await runServer(settings.host, settings.port, args.launch_browser, args.logLevel)
await runServer(settings.host, settings.port, args.launch_browser, args.log_level)


if __name__ == "__main__":
parser = setupArgParser()
args, _ = parser.parse_known_args()
args.launch_browser = False

printMessage(f"Booting PHASE :{__name__}", level=2)

try:
asyncio.run(main(args))
except Exception as e:
print(traceback.format_exc())
logger.exception(e)
input('Press Enter to continue...')
Loading

0 comments on commit ffc0cb7

Please sign in to comment.