-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DNM] Proposal for Adding local model loading and automatically dependency installation in webui #91
base: main
Are you sure you want to change the base?
[DNM] Proposal for Adding local model loading and automatically dependency installation in webui #91
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,52 @@ | |
import sys | ||
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | ||
|
||
import importlib, pip | ||
|
||
def fix_package_name(package): | ||
a = package.split('>')[0] | ||
a = a.split('<')[0] | ||
b = a.split('=')[0] | ||
|
||
package_name_map = { | ||
} | ||
|
||
if b in package_name_map: | ||
b = package_name_map[b] | ||
return b | ||
|
||
def check_availability_and_install(package_or_list, verbose=1): | ||
def actual_func(package): | ||
pip_name = fix_package_name(package) | ||
try: | ||
if pip_name == 'gradio': | ||
import gradio | ||
a = gradio.Progress() | ||
else: | ||
return importlib.import_module(pip_name) | ||
except ImportError: | ||
print(f"automatically install for {package}") | ||
pip.main(['install', '-q', package]) | ||
#importlib.import_module(pip_name) | ||
|
||
if isinstance(package_or_list, list): | ||
if verbose == 1 and len(package_or_list) > 0: | ||
print(f"check_availability_and_install {package_or_list}") | ||
for pkg in package_or_list: | ||
actual_func(pkg) | ||
elif isinstance(package_or_list, str): | ||
if verbose == 1 and package_or_list != "": | ||
print(f"check_availability_and_install {package_or_list}") | ||
actual_func(package_or_list) | ||
else: | ||
raise ValueError(f"{package_or_list} with type of {type(package_or_list)} is not supported.") | ||
|
||
check_availability_and_install(['gradio==3.36.1', 'gradio_client==0.7.3', 'langchain==0.1.4', 'langchain-community==0.0.16', 'lz4', 'sentence-transformers==2.2.2', 'pyrecdp']) | ||
|
||
if not os.environ['RECDP_CACHE_HOME']: | ||
os.environ['RECDP_CACHE_HOME'] = os.getcwd() | ||
|
||
from inference.inference_config import all_models, ModelDescription, Prompt | ||
from inference.inference_config import InferenceConfig as FinetunedConfig | ||
from inference.chat_process import ChatModelGptJ, ChatModelLLama # noqa: F401 | ||
|
@@ -48,6 +94,13 @@ | |
RAGTextFix, | ||
) | ||
from pyrecdp.primitives.document.reader import _default_file_readers | ||
from pyrecdp.core.cache_utils import RECDP_MODELS_CACHE | ||
|
||
import logging | ||
|
||
lib_list = ["httpcore", "httpx", "paramiko", "urllib3", "markdown_it", "matplotlib"] | ||
for lib in lib_list: | ||
logging.getLogger(lib).setLevel(logging.ERROR) | ||
Comment on lines
+101
to
+103
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is to solve the DEBUG log problem, I think it can be solved after #101 is merged. |
||
|
||
|
||
class CustomStopper(Stopper): | ||
|
@@ -143,7 +196,12 @@ def __init__( | |
self.finetune_actor = None | ||
self.finetune_status = False | ||
self.default_rag_path = default_rag_path | ||
self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2" | ||
local_embedding_model_path = os.path.join(RECDP_MODELS_CACHE, "sentence-transformers/all-mpnet-base-v2") | ||
print(local_embedding_model_path) | ||
if os.path.exists(local_embedding_model_path): | ||
self.embedding_model_name = local_embedding_model_path | ||
else: | ||
self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2" | ||
self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model_name) | ||
|
||
self._init_ui() | ||
|
@@ -794,23 +852,24 @@ def set_rag_default_path(self, selector, rag_path): | |
|
||
def _init_ui(self): | ||
mark_alive = None | ||
private_key = paramiko.Ed25519Key.from_private_key_file("/root/.ssh/id_ed25519") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why use a custom path to set private key? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For some reason, the id_rsa decoder is reporting error from my env, I found similiar report on google. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I see. I'm not sure about the format of your known_hosts, could you verify this method in your env? |
||
for index in range(len(self.ray_nodes)): | ||
if "node:__internal_head__" in ray.nodes()[index]["Resources"]: | ||
mark_alive = index | ||
node_ip = self.ray_nodes[index]["NodeName"] | ||
self.ssh_connect[index] = paramiko.SSHClient() | ||
self.ssh_connect[index].load_system_host_keys() | ||
self.ssh_connect[index].set_missing_host_key_policy(paramiko.RejectPolicy()) | ||
#self.ssh_connect[index].load_system_host_keys() | ||
self.ssh_connect[index].set_missing_host_key_policy(paramiko.AutoAddPolicy()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't use AutoAddPolicy because of code scan issue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @KepingYan , I see, thanks for the clarification, I'll make the change |
||
self.ssh_connect[index].connect( | ||
hostname=node_ip, port=self.node_port, username=self.user_name | ||
hostname=node_ip, username=self.user_name, pkey=private_key | ||
) | ||
self.ssh_connect[-1] = paramiko.SSHClient() | ||
self.ssh_connect[-1].load_system_host_keys() | ||
self.ssh_connect[-1].set_missing_host_key_policy(paramiko.RejectPolicy()) | ||
#self.ssh_connect[-1].load_system_host_keys() | ||
self.ssh_connect[-1].set_missing_host_key_policy(paramiko.AutoAddPolicy()) | ||
self.ssh_connect[-1].connect( | ||
hostname=self.ray_nodes[mark_alive]["NodeName"], | ||
port=self.node_port, | ||
username=self.user_name, | ||
pkey=private_key | ||
) | ||
|
||
title = "Manage LLM Lifecycle" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a script in #101 to install the UI environment, maybe we don't need to check it in code anymore. And we have to use
pip install 'git+https://github.com/intel/e2eAIOK.git#egg=pyrecdp&subdirectory=RecDP'
to install pyrecdp,pip install pyrecdp
does not contain the latest code and will change the ray version.