diff --git a/ui/start_ui.py b/ui/start_ui.py
index d4c05e4a5..c1be933dd 100644
--- a/ui/start_ui.py
+++ b/ui/start_ui.py
@@ -20,6 +20,77 @@ import sys
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
+
+import importlib
+import subprocess
+
+
+def fix_package_name(package):
+    """Return the importable module name for a pip requirement string.
+
+    The version specifier (anything from the first ``>``, ``<`` or ``=``) is
+    dropped, then distributions whose import name differs from their pip
+    name are translated through ``package_name_map``.
+    """
+    name = package.split('>')[0]
+    name = name.split('<')[0]
+    name = name.split('=')[0]
+
+    # pip distribution name -> import name, where the two differ.
+    package_name_map = {
+        'langchain-community': 'langchain_community',
+        'sentence-transformers': 'sentence_transformers',
+    }
+
+    return package_name_map.get(name, name)
+
+
+def check_availability_and_install(package_or_list, verbose=1):
+    """Import each requested package, pip-installing any that cannot be imported.
+
+    ``package_or_list`` is one requirement string or a list of them; any other
+    type raises ``ValueError``. When ``verbose`` is 1 a non-empty request is
+    printed before it is checked.
+    """
+    def actual_func(package):
+        module_name = fix_package_name(package)
+        try:
+            if module_name == 'gradio':
+                import gradio
+                # NOTE(review): presumably probes for a gradio release that provides
+                # Progress; an AttributeError raised here is not caught - confirm.
+                gradio.Progress()
+            else:
+                return importlib.import_module(module_name)
+        except ImportError:
+            print(f"automatically install for {package}")
+            # pip has no supported in-process API, so run it as a child process
+            # of the current interpreter instead of calling pip.main().
+            status = subprocess.call([sys.executable, '-m', 'pip', 'install', '-q', package])
+            if status != 0:
+                print(f"pip install failed for {package} (exit code {status})")
+            # Make the freshly installed distribution visible to later imports.
+            importlib.invalidate_caches()
+
+    if isinstance(package_or_list, list):
+        if verbose == 1 and len(package_or_list) > 0:
+            print(f"check_availability_and_install {package_or_list}")
+        for pkg in package_or_list:
+            actual_func(pkg)
+    elif isinstance(package_or_list, str):
+        if verbose == 1 and package_or_list != "":
+            print(f"check_availability_and_install {package_or_list}")
+        actual_func(package_or_list)
+    else:
+        raise ValueError(f"{package_or_list} with type of {type(package_or_list)} is not supported.")
+
+
+check_availability_and_install(['gradio==3.36.1', 'gradio_client==0.7.3', 'langchain==0.1.4', 'langchain-community==0.0.16', 'lz4', 'sentence-transformers==2.2.2', 'pyrecdp'])
+
+# Default RECDP_CACHE_HOME to the working directory when it is unset or empty.
+if not os.environ.get('RECDP_CACHE_HOME'):
+    os.environ['RECDP_CACHE_HOME'] = os.getcwd()
+
 from inference.inference_config import all_models, ModelDescription, Prompt
 from inference.inference_config import InferenceConfig as FinetunedConfig
 from inference.chat_process import ChatModelGptJ, ChatModelLLama  # noqa: F401
@@ -48,6 +119,13 @@
     RAGTextFix,
 )
 from pyrecdp.primitives.document.reader import _default_file_readers
+from pyrecdp.core.cache_utils import RECDP_MODELS_CACHE
+
+import logging
+
+lib_list = ["httpcore", "httpx", "paramiko", "urllib3", "markdown_it", "matplotlib"]
+for lib in lib_list:
+    logging.getLogger(lib).setLevel(logging.ERROR)
 
 
 class CustomStopper(Stopper):
@@ -143,7 +221,12 @@ def __init__(
         self.finetune_actor = None
         self.finetune_status = False
         self.default_rag_path = default_rag_path
-        self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
+        local_embedding_model_path = os.path.join(RECDP_MODELS_CACHE, "sentence-transformers/all-mpnet-base-v2")
+        print(local_embedding_model_path)
+        if os.path.exists(local_embedding_model_path):
+            self.embedding_model_name = local_embedding_model_path
+        else:
+            self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
         self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
 
         self._init_ui()
@@ -794,23 +877,26 @@ def set_rag_default_path(self, selector, rag_path):
 
     def _init_ui(self):
         mark_alive = None
+        # NOTE(review): the key path is hard-coded and AutoAddPolicy below accepts
+        # unknown host keys (previously RejectPolicy) - confirm this is intended.
+        private_key = paramiko.Ed25519Key.from_private_key_file("/root/.ssh/id_ed25519")
         for index in range(len(self.ray_nodes)):
             if "node:__internal_head__" in ray.nodes()[index]["Resources"]:
                 mark_alive = index
             node_ip = self.ray_nodes[index]["NodeName"]
             self.ssh_connect[index] = paramiko.SSHClient()
-            self.ssh_connect[index].load_system_host_keys()
-            self.ssh_connect[index].set_missing_host_key_policy(paramiko.RejectPolicy())
+            #self.ssh_connect[index].load_system_host_keys()
+            self.ssh_connect[index].set_missing_host_key_policy(paramiko.AutoAddPolicy())
             self.ssh_connect[index].connect(
-                hostname=node_ip, port=self.node_port, username=self.user_name
+                hostname=node_ip, username=self.user_name, pkey=private_key
             )
         self.ssh_connect[-1] = paramiko.SSHClient()
-        self.ssh_connect[-1].load_system_host_keys()
-        self.ssh_connect[-1].set_missing_host_key_policy(paramiko.RejectPolicy())
+        #self.ssh_connect[-1].load_system_host_keys()
+        self.ssh_connect[-1].set_missing_host_key_policy(paramiko.AutoAddPolicy())
         self.ssh_connect[-1].connect(
             hostname=self.ray_nodes[mark_alive]["NodeName"],
-            port=self.node_port,
             username=self.user_name,
+            pkey=private_key
         )
 
         title = "Manage LLM Lifecycle"