From 2d16e4d387c423aeb22bdc6ed90d4c19ddc3a5af Mon Sep 17 00:00:00 2001 From: Xiaochang Wu Date: Tue, 9 Apr 2024 09:15:04 +0800 Subject: [PATCH 1/5] Add OMP_PROC_BIND environment variable for VLLM-enabled inference (#180) Signed-off-by: Wu, Xiaochang --- llm_on_ray/inference/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llm_on_ray/inference/utils.py b/llm_on_ray/inference/utils.py index 152a336d3..3a0f4575e 100644 --- a/llm_on_ray/inference/utils.py +++ b/llm_on_ray/inference/utils.py @@ -38,6 +38,8 @@ def get_deployment_actor_options(infer_conf: InferenceConfig): runtime_env[_ray_env_key].update(_predictor_runtime_env_ipex) if infer_conf.deepspeed: runtime_env[_ray_env_key]["DS_ACCELERATOR"] = infer_conf.device + if infer_conf.vllm.enabled: + runtime_env[_ray_env_key]["OMP_PROC_BIND"] = "true" # now PredictorDeployment itself is a worker, we should require resources for it ray_actor_options: Dict[str, Any] = {"runtime_env": runtime_env} if infer_conf.device == "cpu": From 4d2b8e9c0a09b247b2fccf8f5d5a17825df010a9 Mon Sep 17 00:00:00 2001 From: harborn Date: Tue, 9 Apr 2024 09:15:45 +0800 Subject: [PATCH 2/5] [Finetune] Update doc, fine-tuning log field info, and fix evaluation speed. (#170) * some fixes and updates * update * update * update * update --- docs/setup.md | 7 +++-- llm_on_ray/common/trainer/default_trainer.py | 31 ++++++++++---------- llm_on_ray/finetune/finetune.py | 4 +-- llm_on_ray/finetune/finetune_config.py | 1 + 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/docs/setup.md b/docs/setup.md index 2451ff3aa..a411dc527 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -23,6 +23,7 @@ Intel® 1st, 2nd, 3rd, and 4th Gen Xeon® Scalable Performance processor ### Software Requirements - Git - Conda +- Docker ## Setup @@ -38,11 +39,11 @@ cd llm-on-ray conda create -n llm-on-ray python=3.9 conda activate llm-on-ray ``` -For CPU: +##### For CPU: ```bash pip install .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ ``` -For GPU: +##### For GPU: ```bash pip install .[gpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ ``` @@ -51,7 +52,7 @@ If DeepSpeed is enabled or doing distributed finetuing, oneCCL and Intel MPI lib source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh ``` -For Gaudi: +##### For Gaudi: Please use the [Dockerfile](../dev/docker/Dockerfile.habana) to build the image. Alternatively, you can install the dependecies on a bare metal machine. In this case, please refer to [here](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#build-docker-bare). 
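
A note on the `OMP_PROC_BIND` change in PATCH 1/5 above: setting `OMP_PROC_BIND=true` tells the OpenMP runtime to keep worker threads bound to their cores instead of letting the OS migrate them, which generally helps vLLM-enabled CPU serving throughput. Below is a minimal sketch of how the updated `get_deployment_actor_options` populates the Ray `runtime_env`; the config classes and the `_ray_env_key` value are illustrative stand-ins, not the project's real `InferenceConfig`.

```python
# Simplified sketch of the runtime_env wiring from PATCH 1/5. The dataclasses
# below are stand-ins for the project's InferenceConfig; only the env-var
# logic mirrors the diff above.
from dataclasses import dataclass, field
from typing import Any, Dict

_ray_env_key = "env_vars"  # Ray runtime_env key for environment variables (assumed here)


@dataclass
class VllmConfig:
    enabled: bool = False


@dataclass
class InferenceConfig:
    device: str = "cpu"
    deepspeed: bool = False
    vllm: VllmConfig = field(default_factory=VllmConfig)


def build_runtime_env(infer_conf: InferenceConfig) -> Dict[str, Any]:
    runtime_env: Dict[str, Any] = {_ray_env_key: {}}
    if infer_conf.deepspeed:
        runtime_env[_ray_env_key]["DS_ACCELERATOR"] = infer_conf.device
    if infer_conf.vllm.enabled:
        # Bind OpenMP threads to cores for vLLM-enabled CPU serving.
        runtime_env[_ray_env_key]["OMP_PROC_BIND"] = "true"
    return runtime_env


print(build_runtime_env(InferenceConfig(vllm=VllmConfig(enabled=True))))
# {'env_vars': {'OMP_PROC_BIND': 'true'}}
```
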
diff --git a/llm_on_ray/common/trainer/default_trainer.py b/llm_on_ray/common/trainer/default_trainer.py index 122645ff9..366d6f28b 100644 --- a/llm_on_ray/common/trainer/default_trainer.py +++ b/llm_on_ray/common/trainer/default_trainer.py @@ -196,13 +196,13 @@ def train(self): num_train_epochs = self.config.get("num_train_epochs", 1) checkpoint = self.config.get("checkpoint") logging_steps = self.config.get("logging_steps", 1) - max_train_step = self.config.get("max_train_step") - max_eval_step = self.config.get("max_eval_step") + max_train_steps = self.config.get("max_train_steps") + steps_per_epoch = len(self.train_dataloader) + completed_steps = self.starting_epoch * steps_per_epoch for idx in range(self.starting_epoch, num_train_epochs, 1): self.model.train() start = time.time() - total_steps = len(self.train_dataloader) - logger.info(f"Start training epoch {idx}, total_steps {total_steps}") + logger.info(f"Start training epoch {idx}, steps_per_epoch {steps_per_epoch}") for step, batch in enumerate(self.train_dataloader): with self.accelerator.accumulate(self.model): self.model.train() @@ -224,7 +224,7 @@ def train(self): if step % logging_steps == 0: loss = loss.item() ppl = math.exp(loss) - epochs = (step + idx * total_steps) / (num_train_epochs * total_steps) + epochs = idx + step / steps_per_epoch logger.info( f"train epoch:{epochs:.6f}\tloss:{loss:.6f}\tppl:{ppl:.6f}\ttime:{time.time()-start:.6f}" ) @@ -235,15 +235,18 @@ def train(self): "train_epoch": idx, "total_epochs": num_train_epochs, "train_step": step, - "total_steps": min(max_train_step, total_steps) - if max_train_step - else total_steps, + "completed_steps": completed_steps, + "total_steps": min( + max_train_steps, steps_per_epoch * num_train_epochs + ) + if max_train_steps + else steps_per_epoch * num_train_epochs, } ) start = time.time() - if max_train_step is not None: - if step >= max_train_step - 1: - break + completed_steps += 1 + if max_train_steps is not None and completed_steps >= max_train_steps: + break if self.eval_dataloader: logger.info(f"start eval epoch {idx}") @@ -251,6 +254,7 @@ def train(self): start = time.time() losses = [] for step, batch in enumerate(self.eval_dataloader): + batch = batch.to(device=self.accelerator.device) with torch.no_grad(): outputs = self.model(**batch) loss = outputs.loss @@ -259,9 +263,6 @@ def train(self): loss.repeat(batch["input_ids"].shape[0]) ) ) - if max_eval_step is not None: - if step >= max_eval_step: - break losses = torch.cat(losses) try: @@ -271,7 +272,7 @@ def train(self): eval_loss = float("inf") perplexity = float("inf") logger.info( - f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss:.6f}]\tppl:[{perplexity:.6f}]\ttime:[{time.time()-start:.6f}]" + f"eval epoch:{idx}\tloss:{eval_loss:.6f}\tppl:{perplexity:.6f}\ttime:{time.time()-start:.6f}" ) if checkpoint is not None: diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py index 7b21c3218..b31a5f01d 100644 --- a/llm_on_ray/finetune/finetune.py +++ b/llm_on_ray/finetune/finetune.py @@ -98,7 +98,7 @@ def get_accelerate_environment_variable(config: Dict[str, Any]) -> dict: "ACCELERATE_MIXED_PRECISION": mixed_precision, }, }, - "hpu:": { + "hpu": { "DDP": { "ACCELERATE_USE_CPU": "false", "ACCELERATE_USE_XPU": "false", @@ -232,7 +232,7 @@ def train_func(config: Dict[str, Any]): "device": config["Training"]["device"], "accelerate_mode": config["Training"]["accelerate_mode"], "num_train_epochs": epochs, - "max_train_step": config["Training"].get("max_train_steps", None), + 
"max_train_steps": config["Training"].get("max_train_steps", None), "logging_steps": config["Training"].get("logging_steps", 1), "output": output_dir, "dataprocesser": { diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py index 834829897..8f5f6ed6f 100644 --- a/llm_on_ray/finetune/finetune_config.py +++ b/llm_on_ray/finetune/finetune_config.py @@ -81,6 +81,7 @@ class Training(BaseModel): optimizer: str batch_size: int epochs: int + max_train_steps: Optional[int] = None learning_rate: float lr_scheduler: str weight_decay: float From 518e3b0b5405ea5831ab801c07609f6ae73a6712 Mon Sep 17 00:00:00 2001 From: Xiaochang Wu Date: Tue, 9 Apr 2024 09:19:24 +0800 Subject: [PATCH 3/5] Support latest Ray 2.10 release (#158) * update * fix blocking * update Signed-off-by: Wu, Xiaochang * update Signed-off-by: Wu, Xiaochang * fix setup and getting started Signed-off-by: Wu, Xiaochang * update Signed-off-by: Wu, Xiaochang * update Signed-off-by: Wu, Xiaochang * nit Signed-off-by: Wu, Xiaochang * Add dependencies for tests and update pyproject.toml Signed-off-by: Wu, Xiaochang * Update dependencies and test workflow Signed-off-by: Wu, Xiaochang * Update dependencies and fix torch_dist.py Signed-off-by: Wu, Xiaochang * Update OpenAI SDK installation and start ray cluster Signed-off-by: Wu, Xiaochang --------- Signed-off-by: Wu, Xiaochang --- llm_on_ray/inference/api_server_openai.py | 3 +-- llm_on_ray/inference/api_server_simple.py | 5 ++--- llm_on_ray/inference/torch_dist.py | 4 ++-- pyproject.toml | 10 +++++----- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/llm_on_ray/inference/api_server_openai.py b/llm_on_ray/inference/api_server_openai.py index 4bc42e99b..a91f9f112 100644 --- a/llm_on_ray/inference/api_server_openai.py +++ b/llm_on_ray/inference/api_server_openai.py @@ -57,12 +57,11 @@ def router_application(deployments, max_concurrent_queries): def openai_serve_run(deployments, host, route_prefix, port, max_concurrent_queries): router_app = router_application(deployments, max_concurrent_queries) + serve.start(http_options={"host": host, "port": port}) serve.run( router_app, name="router", route_prefix=route_prefix, - host=host, - _blocking=True, ).options( stream=True, use_new_handle_api=True, diff --git a/llm_on_ray/inference/api_server_simple.py b/llm_on_ray/inference/api_server_simple.py index 0663700d8..f2cf0a1e7 100644 --- a/llm_on_ray/inference/api_server_simple.py +++ b/llm_on_ray/inference/api_server_simple.py @@ -22,11 +22,10 @@ def serve_run(deployments, model_list): for model_id, infer_conf in model_list.items(): print("deploy model: ", model_id) deployment = deployments[model_id] + + serve.start(http_options={"host": infer_conf.host, "port": infer_conf.port}) serve.run( deployment, - _blocking=True, - host=infer_conf.host, - port=infer_conf.port, name=infer_conf.name, route_prefix=infer_conf.route_prefix, ) diff --git a/llm_on_ray/inference/torch_dist.py b/llm_on_ray/inference/torch_dist.py index c99baf0c0..91358db03 100644 --- a/llm_on_ray/inference/torch_dist.py +++ b/llm_on_ray/inference/torch_dist.py @@ -44,7 +44,7 @@ import ray from ray.actor import ActorHandle from ray.train._internal.utils import get_address_and_port -from ray.air._internal.torch_utils import get_device +from ray.air._internal.torch_utils import get_devices from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE if HPU_PACKAGE_AVAILABLE: @@ -212,7 +212,7 @@ def _shutdown_torch_distributed(): return # Clean up cuda memory. 
- devices = get_device() + devices = get_devices() for device in devices: with torch.cuda.device(device): torch.cuda.empty_cache() diff --git a/pyproject.toml b/pyproject.toml index 4cd11d4a5..e9462638f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,11 +21,10 @@ dependencies = [ "accelerate", "datasets>=2.14.6", "numpy", - "ray<2.10", + "ray>=2.10", + "ray[serve,tune]>=2.10", "typing>=3.7.4.3", "tabulate", - "ray[tune]", - "ray[serve]", "gymnasium", "dm-tree", "tensorboard", @@ -35,7 +34,8 @@ dependencies = [ "deltatuner==1.1.9", "py-cpuinfo", "pydantic-yaml", - "async-timeout" + "async_timeout", + "typer" ] [project.optional-dependencies] @@ -85,4 +85,4 @@ llm_on_ray-pretrain = "llm_on_ray.pretrain.pretrain:main" llm_on_ray-megatron_deepspeed_pretrain = "llm_on_ray.pretrain.megatron_deepspeed_pretrain:main" [tool.black] -line-length = 100 +line-length = 100 \ No newline at end of file From bb798698057acc5bb6dc543ff4a88acd87f6c094 Mon Sep 17 00:00:00 2001 From: yutianchen Date: Tue, 9 Apr 2024 15:38:35 +0800 Subject: [PATCH 4/5] [Tests] Add query single test (#156) * single test * single test * single test * single test * fix hang error --- tests/inference/test_query_single.py | 107 +++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/inference/test_query_single.py diff --git a/tests/inference/test_query_single.py b/tests/inference/test_query_single.py new file mode 100644 index 000000000..1c32f6b73 --- /dev/null +++ b/tests/inference/test_query_single.py @@ -0,0 +1,107 @@ +import subprocess +import pytest +import os + +os.environ["no_proxy"] = "localhost,127.0.0.1" + + +def start_serve(model_name): + current_path = os.path.dirname(os.path.abspath(__file__)) + + config_path = os.path.join( + current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml" + ) + + cmd_serve = ["llm_on_ray-serve", "--config_file", config_path, "--simple"] + + result_serve = subprocess.run(cmd_serve, capture_output=True, text=True) + + # Ensure there are no errors in the serve script execution + assert result_serve.returncode == 0, print( + "\n" + "Serve error stderr message: " + "\n", result_serve.stderr + ) + + # Print the output of subprocess.run for checking if output is expected + print("\n" + "Serve message: " + "\n", result_serve.stdout) + + # Ensure there are no errors in the serve script execution + assert "Error" not in result_serve.stderr + + +def script_with_args( + base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k +): + current_path = os.path.dirname(os.path.abspath(__file__)) + + os.path.join(current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml") + + example_query_single_path = os.path.join( + current_path, "../../examples/inference/api_server_simple/query_single.py" + ) + + cmd_single = [ + "python", + example_query_single_path, + "--model_endpoint", + base_url + model_name, + ] + + if streaming_response: + cmd_single.append("--streaming_response") + + if max_new_tokens is not None: + cmd_single.extend(["--max_new_tokens", str(max_new_tokens)]) + + if temperature is not None: + cmd_single.extend(["--temperature", str(temperature)]) + + if top_p is not None: + cmd_single.extend(["--top_p", str(top_p)]) + + if top_k is not None: + cmd_single.extend(["--top_k", str(top_k)]) + + result_query_single = subprocess.run(cmd_single, capture_output=True, text=True) + + # Print the output of subprocess.run for checking if output is expected + print(result_query_single) + + # Ensure there are no errors 
in the OpenAI API query script execution + assert "Error" not in result_query_single.stderr + + # Returncode should be 0 when there is no exception + assert result_query_single.returncode == 0 + + +executed_models = {} + + +# Parametrize the test function with different combinations of parameters +# TODO: more models and combinations will be added and tested. +@pytest.mark.parametrize( + "base_url,model_name,streaming_response,max_new_tokens,temperature,top_p, top_k", + [ + (base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k) + for base_url in ["http://localhost:8000/"] + for model_name in ["gpt2"] + for streaming_response in [None] + for max_new_tokens in [None] + for temperature in [None] + for top_p in [None] + for top_k in [None] + ], +) +def test_script( + base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k +): + global executed_models + + # Check if this modelname has already executed start_serve + if model_name not in executed_models: + start_serve(model_name) + # Mark this modelname has already executed start_serve + executed_models[model_name] = True + + script_with_args( + base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k + ) From 918290775735c21b1c35398e141becbffa4f247f Mon Sep 17 00:00:00 2001 From: minmingzhu <45281494+minmingzhu@users.noreply.github.com> Date: Wed, 10 Apr 2024 11:22:03 +0000 Subject: [PATCH 5/5] [Finetune] use base model mpt-7b instead of mpt-7b-chat (#181) * use base model mpt-7b instead of mpt-7b-chat Signed-off-by: minmingzhu * manual setting specify tokenizer Signed-off-by: minmingzhu * update Signed-off-by: minmingzhu * update doc/finetune_parameters.md Signed-off-by: minmingzhu --------- Signed-off-by: minmingzhu --- .github/workflows/night_build_memo.txt | 2 +- .github/workflows/workflow_finetune.yml | 6 +++--- docs/finetune_parameters.md | 1 + llm_on_ray/finetune/finetune.py | 6 +++++- llm_on_ray/finetune/finetune_config.py | 1 + .../finetune/models/{mpt-7b-chat.yaml => mpt-7b.yaml} | 3 ++- 6 files changed, 13 insertions(+), 6 deletions(-) rename llm_on_ray/finetune/models/{mpt-7b-chat.yaml => mpt-7b.yaml} (91%) diff --git a/.github/workflows/night_build_memo.txt b/.github/workflows/night_build_memo.txt index e5197571c..520e176e1 100644 --- a/.github/workflows/night_build_memo.txt +++ b/.github/workflows/night_build_memo.txt @@ -1 +1 @@ -finetune: gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b-chat, huggyllama/llama-7b \ No newline at end of file +finetune: gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b, huggyllama/llama-7b \ No newline at end of file diff --git a/.github/workflows/workflow_finetune.yml b/.github/workflows/workflow_finetune.yml index 76f1097a4..ddc547774 100644 --- a/.github/workflows/workflow_finetune.yml +++ b/.github/workflows/workflow_finetune.yml @@ -34,7 +34,7 @@ jobs: name: finetune strategy: matrix: - model: [ EleutherAI/gpt-j-6b, meta-llama/Llama-2-7b-chat-hf, gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b-chat, meta-llama/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, google/gemma-2b] + model: [ EleutherAI/gpt-j-6b, meta-llama/Llama-2-7b-chat-hf, gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b, meta-llama/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, google/gemma-2b] isPR: - ${{inputs.ci_type == 'pr'}} @@ -92,7 +92,7 @@ jobs: with open(conf_path, encoding="utf-8") as reader: result = yaml.load(reader, Loader=yaml.FullLoader) result['General']['base_model'] = "${{ 
matrix.model }}" - if "${{ matrix.model }}" == "mosaicml/mpt-7b-chat": + if "${{ matrix.model }}" == "mosaicml/mpt-7b": result['General']['config']['trust_remote_code'] = True else: result['General']['config']['trust_remote_code'] = False @@ -147,7 +147,7 @@ jobs: - name: Run Deltatuner Test on DENAS-LoRA Model run: | - if [[ ${{ matrix.model }} =~ ^(mosaicml\/mpt-7b-chat|huggyllama\/llama-7b|meta-llama\/Llama-2-7b-chat-hf|mistralai\/Mistral-7B-v0.1|google\/gemma-2b)$ ]]; then + if [[ ${{ matrix.model }} =~ ^(mosaicml\/mpt-7b|huggyllama\/llama-7b|meta-llama\/Llama-2-7b-chat-hf|mistralai\/Mistral-7B-v0.1|google\/gemma-2b)$ ]]; then echo ${{ matrix.model }} is not supported! else docker exec "finetune" bash -c "rm -rf /tmp/llm-ray/*" diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md index 531549adf..5d24f42e6 100644 --- a/docs/finetune_parameters.md +++ b/docs/finetune_parameters.md @@ -7,6 +7,7 @@ The following are the parameters supported in the finetuning workflow. |Configuration Name| Default|Meaning| |-|-|-| |base_model| EleutherAI/gpt-j-6b|Path to pretrained model or model identifier from huggingface.co/models| +|tokenizer_name|None|Path to pretrained tokenizer from huggingface.co/models. If not provided, the tokenizer will be loaded from the `base_model`.| |gpt_base_model|True|This parameter is for [Transformers#22482](https://github.com/huggingface/transformers/issues/22482). It needs to be set to True when the pretrained model is realted to gpt, otherwise it is False.| |output_dir|/tmp/llm-ray/output|The output directory to store the finetuned model| |checkpoint_dir|/tmp/llm-ray/checkpoint|The directory to store checkpoint| diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py index b31a5f01d..0f9e96f96 100644 --- a/llm_on_ray/finetune/finetune.py +++ b/llm_on_ray/finetune/finetune.py @@ -155,6 +155,10 @@ def train_func(config: Dict[str, Any]): gradient_accumulation_steps = config["Training"].get("gradient_accumulation_steps", 1) base_model = config["General"]["base_model"] + if config["General"].get("tokenizer_name") is not None: + tokenizer_name = config["General"].get("tokenizer_name") + else: + tokenizer_name = base_model dataset_file = config["Dataset"]["train_file"] seed = config["Training"].get("seed") @@ -171,7 +175,7 @@ def train_func(config: Dict[str, Any]): tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()( config={ - "name": base_model, + "name": tokenizer_name, "config": config["General"]["config"], } ) diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py index 8f5f6ed6f..a01095c16 100644 --- a/llm_on_ray/finetune/finetune_config.py +++ b/llm_on_ray/finetune/finetune_config.py @@ -52,6 +52,7 @@ class DeltatunerConfig(BaseModel): class General(BaseModel): base_model: str + tokenizer_name: Optional[str] = None gpt_base_model: bool output_dir: str checkpoint_dir: Optional[str] diff --git a/llm_on_ray/finetune/models/mpt-7b-chat.yaml b/llm_on_ray/finetune/models/mpt-7b.yaml similarity index 91% rename from llm_on_ray/finetune/models/mpt-7b-chat.yaml rename to llm_on_ray/finetune/models/mpt-7b.yaml index b4644194f..067a093a2 100644 --- a/llm_on_ray/finetune/models/mpt-7b-chat.yaml +++ b/llm_on_ray/finetune/models/mpt-7b.yaml @@ -1,5 +1,6 @@ General: - base_model: mosaicml/mpt-7b-chat + base_model: mosaicml/mpt-7b + tokenizer_name: EleutherAI/gpt-neox-20b gpt_base_model: false output_dir: /tmp/llm-ray/output checkpoint_dir: /tmp/llm-ray/checkpoint
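
Looking back at PATCH 2/5: the renamed `max_train_steps` now caps the total number of training steps across all epochs (tracked via `completed_steps`), and the logged epoch value is fractional (`idx + step / steps_per_epoch`). A condensed, runnable sketch of that accounting follows; a plain list stands in for the DataLoader and all values are illustrative.

```python
# Condensed sketch of the step accounting introduced in PATCH 2/5:
# max_train_steps bounds the total number of steps across all epochs,
# not the steps within each epoch.
from typing import List, Optional

train_dataloader: List[int] = list(range(8))   # 8 batches per epoch (illustrative)
num_train_epochs = 3
starting_epoch = 0
logging_steps = 4
max_train_steps: Optional[int] = 10            # None disables the global cap

steps_per_epoch = len(train_dataloader)
completed_steps = starting_epoch * steps_per_epoch

for idx in range(starting_epoch, num_train_epochs):
    if max_train_steps is not None and completed_steps >= max_train_steps:
        break
    for step, _batch in enumerate(train_dataloader):
        # ... forward / backward / optimizer step would run here ...
        if step % logging_steps == 0:
            # Fractional epoch, matching the new logging in default_trainer.py.
            print(f"train epoch:{idx + step / steps_per_epoch:.6f}\tcompleted_steps:{completed_steps}")
        completed_steps += 1
        if max_train_steps is not None and completed_steps >= max_train_steps:
            break

print(completed_steps)  # 10: training stops once the global step budget is used
```
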
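On PATCH 3/5: the `host`, `port`, and `_blocking` keyword arguments are no longer passed to `serve.run()`; both API servers now configure HTTP options through `serve.start()` first, in line with Ray 2.10. Below is a minimal, self-contained sketch of that startup pattern, assuming `ray[serve]>=2.10` is installed; the echo deployment is illustrative and not part of the project.

```python
# Sketch of the Ray Serve startup pattern adopted in PATCH 3/5: HTTP options
# go to serve.start(), and serve.run() only receives the application itself.
import time

from ray import serve
from starlette.requests import Request


@serve.deployment
class EchoDeployment:
    async def __call__(self, request: Request) -> str:
        body = await request.body()
        return body.decode("utf-8")


def simple_serve_run(host: str = "127.0.0.1", port: int = 8000) -> None:
    # Host/port now go through http_options instead of serve.run() kwargs.
    serve.start(http_options={"host": host, "port": port})
    serve.run(
        EchoDeployment.bind(),
        name="echo",
        route_prefix="/echo",
    )


if __name__ == "__main__":
    simple_serve_run()
    # Keep the driver alive so the HTTP proxy stays up for manual testing:
    #   curl -d "hello" http://127.0.0.1:8000/echo
    while True:
        time.sleep(10)
```
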
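On PATCH 5/5: MPT-7B is trained with the `EleutherAI/gpt-neox-20b` tokenizer, which is why the renamed `mpt-7b.yaml` sets `tokenizer_name` explicitly while other model configs can omit it. A minimal sketch of the fallback added to `train_func`, assuming a parsed config dict shaped like the YAML files above:

```python
# Sketch of the tokenizer_name fallback introduced in PATCH 5/5.
# `config` mirrors the General section of the finetune YAML; values are illustrative.
from typing import Any, Dict

config: Dict[str, Any] = {
    "General": {
        "base_model": "mosaicml/mpt-7b",
        "tokenizer_name": "EleutherAI/gpt-neox-20b",  # optional field
    }
}

base_model = config["General"]["base_model"]
if config["General"].get("tokenizer_name") is not None:
    # Load a tokenizer that differs from the base model checkpoint.
    tokenizer_name = config["General"]["tokenizer_name"]
else:
    # Default: the tokenizer comes from the same repo as the model weights.
    tokenizer_name = base_model

print(tokenizer_name)  # EleutherAI/gpt-neox-20b
```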