[DNM] Proposal for adding local model loading and automatic dependency installation in the web UI #91

Open · wants to merge 1 commit into base: main
73 changes: 66 additions & 7 deletions ui/start_ui.py
@@ -20,6 +20,52 @@
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import importlib, pip

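# Strip any version specifier from a requirement string (e.g. 'gradio==3.36.1' -> 'gradio')
# and remap it through package_name_map (currently empty) to the importable module name.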
def fix_package_name(package):
a = package.split('>')[0]
a = a.split('<')[0]
b = a.split('=')[0]

package_name_map = {
}

if b in package_name_map:
b = package_name_map[b]
return b

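# Import every requested package and quietly pip-install any that fail to import;
# the gradio branch also constructs gradio.Progress() as an extra sanity check.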
def check_availability_and_install(package_or_list, verbose=1):
def actual_func(package):
pip_name = fix_package_name(package)
try:
if pip_name == 'gradio':
import gradio
a = gradio.Progress()
else:
return importlib.import_module(pip_name)
except ImportError:
print(f"automatically install for {package}")
pip.main(['install', '-q', package])
#importlib.import_module(pip_name)

if isinstance(package_or_list, list):
if verbose == 1 and len(package_or_list) > 0:
print(f"check_availability_and_install {package_or_list}")
for pkg in package_or_list:
actual_func(pkg)
elif isinstance(package_or_list, str):
if verbose == 1 and package_or_list != "":
print(f"check_availability_and_install {package_or_list}")
actual_func(package_or_list)
else:
raise ValueError(f"{package_or_list} with type of {type(package_or_list)} is not supported.")

check_availability_and_install(['gradio==3.36.1', 'gradio_client==0.7.3', 'langchain==0.1.4', 'langchain-community==0.0.16', 'lz4', 'sentence-transformers==2.2.2', 'pyrecdp'])
Contributor:
There is a script in #101 to install the UI environment, so maybe we don't need to check this in code anymore. Also, pyrecdp has to be installed with pip install 'git+https://github.com/intel/e2eAIOK.git#egg=pyrecdp&subdirectory=RecDP'; plain pip install pyrecdp does not contain the latest code and will change the ray version:

Collecting ray==2.7.1 (from pyrecdp)
  Using cached ray-2.7.1-cp39-cp39-manylinux2014_x86_64.whl.metadata (13 kB)
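A minimal sketch of how the suggested git-based install could be scripted, assuming a subprocess-based pip invocation rather than pip.main; the helper name is illustrative, only the URL comes from the comment above:

import subprocess
import sys

# URL quoted in the review comment above: installs pyrecdp from the RecDP
# subdirectory of the e2eAIOK repository instead of the outdated PyPI package.
PYRECDP_GIT_SPEC = "git+https://github.com/intel/e2eAIOK.git#egg=pyrecdp&subdirectory=RecDP"

def install_pyrecdp_from_git():
    # Equivalent to: pip install 'git+https://github.com/intel/e2eAIOK.git#egg=pyrecdp&subdirectory=RecDP'
    subprocess.check_call([sys.executable, "-m", "pip", "install", PYRECDP_GIT_SPEC])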


if not os.environ.get('RECDP_CACHE_HOME'):
os.environ['RECDP_CACHE_HOME'] = os.getcwd()

from inference.inference_config import all_models, ModelDescription, Prompt
from inference.inference_config import InferenceConfig as FinetunedConfig
from inference.chat_process import ChatModelGptJ, ChatModelLLama # noqa: F401
@@ -48,6 +94,13 @@
RAGTextFix,
)
from pyrecdp.primitives.document.reader import _default_file_readers
from pyrecdp.core.cache_utils import RECDP_MODELS_CACHE

import logging

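# Raise these chatty third-party loggers to ERROR so their output does not flood the web UI log.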
lib_list = ["httpcore", "httpx", "paramiko", "urllib3", "markdown_it", "matplotlib"]
for lib in lib_list:
logging.getLogger(lib).setLevel(logging.ERROR)
Comment on lines +101 to +103

Contributor:

If this is to solve the DEBUG log problem, I think it can be solved after #101 is merged.



class CustomStopper(Stopper):
@@ -143,7 +196,12 @@ def __init__(
self.finetune_actor = None
self.finetune_status = False
self.default_rag_path = default_rag_path
self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
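# Prefer a locally cached copy of the embedding model under RECDP_MODELS_CACHE;
# fall back to the Hugging Face Hub model name if no local copy exists.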
local_embedding_model_path = os.path.join(RECDP_MODELS_CACHE, "sentence-transformers/all-mpnet-base-v2")
print(local_embedding_model_path)
if os.path.exists(local_embedding_model_path):
self.embedding_model_name = local_embedding_model_path
else:
self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model_name)

self._init_ui()
@@ -794,23 +852,24 @@ def set_rag_default_path(self, selector, rag_path):

def _init_ui(self):
mark_alive = None
private_key = paramiko.Ed25519Key.from_private_key_file("/root/.ssh/id_ed25519")
Contributor:
Why use a custom path to set the private key? self.ssh_connect[index].load_system_host_keys() loads the content of '~/.ssh/known_hosts' by default.

Contributor Author:
For some reason, the id_rsa decoder reports an error in my environment, and I found similar reports on Google.
Using a path gives users the flexibility to decide whether they want to use ED25519 or RSA. On some "very healthy" systems self.ssh_connect[index].load_system_host_keys() may also work, but it does not on mine.
So if you can provide this automatic option plus a user-defined path, it will go a long way toward keeping users from hard-coding things just to get the web UI working.

Contributor:
OK, I see. I'm not sure about the format of your known_hosts; could you verify this method in your environment?
For example, if the node IP is 10.1.0.133, run ssh-keyscan 10.1.0.133 >> ~/.ssh/known_hosts, and then change the 10.1.0.133 entry to [10.1.0.133]:22 (screenshot of the edited known_hosts entry omitted).
If this doesn't work, it really needs a fallback method.
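A minimal sketch of the "automatic option plus a user-defined path" fallback discussed above; the default key paths and the SSH_PRIVATE_KEY_PATH environment variable are illustrative assumptions, not part of this PR:

import os
import paramiko

def load_private_key(key_path=None):
    # Prefer an explicit, user-supplied key path; otherwise probe common defaults.
    candidates = [key_path] if key_path else [
        os.path.expanduser("~/.ssh/id_ed25519"),
        os.path.expanduser("~/.ssh/id_rsa"),
    ]
    for path in candidates:
        if path and os.path.exists(path):
            # Either key type may be present, so try ED25519 first, then RSA.
            for key_cls in (paramiko.Ed25519Key, paramiko.RSAKey):
                try:
                    return key_cls.from_private_key_file(path)
                except paramiko.SSHException:
                    continue
    # Returning None lets SSHClient.connect() fall back to the SSH agent and
    # default key discovery (look_for_keys=True).
    return None

# Example: private_key = load_private_key(os.environ.get("SSH_PRIVATE_KEY_PATH"))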

for index in range(len(self.ray_nodes)):
if "node:__internal_head__" in ray.nodes()[index]["Resources"]:
mark_alive = index
node_ip = self.ray_nodes[index]["NodeName"]
self.ssh_connect[index] = paramiko.SSHClient()
self.ssh_connect[index].load_system_host_keys()
self.ssh_connect[index].set_missing_host_key_policy(paramiko.RejectPolicy())
#self.ssh_connect[index].load_system_host_keys()
self.ssh_connect[index].set_missing_host_key_policy(paramiko.AutoAddPolicy())
Contributor:
We can't use AutoAddPolicy because of a code scan issue.

Contributor Author:
@KepingYan, I see, thanks for the clarification. I'll make the change.
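A minimal sketch of the direction agreed above: keep RejectPolicy so the code scan passes and rely on known_hosts entries added beforehand with ssh-keyscan. The connect_node helper and its parameters are placeholders for the surrounding class attributes:

import paramiko

def connect_node(node_ip, node_port, user_name, private_key=None):
    client = paramiko.SSHClient()
    # known_hosts is assumed to already contain the node's key (e.g. added via ssh-keyscan).
    client.load_system_host_keys()
    # RejectPolicy refuses unknown hosts instead of silently trusting them, unlike AutoAddPolicy.
    client.set_missing_host_key_policy(paramiko.RejectPolicy())
    client.connect(hostname=node_ip, port=node_port, username=user_name, pkey=private_key)
    return client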

self.ssh_connect[index].connect(
hostname=node_ip, port=self.node_port, username=self.user_name
hostname=node_ip, username=self.user_name, pkey=private_key
)
self.ssh_connect[-1] = paramiko.SSHClient()
self.ssh_connect[-1].load_system_host_keys()
self.ssh_connect[-1].set_missing_host_key_policy(paramiko.RejectPolicy())
#self.ssh_connect[-1].load_system_host_keys()
self.ssh_connect[-1].set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.ssh_connect[-1].connect(
hostname=self.ray_nodes[mark_alive]["NodeName"],
port=self.node_port,
username=self.user_name,
pkey=private_key
)

title = "Manage LLM Lifecycle"