From 352b64ebed5b1fffe1aced0580a843d19b0f2977 Mon Sep 17 00:00:00 2001
From: KepingYan
Date: Thu, 11 Jan 2024 21:05:03 +0800
Subject: [PATCH] [Lint] add lint (#34)

* add pre-commit, port from https://github.com/intel-sandbox/llm-ray/pull/185
* move lint check to github
* modify permission
* test
* test
* test
* test
* move to ubuntu-latest
* fix lint
* fix parameter error
* Recover lines that should not be deleted
* test lint in ci
* done
* add needs in ci
* move hpu tokenizer
---
 .../update_finetune_config_on_intel_gpu.py | 2 +-
 .../config/update_inference_config.py | 4 +-
 .github/workflows/workflow_lint.yml | 28 +
 .../workflows/workflow_orders_on_merge.yml | 2 +
 .github/workflows/workflow_orders_on_pr.yml | 4 +
 .pre-commit-config.yaml | 30 +
 common/__init__.py | 14 +-
 common/agentenv/__init__.py | 2 +-
 common/agentenv/agentenv.py | 3 +-
 common/agentenv/rlhf_env.py | 82 +-
 common/common.py | 7 +-
 common/config.py | 30 +-
 common/dataprocesser/__init__.py | 2 +-
 common/dataprocesser/dataprocesser.py | 4 +-
 common/dataprocesser/general_processer.py | 47 +-
 common/dataprocesser/rm_dataprocesser.py | 32 +-
 common/dataset/__init__.py | 2 +-
 common/dataset/dataset.py | 6 +-
 common/dataset/huggingface_dataset.py | 14 +-
 common/init.py | 4 +-
 common/initializer/__init__.py | 2 +-
 common/initializer/initializer.py | 5 +-
 common/load.py | 10 +-
 common/logging.py | 22 +-
 .../model/huggingface_model_for_causal_lm.py | 1 +
 common/model/model.py | 6 +-
 common/model/reward_model.py | 20 +-
 common/optimizer/__init__.py | 2 +-
 common/optimizer/default_optimizer.py | 9 +-
 common/optimizer/group_optimizer.py | 17 +-
 common/optimizer/optimizer.py | 6 +-
 common/tokenizer/empty_tokenizer.py | 5 +-
 common/tokenizer/huggingface_tokenizer.py | 3 +-
 common/tokenizer/tokenizer.py | 6 +-
 common/torch_config.py | 19 +-
 common/trainer/__init__.py | 2 +-
 common/trainer/default_trainer.py | 95 +-
 common/trainer/rm_trainer.py | 36 +-
 common/trainer/trainer.py | 6 +-
 dev/memory_status/csv_analysis.py | 6 +-
 dev/memory_status/memory_collect_ray.py | 5 +-
 ...process_data.py => process_data_dolly1.py} | 13 +-
 ...data.py => process_data_open_assistant.py} | 15 +-
 .../api_server_openai/query_http_requests.py | 62 +-
 .../api_server_openai/query_openai_sdk.py | 6 +-
 .../api_server_simple/query_single.py | 57 +-
 examples/rlhf/process_data.py | 37 +-
 finetune/finetune.py | 173 +--
 format.sh | 24 +
 .../api_openai_backend/openai_protocol.py | 51 +-
 inference/api_openai_backend/query_client.py | 14 +-
 .../api_openai_backend/request_handler.py | 20 +-
 inference/api_openai_backend/router_app.py | 34 +-
 inference/api_server_openai.py | 15 +-
 inference/api_server_simple.py | 9 +-
 inference/chat_process.py | 24 +-
 inference/deepspeed_predictor.py | 144 ++-
 inference/inference_config.py | 63 +-
 inference/logger.py | 4 +-
 inference/predictor.py | 36 +-
 inference/predictor_deployment.py | 43 +-
 inference/serve.py | 73 +-
 inference/transformer_predictor.py | 57 +-
 inference/utils.py | 32 +-
 pretrain/backend/deepspeed_backend.py | 14 +-
 pretrain/backend/habana_backend.py | 13 +-
 pretrain/megatron_deepspeed_pretrain.py | 73 +-
 pretrain/plugin/group_dataset.py | 2 +-
 pretrain/plugin/hf_pretrainer.py | 80 +-
 .../plugin/huggingface_model_from_config.py | 61 +-
 pretrain/plugin/megatron_dataset.py | 10 +-
 pretrain/plugin/megatron_pretrainer.py | 102 +-
 pretrain/plugin/megatron_processer.py | 20 +-
 pretrain/plugin/megtron_initializer.py | 7 +-
 pretrain/plugin/plain_id_processer.py | 22 +-
 pretrain/plugin/pretrainer.py | 124 +-
 pretrain/pretrain.py | 
88 +- pyproject.toml | 3 + rlhf/ppo.py | 61 +- rlhf/reward.py | 134 +- rlhf/rl_algo/ppo/ppo_rlhf.py | 112 +- rlhf/rl_algo/ppo/rlhf_buffer.py | 267 ++-- rlhf/rl_algo/ppo/rlhf_ppo_module.py | 30 +- rlhf/rl_algo/ppo/rlhf_ppo_torch_learner.py | 26 +- rlhf/rl_algo/ppo/util.py | 3 +- ui/start_ui.py | 1074 +++++++++++++---- 86 files changed, 2438 insertions(+), 1496 deletions(-) create mode 100644 .github/workflows/workflow_lint.yml create mode 100644 .pre-commit-config.yaml rename examples/finetune/dolly1/{process_data.py => process_data_dolly1.py} (73%) rename examples/finetune/open_assistant/{process_data.py => process_data_open_assistant.py} (85%) create mode 100755 format.sh diff --git a/.github/workflows/config/update_finetune_config_on_intel_gpu.py b/.github/workflows/config/update_finetune_config_on_intel_gpu.py index f0e2a715e..e46dda811 100644 --- a/.github/workflows/config/update_finetune_config_on_intel_gpu.py +++ b/.github/workflows/config/update_finetune_config_on_intel_gpu.py @@ -11,7 +11,7 @@ def update_finetune_config(base_model): # avaiable base models are: # # Mistral-7B-v0.1 - # Llama-2-7b + # Llama-2-7b # pythia-1.4b # pythia-2.8b # pythia-70m diff --git a/.github/workflows/config/update_inference_config.py b/.github/workflows/config/update_inference_config.py index c1c700cdd..cf0c5abe6 100644 --- a/.github/workflows/config/update_inference_config.py +++ b/.github/workflows/config/update_inference_config.py @@ -16,8 +16,8 @@ def get_parser(): parser = argparse.ArgumentParser(description="Adjust Inference Config File") parser.add_argument("--config_file", type=str, required=True) parser.add_argument("--output_file", type=str, required=True) - parser.add_argument("--deepspeed", action='store_true') - parser.add_argument("--ipex", action='store_true') + parser.add_argument("--deepspeed", action="store_true") + parser.add_argument("--ipex", action="store_true") return parser diff --git a/.github/workflows/workflow_lint.yml b/.github/workflows/workflow_lint.yml new file mode 100644 index 000000000..7a7be885e --- /dev/null +++ b/.github/workflows/workflow_lint.yml @@ -0,0 +1,28 @@ +name: Lint + +on: + workflow_call: + inputs: + ci_type: + type: string + default: 'pr' + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-lt + cancel-in-progress: true + +jobs: + lint: + name: lint check + runs-on: ubuntu-latest + + defaults: + run: + shell: bash + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Run Lint + run: ./format.sh -a diff --git a/.github/workflows/workflow_orders_on_merge.yml b/.github/workflows/workflow_orders_on_merge.yml index 56bda5006..780b010b9 100644 --- a/.github/workflows/workflow_orders_on_merge.yml +++ b/.github/workflows/workflow_orders_on_merge.yml @@ -16,6 +16,8 @@ on: - 'pyproject.toml' jobs: + call-lint: + uses: ./.github/workflows/workflow_lint.yml call-inference: uses: ./.github/workflows/workflow_inference.yml diff --git a/.github/workflows/workflow_orders_on_pr.yml b/.github/workflows/workflow_orders_on_pr.yml index 2c8f93f3d..22ce15dcb 100644 --- a/.github/workflows/workflow_orders_on_pr.yml +++ b/.github/workflows/workflow_orders_on_pr.yml @@ -16,9 +16,13 @@ on: - 'pyproject.toml' jobs: + call-lint: + uses: ./.github/workflows/workflow_lint.yml call-inference: + needs: call-lint uses: ./.github/workflows/workflow_inference.yml call-finetune: + needs: call-lint uses: ./.github/workflows/workflow_finetune.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 
100644 index 000000000..eef34287b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +ci: + autoupdate_schedule: monthly + +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.289 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix, --ignore=E402, --ignore=E501, --ignore=E731] + + # Black needs to be ran after ruff with --fix + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v0.950" + hooks: + - id: mypy + exclude: tests + additional_dependencies: + - mypy-extensions + - pydantic==1.10.0 + - types-cachetools + - types-filelock + - types-PyYAML + - types-redis + - types-requests + - types-paramiko diff --git a/common/__init__.py b/common/__init__.py index ed67a6544..0c882ef13 100644 --- a/common/__init__.py +++ b/common/__init__.py @@ -1,12 +1,12 @@ -from .config import Config, parse_args, parse_config -from .torch_config import TorchConfig -from .logging import get_logger, logger -from .load import * -from .init import init -from .common import import_all_module +from .logging import logger +from .load import * # noqa: F403 # unable to detect undefined names from . import agentenv +from .torch_config import TorchConfig # noqa: F401 +from typing import Dict, Any +import sys -@load_check_decorator + +@load_check_decorator # noqa: F405 # may be undefined, or defined from star imports def get_agentenv(config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") agentenv_type = config.get("type", None) diff --git a/common/agentenv/__init__.py b/common/agentenv/__init__.py index 49a28dea8..fe05d28ec 100644 --- a/common/agentenv/__init__.py +++ b/common/agentenv/__init__.py @@ -6,4 +6,4 @@ basedir = os.path.dirname(realpath) import_all_module(basedir, "common.agentenv") -__all__ = ["AgentEnv"] +__all__ = ["AgentEnv"] diff --git a/common/agentenv/agentenv.py b/common/agentenv/agentenv.py index 545605874..24e53bd8b 100644 --- a/common/agentenv/agentenv.py +++ b/common/agentenv/agentenv.py @@ -1,13 +1,14 @@ class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class AgentEnv(metaclass=Meta): def __init__(self, config): self.config = config diff --git a/common/agentenv/rlhf_env.py b/common/agentenv/rlhf_env.py index dab4de834..d4b0a5833 100644 --- a/common/agentenv/rlhf_env.py +++ b/common/agentenv/rlhf_env.py @@ -1,4 +1,3 @@ -from typing import Any import gymnasium as gym import numpy as np @@ -13,15 +12,15 @@ def generate_response( - model: torch.nn.Module, - *, - input_ids: torch.tensor, - max_length:int, - eos_token_id: int + model: torch.nn.Module, + *, + input_ids: torch.tensor, + max_length: int, + eos_token_id: int, ): """Generate a response using the model.""" generated_sequence = [] - probs_list = [] + # probs_list = [] model_in = torch.clone(input_ids) with torch.no_grad(): for i in range(max_length): @@ -66,6 +65,7 @@ def generate_response( "n_generated_tokens": generated_tokens.shape[-1], } + def compute_approx_kl( logits: torch.Tensor, logits_base: torch.Tensor, @@ -79,53 +79,52 @@ def compute_approx_kl( class RLHFEnv(gym.Env, AgentEnv): - def __init__(self, config): - self.config = config agentenv_config = config.get("config") # Prompt dataset - self.prompt_dataset = 
load_dataset(agentenv_config.get("datasets")) + self.prompt_dataset = load_dataset(agentenv_config.get("datasets")) self.dsize = len(self.prompt_dataset) - + # base tokenizer self.tokenizer = load_tokenizer(agentenv_config.get("tokenizer")) vocab_size = self.tokenizer.vocab_size - model_max_length = min(agentenv_config['model_max_length'], self.tokenizer.model_max_length) + model_max_length = min(agentenv_config["model_max_length"], self.tokenizer.model_max_length) # reward and sft model self.reward_model = load_model(agentenv_config.get("reward_model")) self.sft_model = load_model(agentenv_config.get("sft_model")) - + # the KL coefficient self.kl_coeff = agentenv_config["kl_coeff"] # The maximum length of the generated text self.max_generation_length = agentenv_config["max_generation_length"] # action space - self.action_space = sp.Dict({ - "sequence": Repeated(sp.Discrete(vocab_size), max_len=model_max_length), - "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length), - "response_mask": Repeated(sp.Discrete(2), max_len=model_max_length), - "logits": Repeated( - sp.Box(0, 1, shape=(vocab_size,)), max_len=model_max_length - ), - }) + self.action_space = sp.Dict( + { + "sequence": Repeated(sp.Discrete(vocab_size), max_len=model_max_length), + "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length), + "response_mask": Repeated(sp.Discrete(2), max_len=model_max_length), + "logits": Repeated(sp.Box(0, 1, shape=(vocab_size,)), max_len=model_max_length), + } + ) # observation space - self.observation_space = sp.Dict({ - "input_ids": Repeated(sp.Discrete(vocab_size), max_len=model_max_length), - "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length), - }) + self.observation_space = sp.Dict( + { + "input_ids": Repeated(sp.Discrete(vocab_size), max_len=model_max_length), + "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length), + } + ) def reset(self, *, seed=None, options=None): - if seed: np.random.seed(seed) index = np.random.randint(self.dsize) - if 'train' in self.prompt_dataset: - self.prompt_dataset = self.prompt_dataset['train'] + if "train" in self.prompt_dataset: + self.prompt_dataset = self.prompt_dataset["train"] prompt = self.prompt_dataset[index]["prompt"] prompt_tokens = self.tokenizer(prompt, return_tensors="np") # remove the batch dimension since we can only do one sentence generation at a time @@ -134,7 +133,6 @@ def reset(self, *, seed=None, options=None): return prompt_tokens, {} def step(self, action): - sequence = action["sequence"] response_mask = action["response_mask"] attention_mask = action["attention_mask"] @@ -147,28 +145,28 @@ def step(self, action): r_align = r_align[-1].item() # Compute the probs from the sft model for the same number of tokens - sequence = torch.tensor(sequence, dtype=torch.long)[None] # add batch dim + sequence = torch.tensor(sequence, dtype=torch.long)[None] # add batch dim sft_output = generate_response( - self.sft_model, - input_ids=sequence, - max_length=n_response_tokens, - eos_token_id=self.tokenizer.eos_token_id + self.sft_model, + input_ids=sequence, + max_length=n_response_tokens, + eos_token_id=self.tokenizer.eos_token_id, ) - - logits = torch.tensor(logits, dtype=torch.float32)[None] # add batch dim + + logits = torch.tensor(logits, dtype=torch.float32)[None] # add batch dim # only compute kl on the response tokens r_kl = compute_approx_kl( - logits[:, -n_response_tokens:], # the inner term - sft_output["logits"][:, -n_response_tokens:] # the outer term + logits[:, 
-n_response_tokens:], # the inner term + sft_output["logits"][:, -n_response_tokens:], # the outer term ).item() reward = r_align - self.kl_coeff * r_kl info = { - "r_align": r_align, - "r_kl": r_kl, - "n_response_tokens": n_response_tokens + "r_align": r_align, + "r_kl": r_kl, + "n_response_tokens": n_response_tokens, } # Produce a random reward when we reach the goal. - return self.observation_space.sample(), reward, True, False, info \ No newline at end of file + return self.observation_space.sample(), reward, True, False, info diff --git a/common/common.py b/common/common.py index 35e9d7da8..b846ea75a 100644 --- a/common/common.py +++ b/common/common.py @@ -4,8 +4,9 @@ from .logging import logger -def import_all_module(basedir, prefix = None): - all_py_files = glob.glob(basedir+"/*.py") + +def import_all_module(basedir, prefix=None): + all_py_files = glob.glob(basedir + "/*.py") modules = [os.path.basename(f) for f in all_py_files] for module in modules: @@ -17,5 +18,5 @@ def import_all_module(basedir, prefix = None): module_name = f"{prefix}.{module}" try: importlib.import_module(module_name) - except Exception as e: + except Exception: logger.warning(f"import {module_name} erro", exc_info=True) diff --git a/common/config.py b/common/config.py index 801e48f72..392410128 100644 --- a/common/config.py +++ b/common/config.py @@ -1,10 +1,12 @@ -import os import yaml import argparse from typing import Dict + def parse_args(): - parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser = argparse.ArgumentParser( + description="Finetune a transformers model on a causal language modeling task" + ) parser.add_argument( "--config_file", type=str, @@ -15,6 +17,7 @@ def parse_args(): args, unparsed = parser.parse_known_args() return args + def parse_config(config_file=None): if config_file is None: args = parse_args() @@ -30,25 +33,31 @@ def parse_config(config_file=None): assert isinstance(config, dict) return config + def _singleton(cls): _instance = {} + def inner(): if cls not in _instance: _instance[cls] = cls() return _instance[cls] + return inner + def flat(x, separator="."): for key, value in x.items(): if isinstance(value, dict): for k, v in flat(value): - k = f'{key}{separator}{k}' + k = f"{key}{separator}{k}" yield (k, v) else: yield (key, value) + def pack(x, separator="."): - return {k:v for k,v in flat(x, separator)} + return {k: v for k, v in flat(x, separator)} + def rank(key, value): if len(key) == 1: @@ -57,10 +66,12 @@ def rank(key, value): prefix = key.pop(0) return {prefix: rank(key, value)} + def deflat(x, separator="."): for key, value in x.items(): yield rank(key.split(separator), value) + def recursive_merge(dst, src): for key, value in src.items(): if key not in dst: @@ -70,15 +81,17 @@ def recursive_merge(dst, src): else: dst[key] = value + def unpack(x, separator="."): result = {} for i in deflat(x, separator): recursive_merge(result, i) return result -def mapping(x, table, only_in_table = True): + +def mapping(x, table, only_in_table=True): new_x = {} - for k,v in x.items(): + for k, v in x.items(): if k in table: new_keys = table[k] if isinstance(new_keys, list): @@ -106,16 +119,17 @@ def mapping(x, table, only_in_table = True): return new_x -def merge_with_mapping(dict1, dict2, table, only_in_table = True): +def merge_with_mapping(dict1, dict2, table, only_in_table=True): dict1_pack = pack(dict1) dict2_pack = pack(dict2) dict2_pack = mapping(dict2_pack, table, only_in_table) recursive_merge(dict1_pack, 
dict2_pack) dict1.clear() - for k,v in unpack(dict1_pack).items(): + for k, v in unpack(dict1_pack).items(): dict1[k] = v return dict1 + @_singleton class Config(Dict): def __init__(self): diff --git a/common/dataprocesser/__init__.py b/common/dataprocesser/__init__.py index b92b29cb1..7e74e6a13 100644 --- a/common/dataprocesser/__init__.py +++ b/common/dataprocesser/__init__.py @@ -6,4 +6,4 @@ basedir = os.path.dirname(realpath) import_all_module(basedir, "common.dataprocesser") -__all__ = ["DataProcesser"] +__all__ = ["DataProcesser"] diff --git a/common/dataprocesser/dataprocesser.py b/common/dataprocesser/dataprocesser.py index ab4820e5d..40fe5f4cf 100644 --- a/common/dataprocesser/dataprocesser.py +++ b/common/dataprocesser/dataprocesser.py @@ -1,14 +1,14 @@ - class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class DataProcesser(metaclass=Meta): def __init__(self, config): self.config = config diff --git a/common/dataprocesser/general_processer.py b/common/dataprocesser/general_processer.py index d1eb27aeb..4873b4594 100644 --- a/common/dataprocesser/general_processer.py +++ b/common/dataprocesser/general_processer.py @@ -1,5 +1,3 @@ -import math -import time from itertools import chain import numpy as np @@ -8,11 +6,8 @@ import transformers from .dataprocesser import DataProcesser -from ..logging import logger -INTRO_BLURB = ( - "Below is an instruction that describes a task. Write a response that appropriately completes the request." -) +INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request." 
INSTRUCTION_KEY = "### Instruction:" INPUT_KEY = "Input:" RESPONSE_KEY = "### Response:" @@ -58,6 +53,8 @@ end_key=END_KEY, ) TEXT_COLUMN_NAME = "text" + + class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling): def torch_call(self, examples): batch = super().torch_call(examples) @@ -69,7 +66,6 @@ def torch_call(self, examples): labels = batch["labels"].clone() for i in range(len(examples)): - response_token_ids_start_idx = None for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]: response_token_ids_start_idx = idx @@ -85,12 +81,13 @@ def torch_call(self, examples): return batch + class GeneralProcesser(DataProcesser): def prepare(self, tokenizer, dataset): per_device_train_batch_size = self.config.get("per_device_train_batch_size", 1) per_device_eval_batch_size = self.config.get("per_device_eval_batch_size", 1) group = self.config.get("group", False) - shuffle = self.config.get("shuffle", False) + self.config.get("shuffle", False) tokenizer.pad_token = tokenizer.eos_token if isinstance(dataset, datasets.Dataset): @@ -100,6 +97,7 @@ def prepare(self, tokenizer, dataset): column_names = dataset["train"].column_names if column_names and TEXT_COLUMN_NAME not in column_names: + def prompt(rec): instruction = rec["instruction"] response = rec["response"] @@ -109,10 +107,14 @@ def prompt(rec): if not response: raise ValueError(f"Expected a response in: {rec}") if context: - rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context) + rec["text"] = PROMPT_WITH_INPUT_FORMAT.format( + instruction=instruction, response=response, input=context + ) else: - rec["text"] = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response) - return rec + rec["text"] = PROMPT_NO_INPUT_FORMAT.format( + instruction=instruction, response=response + ) + return rec dataset = dataset.map( prompt, @@ -122,6 +124,7 @@ def prompt(rec): column_names += [TEXT_COLUMN_NAME] max_length = self.config.get("max_length", 1024) + def tokenize_function(examples): return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length) @@ -134,6 +137,7 @@ def tokenize_function(examples): if group: block_size = self.config.get("block_size", 1024) + def group_texts(examples): # Concatenate all texts. 
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} @@ -156,27 +160,30 @@ def group_texts(examples): load_from_cache_file=False, desc=f"Grouping texts in chunks of {block_size}", ) - default_data_collator=transformers.default_data_collator + default_data_collator = transformers.default_data_collator else: default_data_collator = DataCollatorForCompletionOnlyLM( - tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 + tokenizer=tokenizer, + mlm=False, + return_tensors="pt", + pad_to_multiple_of=8, ) train_dataset = tokenized_datasets["train"] train_dataloader = torch.utils.data.DataLoader( - train_dataset, - shuffle=True, - collate_fn=default_data_collator, - batch_size=per_device_train_batch_size + train_dataset, + shuffle=True, + collate_fn=default_data_collator, + batch_size=per_device_train_batch_size, ) eval_dataloader = None if "validation" in tokenized_datasets: eval_dataset = tokenized_datasets["validation"] eval_dataloader = torch.utils.data.DataLoader( - eval_dataset, - collate_fn=default_data_collator, - batch_size=per_device_eval_batch_size + eval_dataset, + collate_fn=default_data_collator, + batch_size=per_device_eval_batch_size, ) return train_dataloader, eval_dataloader diff --git a/common/dataprocesser/rm_dataprocesser.py b/common/dataprocesser/rm_dataprocesser.py index b4141c421..36ead7d8b 100644 --- a/common/dataprocesser/rm_dataprocesser.py +++ b/common/dataprocesser/rm_dataprocesser.py @@ -1,20 +1,14 @@ -import math -import time -from itertools import chain - import torch import transformers from .dataprocesser import DataProcesser from ..logging import logger -class RMDataProcesser(DataProcesser): +class RMDataProcesser(DataProcesser): def prepare(self, tokenizer, dataset): - block_size = self.config.get("block_size") - - + if block_size is None: block_size = tokenizer.model_max_length if block_size > 1024: @@ -31,9 +25,8 @@ def prepare(self, tokenizer, dataset): f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 
) block_size = min(block_size, tokenizer.model_max_length) - + def tokenize_function(examples): - tokenizer.pad_token = tokenizer.eos_token chosen = tokenizer( examples["prompt"] + examples["chosen"], @@ -41,11 +34,10 @@ def tokenize_function(examples): truncation=True, padding="max_length", ) - + examples["chosen_input_ids"] = chosen["input_ids"] examples["chosen_attention_mask"] = chosen["attention_mask"] - rejected = tokenizer( examples["prompt"] + examples["rejected"], max_length=block_size, @@ -77,14 +69,14 @@ def tokenize_function(examples): per_device_train_batch_size = self.config.get("per_device_train_batch_size", 2) per_device_eval_batch_size = self.config.get("per_device_eval_batch_size", 4) train_dataloader = torch.utils.data.DataLoader( - train_dataset, - shuffle=True, - collate_fn=transformers.default_data_collator, - batch_size=per_device_train_batch_size + train_dataset, + shuffle=True, + collate_fn=transformers.default_data_collator, + batch_size=per_device_train_batch_size, ) eval_dataloader = torch.utils.data.DataLoader( - eval_dataset, - collate_fn=transformers.default_data_collator, - batch_size=per_device_eval_batch_size + eval_dataset, + collate_fn=transformers.default_data_collator, + batch_size=per_device_eval_batch_size, ) - return train_dataloader, eval_dataloader \ No newline at end of file + return train_dataloader, eval_dataloader diff --git a/common/dataset/__init__.py b/common/dataset/__init__.py index e90f6cb35..9b04a188b 100644 --- a/common/dataset/__init__.py +++ b/common/dataset/__init__.py @@ -6,4 +6,4 @@ basedir = os.path.dirname(realpath) import_all_module(basedir, "common.dataset") -__all__ = ["Dataset"] \ No newline at end of file +__all__ = ["Dataset"] diff --git a/common/dataset/dataset.py b/common/dataset/dataset.py index 8e4e08cde..9b865c2fa 100644 --- a/common/dataset/dataset.py +++ b/common/dataset/dataset.py @@ -1,13 +1,13 @@ - class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class Dataset(metaclass=Meta): - pass \ No newline at end of file + pass diff --git a/common/dataset/huggingface_dataset.py b/common/dataset/huggingface_dataset.py index d2ef9e794..9173e067f 100644 --- a/common/dataset/huggingface_dataset.py +++ b/common/dataset/huggingface_dataset.py @@ -3,6 +3,7 @@ from .dataset import Dataset + def local_load(name, **load_config): if os.path.isfile(name): file = os.path.basename(os.path.abspath(name)) @@ -12,6 +13,7 @@ def local_load(name, **load_config): dataset = datasets.load_dataset(name, **load_config) return dataset["train"] + class HuggingfaceDataset(Dataset): def __call__(self, config): name = config.get("name") @@ -22,12 +24,16 @@ def __call__(self, config): train_dataset = local_load(name) if validation_file is not None: validation_dataset = local_load(validation_file) - return datasets.DatasetDict({"train":train_dataset, "validation_dataset": validation_dataset}) - if validation_split_percentage/100 > 0.0 and validation_split_percentage/100 < 1.0: - datasets_dict = train_dataset.train_test_split(test_size = validation_split_percentage/100) + return datasets.DatasetDict( + {"train": train_dataset, "validation_dataset": validation_dataset} + ) + if validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0: + datasets_dict = train_dataset.train_test_split( + 
test_size=validation_split_percentage / 100 + ) datasets_dict["validation"] = datasets_dict["test"] return datasets_dict - return datasets.DatasetDict({"train":train_dataset}) + return datasets.DatasetDict({"train": train_dataset}) else: load_config = config.get("load_config", {}) if load_from_disk: diff --git a/common/init.py b/common/init.py index a39c27ae1..63715f18f 100644 --- a/common/init.py +++ b/common/init.py @@ -3,6 +3,7 @@ from .logging import logger + def check_config(config): logger.debug("check config start") if isinstance(config, dict): @@ -10,6 +11,7 @@ def check_config(config): else: return False + def init(config): logger.debug("global init start") if not check_config(config): @@ -31,4 +33,4 @@ def init(config): else: logger.info("seed is not set") - logger.debug("global init finish") \ No newline at end of file + logger.debug("global init finish") diff --git a/common/initializer/__init__.py b/common/initializer/__init__.py index c5276e8d6..2cdc27adb 100644 --- a/common/initializer/__init__.py +++ b/common/initializer/__init__.py @@ -6,4 +6,4 @@ basedir = os.path.dirname(realpath) import_all_module(basedir, "common.initializer") -__all__ = ["Initializer"] \ No newline at end of file +__all__ = ["Initializer"] diff --git a/common/initializer/initializer.py b/common/initializer/initializer.py index 3e23c87bb..341412112 100644 --- a/common/initializer/initializer.py +++ b/common/initializer/initializer.py @@ -1,12 +1,13 @@ class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class Initializer(metaclass=Meta): - pass \ No newline at end of file + pass diff --git a/common/load.py b/common/load.py index 687e73e25..16fcfd1c5 100644 --- a/common/load.py +++ b/common/load.py @@ -9,6 +9,7 @@ from . import trainer from . 
import initializer + def load_check_decorator(func): def wrapper(*args, **kwargs): try: @@ -23,8 +24,10 @@ def wrapper(*args, **kwargs): exit(1) else: return ret + return wrapper + @load_check_decorator def load_dataset(config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") @@ -40,6 +43,7 @@ def load_dataset(config: Dict[str, Any]): exit(1) return _ + @load_check_decorator def load_tokenizer(config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") @@ -55,6 +59,7 @@ def load_tokenizer(config: Dict[str, Any]): exit(1) return _ + @load_check_decorator def load_model(config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") @@ -70,6 +75,7 @@ def load_model(config: Dict[str, Any]): exit(1) return _ + @load_check_decorator def load_optimizer(model, config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") @@ -85,6 +91,7 @@ def load_optimizer(model, config: Dict[str, Any]): exit(1) return _ + @load_check_decorator def get_trainer(config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") @@ -99,6 +106,7 @@ def get_trainer(config: Dict[str, Any]): exit(1) return _ + @load_check_decorator def get_initializer(config: Dict[str, Any]): logger.info(f"{sys._getframe().f_code.co_name} config: {config}") @@ -111,4 +119,4 @@ def get_initializer(config: Dict[str, Any]): except Exception as e: logger.critical(f"{Factory.__name__} init error: {e}", exc_info=True) exit(1) - return _ \ No newline at end of file + return _ diff --git a/common/logging.py b/common/logging.py index 20a2baf46..f181ba915 100644 --- a/common/logging.py +++ b/common/logging.py @@ -1,4 +1,3 @@ -import sys import logging import logging.config import traceback @@ -12,42 +11,39 @@ logging_config = { "version": 1, "loggers": { - "root": { - "level": "DEBUG", - "handlers": ["consoleHandler"] - }, + "root": {"level": "DEBUG", "handlers": ["consoleHandler"]}, "common": { "level": "DEBUG", "handlers": ["consoleHandler"], "qualname": "common", - "propagate": 0 - } - + "propagate": 0, + }, }, "handlers": { "consoleHandler": { "class": "logging.StreamHandler", "level": "DEBUG", - "formatter": "standardFormatter" - }, + "formatter": "standardFormatter", + }, }, "formatters": { "standardFormatter": { "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", - "datefmt": "" + "datefmt": "", } - } + }, } if logging_config is not None: try: logging.config.dictConfig(logging_config) - except Exception as e: + except Exception: traceback.print_exc() exit(1) if use_accelerate_log: import accelerate + get_logger = functools.partial(accelerate.logging.get_logger, name=logger_name) else: get_logger = functools.partial(logging.getLogger, name=logger_name) diff --git a/common/model/huggingface_model_for_causal_lm.py b/common/model/huggingface_model_for_causal_lm.py index 7e4c969cd..30ad5a809 100644 --- a/common/model/huggingface_model_for_causal_lm.py +++ b/common/model/huggingface_model_for_causal_lm.py @@ -4,6 +4,7 @@ from peft import get_peft_model, LoraConfig import deltatuner + class HuggingFaceModelForCausalLM(Model): def __call__(self, config): name = config.get("name") diff --git a/common/model/model.py b/common/model/model.py index 369d9b62a..308e15c36 100644 --- a/common/model/model.py +++ b/common/model/model.py @@ -1,13 +1,13 @@ - class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 
'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class Model(metaclass=Meta): - pass \ No newline at end of file + pass diff --git a/common/model/reward_model.py b/common/model/reward_model.py index 294c9f7df..a4aa237ef 100644 --- a/common/model/reward_model.py +++ b/common/model/reward_model.py @@ -6,8 +6,8 @@ from .model import Model + class HuggingFaceRewardModel(Model): - def __call__(self, config): name = config.get("name") if name is None: @@ -19,7 +19,7 @@ def __call__(self, config): try: model = get_reward_model(model_cls, name) - except: + except Exception: print("Load reward model error") exit() @@ -27,20 +27,18 @@ def __call__(self, config): def get_reward_model(model_cls, name): - class RewardModel(model_cls): - def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) # The additional value head. - if hasattr(self.config, 'hidden_size'): + if hasattr(self.config, "hidden_size"): self.value_head = nn.Linear(self.config.hidden_size, 1) - elif hasattr(self.config, 'n_embd'): + elif hasattr(self.config, "n_embd"): self.value_head = nn.Linear(self.config.n_embd, 1) else: raise ValueError("current model does not support") - + self.post_init() def forward( @@ -49,15 +47,14 @@ def forward( chosen_attention_mask, rejected_input_ids, rejected_attention_mask, - **kwargs + **kwargs, ) -> torch.Tensor: chosen_value = self.value(chosen_input_ids, chosen_attention_mask) rejected_value = self.value(rejected_input_ids, rejected_attention_mask) return torch.stack([chosen_value, rejected_value], dim=1) def value(self, input_ids, attention_mask) -> torch.Tensor: - """Forward function predicts whether chosen response has a higher reward. - """ + """Forward function predicts whether chosen response has a higher reward.""" # Force inputs to be torch tensors. if not isinstance(input_ids, torch.Tensor): input_ids = torch.tensor(input_ids).to(self.device) @@ -68,7 +65,7 @@ def value(self, input_ids, attention_mask) -> torch.Tensor: input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, - )['last_hidden_state'] + )["last_hidden_state"] values = self.value_head(last_hidden_state) # Remove the last dimension, since there is only a single value per token. 
@@ -80,4 +77,3 @@ def generate(self, *kwargs): raise NotImplementedError("Reward model does not generate token.") return RewardModel.from_pretrained(name) - diff --git a/common/optimizer/__init__.py b/common/optimizer/__init__.py index bdba4937e..122acc90f 100644 --- a/common/optimizer/__init__.py +++ b/common/optimizer/__init__.py @@ -6,4 +6,4 @@ basedir = os.path.dirname(realpath) import_all_module(basedir, "common.optimizer") -__all__ = ["Optimizer"] +__all__ = ["Optimizer"] diff --git a/common/optimizer/default_optimizer.py b/common/optimizer/default_optimizer.py index a41b4021a..dab5803a2 100644 --- a/common/optimizer/default_optimizer.py +++ b/common/optimizer/default_optimizer.py @@ -1,13 +1,12 @@ -import torch - +import torch # noqa: F401 from .optimizer import Optimizer + class DefaultOptimizer(Optimizer): def __call__(self, model, config): - optimizer_name = config.get("name", "SGD") optimizer_config = config.get("config", {}) - optimizer_type = eval("torch.optim.%s"%(optimizer_name)) + optimizer_type = eval("torch.optim.%s" % (optimizer_name)) optimizer_grouped_parameters = self.get_grouped_parameters(model, config) optimizer = optimizer_type(optimizer_grouped_parameters, **optimizer_config) @@ -15,4 +14,4 @@ def __call__(self, model, config): return optimizer def get_grouped_parameters(self, model, config): - return model.parameters() \ No newline at end of file + return model.parameters() diff --git a/common/optimizer/group_optimizer.py b/common/optimizer/group_optimizer.py index 8cfdf4848..0e07878db 100644 --- a/common/optimizer/group_optimizer.py +++ b/common/optimizer/group_optimizer.py @@ -1,13 +1,12 @@ -import torch - +import torch # noqa: F401 from .optimizer import Optimizer + class GroupOptimizer(Optimizer): def __call__(self, model, config): - optimizer_name = config.get("name", "SGD") optimizer_config = config.get("config", {}) - optimizer_type = eval("torch.optim.%s"%(optimizer_name)) + optimizer_type = eval("torch.optim.%s" % (optimizer_name)) optimizer_grouped_parameters = self.get_grouped_parameters(model, config) optimizer = optimizer_type(optimizer_grouped_parameters, **optimizer_config) @@ -18,12 +17,16 @@ def get_grouped_parameters(self, model, config): no_decay = ["bias", "layer_norm.weight"] optimizer_grouped_parameters = [ { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "params": [ + p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) + ], "weight_decay": 0.1, }, { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "params": [ + p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) + ], "weight_decay": 0.0, }, ] - return optimizer_grouped_parameters \ No newline at end of file + return optimizer_grouped_parameters diff --git a/common/optimizer/optimizer.py b/common/optimizer/optimizer.py index 298f02660..bbcd54fe1 100644 --- a/common/optimizer/optimizer.py +++ b/common/optimizer/optimizer.py @@ -1,13 +1,13 @@ - class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class Optimizer(metaclass=Meta): - pass \ No newline at end of file + pass diff --git a/common/tokenizer/empty_tokenizer.py b/common/tokenizer/empty_tokenizer.py index 917526e95..c2684aca0 100644 --- 
a/common/tokenizer/empty_tokenizer.py +++ b/common/tokenizer/empty_tokenizer.py @@ -1,13 +1,14 @@ -import transformers - from .tokenizer import Tokenizer + class _EmptyTokenizer: def __init__(self, max_token_id): self.max_token_id = max_token_id + def __len__(self): return self.max_token_id + class EmptyTokenizer(Tokenizer): def __call__(self, config): tokenizer_config = config.get("config") diff --git a/common/tokenizer/huggingface_tokenizer.py b/common/tokenizer/huggingface_tokenizer.py index e8a50ab07..a6a60bc7f 100644 --- a/common/tokenizer/huggingface_tokenizer.py +++ b/common/tokenizer/huggingface_tokenizer.py @@ -2,9 +2,10 @@ from .tokenizer import Tokenizer + class HuggingFaceTokenizer(Tokenizer): def __call__(self, config): name = config.get("name") load_config = config.get("config", {}) tokenizer = transformers.AutoTokenizer.from_pretrained(name, **load_config) - return tokenizer \ No newline at end of file + return tokenizer diff --git a/common/tokenizer/tokenizer.py b/common/tokenizer/tokenizer.py index 6ae9cc173..973d50aee 100644 --- a/common/tokenizer/tokenizer.py +++ b/common/tokenizer/tokenizer.py @@ -1,13 +1,13 @@ - class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class Tokenizer(metaclass=Meta): - pass \ No newline at end of file + pass diff --git a/common/torch_config.py b/common/torch_config.py index 89632eb2d..5a63ab565 100644 --- a/common/torch_config.py +++ b/common/torch_config.py @@ -5,6 +5,7 @@ from typing import Optional import os import sys + # The package importlib_metadata is in a different place, depending on the Python version. if sys.version_info < (3, 8): import importlib_metadata @@ -23,25 +24,19 @@ def backend_cls(self): def libs_import(): - """try to import IPEX and oneCCL. 
- """ + """try to import IPEX and oneCCL.""" try: - import intel_extension_for_pytorch + import intel_extension_for_pytorch # noqa: F401 except ImportError: - raise ImportError( - "Please install intel_extension_for_pytorch" - ) + raise ImportError("Please install intel_extension_for_pytorch") try: ccl_version = importlib_metadata.version("oneccl_bind_pt") if ccl_version >= "1.12": - # pylint: disable-all - import oneccl_bindings_for_pytorch + import oneccl_bindings_for_pytorch # noqa: F401 else: - import torch_ccl + import torch_ccl # noqa: F401 except ImportError as ccl_not_exist: - raise ImportError( - "Please install torch-ccl" - ) from ccl_not_exist + raise ImportError("Please install torch-ccl") from ccl_not_exist def _set_torch_distributed_env_vars(device): diff --git a/common/trainer/__init__.py b/common/trainer/__init__.py index d031b5d8b..b33b565a5 100644 --- a/common/trainer/__init__.py +++ b/common/trainer/__init__.py @@ -6,4 +6,4 @@ basedir = os.path.dirname(realpath) import_all_module(basedir, "common.trainer") -__all__ = ["Trainer"] +__all__ = ["Trainer"] diff --git a/common/trainer/default_trainer.py b/common/trainer/default_trainer.py index f3aa965b9..a33ac2bdd 100644 --- a/common/trainer/default_trainer.py +++ b/common/trainer/default_trainer.py @@ -14,6 +14,7 @@ from ..logging import logger + class DefaultTrainer(Trainer): def __init__(self, config): self.config = config @@ -27,12 +28,12 @@ def __init__(self, config): def recovery(self, config): if config is None or config is {}: - logger.warning(f"checkpoint is empty, skip") + logger.warning("checkpoint is empty, skip") return root_path = config.get("root_path") model_name = config.get("model_name", "") if root_path is None: - logger.warning(f"checkpoint root_path is empty, skip") + logger.warning("checkpoint root_path is empty, skip") local_checkpoint_path = self._get_local_path(root_path, model_name) try: logger.info(f"start recovery from {local_checkpoint_path}") @@ -48,8 +49,12 @@ def recovery(self, config): self.optimizer.load_state_dict(optimizer_state) # update lr_scheduler status - if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr(self, "lr_scheduler"): - scheduler_state = torch.load(checkpoint_dir / "lr_scheduler.pt", map_location="cpu") + if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr( + self, "lr_scheduler" + ): + scheduler_state = torch.load( + checkpoint_dir / "lr_scheduler.pt", map_location="cpu" + ) self.lr_scheduler.load_state_dict(scheduler_state) # update current epoch @@ -59,7 +64,7 @@ def recovery(self, config): logger.info(f"recovery to epoch {self.starting_epoch}") except FileNotFoundError as e: logger.info(e) - except Exception as e: + except Exception: logger.warning("recovery error", exc_info=True) def _coordinate(self, accelerator): @@ -68,15 +73,24 @@ def _coordinate(self, accelerator): self.size = accelerator.num_processes self.local_rank = accelerator.local_process_index accelerator.wait_for_everyone() - logger.info(f"coordinate workers finish, cluster size:{self.size} worker rank:{self.rank} worker local_rank:{self.local_rank}") + logger.info( + f"coordinate workers finish, cluster size:{self.size} worker rank:{self.rank} worker local_rank:{self.local_rank}" + ) - def _get_lr_scheduler(self, lr_scheduler_config, optimizer, num_train_epochs, num_steps_per_epoch, accelerator): + def _get_lr_scheduler( + self, + lr_scheduler_config, + optimizer, + num_train_epochs, + num_steps_per_epoch, + accelerator, + ): # gradient_accumulation_steps = 
accelerator.gradient_accumulation_steps # num_update_steps_per_epoch = math.ceil(num_steps_per_epoch / gradient_accumulation_steps) enable = lr_scheduler_config.get("enable", False) if not enable: return None - max_train_steps = lr_scheduler_config.get("max_train_steps") + max_train_steps = lr_scheduler_config.get("max_train_steps") lr_scheduler_type = lr_scheduler_config.get("lr_scheduler_type", "linear") num_warmup_steps = lr_scheduler_config.get("num_warmup_steps", 0) @@ -98,17 +112,23 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): logger.info(f"model embedding size: {embedding_size}") if len(tokenizer) > embedding_size: model.resize_token_embeddings(len(tokenizer)) - logger.warning(f"model embedding size resize to {len(tokenizer)} because of tokenizer size") + logger.warning( + f"model embedding size resize to {len(tokenizer)} because of tokenizer size" + ) - train_dataloader, eval_dataloader = self.dataprocesser.prepare( - tokenizer, dataset - ) + train_dataloader, eval_dataloader = self.dataprocesser.prepare(tokenizer, dataset) lr_scheduler_config = self.config.get("lr_scheduler") if lr_scheduler_config: num_steps_per_epoch = len(train_dataloader) num_train_epochs = self.config.get("num_train_epochs", 1) - lr_scheduler = self._get_lr_scheduler(lr_scheduler_config, optimizer, num_train_epochs, num_steps_per_epoch, accelerator) + lr_scheduler = self._get_lr_scheduler( + lr_scheduler_config, + optimizer, + num_train_epochs, + num_steps_per_epoch, + accelerator, + ) else: lr_scheduler = None @@ -119,9 +139,12 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): # https://huggingface.co/docs/accelerate/usage_guides/fsdp#a-few-caveats-to-be-aware-of self.model = accelerator.prepare(model) - self.optimizer, self.train_dataloader, self.eval_dataloader, self.lr_scheduler = accelerator.prepare( - optimizer, train_dataloader, eval_dataloader, lr_scheduler - ) + ( + self.optimizer, + self.train_dataloader, + self.eval_dataloader, + self.lr_scheduler, + ) = accelerator.prepare(optimizer, train_dataloader, eval_dataloader, lr_scheduler) checkpoint = self.config.get("checkpoint") if checkpoint is not None: @@ -148,8 +171,19 @@ def train(self): self.lr_scheduler.step() self.optimizer.zero_grad() if step % log_step == 0: - logger.info(f"train epoch:[{idx}/{num_train_epochs}]\tstep:[{step}/{total_steps}]\tloss:{loss:.6f}\tppl:{math.exp(loss):.6f}\ttime:{time.time()-start:.6f}") - report({"train_epoch": idx, "total_epochs": num_train_epochs, "train_step": step, "total_steps": min(max_train_step, total_steps) if max_train_step else total_steps}) + logger.info( + f"train epoch:[{idx}/{num_train_epochs}]\tstep:[{step}/{total_steps}]\tloss:{loss:.6f}\tppl:{math.exp(loss):.6f}\ttime:{time.time()-start:.6f}" + ) + report( + { + "train_epoch": idx, + "total_epochs": num_train_epochs, + "train_step": step, + "total_steps": min(max_train_step, total_steps) + if max_train_step + else total_steps, + } + ) start = time.time() if max_train_step is not None: if step >= max_train_step - 1: @@ -164,7 +198,11 @@ def train(self): with torch.no_grad(): outputs = self.model(**batch) loss = outputs.loss - losses.append(self.accelerator.gather_for_metrics(loss.repeat(batch["input_ids"].shape[0]))) + losses.append( + self.accelerator.gather_for_metrics( + loss.repeat(batch["input_ids"].shape[0]) + ) + ) if max_eval_step is not None: if step >= max_eval_step: break @@ -176,7 +214,9 @@ def train(self): except OverflowError: eval_loss = float("inf") perplexity = float("inf") - 
logger.info(f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss:.6f}]\tppl:[{perplexity:.6f}]\ttime:[{time.time()-start:.6f}]") + logger.info( + f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss:.6f}]\tppl:[{perplexity:.6f}]\ttime:[{time.time()-start:.6f}]" + ) if checkpoint is not None: self.save(checkpoint, idx) @@ -187,7 +227,9 @@ def train(self): logger.info(f"start save model to {output}") unwrapped_model = self.accelerator.unwrap_model(self.model) unwrapped_model.save_pretrained( - output, is_main_process=self.accelerator.is_main_process, save_function=self.accelerator.save + output, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, ) logger.info(f"finish save model to {output}") self.accelerator.wait_for_everyone() @@ -195,14 +237,14 @@ def train(self): def _get_local_path(self, root_path, model_name): return f"{root_path}/{model_name}_{self.rank}-of-{self.size}" - def save(self, config, epoch = 0): + def save(self, config, epoch=0): if config is None or config is {}: - logger.warning(f"checkpoint is empty, skip") + logger.warning("checkpoint is empty, skip") return root_path = config.get("root_path") model_name = config.get("model_name", "") if root_path is None: - logger.warning(f"checkpoint root_path is empty, skip") + logger.warning("checkpoint root_path is empty, skip") return local_checkpoint_path = self._get_local_path(root_path, model_name) @@ -211,7 +253,10 @@ def save(self, config, epoch = 0): torch.save(self.optimizer.state_dict(), os.path.join(tmpdir, "optim.pt")) torch.save({"epoch": epoch}, os.path.join(tmpdir, "epoch.pt")) if self.lr_scheduler: - torch.save(self.lr_scheduler.state_dict(), os.path.join(tmpdir, "lr_scheduler.pt")) + torch.save( + self.lr_scheduler.state_dict(), + os.path.join(tmpdir, "lr_scheduler.pt"), + ) checkpoint = Checkpoint.from_directory(tmpdir) checkpoint.to_directory(local_checkpoint_path) logger.info(f"save checkpoint to {local_checkpoint_path} finished") diff --git a/common/trainer/rm_trainer.py b/common/trainer/rm_trainer.py index 4bb725d87..0ee8ee7eb 100644 --- a/common/trainer/rm_trainer.py +++ b/common/trainer/rm_trainer.py @@ -1,27 +1,22 @@ -from .trainer import Trainer -from itertools import chain import os import torch -from torch.utils.tensorboard import SummaryWriter -import transformers +from torch.utils.tensorboard import SummaryWriter import math import time -from .. import dataprocesser from .default_trainer import DefaultTrainer from ..logging import logger -class RMTrainer(DefaultTrainer): +class RMTrainer(DefaultTrainer): def compute_loss(self, batch, return_outputs=False): - chosen_ids = batch.pop("chosen_input_ids").to(self.model.device) chosen_mask = batch.pop("chosen_attention_mask").to(self.model.device) rejected_ids = batch.pop("rejected_input_ids").to(self.model.device) rejected_mask = batch.pop("rejected_attention_mask").to(self.model.device) result = self.model(chosen_ids, chosen_mask, rejected_ids, rejected_mask) - + chosen_rewards = result[:, 0, :] rejected_rewards = result[:, 1, :] @@ -29,10 +24,11 @@ def compute_loss(self, batch, return_outputs=False): loss = 0 for i in range(batch_size): - divergence = ( - chosen_ids[i] * chosen_mask[i] != rejected_ids[i] * rejected_mask[i] - ).squeeze().nonzero(as_tuple=True)[0] + (chosen_ids[i] * chosen_mask[i] != rejected_ids[i] * rejected_mask[i]) + .squeeze() + .nonzero(as_tuple=True)[0] + ) if len(divergence) <= 0: # Chosen and rejected prompts are identical. 
@@ -43,8 +39,8 @@ def compute_loss(self, batch, return_outputs=False): end_index = divergence[-1].item() # Loss is the negative log probability loss between the chosen and rejected prompt. - selected_chosen_rewards = chosen_rewards[i][start_index:end_index + 1] - selected_rejected_rewards = rejected_rewards[i][start_index:end_index + 1] + selected_chosen_rewards = chosen_rewards[i][start_index : end_index + 1] + selected_rejected_rewards = rejected_rewards[i][start_index : end_index + 1] loss += -torch.log( torch.sigmoid(selected_chosen_rewards - selected_rejected_rewards) @@ -67,14 +63,16 @@ def train(self): for step, batch in enumerate(self.train_dataloader): with self.accelerator.accumulate(self.model): loss = self.compute_loss(batch) - writer.add_scalar('training loss', loss, step) + writer.add_scalar("training loss", loss, step) self.accelerator.backward(loss) self.optimizer.step() if self.lr_scheduler is not None: self.lr_scheduler.step() self.optimizer.zero_grad() if step % log_step == 0: - logger.info(f"train epoch:[{idx}/{num_train_epochs}]\tstep:[{step}/{len(self.train_dataloader)}]\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}") + logger.info( + f"train epoch:[{idx}/{num_train_epochs}]\tstep:[{step}/{len(self.train_dataloader)}]\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}" + ) start = time.time() if self.eval_dataloader: @@ -94,6 +92,8 @@ def train(self): except OverflowError: eval_loss = float("inf") perplexity = float("inf") - logger.info(f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]") - writer.add_scalar('eval loss', eval_loss, idx) - writer.add_scalar('perplexity', perplexity, idx) + logger.info( + f"eval epoch:[{idx}/{num_train_epochs}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]" + ) + writer.add_scalar("eval loss", eval_loss, idx) + writer.add_scalar("perplexity", perplexity, idx) diff --git a/common/trainer/trainer.py b/common/trainer/trainer.py index 4dcdce359..d9974d003 100644 --- a/common/trainer/trainer.py +++ b/common/trainer/trainer.py @@ -1,13 +1,13 @@ - class Meta(type): def __init__(cls, name, bases, namespace, **kwargs): super().__init__(name, bases, namespace, **kwargs) - if not hasattr(cls, 'registory'): + if not hasattr(cls, "registory"): # this is the base class cls.registory = {} else: # this is the subclass cls.registory[name] = cls + class Trainer(metaclass=Meta): - pass \ No newline at end of file + pass diff --git a/dev/memory_status/csv_analysis.py b/dev/memory_status/csv_analysis.py index ac71222b2..1b6523ac4 100644 --- a/dev/memory_status/csv_analysis.py +++ b/dev/memory_status/csv_analysis.py @@ -26,7 +26,7 @@ plt.ylabel("rss GB") plt.subplot(121) -csv_wo_fsdp="./res/rss_per_process_wo_FSDP.csv" +csv_wo_fsdp = "./res/rss_per_process_wo_FSDP.csv" f = open(csv_wo_fsdp, mode="r", encoding="utf-8", newline="") csv_reader = csv.DictReader( f, @@ -40,10 +40,10 @@ line = next(csv_reader) rss_2ddp = list() for line in csv_reader: - rss_2ddp.append(float(line["rss"])) + rss_2ddp.append(float(line["rss"])) x = range(len(rss_2ddp)) plt.plot(x, rss_2ddp) plt.title("rss/process wo FSDP") plt.xlabel("second") plt.ylabel("rss GB") -plt.savefig("rss.png") \ No newline at end of file +plt.savefig("rss.png") diff --git a/dev/memory_status/memory_collect_ray.py b/dev/memory_status/memory_collect_ray.py index 1d5e28c63..aaa7b0273 100644 --- a/dev/memory_status/memory_collect_ray.py +++ b/dev/memory_status/memory_collect_ray.py @@ -1,11 +1,11 @@ import 
time + def collect_memory(eval_pid: int, name: str, output: str): import csv import matplotlib.pyplot as plt import psutil - from psutil import Process f = open(output + name + ".csv", mode="w", encoding="utf-8", newline="") @@ -42,10 +42,9 @@ def collect_memory(eval_pid: int, name: str, output: str): plt.savefig(output + name + ".png") f.close() + if __name__ == "__main__": pid = 66611 title = "rss_per_process_with_FSDP" output_path = "./res/" collect_memory(pid, title, output_path) - - diff --git a/examples/finetune/dolly1/process_data.py b/examples/finetune/dolly1/process_data_dolly1.py similarity index 73% rename from examples/finetune/dolly1/process_data.py rename to examples/finetune/dolly1/process_data_dolly1.py index 507c08245..716c073cf 100644 --- a/examples/finetune/dolly1/process_data.py +++ b/examples/finetune/dolly1/process_data_dolly1.py @@ -1,21 +1,20 @@ import os -import numpy as np -import pandas as pd from datasets import load_dataset + ds = load_dataset("tatsu-lab/alpaca") -train = ds['train'].to_pandas() +train = ds["train"].to_pandas() + def prep_data(df): df["context"] = df["input"] df["response"] = df["output"] df = df[df.response != ""] - df = df[ - ["instruction", "context", "response"] - ] + df = df[["instruction", "context", "response"]] return df + + df_train = prep_data(train) if not os.path.exists("data/train"): os.makedirs("data/train") df_train.to_json("data/train/train.jsonl", lines=True, orient="records") - diff --git a/examples/finetune/open_assistant/process_data.py b/examples/finetune/open_assistant/process_data_open_assistant.py similarity index 85% rename from examples/finetune/open_assistant/process_data.py rename to examples/finetune/open_assistant/process_data_open_assistant.py index e56da8b84..8598f86ea 100644 --- a/examples/finetune/open_assistant/process_data.py +++ b/examples/finetune/open_assistant/process_data_open_assistant.py @@ -1,10 +1,10 @@ import os -import numpy as np -import pandas as pd from datasets import load_dataset + ds = load_dataset("OpenAssistant/oasst1") -train = ds['train'].to_pandas() -val = ds['validation'].to_pandas() +train = ds["train"].to_pandas() +val = ds["validation"].to_pandas() + def prep_data(df): df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy() @@ -26,11 +26,11 @@ def prep_data(df): df_assistant = df_assistant[df_assistant.lang == "en"] - df_assistant = df_assistant[ - ["instruction", "context", "response"] - ] + df_assistant = df_assistant[["instruction", "context", "response"]] return df_assistant + + df_train = prep_data(train) df_val = prep_data(val) if not os.path.exists("data/train"): @@ -39,4 +39,3 @@ def prep_data(df): os.makedirs("data/validation") df_train.to_json("data/train/train.jsonl", lines=True, orient="records") df_val.to_json("data/validation/validation.jsonl", lines=True, orient="records") - diff --git a/examples/inference/api_server_openai/query_http_requests.py b/examples/inference/api_server_openai/query_http_requests.py index d7e57021e..6418a58f3 100644 --- a/examples/inference/api_server_openai/query_http_requests.py +++ b/examples/inference/api_server_openai/query_http_requests.py @@ -27,36 +27,36 @@ model_name = os.getenv("MODEL_TO_SERVE", "gpt2") body = { - "model": model_name, - "messages": [ - {"role": "assistant", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a long story with many words."} - ], - "temperature": 0.7, - "stream": True, + "model": model_name, + "messages": [ + {"role": "assistant", "content": "You are a 
helpful assistant."}, + {"role": "user", "content": "Tell me a long story with many words."}, + ], + "temperature": 0.7, + "stream": True, } -proxies = { "http": None, "https": None} -with s.post(url, json=body, proxies=proxies) as response: - for chunk in response.iter_lines(decode_unicode=True): - if chunk is not None: - try: - # Get data from reponse chunk - chunk_data = chunk.split("data: ")[1] - - # Get message choices from data - choices = json.loads(chunk_data)["choices"] - - # Pick content from first choice - content = choices[0]["delta"]["content"] - - print(content, end="", flush=True) - except json.decoder.JSONDecodeError: - # Chunk was not formatted as expected - pass - except KeyError: - # No message was contained in the chunk - pass - except: - pass - print("") \ No newline at end of file +proxies = {"http": None, "https": None} +response = s.post(url, json=body, proxies=proxies) # type: ignore +for chunk in response.iter_lines(decode_unicode=True): + if chunk is not None: + try: + # Get data from reponse chunk + chunk_data = chunk.split("data: ")[1] + + # Get message choices from data + choices = json.loads(chunk_data)["choices"] + + # Pick content from first choice + content = choices[0]["delta"]["content"] + + print(content, end="", flush=True) + except json.decoder.JSONDecodeError: + # Chunk was not formatted as expected + pass + except KeyError: + # No message was contained in the chunk + pass + except Exception: + pass +print("") diff --git a/examples/inference/api_server_openai/query_openai_sdk.py b/examples/inference/api_server_openai/query_openai_sdk.py index 48e0974a4..d17e9f0bb 100644 --- a/examples/inference/api_server_openai/query_openai_sdk.py +++ b/examples/inference/api_server_openai/query_openai_sdk.py @@ -26,10 +26,10 @@ chat_completion = openai.ChatCompletion.create( model=model_name, messages=[ - {"role": "assistant", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a long story with many words."} + {"role": "assistant", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a long story with many words."}, ], temperature=0.7, stream=False, ) -print(chat_completion) \ No newline at end of file +print(chat_completion) diff --git a/examples/inference/api_server_simple/query_single.py b/examples/inference/api_server_simple/query_single.py index 890abf549..3d74bbb93 100644 --- a/examples/inference/api_server_simple/query_single.py +++ b/examples/inference/api_server_simple/query_single.py @@ -15,20 +15,48 @@ # import requests -import time import argparse +from typing import Dict, Union -parser = argparse.ArgumentParser(description="Example script to query with single request", add_help=True) -parser.add_argument("--model_endpoint", default="http://127.0.0.1:8000", type=str, help="Deployed model endpoint.") -parser.add_argument("--streaming_response", default=False, action="store_true", help="Whether to enable streaming response.") -parser.add_argument("--max_new_tokens", default=None, help="The maximum numbers of tokens to generate.") -parser.add_argument("--temperature", default=None, help="The value used to modulate the next token probabilities.") -parser.add_argument("--top_p", default=None, help="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`Top p` or higher are kept for generation.") -parser.add_argument("--top_k", default=None, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.") +parser = 
argparse.ArgumentParser( + description="Example script to query with single request", add_help=True +) +parser.add_argument( + "--model_endpoint", + default="http://127.0.0.1:8000", + type=str, + help="Deployed model endpoint.", +) +parser.add_argument( + "--streaming_response", + default=False, + action="store_true", + help="Whether to enable streaming response.", +) +parser.add_argument( + "--max_new_tokens", default=None, help="The maximum numbers of tokens to generate." +) +parser.add_argument( + "--temperature", + default=None, + help="The value used to modulate the next token probabilities.", +) +parser.add_argument( + "--top_p", + default=None, + help="If set to float < 1, only the smallest set of most probable tokens \ + with probabilities that add up to `Top p` or higher are kept for generation.", +) +parser.add_argument( + "--top_k", + default=None, + help="The number of highest probability vocabulary tokens to keep \ + for top-k-filtering.", +) args = parser.parse_args() prompt = "Once upon a time," -config = {} +config: Dict[str, Union[int, float]] = {} if args.max_new_tokens: config["max_new_tokens"] = int(args.max_new_tokens) if args.temperature: @@ -40,12 +68,17 @@ sample_input = {"text": prompt, "config": config, "stream": args.streaming_response} -proxies = { "http": None, "https": None} -outputs = requests.post(args.model_endpoint, proxies=proxies, json=sample_input, stream=args.streaming_response) +proxies = {"http": None, "https": None} +outputs = requests.post( + args.model_endpoint, + proxies=proxies, # type: ignore + json=sample_input, + stream=args.streaming_response, +) if args.streaming_response: outputs.raise_for_status() for output in outputs.iter_content(chunk_size=None, decode_unicode=True): - print(output, end='', flush=True) + print(output, end="", flush=True) print() else: print(outputs.text, flush=True) diff --git a/examples/rlhf/process_data.py b/examples/rlhf/process_data.py index fd5305a32..e45d8a34d 100644 --- a/examples/rlhf/process_data.py +++ b/examples/rlhf/process_data.py @@ -1,58 +1,49 @@ import os -import numpy as np -import pandas as pd from datasets import load_dataset import argparse - def prep_data(df, colume): - df = df[ - colume - ] + df = df[colume] return df -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--dataset", type=str, required=False, - default='Dahoas/rm-static', + default="Dahoas/rm-static", ) parser.add_argument( "--save_dir", type=str, required=False, - default='data', + default="data", ) parser.add_argument( - "--mode", - type=str, - required=False, - default='reward', - choices=['reward', 'rlhf'] + "--mode", type=str, required=False, default="reward", choices=["reward", "rlhf"] ) args = parser.parse_args() ds = load_dataset(args.dataset) - train = ds['train'].to_pandas() - test = ds['test'].to_pandas() + train = ds["train"].to_pandas() + test = ds["test"].to_pandas() - if args.mode == 'reward': + if args.mode == "reward": df_train = prep_data(train, colume=["prompt", "chosen", "rejected"]) - df_test = prep_data(test, colume=["prompt", "chosen", "rejected"]) - elif args.mode == 'rlhf': + df_test = prep_data(test, colume=["prompt", "chosen", "rejected"]) + elif args.mode == "rlhf": df_train = prep_data(train, colume=["prompt"]) - df_test = prep_data(test, colume=["prompt"]) + df_test = prep_data(test, colume=["prompt"]) else: - raise ValueError('unsupport mode') + raise ValueError("unsupport mode") save_dir = os.path.join(args.save_dir, args.mode) if not 
os.path.exists(save_dir): os.makedirs(save_dir) - df_train.to_json(os.path.join(save_dir, 'train.jsonl'), lines=True, orient="records") - df_test.to_json(os.path.join(save_dir, 'test.jsonl'), lines=True, orient="records") + df_train.to_json(os.path.join(save_dir, "train.jsonl"), lines=True, orient="records") + df_test.to_json(os.path.join(save_dir, "test.jsonl"), lines=True, orient="records") diff --git a/finetune/finetune.py b/finetune/finetune.py index 90351e3a8..0815dabfe 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -1,10 +1,8 @@ #!/usr/bin/env python import os -import time import argparse -import traceback -from typing import Any, Dict +from typing import Any, Dict, Union import accelerate from accelerate.utils import is_xpu_available @@ -17,19 +15,23 @@ from pydantic_yaml import parse_yaml_raw_as from accelerate import FullyShardedDataParallelPlugin -from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, + FullStateDictConfig, +) import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import common from finetune.finetune_config import FinetuneConfig -def get_accelerate_environment_variable(mode: str, config: Dict[str, Any]) -> dict: - mixed_precision = config["Training"]["mixed_precision"] +def get_accelerate_environment_variable(mode: str, config: Union[Dict[str, Any], None]) -> dict: + mixed_precision = config["Training"]["mixed_precision"] if config else "no" mode_env_vars = { "CPU_DDP": { - "ACCELERATE_USE_CPU": "true", + "ACCELERATE_USE_CPU": "true", "ACCELERATE_USE_IPEX": "true", "ACCELERATE_MIXED_PRECISION": mixed_precision, }, @@ -53,7 +55,7 @@ def get_accelerate_environment_variable(mode: str, config: Dict[str, Any]) -> di "FSDP_USE_ORIG_PARAMS": "false", "FSDP_SYNC_MODULE_STATES": "true", "ACCELERATE_MIXED_PRECISION": mixed_precision, - } + }, } if mode not in mode_env_vars: raise ValueError(f"accelerate mode must be one of {list(mode_env_vars.keys())}") @@ -64,92 +66,112 @@ def train_func(config: Dict[str, Any]): cwd = config.get("cwd") if cwd: os.chdir(cwd) - + gradient_accumulation_steps = config["Training"].get("gradient_accumulation_steps", 1) accelerate_mode = config["Training"]["accelerate_mode"] if accelerate_mode in ["GPU_FSDP"]: fsdp_plugin = FullyShardedDataParallelPlugin( state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False), - optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False) + optim_state_dict_config=FullOptimStateDictConfig( + offload_to_cpu=False, rank0_only=False + ), ) else: fsdp_plugin = None - accelerator = accelerate.Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, - fsdp_plugin=fsdp_plugin) - common.logger.info(f"accelerator generate finish, accelerator device type = {accelerator.device}") + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, fsdp_plugin=fsdp_plugin + ) + common.logger.info( + f"accelerator generate finish, accelerator device type = {accelerator.device}" + ) seed = config["Training"].get("seed") if seed is not None: accelerate.utils.set_seed(seed) - datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()(config = { - "name": config["Dataset"]["train_file"], - "validation_file": config["Dataset"]["validation_file"], - 
"validation_split_percentage": config["Dataset"]["validation_split_percentage"] - }) - - tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()(config = { - "name": config["General"]["base_model"], - "config": config["General"]["config"] - }) - - model = common.model.Model.registory.get("HuggingFaceModelForCausalLM")()(config = { - "name": config["General"]["base_model"], - "config": config["General"]["config"], - "lora_config": config["General"]["lora_config"] if config["General"].get("lora_config") else None - }) - - optimizer = common.optimizer.Optimizer.registory.get("DefaultOptimizer")()(model, config = { - "name": config["Training"]["optimizer"], - "config": { - "lr": config["Training"]["learning_rate"] - }, - }) - - trainer = common.trainer.Trainer.registory.get("DefaultTrainer")(config = { - "num_train_epochs": config["Training"]["epochs"], - "max_train_step": config["Training"].get("max_train_steps", None), - "log_step": 1, - "output": config["General"]["output_dir"], - "dataprocesser": { - "type": "GeneralProcesser", - "per_device_train_batch_size": config["Training"]["batch_size"], - "per_device_eval_batch_size": config["Training"]["batch_size"], - "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1), - "shuffle": True - }, - "lr_scheduler": { - "enable": True, - "max_train_steps": None, - "lr_scheduler_type": config["Training"]["lr_scheduler"], - "num_warmup_steps": 0, + datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()( + config={ + "name": config["Dataset"]["train_file"], + "validation_file": config["Dataset"]["validation_file"], + "validation_split_percentage": config["Dataset"]["validation_split_percentage"], + } + ) + + tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()( + config={ + "name": config["General"]["base_model"], + "config": config["General"]["config"], + } + ) + + model = common.model.Model.registory.get("HuggingFaceModelForCausalLM")()( + config={ + "name": config["General"]["base_model"], + "config": config["General"]["config"], + "lora_config": config["General"]["lora_config"] + if config["General"].get("lora_config") + else None, + } + ) + + optimizer = common.optimizer.Optimizer.registory.get("DefaultOptimizer")()( + model, + config={ + "name": config["Training"]["optimizer"], + "config": {"lr": config["Training"]["learning_rate"]}, }, - "checkpoint": { - "root_path": config["General"]["checkpoint_dir"], - } if config["General"].get("checkpoint_dir") else None - }) + ) + + trainer = common.trainer.Trainer.registory.get("DefaultTrainer")( + config={ + "num_train_epochs": config["Training"]["epochs"], + "max_train_step": config["Training"].get("max_train_steps", None), + "log_step": 1, + "output": config["General"]["output_dir"], + "dataprocesser": { + "type": "GeneralProcesser", + "per_device_train_batch_size": config["Training"]["batch_size"], + "per_device_eval_batch_size": config["Training"]["batch_size"], + "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1), + "shuffle": True, + }, + "lr_scheduler": { + "enable": True, + "max_train_steps": None, + "lr_scheduler_type": config["Training"]["lr_scheduler"], + "num_warmup_steps": 0, + }, + "checkpoint": { + "root_path": config["General"]["checkpoint_dir"], + } + if config["General"].get("checkpoint_dir") + else None, + } + ) try: - common.logger.info(f"trainer prepare start") + common.logger.info("trainer prepare start") model.training = True trainer.prepare(model, 
tokenizer, datasets, optimizer, accelerator) except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"trainer prepare finish") + common.logger.info("trainer prepare finish") try: - common.logger.info(f"train start") + common.logger.info("train start") trainer.train() except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"train finish") + common.logger.info("train finish") def get_finetune_config(): - parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser = argparse.ArgumentParser( + description="Finetune a transformers model on a causal language modeling task" + ) parser.add_argument( "--config_file", type=str, @@ -165,7 +187,7 @@ def get_finetune_config(): return finetune_config.dict() -def main(external_config = None): +def main(external_config=None): if not external_config: config = get_finetune_config() else: @@ -186,31 +208,31 @@ def main(external_config = None): if not ray.is_initialized(): runtime_env = { "env_vars": { - "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), + "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), "CCL_ZE_IPC_EXCHANGE": "sockets", "CCL_WORKER_COUNT": str(ccl_worker_count), "CCL_LOG_LEVEL": "info", "WORLD_SIZE": str(num_training_workers), "FI_TCP_IFACE": "lo", - "FI_PROVIDER": "tcp" + "FI_PROVIDER": "tcp", } } accelerate_env_vars = get_accelerate_environment_variable(accelerate_mode, config) runtime_env["env_vars"].update(accelerate_env_vars) - if config["General"]["gpt_base_model"] == True: + if config["General"]["gpt_base_model"] is True: runtime_env["pip"] = ["transformers==4.26.0"] - ray.init(runtime_env = runtime_env) + ray.init(runtime_env=runtime_env) common.logger.info(f"ray available resources = {ray.available_resources()}") scaling_config = ScalingConfig( - num_workers = num_training_workers, - use_gpu = use_gpu, - resources_per_worker = resources_per_worker, - placement_strategy = "SPREAD", + num_workers=num_training_workers, + use_gpu=use_gpu, + resources_per_worker=resources_per_worker, + placement_strategy="SPREAD", ) device = config["Training"]["device"].lower() @@ -223,7 +245,7 @@ def main(external_config = None): else: customer_torch_config = config.get("torch_config") torch_config = common.TorchConfig(**customer_torch_config, device=device) - + if config.get("failure_config", None) is None: failure_config = FailureConfig() else: @@ -243,11 +265,12 @@ def main(external_config = None): train_loop_config=config, scaling_config=scaling_config, torch_config=torch_config, - run_config=run_config + run_config=run_config, ) results = trainer.fit() return results + if __name__ == "__main__": main() diff --git a/format.sh b/format.sh new file mode 100755 index 000000000..187dd0315 --- /dev/null +++ b/format.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Formats the files by running pre-commit hooks. + +while getopts 'ah' opt; do + case "$opt" in + a) + pip install -q pre-commit + pre-commit install + pre-commit run --all-files + exit $? + ;; + + "?"|h) + echo -e "Usage: $(basename "$0") [-a]\n\t-a\trun on all files\n\t-h\tshow this help message" + exit 1 + ;; + esac +done +shift "$((OPTIND -1))" + +pip install -q pre-commit +pre-commit install +pre-commit run +exit $? 
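For the finetune entry point refactored above, a minimal sketch of driving main(external_config=None) programmatically instead of through --config_file; the import paths follow the modules touched in this diff, while the YAML path shown is only an illustrative assumption and is not referenced by the patch:

    # Illustrative sketch only; not part of this patch.
    # Assumes the repository root is importable and that an example
    # FinetuneConfig YAML exists at the hypothetical path below.
    from pydantic_yaml import parse_yaml_raw_as

    from finetune.finetune import main
    from finetune.finetune_config import FinetuneConfig

    with open("finetune/finetune.yaml", "r") as f:  # hypothetical config path
        # Validate the YAML through FinetuneConfig, mirroring get_finetune_config().
        external_config = parse_yaml_raw_as(FinetuneConfig, f.read()).dict()

    # main() only falls back to parsing --config_file when external_config is None;
    # otherwise the dict above is passed straight through as train_loop_config.
    results = main(external_config=external_config)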
diff --git a/inference/api_openai_backend/openai_protocol.py b/inference/api_openai_backend/openai_protocol.py index 351e01c66..d482701d9 100644 --- a/inference/api_openai_backend/openai_protocol.py +++ b/inference/api_openai_backend/openai_protocol.py @@ -34,10 +34,10 @@ # limitations under the License. # -from typing import Any, Dict, Literal, List, TypeVar, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Literal, List, TypeVar, Optional, Tuple, Type, Union -from pydantic import BaseModel, Field, root_validator, validator -from enum import IntEnum, Enum +from pydantic import BaseModel, Field, root_validator +from enum import Enum import uuid import time import yaml @@ -77,9 +77,7 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 @classmethod - def from_response( - cls, response: Union["ModelResponse", Dict[str, Any]] - ) -> "UsageInfo": + def from_response(cls, response: Union["ModelResponse", Dict[str, Any]]) -> "UsageInfo": if isinstance(response, BaseModel): response_dict = response.dict() else: @@ -183,7 +181,8 @@ def yaml( **kwargs, ): """ - Generate a YAML representation of the model, `include` and `exclude` arguments as per `dict()`. + Generate a YAML representation of the model, `include` and `exclude` + arguments as per `dict()`. """ return yaml.dump( self.dict( @@ -211,9 +210,7 @@ def get_properties(cls): return [prop for prop in dir(cls) if isinstance(getattr(cls, prop), property)] def dict(self, *args, **kwargs): - self.__dict__.update( - {prop: getattr(self, prop) for prop in self.get_properties()} - ) + self.__dict__.update({prop: getattr(self, prop) for prop in self.get_properties()}) return super().dict(*args, **kwargs) # type: ignore def json( @@ -221,9 +218,7 @@ def json( *args, **kwargs, ) -> str: - self.__dict__.update( - {prop: getattr(self, prop) for prop in self.get_properties()} - ) + self.__dict__.update({prop: getattr(self, prop) for prop in self.get_properties()}) return super().json(*args, **kwargs) # type: ignore @@ -247,9 +242,7 @@ def text_or_error_or_finish_reason(cls, values): and values.get("error") is None and values.get("finish_reason") is None ): - raise ValueError( - "Either 'generated_text' or 'error' or 'finish_reason' must be set" - ) + raise ValueError("Either 'generated_text' or 'error' or 'finish_reason' must be set") return values @classmethod @@ -263,9 +256,7 @@ def merge_stream(cls, *responses: "ModelResponse") -> "ModelResponse": if len(responses) == 1: return responses[0] - generated_text = "".join( - [response.generated_text or "" for response in responses] - ) + generated_text = "".join([response.generated_text or "" for response in responses]) num_input_tokens = [ response.num_input_tokens for response in responses @@ -277,17 +268,13 @@ def merge_stream(cls, *responses: "ModelResponse") -> "ModelResponse": for response in responses if response.num_input_tokens_batch is not None ] - max_num_input_tokens_batch = ( - max(num_input_tokens_batch) if num_input_tokens_batch else None - ) + max_num_input_tokens_batch = max(num_input_tokens_batch) if num_input_tokens_batch else None num_generated_tokens = [ response.num_generated_tokens for response in responses if response.num_generated_tokens is not None ] - total_generated_tokens = ( - sum(num_generated_tokens) if num_generated_tokens else None - ) + total_generated_tokens = sum(num_generated_tokens) if num_generated_tokens else None num_generated_tokens_batch = [ response.num_generated_tokens_batch for response in responses @@ -308,9 +295,7 @@ def 
merge_stream(cls, *responses: "ModelResponse") -> "ModelResponse": if response.generation_time is not None ] total_generation_time = sum(generation_time) if generation_time else None - error = next( - (response.error for response in reversed(responses) if response.error), None - ) + error = next((response.error for response in reversed(responses) if response.error), None) return cls( generated_text=generated_text, @@ -341,9 +326,7 @@ def num_total_tokens(self) -> Optional[float]: @property def num_total_tokens_batch(self) -> Optional[float]: try: - return (self.num_input_tokens_batch or 0) + ( - self.num_generated_tokens_batch or 0 - ) + return (self.num_input_tokens_batch or 0) + (self.num_generated_tokens_batch or 0) except Exception: return None @@ -395,9 +378,7 @@ def __str__(self) -> str: return self.value @classmethod - def from_vllm_finish_reason( - cls, finish_reason: Optional[str] - ) -> Optional["FinishReason"]: + def from_vllm_finish_reason(cls, finish_reason: Optional[str]) -> Optional["FinishReason"]: if finish_reason is None: return None if finish_reason == "stop": @@ -406,4 +387,4 @@ def from_vllm_finish_reason( return cls.LENGTH if finish_reason == "abort": return cls.CANCELLED - return cls.STOP \ No newline at end of file + return cls.STOP diff --git a/inference/api_openai_backend/query_client.py b/inference/api_openai_backend/query_client.py index e19566243..ecc19f1d9 100644 --- a/inference/api_openai_backend/query_client.py +++ b/inference/api_openai_backend/query_client.py @@ -34,11 +34,11 @@ from typing import Dict from fastapi import HTTPException -from .openai_protocol import ModelCard, Prompt -from .openai_protocol import Prompt, ModelResponse +from .openai_protocol import ModelCard, Prompt, ModelResponse from .request_handler import handle_request -class RouterQueryClient(): + +class RouterQueryClient: def __init__(self, serve_deployments): self.serve_deployments = serve_deployments @@ -51,9 +51,7 @@ async def query(self, model: str, prompt: Prompt, request_id: str): responses = [resp async for resp in response_stream] return ModelResponse.merge_stream(*responses) - async def stream( - self, model: str, prompt: Prompt, request_id: str - ): + async def stream(self, model: str, prompt: Prompt, request_id: str): if model in self.serve_deployments: deploy_handle = self.serve_deployments[model] else: @@ -76,7 +74,9 @@ async def stream( model=model, prompt=prompt, request_id=request_id, - async_iterator=deploy_handle.options(stream=True).stream_response.options(stream=True, use_new_handle_api=True).remote(prompt_content, gen_config) + async_iterator=deploy_handle.options(stream=True) + .stream_response.options(stream=True, use_new_handle_api=True) + .remote(prompt_content, gen_config), ): yield x diff --git a/inference/api_openai_backend/request_handler.py b/inference/api_openai_backend/request_handler.py index 701684953..00ae49256 100644 --- a/inference/api_openai_backend/request_handler.py +++ b/inference/api_openai_backend/request_handler.py @@ -35,9 +35,8 @@ import asyncio import traceback from typing import AsyncIterator, List -from fastapi import Request, status, HTTPException +from fastapi import status, HTTPException from starlette.responses import JSONResponse -from starlette.requests import Request from pydantic import ValidationError as PydanticValidationError from logger import get_logger from .openai_protocol import Prompt, ModelResponse, ErrorResponse, FinishReason @@ -56,12 +55,11 @@ def __init__( self.message = message self.type = type -def 
openai_exception_handler(request: Request, exc: OpenAIHTTPException): - assert isinstance( - exc, OpenAIHTTPException - ), f"Unable to handle invalid exception {type(exc)}" + +def openai_exception_handler(exc: OpenAIHTTPException): + assert isinstance(exc, OpenAIHTTPException), f"Unable to handle invalid exception {type(exc)}" if exc.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR: - message = f"Internal Server Error" + message = "Internal Server Error" internal_message = message exc_type = "InternalServerError" else: @@ -78,6 +76,7 @@ def openai_exception_handler(request: Request, exc: OpenAIHTTPException): ) return JSONResponse(content=err_response.dict(), status_code=exc.status_code) + def extract_message_from_exception(e: Exception) -> str: # If the exception is a Ray exception, we need to dig through the text to get just # the exception message without the stack trace @@ -85,8 +84,8 @@ def extract_message_from_exception(e: Exception) -> str: # format_exception_only in that case) message_lines = traceback.format_exception_only(type(e), e)[-1].strip().split("\n") message = "" - # The stack trace lines will be prefixed with spaces, so we need to start from the bottom - # and stop at the last line before a line with a space + # The stack trace lines will be prefixed with spaces, so we need to start + # from the bottom and stop at the last line before a line with a space found_last_line_before_stack_trace = False for line in reversed(message_lines): if not line.startswith(" "): @@ -97,6 +96,7 @@ def extract_message_from_exception(e: Exception) -> str: message = message.strip() return message + async def handle_request( model: str, request_id: str, @@ -124,7 +124,7 @@ async def handle_request( # We do not raise here because that would cause a disconnection for streaming. -def _get_response_for_error(e: Exception, request_id: str): +def _get_response_for_error(e, request_id: str): """Convert an exception to an ModelResponse object""" logger.error(f"Request {request_id} failed with:", exc_info=e) status_code = status.HTTP_500_INTERNAL_SERVER_ERROR diff --git a/inference/api_openai_backend/router_app.py b/inference/api_openai_backend/router_app.py index 1afed2659..fc4328610 100644 --- a/inference/api_openai_backend/router_app.py +++ b/inference/api_openai_backend/router_app.py @@ -43,7 +43,12 @@ from logger import get_logger from .request_handler import OpenAIHTTPException, openai_exception_handler from .query_client import RouterQueryClient -from .openai_protocol import Prompt, ModelResponse, CompletionRequest, ChatCompletionRequest +from .openai_protocol import ( + Prompt, + ModelResponse, + CompletionRequest, + ChatCompletionRequest, +) from .openai_protocol import ( ChatCompletionResponse, CompletionResponse, @@ -64,6 +69,7 @@ # timeout in 10 minutes. 
Streaming can take longer than 3 min TIMEOUT = float(os.environ.get("ROUTER_HTTP_TIMEOUT", 600)) + def init() -> FastAPI: router_app = FastAPI() router_app.add_exception_handler(OpenAIHTTPException, openai_exception_handler) @@ -102,9 +108,7 @@ async def _completions_wrapper( logger.error(f"{subresult_dict['error']}") all_results.pop() had_error = True - yield "data: " + ModelResponse( - **subresult_dict - ).json() + "\n\n" + yield "data: " + ModelResponse(**subresult_dict).json() + "\n\n" # Return early in case of an error break choices = [ @@ -117,9 +121,7 @@ async def _completions_wrapper( usage = None if subresult_dict["finish_reason"]: usage = ( - UsageInfo.from_response( - ModelResponse.merge_stream(*all_results) - ) + UsageInfo.from_response(ModelResponse.merge_stream(*all_results)) if all_results else None ) @@ -173,18 +175,14 @@ async def _chat_completions_wrapper( subresult_dict["finish_reason"] = None all_results.pop() had_error = True - yield "data: " + ModelResponse( - **subresult_dict - ).json() + "\n\n" + yield "data: " + ModelResponse(**subresult_dict).json() + "\n\n" # Return early in case of an error break else: finish_reason = subresult_dict["finish_reason"] - choices: List[DeltaChoices] = [ + choices = [ DeltaChoices( - delta=DeltaContent( - content=subresult_dict["generated_text"] or "" - ), + delta=DeltaContent(content=subresult_dict["generated_text"] or ""), index=0, finish_reason=None, ) @@ -200,7 +198,7 @@ async def _chat_completions_wrapper( # Return early in case of an error break if not had_error: - choices: List[DeltaChoices] = [ + choices = [ DeltaChoices( delta=DeltaEOS(), index=0, @@ -334,11 +332,7 @@ async def chat( request_id, body, response, - self.query_client.stream( - body.model, - prompt, - request_id - ) + self.query_client.stream(body.model, prompt, request_id), ), media_type="text/event-stream", ) diff --git a/inference/api_server_openai.py b/inference/api_server_openai.py index 45179d6c5..77831a9d2 100644 --- a/inference/api_server_openai.py +++ b/inference/api_server_openai.py @@ -68,19 +68,20 @@ def router_application(deployments): } ).bind(merged_client) + def openai_serve_run(deployments, host, route_prefix, port): router_app = router_application(deployments) serve.run( - router_app, - name="router", - route_prefix=route_prefix, - host=host, - _blocking=True, - ).options( + router_app, + name="router", + route_prefix=route_prefix, + host=host, + _blocking=True, + ).options( stream=True, use_new_handle_api=True, ) deployment_address = f"http://{host}:{port}{route_prefix}" print(f"Deployment is ready at `{deployment_address}`.") - return deployment_address \ No newline at end of file + return deployment_address diff --git a/inference/api_server_simple.py b/inference/api_server_simple.py index a52348b84..0663700d8 100644 --- a/inference/api_server_simple.py +++ b/inference/api_server_simple.py @@ -22,7 +22,14 @@ def serve_run(deployments, model_list): for model_id, infer_conf in model_list.items(): print("deploy model: ", model_id) deployment = deployments[model_id] - handle = serve.run(deployment, _blocking=True, host=infer_conf.host, port=infer_conf.port, name=infer_conf.name, route_prefix=infer_conf.route_prefix) + serve.run( + deployment, + _blocking=True, + host=infer_conf.host, + port=infer_conf.port, + name=infer_conf.name, + route_prefix=infer_conf.route_prefix, + ) deployment_name = infer_conf.name if infer_conf.host == "0.0.0.0": all_nodes = ray.nodes() diff --git a/inference/chat_process.py b/inference/chat_process.py index 
f52c6a88f..05849aa22 100644 --- a/inference/chat_process.py +++ b/inference/chat_process.py @@ -14,11 +14,12 @@ # limitations under the License. # + class ChatModel: human_id = "" bot_id = "" unknown_id = "" - MEANINGLESS_WORDS = ['', '', '<|endoftext|>', '
'] + MEANINGLESS_WORDS = ["", "", "<|endoftext|>", "
"] stop_words = [""] def __init__(self, intro, human_id, bot_id, stop_words) -> None: @@ -30,7 +31,7 @@ def __init__(self, intro, human_id, bot_id, stop_words) -> None: def prepare_prompt(self, messages: list): """Prepare prompt from history messages.""" - prompt = '' + prompt = "" for msg in messages: role, content = msg.role, msg.content if role == "user": @@ -54,16 +55,19 @@ def convert_output(self, output: str): output = output.replace(word, "") text = output # remove partial human_id or bot id - if '\n' in text and (human_id.startswith(text[text.rfind('\n')+1:]) or - bot_id.startswith(text[text.rfind('\n')+1])): - text = text[:text.rfind('\n')] + if "\n" in text and ( + human_id.startswith(text[text.rfind("\n") + 1 :]) + or bot_id.startswith(text[text.rfind("\n") + 1]) + ): + text = text[: text.rfind("\n")] return text - def get_prompt(self ,messages): + def get_prompt(self, messages): """Generate response based on messages.""" prompt = self.prepare_prompt(messages) return prompt - + + class ChatModelGptJ(ChatModel): def __init__(self, intro, human_id, bot_id, stop_words): super().__init__(intro, human_id, bot_id, stop_words) @@ -90,6 +94,7 @@ def prepare_prompt(self, messages: list): prompt += f"{self.bot_id}:\n" return prompt + class ChatModelLLama(ChatModel): def __init__(self, intro, human_id, bot_id, stop_words): super().__init__(intro, human_id, bot_id, stop_words) @@ -113,5 +118,8 @@ def prepare_prompt(self, messages: list): prompt += f"{self.bot_id}:\n" return prompt + if __name__ == "__main__": - process_tool = ChatModelGptJ("### Instruction", "### Response", stop_words=["##", "### Instruction"]) + process_tool = ChatModelGptJ( + "", "### Instruction", "### Response", stop_words=["##", "### Instruction"] + ) diff --git a/inference/deepspeed_predictor.py b/inference/deepspeed_predictor.py index c2bf14835..b9ec6cda9 100644 --- a/inference/deepspeed_predictor.py +++ b/inference/deepspeed_predictor.py @@ -14,38 +14,48 @@ import os from predictor import Predictor from utils import get_torch_dtype +from inference.inference_config import ( + InferenceConfig, + DEVICE_CPU, + DEVICE_XPU, + IPEX_PRECISION_BF16, +) -from inference.inference_config import InferenceConfig, DEVICE_CPU, DEVICE_XPU, IPEX_PRECISION_BF16 class DSPipeline: - def __init__( - self, - infer_conf: InferenceConfig, - pad_token_id, - stopping_criteria - ): + def __init__(self, infer_conf: InferenceConfig, pad_token_id, stopping_criteria): self.device = torch.device(infer_conf.device) self.pad_token_id = pad_token_id self.stopping_criteria = stopping_criteria model_desc = infer_conf.model_description model_config = model_desc.config - hf_config = AutoConfig.from_pretrained(model_desc.model_id_or_path, torchscript=True, trust_remote_code=model_config.trust_remote_code) + hf_config = AutoConfig.from_pretrained( + model_desc.model_id_or_path, + torchscript=True, + trust_remote_code=model_config.trust_remote_code, + ) # get correct torch type for loading HF model torch_dtype = get_torch_dtype(infer_conf, hf_config) - self.model = AutoModelForCausalLM.from_pretrained(model_desc.model_id_or_path, - config=hf_config, - torch_dtype=torch_dtype, - low_cpu_mem_usage=True, - **model_config.dict()) - + self.model = AutoModelForCausalLM.from_pretrained( + model_desc.model_id_or_path, + config=hf_config, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + **model_config.dict(), + ) + if model_desc.peft_model_id_or_path: from peft import PeftModel + self.model = PeftModel.from_pretrained(self.model, 
model_desc.peft_model_id_or_path) if model_desc.peft_type == "deltatuner": from deltatuner import DeltaTunerModel - self.model = DeltaTunerModel.from_pretrained(self.model, model_desc.peft_model_id_or_path) + + self.model = DeltaTunerModel.from_pretrained( + self.model, model_desc.peft_model_id_or_path + ) self.model = self.model.merge_and_unload() self.model = self.model.eval().to(self.device) @@ -54,29 +64,40 @@ def __init__( self.model.eval() def streaming_generate(self, inputs, streamer, **generate_kwargs): - self.model.generate(inputs, - pad_token_id=self.pad_token_id, - stopping_criteria=self.stopping_criteria, - streamer=streamer, - **generate_kwargs) + self.model.generate( + inputs, + pad_token_id=self.pad_token_id, + stopping_criteria=self.stopping_criteria, + streamer=streamer, + **generate_kwargs, + ) def generate(self, inputs, **config): gen_tokens = self.model.generate( inputs, pad_token_id=self.pad_token_id, stopping_criteria=self.stopping_criteria, - **config + **config, ) return gen_tokens + @ray.remote class PredictionWorker(TorchDistributedWorker): - """A PredictionWorker is a Ray remote actor that runs a single shard of a DeepSpeed job. + """A PredictionWorker is a Ray remote actor that runs a single shard + of a DeepSpeed job. - Multiple PredictionWorkers of the same WorkerGroup form a PyTorch DDP process - group and work together under the orchestration of DeepSpeed. + Multiple PredictionWorkers of the same WorkerGroup form a PyTorch DDP + process group and work together under the orchestration of DeepSpeed. """ - def __init__(self, world_size: int, infer_conf: InferenceConfig, pad_token_id, stopping_criteria): + + def __init__( + self, + world_size: int, + infer_conf: InferenceConfig, + pad_token_id, + stopping_criteria, + ): self.world_size = world_size self.infer_conf = infer_conf self.pad_token_id = pad_token_id @@ -92,8 +113,8 @@ def init_model(self, local_rank: int): else: replace_with_kernel_inject = True - os.environ['LOCAL_RANK'] = str(local_rank) - os.environ['WORLD_SIZE'] = str(self.world_size) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) pipe = DSPipeline( self.infer_conf, @@ -105,17 +126,23 @@ def init_model(self, local_rank: int): pipe.model, mp_size=self.world_size, dtype=torch.bfloat16, - replace_with_kernel_inject=replace_with_kernel_inject + replace_with_kernel_inject=replace_with_kernel_inject, ) if self.infer_conf.ipex.enabled: import intel_extension_for_pytorch as ipex - try: ipex._C.disable_jit_linear_repack() - except: pass + + try: + ipex._C.disable_jit_linear_repack() + except Exception: + pass pipe.model = ipex.optimize_transformers( pipe.model.eval(), - dtype=torch.bfloat16 if self.infer_conf.ipex.precision == IPEX_PRECISION_BF16 else torch.float32, - inplace=True) + dtype=torch.bfloat16 + if self.infer_conf.ipex.precision == IPEX_PRECISION_BF16 + else torch.float32, + inplace=True, + ) self.generator = pipe @@ -125,6 +152,7 @@ def streaming_generate(self, inputs, streamer, **config): def generate(self, inputs, **config): return self.generator.generate(inputs, **config) + class DeepSpeedPredictor(Predictor): def __init__(self, infer_conf: InferenceConfig) -> None: super().__init__(infer_conf) @@ -138,7 +166,7 @@ def __init__(self, infer_conf: InferenceConfig) -> None: scaling_conf = ScalingConfig( use_gpu=use_gpu, num_workers=infer_conf.workers_per_group, - resources_per_worker=resource + resources_per_worker=resource, ) print(scaling_conf) @@ -150,7 +178,7 @@ def __del__(self): # Use dummy 
streamer to ignore other workers' ouputs def _create_dummy_streamer(self): - class DummyStreamer(): + class DummyStreamer: def put(self, value): pass @@ -180,35 +208,41 @@ def _init_worker_group(self, scaling_config: ScalingConfig): # Create the prediction workers. self.prediction_workers = [ - prediction_worker_cls.remote(scaling_config.num_workers, self.infer_conf, - self.pad_token_id, self.stopping_criteria) + prediction_worker_cls.remote( + scaling_config.num_workers, + self.infer_conf, + self.pad_token_id, + self.stopping_criteria, + ) for i in range(scaling_config.num_workers) ] # Initialize torch distributed process group for the workers. - local_ranks = init_torch_dist_process_group(self.prediction_workers, backend="ccl" if self.infer_conf.device != "cuda" else "nccl") + local_ranks = init_torch_dist_process_group( + self.prediction_workers, + backend="ccl" if self.infer_conf.device != "cuda" else "nccl", + ) # Initialize the model on each worker. - ray.get([ - worker.init_model.remote(local_rank) - for worker, local_rank in zip(self.prediction_workers, local_ranks) - ]) + ray.get( + [ + worker.init_model.remote(local_rank) + for worker, local_rank in zip(self.prediction_workers, local_ranks) + ] + ) def streaming_generate(self, prompt, streamer, **config): input_ids = self.tokenize_inputs(prompt) inputs_ref = ray.put(input_ids) self.prediction_workers[0].streaming_generate.remote(inputs_ref, streamer, **config) for worker in self.prediction_workers[1:]: - worker.streaming_generate.remote(inputs_ref, self._create_dummy_streamer(), **config) + worker.streaming_generate.remote(inputs_ref, self._create_dummy_streamer(), **config) def generate(self, prompt, **config): input_ids = self.tokenize_inputs(prompt) inputs_ref = ray.put(input_ids) gen_tokens = ray.get( - [ - worker.generate.remote(inputs_ref, **config) - for worker in self.prediction_workers - ] + [worker.generate.remote(inputs_ref, **config) for worker in self.prediction_workers] )[0] return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0] @@ -219,7 +253,11 @@ def get_streamer(self): class RayTextIteratorStreamer(TextStreamer): def __init__( - self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs + self, + tokenizer: "AutoTokenizer", + skip_prompt: bool = False, + timeout: Optional[float] = None, + **decode_kwargs, ): super().__init__(tokenizer, skip_prompt, **decode_kwargs) self.text_queue = Queue() @@ -240,19 +278,13 @@ def __next__(self): raise StopIteration() else: return value + return RayTextIteratorStreamer(self.tokenizer, skip_special_tokens=True) - def predict( - self, - data: List[str], - **kwargs - ) -> str: + def predict(self, data: List[str], **kwargs) -> str: data_ref = ray.put(data) prediction = ray.get( - [ - worker.generate.remote(data_ref, **kwargs) - for worker in self.prediction_workers - ] + [worker.generate.remote(data_ref, **kwargs) for worker in self.prediction_workers] ) return prediction diff --git a/inference/inference_config.py b/inference/inference_config.py index 812d579f4..b82d968b3 100644 --- a/inference/inference_config.py +++ b/inference/inference_config.py @@ -1,37 +1,41 @@ import os from pydantic import BaseModel, validator, ConfigDict from pydantic_yaml import parse_yaml_raw_as -from typing import List, Dict +from typing import List, Dict, Union -IPEX_PRECISION_BF16 = 'bf16' -IPEX_PRECISION_FP32 = 'fp32' +IPEX_PRECISION_BF16 = "bf16" +IPEX_PRECISION_FP32 = "fp32" DEVICE_CPU = "cpu" DEVICE_HPU = "hpu" DEVICE_XPU = 
"xpu" DEVICE_CUDA = "cuda" + class Prompt(BaseModel): intro: str = "" human_id: str = "" bot_id: str = "" stop_words: List[str] = [] + class ModelConfig(BaseModel): trust_remote_code: bool = False - use_auth_token: str = None + use_auth_token: Union[str, None] = None load_in_4bit: bool = False + class Ipex(BaseModel): enabled: bool = True - precision: str = 'bf16' + precision: str = "bf16" - @validator('precision') + @validator("precision") def _check_precision(cls, v: str): if v: assert v in [IPEX_PRECISION_BF16, IPEX_PRECISION_FP32] return v + # for bigdl model class BigDLModelConfig(BaseModel): load_in_low_bit: str = "" @@ -42,16 +46,17 @@ def _check_load_in_low_bit(cls, v: str): assert v in ["sym_int4", "asym_int4", "sym_int5", "asym_int5", "sym_int8"] return v + class ModelDescription(BaseModel): - model_id_or_path: str = None + model_id_or_path: Union[str, None] = None bigdl: bool = False - tokenizer_name_or_path: str = None - chat_processor: str = None + tokenizer_name_or_path: Union[str, None] = None + chat_processor: Union[str, None] = None gpt_base_model: bool = False - quantized_model_id_or_path: str = None - quantization_type: str = None - peft_model_id_or_path: str = None - peft_type: str = None + quantized_model_id_or_path: Union[str, None] = None + quantization_type: Union[str, None] = None + peft_model_id_or_path: Union[str, None] = None + peft_type: Union[str, None] = None # only effective when device is hpu use_hpu_graphs: bool = True prompt: Prompt = Prompt() @@ -60,25 +65,26 @@ class ModelDescription(BaseModel): # prevent warning of protected namespaces # DO NOT TOUCH - model_config = ConfigDict(protected_namespaces=()) - - @validator('quantization_type') + model_config = ConfigDict(protected_namespaces=()) # type: ignore + + @validator("quantization_type") def _check_quant_type(cls, v: str): if v: - assert v in ["ipex_smoothquant", 'ipex_weightonly', 'llamacpp'] + assert v in ["ipex_smoothquant", "ipex_weightonly", "llamacpp"] return v - @validator('peft_type') + @validator("peft_type") def _check_perftype(cls, v: str): if v: assert v in ["lora", "adalora", "deltatuner"] return v + class InferenceConfig(BaseModel): host: str = "0.0.0.0" port: int = 8000 - name: str = None - route_prefix: str = None + name: str = "default" + route_prefix: Union[str, None] = None cpus_per_worker: int = 24 gpus_per_worker: int = 0 hpus_per_worker: int = 0 @@ -90,34 +96,35 @@ class InferenceConfig(BaseModel): # prevent warning of protected namespaces # DO NOT TOUCH - model_config = ConfigDict(protected_namespaces=()) + model_config = ConfigDict(protected_namespaces=()) # type: ignore - @validator('host') + @validator("host") def _check_host(cls, v: str): if v: assert v in ["0.0.0.0", "127.0.0.1"] return v - @validator('port') + @validator("port") def _check_port(cls, v: int): assert v > 0 & v < 65535 return v - @validator('device') + @validator("device") def _check_device(cls, v: str): if v: assert v.lower() in [DEVICE_CPU, DEVICE_XPU, DEVICE_CUDA, DEVICE_HPU] return v.lower() - @validator('workers_per_group') + @validator("workers_per_group") def _check_workers_per_group(cls, v: int): if v: assert v > 0 return v -all_models : Dict[str, InferenceConfig] = {} -base_models : Dict[str, InferenceConfig] = {} -_models : Dict[str, InferenceConfig] = {} + +all_models: Dict[str, InferenceConfig] = {} +base_models: Dict[str, InferenceConfig] = {} +_models: Dict[str, InferenceConfig] = {} _cur = os.path.dirname(os.path.abspath(__file__)) _models_folder = _cur + "/models" diff --git 
a/inference/logger.py b/inference/logger.py index 44768c1e6..d00283fca 100644 --- a/inference/logger.py +++ b/inference/logger.py @@ -18,9 +18,7 @@ import os from typing import Optional -LOG_FORMAT = ( - "[%(levelname)s %(asctime)s]{rank} %(filename)s: %(lineno)d " "%(message)s" -) +LOG_FORMAT = "[%(levelname)s %(asctime)s]{rank} %(filename)s: %(lineno)d " "%(message)s" def get_logger(name: str = None, rank: Optional[int] = None, **kwargs): diff --git a/inference/predictor.py b/inference/predictor.py index 8344ecf86..1965f3d59 100644 --- a/inference/predictor.py +++ b/inference/predictor.py @@ -2,12 +2,15 @@ import torch from transformers import AutoTokenizer, StoppingCriteriaList from inference.inference_config import InferenceConfig -from utils import max_input_len, StoppingCriteriaSub +from utils import StoppingCriteriaSub + class Predictor: def __init__(self, infer_conf: InferenceConfig) -> None: self.infer_conf = infer_conf - self.tokenizer = AutoTokenizer.from_pretrained(infer_conf.model_description.tokenizer_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained( + infer_conf.model_description.tokenizer_name_or_path + ) self.device = torch.device(infer_conf.device) # now deepspeed predictor don't have the model # so configure_tokenizer cannot be called @@ -19,21 +22,14 @@ def __init__(self, infer_conf: InferenceConfig) -> None: prompt = infer_conf.model_description.prompt stop_words = prompt.stop_words - stop_words_ids = [self.tokenizer(stop_word, return_tensors='pt').input_ids.squeeze() for stop_word in stop_words] + stop_words_ids = [ + self.tokenizer(stop_word, return_tensors="pt").input_ids.squeeze() + for stop_word in stop_words + ] self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) def tokenize_inputs(self, text): - if self.device.type == "hpu": - input_tokens = self.tokenizer( - text, - return_tensors="pt", - padding="max_length", - max_length=max_input_len(input_token_len), - ) - else: - input_tokens = self.tokenizer( - text, return_tensors="pt", padding=True - ) + input_tokens = self.tokenizer(text, return_tensors="pt", padding=True) return input_tokens.input_ids.to(device=self.device) def configure_tokenizer(self, model_name): @@ -48,13 +44,13 @@ def configure_tokenizer(self, model_name): if ( hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None - and not "chatglm" in model_name + and "chatglm" not in model_name ): tokenizer.pad_token_id = model.generation_config.pad_token_id if ( hasattr(model.generation_config, "eos_token_id") and model.generation_config.eos_token_id is not None - and not "chatglm" in model_name + and "chatglm" not in model_name ): tokenizer.eos_token_id = model.generation_config.eos_token_id if ( @@ -64,20 +60,18 @@ def configure_tokenizer(self, model_name): tokenizer.bos_token_id = model.generation_config.bos_token_id if tokenizer.pad_token_id is None: - model.generation_config.pad_token_id = ( - tokenizer.pad_token_id - ) = tokenizer.eos_token_id + model.generation_config.pad_token_id = tokenizer.pad_token_id = tokenizer.eos_token_id if model.generation_config.eos_token_id is None: model.generation_config.eos_token_id = tokenizer.eos_token_id - + if not model.config.is_encoder_decoder: tokenizer.padding_side = "left" if tokenizer.pad_token is None and tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model.generation_config.pad_token_id = model.generation_config.eos_token_id - + def generate(self, prompt, **config): pass diff --git 
a/inference/predictor_deployment.py b/inference/predictor_deployment.py index 440d071f9..fff8375d1 100644 --- a/inference/predictor_deployment.py +++ b/inference/predictor_deployment.py @@ -24,7 +24,7 @@ import torch from transformers import TextIteratorStreamer from inference.inference_config import InferenceConfig -from typing import Union +from typing import Union, Dict, Any from starlette.responses import StreamingResponse from inference.api_openai_backend.openai_protocol import ModelResponse @@ -39,24 +39,31 @@ def __init__(self, infer_conf: InferenceConfig): if chat_processor_name: try: module = __import__("chat_process") - except: + except Exception: sys.path.append(os.path.dirname(__file__)) module = __import__("chat_process") chat_processor = getattr(module, chat_processor_name, None) if chat_processor is None: - raise ValueError(infer_conf.name + " deployment failed. chat_processor(" + chat_processor_name + ") does not exist.") + raise ValueError( + infer_conf.name + + " deployment failed. chat_processor(" + + chat_processor_name + + ") does not exist." + ) self.process_tool = chat_processor(**prompt.dict()) - + self.use_deepspeed = infer_conf.deepspeed if self.use_deepspeed: from deepspeed_predictor import DeepSpeedPredictor + self.predictor = DeepSpeedPredictor(infer_conf) self.streamer = self.predictor.get_streamer() else: from transformer_predictor import TransformerPredictor + self.predictor = TransformerPredictor(infer_conf) self.loop = asyncio.get_running_loop() - + def consume_streamer(self): for text in self.streamer: yield text @@ -74,10 +81,10 @@ async def consume_streamer_async(self, streamer: TextIteratorStreamer): await asyncio.sleep(0.001) async def __call__(self, http_request: Request) -> Union[StreamingResponse, str]: - json_request: str = await http_request.json() + json_request: Dict[str, Any] = await http_request.json() prompts = [] text = json_request["text"] - config = json_request["config"] if "config" in json_request else {} + config = json_request["config"] if "config" in json_request else {} streaming_response = json_request["stream"] if isinstance(text, list): if self.process_tool is not None: @@ -91,12 +98,21 @@ async def __call__(self, http_request: Request) -> Union[StreamingResponse, str] return self.predictor.generate(prompts, **config) if self.use_deepspeed: self.predictor.streaming_generate(prompts, self.streamer, **config) - return StreamingResponse(self.consume_streamer(), status_code=200, media_type="text/plain") + return StreamingResponse( + self.consume_streamer(), status_code=200, media_type="text/plain" + ) else: streamer = self.predictor.get_streamer() - self.loop.run_in_executor(None, functools.partial(self.predictor.streaming_generate, prompts, streamer, **config)) - return StreamingResponse(self.consume_streamer_async(streamer), status_code=200, media_type="text/plain") - + self.loop.run_in_executor( + None, + functools.partial(self.predictor.streaming_generate, prompts, streamer, **config), + ) + return StreamingResponse( + self.consume_streamer_async(streamer), + status_code=200, + media_type="text/plain", + ) + async def stream_response(self, prompt, config): prompts = [] if isinstance(prompt, list): @@ -113,7 +129,10 @@ async def stream_response(self, prompt, config): response_handle = self.consume_streamer() else: streamer = self.predictor.get_streamer() - self.loop.run_in_executor(None, functools.partial(self.predictor.streaming_generate, prompts, streamer, **config)) + self.loop.run_in_executor( + None, + 
functools.partial(self.predictor.streaming_generate, prompts, streamer, **config), + ) response_handle = self.consume_streamer_async(streamer) async for output in response_handle: model_response = ModelResponse( diff --git a/inference/serve.py b/inference/serve.py index 05b838b2b..3fd59180c 100644 --- a/inference/serve.py +++ b/inference/serve.py @@ -23,6 +23,7 @@ from predictor_deployment import PredictorDeployment from inference.inference_config import ModelDescription, InferenceConfig, all_models + def get_deployed_models(args): # serve all pre-defined models, or model from MODEL_TO_SERVE env, if no model argument specified if args.model is None and args.config_file is None: @@ -33,11 +34,13 @@ def get_deployed_models(args): print("reading from config file, " + args.config_file) with open(args.config_file, "r") as f: infer_conf = parse_yaml_raw_as(InferenceConfig, f) - else: # args.model should be set + else: # args.model should be set print("reading from command line, " + args.model) model_desc = ModelDescription() model_desc.model_id_or_path = args.model - model_desc.tokenizer_name_or_path = args.tokenizer if args.tokenizer is not None else args.model + model_desc.tokenizer_name_or_path = ( + args.tokenizer if args.tokenizer is not None else args.model + ) infer_conf = InferenceConfig(model_description=model_desc) infer_conf.host = "127.0.0.1" if args.serve_local_only else "0.0.0.0" infer_conf.port = args.port @@ -51,29 +54,69 @@ def get_deployed_models(args): deployments = {} for model_id, infer_conf in model_list.items(): ray_actor_options = get_deployment_actor_options(infer_conf) - deployments[model_id] = PredictorDeployment.options(ray_actor_options=ray_actor_options).bind(infer_conf) + deployments[model_id] = PredictorDeployment.options( + ray_actor_options=ray_actor_options + ).bind(infer_conf) return deployments, model_list + # make it unittest friendly def main(argv=None): # args import argparse + parser = argparse.ArgumentParser(description="Model Serve Script", add_help=True) - parser.add_argument("--config_file", type=str, help="Inference configuration file in YAML. If specified, all other arguments will be ignored.") + parser.add_argument( + "--config_file", + type=str, + help="Inference configuration file in YAML. 
If specified, all other arguments will be ignored.", + ) parser.add_argument("--model", default=None, type=str, help="Model name or path.") parser.add_argument("--tokenizer", default=None, type=str, help="Tokenizer name or path.") parser.add_argument("--port", default=8000, type=int, help="The port of deployment address.") - parser.add_argument("--route_prefix", default=None, type=str, help="The route prefix for HTTP requests.") + parser.add_argument( + "--route_prefix", + default=None, + type=str, + help="The route prefix for HTTP requests.", + ) parser.add_argument("--cpus_per_worker", default="24", type=int, help="CPUs per worker.") - parser.add_argument("--gpus_per_worker", default=0, type=float, help="GPUs per worker, used when --device is cuda.") - parser.add_argument("--hpus_per_worker", default=0, type=float, help="HPUs per worker, used when --device is hpu.") - parser.add_argument("--deepspeed", action='store_true', help="Enable deepspeed inference.") - parser.add_argument("--workers_per_group", default="2", type=int, help="Workers per group, used with --deepspeed.") - parser.add_argument("--ipex", action='store_true', help="Enable ipex optimization.") + parser.add_argument( + "--gpus_per_worker", + default=0, + type=float, + help="GPUs per worker, used when --device is cuda.", + ) + parser.add_argument( + "--hpus_per_worker", + default=0, + type=float, + help="HPUs per worker, used when --device is hpu.", + ) + parser.add_argument("--deepspeed", action="store_true", help="Enable deepspeed inference.") + parser.add_argument( + "--workers_per_group", + default="2", + type=int, + help="Workers per group, used with --deepspeed.", + ) + parser.add_argument("--ipex", action="store_true", help="Enable ipex optimization.") parser.add_argument("--device", default="cpu", type=str, help="cpu, xpu, hpu or cuda.") - parser.add_argument("--serve_local_only", action="store_true", help="Only support local access to url.") - parser.add_argument("--serve_simple", action="store_true", help="Whether to serve OpenAI-compatible API for all models or serve simple endpoint based on model conf files.") - parser.add_argument("--keep_serve_terminal", action="store_true", help="Whether to keep serve terminal.") + parser.add_argument( + "--serve_local_only", + action="store_true", + help="Only support local access to url.", + ) + parser.add_argument( + "--serve_simple", + action="store_true", + help="Whether to serve OpenAI-compatible API for all models or serve simple endpoint based on model conf files.", + ) + parser.add_argument( + "--keep_serve_terminal", + action="store_true", + help="Whether to keep serve terminal.", + ) args = parser.parse_args(argv) @@ -85,7 +128,8 @@ def main(argv=None): serve_run(deployments, model_list) else: # provide OpenAI compatible api to run LLM models - # all models are served under the same URL and then accessed through model_id, so it needs to pass in a unified URL. + # all models are served under the same URL and then accessed + # through model_id, so it needs to pass in a unified URL. 
host = "127.0.0.1" if args.serve_local_only else "0.0.0.0" rp = args.route_prefix if args.route_prefix else "" route_prefix = "/{}".format(rp) @@ -97,5 +141,6 @@ def main(argv=None): else: print(msg) + if __name__ == "__main__": main(sys.argv[1:]) diff --git a/inference/transformer_predictor.py b/inference/transformer_predictor.py index 406400878..2784016b9 100644 --- a/inference/transformer_predictor.py +++ b/inference/transformer_predictor.py @@ -1,17 +1,22 @@ import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from transformers import AutoModelForCausalLM, AutoConfig from transformers import TextIteratorStreamer from inference.inference_config import InferenceConfig, IPEX_PRECISION_BF16 from predictor import Predictor from utils import get_torch_dtype + class TransformerPredictor(Predictor): def __init__(self, infer_conf: InferenceConfig): super().__init__(infer_conf) - + model_desc = infer_conf.model_description model_config = model_desc.config - hf_config = AutoConfig.from_pretrained(model_desc.model_id_or_path, torchscript=True, trust_remote_code=model_config.trust_remote_code) + hf_config = AutoConfig.from_pretrained( + model_desc.model_id_or_path, + torchscript=True, + trust_remote_code=model_config.trust_remote_code, + ) if self.device.type == "hpu": from optimum.habana.transformers.modeling_utils import ( @@ -22,7 +27,10 @@ def __init__(self, infer_conf: InferenceConfig): # get correct torch type for loading HF model torch_dtype = get_torch_dtype(infer_conf, hf_config) if model_desc.bigdl: - from bigdl.llm.transformers import AutoModelForCausalLM as BigDLAutoModelForCLM + from bigdl.llm.transformers import ( + AutoModelForCausalLM as BigDLAutoModelForCLM, + ) + bmodel_config = {} bmodel_config.update(model_config.dict()) if model_desc.bigdl_config.load_in_low_bit: @@ -32,7 +40,7 @@ def __init__(self, infer_conf: InferenceConfig): torch_dtype=torch_dtype, config=hf_config, low_cpu_mem_usage=True, - **bmodel_config + **bmodel_config, ) else: model = AutoModelForCausalLM.from_pretrained( @@ -40,13 +48,15 @@ def __init__(self, infer_conf: InferenceConfig): torch_dtype=torch_dtype, config=hf_config, low_cpu_mem_usage=True, - **model_config.dict() + **model_config.dict(), ) if model_desc.peft_model_id_or_path: from peft import PeftModel + model = PeftModel.from_pretrained(model, model_desc.peft_model_id_or_path) if model_desc.peft_type == "deltatuner": from deltatuner import DeltaTunerModel + model = DeltaTunerModel.from_pretrained(model, model_desc.peft_model_id_or_path) model = model.merge_and_unload() @@ -54,7 +64,10 @@ def __init__(self, infer_conf: InferenceConfig): if self.device.type == "hpu": self.use_hpu_graphs = model_desc.use_hpu_graphs if self.use_hpu_graphs: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph # pylint: disable=E0401 + from habana_frameworks.torch.hpu import ( + wrap_in_hpu_graph, + ) # pylint: disable=E0401 + model = wrap_in_hpu_graph(model) else: print("Warning: use_hpu_graphs is set to False. 
This will hurt the performance.") @@ -66,12 +79,16 @@ def __init__(self, infer_conf: InferenceConfig): import intel_extension_for_pytorch as ipex torch._C._jit_set_texpr_fuser_enabled(False) - try: ipex._C.disable_jit_linear_repack() - except: pass + try: + ipex._C.disable_jit_linear_repack() + except Exception: + pass model = ipex.optimize_transformers( model.eval(), - dtype=torch.bfloat16 if infer_conf.ipex.precision == IPEX_PRECISION_BF16 else torch.float32, - inplace=True + dtype=torch.bfloat16 + if infer_conf.ipex.precision == IPEX_PRECISION_BF16 + else torch.float32, + inplace=True, ) self.model = model @@ -88,20 +105,22 @@ def _process_config(self, config): def streaming_generate(self, prompt, streamer, **config): self._process_config(config) input_ids = self.tokenize_inputs(prompt) - self.model.generate(input_ids, - stopping_criteria=self.stopping_criteria, - streamer=streamer, - **config) + self.model.generate( + input_ids, + stopping_criteria=self.stopping_criteria, + streamer=streamer, + **config, + ) def generate(self, prompt, **config): self._process_config(config) input_ids = self.tokenize_inputs(prompt) gen_tokens = self.model.generate( - input_ids, - stopping_criteria=self.stopping_criteria, - **config + input_ids, stopping_criteria=self.stopping_criteria, **config ) return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0] def get_streamer(self): - return TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=0, skip_special_tokens=True) + return TextIteratorStreamer( + self.tokenizer, skip_prompt=True, timeout=0, skip_special_tokens=True + ) diff --git a/inference/utils.py b/inference/utils.py index 3c7e47fe0..c0bd3b14a 100644 --- a/inference/utils.py +++ b/inference/utils.py @@ -17,6 +17,8 @@ from transformers import StoppingCriteria import torch from inference.inference_config import InferenceConfig, DEVICE_CPU +from typing import Dict, Any + def get_deployment_actor_options(infer_conf: InferenceConfig): _ray_env_key = "env_vars" @@ -25,15 +27,16 @@ def get_deployment_actor_options(infer_conf: InferenceConfig): "KMP_BLOCKTIME": "1", "KMP_SETTINGS": "1", "KMP_AFFINITY": "granularity=fine,compact,1,0", - "MALLOC_CONF": "oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000" + "MALLOC_CONF": "oversize_threshold:1,background_thread:true,\ + metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000", } - runtime_env = {_ray_env_key: {}} + runtime_env: Dict[str, Any] = {_ray_env_key: {}} if infer_conf.ipex.enabled: runtime_env[_ray_env_key].update(_predictor_runtime_env_ipex) if infer_conf.deepspeed: runtime_env[_ray_env_key]["DS_ACCELERATOR"] = infer_conf.device # now PredictorDeployment itself is a worker, we should require resources for it - ray_actor_options = {"runtime_env": runtime_env} + ray_actor_options: Dict[str, Any] = {"runtime_env": runtime_env} if infer_conf.device == "cpu": ray_actor_options["num_cpus"] = infer_conf.cpus_per_worker elif infer_conf.device == "cuda": @@ -45,19 +48,20 @@ def get_deployment_actor_options(infer_conf: InferenceConfig): pass return ray_actor_options -class StoppingCriteriaSub(StoppingCriteria): - def __init__(self, stops = [], encounters=1): +class StoppingCriteriaSub(StoppingCriteria): + def __init__(self, stops=[], encounters=1): super().__init__() self.stops = stops def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): for stop in self.stops: - length = 1 if len(stop.size())==0 else stop.size()[0] + length = 1 if 
len(stop.size()) == 0 else stop.size()[0] if torch.all((stop == input_ids[0][-length:])).item(): return True return False + # used in inference with Gaudi def max_input_len(input_text_length): if input_text_length <= 128: @@ -70,22 +74,24 @@ def max_input_len(input_text_length): print("Max support length is 4096") return 4096 + def get_torch_dtype(infer_conf: InferenceConfig, hf_config) -> torch.dtype: - ''' - return torch default dtype, a.k.a float32, if it's cpu only inference without ipex because - bfloat16 is too slow and float16 is not supported in CPU - ''' + """ + return torch default dtype, a.k.a float32, if it's cpu only inference without + ipex because bfloat16 is too slow and float16 is not supported in CPU + """ if hf_config is None or is_cpu_without_ipex(infer_conf): return torch.get_default_dtype() - if hasattr(hf_config, 'torch_dtype'): + if hasattr(hf_config, "torch_dtype"): t = hf_config.torch_dtype if t: return t - if hasattr(hf_config, '__getitem__'): - t = hf_config['torch_dtype'] + if hasattr(hf_config, "__getitem__"): + t = hf_config["torch_dtype"] if t: return t return torch.get_default_dtype() + def is_cpu_without_ipex(infer_conf: InferenceConfig) -> bool: return (not infer_conf.ipex.enabled) and infer_conf.device == DEVICE_CPU diff --git a/pretrain/backend/deepspeed_backend.py b/pretrain/backend/deepspeed_backend.py index 961e9c933..cd185a7ab 100644 --- a/pretrain/backend/deepspeed_backend.py +++ b/pretrain/backend/deepspeed_backend.py @@ -2,7 +2,6 @@ import os from dataclasses import dataclass from datetime import timedelta -from typing import Optional import deepspeed import torch.distributed as dist @@ -14,18 +13,17 @@ from ray.train._internal.worker_group import WorkerGroup from ray.train._internal.utils import get_address_and_port from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME -from dataclasses import dataclass logger = logging.getLogger(__name__) @dataclass class TorchConfig(RayTorchConfig): - @property def backend_cls(self): return DeepSpeedBackend + def _set_nccl_network_interface(): """Set the appropriate NCCL network interface to use.""" @@ -39,6 +37,7 @@ def _set_nccl_network_interface(): ) os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_NCCL_SOCKET_IFNAME + def _setup_deepspeed_process_group( backend: str, world_rank: int, @@ -83,7 +82,6 @@ def _setup_deepspeed_process_group( ) os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" - deepspeed.init_distributed( dist_backend=backend, auto_mpi_discovery=False, @@ -93,8 +91,8 @@ def _setup_deepspeed_process_group( timeout=timedelta(seconds=timeout_s), ) -class DeepSpeedBackend(_TorchBackend): +class DeepSpeedBackend(_TorchBackend): def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): if dist.is_available(): # Set the appropriate training backend. 
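For context, a rough sketch of how this custom TorchConfig is consumed by a Ray TorchTrainer (placeholder values; it mirrors the trainer wiring that appears later in this patch):

    from ray.train.torch import TorchTrainer
    from ray.air.config import ScalingConfig
    from backend.deepspeed_backend import TorchConfig  # assumes running from the pretrain/ directory

    trainer = TorchTrainer(
        train_loop_per_worker=lambda config: None,  # placeholder training loop
        train_loop_config={},
        scaling_config=ScalingConfig(num_workers=2),
        torch_config=TorchConfig(),  # backend_cls resolves to the DeepSpeedBackend above
    )
    # trainer.fit() would then run on_start(), which calls deepspeed.init_distributed per worker.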
@@ -109,9 +107,7 @@ def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig): if backend == "nccl": worker_group.execute(_set_nccl_network_interface) - master_addr, master_port = worker_group.execute_single( - 0, get_address_and_port - ) + master_addr, master_port = worker_group.execute_single(0, get_address_and_port) if backend_config.init_method == "env": def set_env_vars(addr, port): @@ -145,5 +141,3 @@ def set_env_vars(addr, port): ray.get(setup_futures) else: raise RuntimeError("Distributed torch is not available.") - - diff --git a/pretrain/backend/habana_backend.py b/pretrain/backend/habana_backend.py index 3c4dbc9fc..125987ba2 100644 --- a/pretrain/backend/habana_backend.py +++ b/pretrain/backend/habana_backend.py @@ -3,26 +3,23 @@ from ray.train._internal.worker_group import WorkerGroup from dataclasses import dataclass + @dataclass class TorchConfig(RayTorchConfig): - @property def backend_cls(self): return EnableHabanaBackend + def habana_import(): try: - import habana_frameworks.torch + import habana_frameworks.torch # noqa: F401 except ImportError as habana_not_exist: - raise ImportError( - "Please install habana_frameworks" - ) from habana_not_exist + raise ImportError("Please install habana_frameworks") from habana_not_exist -class EnableHabanaBackend(_TorchBackend): +class EnableHabanaBackend(_TorchBackend): def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): for i in range(len(worker_group)): worker_group.execute_single_async(i, habana_import) super().on_start(worker_group, backend_config) - - diff --git a/pretrain/megatron_deepspeed_pretrain.py b/pretrain/megatron_deepspeed_pretrain.py index f11e784b7..aa5002711 100644 --- a/pretrain/megatron_deepspeed_pretrain.py +++ b/pretrain/megatron_deepspeed_pretrain.py @@ -1,4 +1,5 @@ -import os, sys +import os +import sys from typing import Any, Dict import ray @@ -6,21 +7,21 @@ from ray.air.config import ScalingConfig from ray.air import RunConfig, FailureConfig -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import common import importlib -loader = importlib.util.find_spec('habana_frameworks') + +loader = importlib.util.find_spec("habana_frameworks") if loader is not None: from backend.habana_backend import TorchConfig else: from ray.train.torch import TorchConfig -from megatron.training import pretrain +from megatron.training import pretrain def train_func(config: Dict[str, Any]): - cwd = config.get("cwd") if cwd: os.chdir(cwd) @@ -35,33 +36,40 @@ def train_func(config: Dict[str, Any]): if pretrain_module_name is not None: pretrain_module = importlib.import_module(pretrain_module_name) else: - pretrain_module = importlib.import_module('pretrain_gpt') + pretrain_module = importlib.import_module("pretrain_gpt") else: raise ImportError("Please set megatron_deepspeed_path in config") - - common.init(config) - megatron_config = config.get('megatron_config', {}) - - if hasattr(pretrain_module, 'ModelType'): - pretrain(pretrain_module.train_valid_test_datasets_provider, - pretrain_module.model_provider, - pretrain_module.ModelType.encoder_or_decoder, - pretrain_module.forward_step, - data_post_process=pretrain_module.data_post_process, - external_args=megatron_config) - elif hasattr(pretrain_module, 'llama_argument_handler'): - pretrain(pretrain_module.train_valid_test_datasets_provider, - pretrain_module.model_provider, - pretrain_module.forward_step, - pretrain_module.llama_argument_handler, - 
external_args=megatron_config) + + common.init(config) # type: ignore + megatron_config = config.get("megatron_config", {}) + + if hasattr(pretrain_module, "ModelType"): + pretrain( + pretrain_module.train_valid_test_datasets_provider, + pretrain_module.model_provider, + pretrain_module.ModelType.encoder_or_decoder, + pretrain_module.forward_step, + data_post_process=pretrain_module.data_post_process, + external_args=megatron_config, + ) + elif hasattr(pretrain_module, "llama_argument_handler"): + pretrain( + pretrain_module.train_valid_test_datasets_provider, + pretrain_module.model_provider, + pretrain_module.forward_step, + pretrain_module.llama_argument_handler, + external_args=megatron_config, + ) else: - pretrain(pretrain_module.train_valid_test_datasets_provider, - pretrain_module.model_provider, - pretrain_module.forward_step, - external_args=megatron_config) + pretrain( + pretrain_module.train_valid_test_datasets_provider, + pretrain_module.model_provider, + pretrain_module.forward_step, + external_args=megatron_config, + ) + -def main(external_config = None): +def main(external_config=None): config = common.Config() if external_config is not None: config.merge(external_config) @@ -71,7 +79,6 @@ def main(external_config = None): ray_init_config = ray_config.get("init", {}) common.logger.info(f"ray init config: {ray_init_config}") - runtime_env = ray_init_config.get("runtime_env") ray.init(**ray_init_config) scaling_config = ScalingConfig(**ray_config.get("scaling_config", {})) @@ -90,11 +97,11 @@ def main(external_config = None): train_func, train_loop_config=config, scaling_config=scaling_config, - torch_config = torch_config, - run_config = run_config + torch_config=torch_config, + run_config=run_config, ) - results = trainer.fit() + trainer.fit() + if __name__ == "__main__": main() - diff --git a/pretrain/plugin/group_dataset.py b/pretrain/plugin/group_dataset.py index c4a538230..93838f7bf 100644 --- a/pretrain/plugin/group_dataset.py +++ b/pretrain/plugin/group_dataset.py @@ -3,6 +3,7 @@ from common.dataset import Dataset + class GroupDataset(Dataset): def __call__(self, config): path = config.get("path") @@ -18,4 +19,3 @@ def get_all_file(self, path): files = os.listdir(path) list.sort(files) return [os.path.join(path, file) for file in files] - diff --git a/pretrain/plugin/hf_pretrainer.py b/pretrain/plugin/hf_pretrainer.py index 23390c1bc..d9aafdfc5 100755 --- a/pretrain/plugin/hf_pretrainer.py +++ b/pretrain/plugin/hf_pretrainer.py @@ -2,15 +2,11 @@ import math import logging import sys -import torch from torch.utils.data import DataLoader, Dataset -from .pretrainer import PreTrainer -from pathlib import Path import common from common import dataprocesser from common.logging import logger import evaluate -from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer from typing import Optional from transformers import ( HfArgumentParser, @@ -19,21 +15,24 @@ import transformers from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version from transformers import Trainer, TrainingArguments from common.trainer import Trainer as RayTrainer + use_habana = True import importlib -loader = importlib.util.find_spec('habana_frameworks') + +loader = importlib.util.find_spec("habana_frameworks") if loader is not None: from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments from optimum.habana.utils import set_seed 
+ try: from optimum.habana.utils import check_optimum_habana_min_version except ImportError: def check_optimum_habana_min_version(*a, **b): return () + finally: # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. check_min_version("4.33.0") @@ -41,15 +40,18 @@ def check_optimum_habana_min_version(*a, **b): else: use_habana = False -class HFCustomerSamplerTrainer(GaudiTrainer if use_habana else Trainer) : - def set_sampler(self, sampler) : + +class HFCustomerSamplerTrainer(GaudiTrainer if use_habana else Trainer): # type: ignore + def set_sampler(self, sampler): self.customer_sampler = sampler def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") - train_dataloader, _, _ = self.customer_sampler.prepare(None, (self.train_dataset, None, None)) + train_dataloader, _, _ = self.customer_sampler.prepare( + None, (self.train_dataset, None, None) + ) return train_dataloader def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: @@ -59,7 +61,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa _, eval_dataloader, _ = self.customer_sampler.prepare(None, (None, eval_dataset, None)) return eval_dataloader - + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: test_dataset = test_dataset if test_dataset is not None else self.test_dataset if test_dataset is None: @@ -69,9 +71,6 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: return test_dataloader - - - class HuggingFacePreTrainer(RayTrainer): def __init__(self, config): self.config = config @@ -84,20 +83,16 @@ def __init__(self, config): self.starting_episode = 0 self.mode = "ddp" - def prepare(self, model, tokenizer, dataset, optimizer, accelerator): - self.train_dataset, self.eval_dataset, self.test_dataset = dataset - - def train(self): if use_habana: del os.environ["ACCELERATE_TORCH_DEVICE"] parser = HfArgumentParser((GaudiTrainingArguments)) else: parser = HfArgumentParser((TrainingArguments)) - + training_args = parser.parse_dict(self.config.get("training_config", None))[0] send_example_telemetry("Ray_HF_Trainer", training_args) @@ -117,7 +112,7 @@ def train(self): transformers.utils.logging.set_verbosity(log_level) transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() - + if use_habana: gaudi_config = GaudiConfig.from_pretrained( training_args.gaudi_config_name, @@ -127,7 +122,11 @@ def train(self): ) # Log on each process the small summary: - mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast or gaudi_config.use_habana_mixed_precision + mixed_precision = ( + training_args.bf16 + or gaudi_config.use_torch_autocast + or gaudi_config.use_habana_mixed_precision + ) logger.warning( f"Process rank: {training_args.local_rank}, device: {training_args.device}, " + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, " @@ -137,7 +136,11 @@ def train(self): # Detecting last checkpoint. 
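The checkpoint-detection logic that follows relies on transformers' get_last_checkpoint helper; a small sketch of its behaviour, with a hypothetical output directory:

    import os
    from transformers.trainer_utils import get_last_checkpoint

    output_dir = "/tmp/pretrain_output"  # hypothetical path
    # get_last_checkpoint returns the newest "checkpoint-<step>" subdirectory, or None if there is none
    last = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    print("resume from:", last)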
last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): last_checkpoint = get_last_checkpoint(training_args.output_dir) if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: raise ValueError( @@ -155,18 +158,18 @@ def train(self): set_seed(training_args.seed) model_config = self.config.get("model") - model_config['deepspeed_zero_stage'] = training_args.deepspeed_plugin.zero_stage + model_config["deepspeed_zero_stage"] = training_args.deepspeed_plugin.zero_stage if model_config: self.model = common.load_model(model_config) else: - common.logger.warn(f"No internal model plugin provided") + common.logger.warn("No internal model plugin provided") self.model.train() - + tokenizer_config = self.config.get("tokenizer") if tokenizer_config: self.tokenizer = common.load_tokenizer(tokenizer_config) else: - common.logger.warn(f"No internal tokenizer plugin provided") + common.logger.warn("No internal tokenizer plugin provided") # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. @@ -174,8 +177,8 @@ def train(self): if len(self.tokenizer) > embedding_size: self.model.resize_token_embeddings(len(self.tokenizer)) - if training_args.do_eval: + def preprocess_logits_for_metrics(logits, labels): if isinstance(logits, tuple): # Depending on the model and config, logits may contain extra tensors, @@ -206,9 +209,11 @@ def compute_metrics(eval_preds): # Data collator will default to DataCollatorWithPadding, so we change it. data_collator=default_data_collator, compute_metrics=compute_metrics if training_args.do_eval else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, + preprocess_logits_for_metrics=preprocess_logits_for_metrics + if training_args.do_eval + else None, ) - + else: trainer = HFCustomerSamplerTrainer( model=self.model, @@ -219,13 +224,14 @@ def compute_metrics(eval_preds): # Data collator will default to DataCollatorWithPadding, so we change it. 
data_collator=default_data_collator, compute_metrics=compute_metrics if training_args.do_eval else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, + preprocess_logits_for_metrics=preprocess_logits_for_metrics + if training_args.do_eval + else None, ) print("use the GPU for training") trainer.set_sampler(self.dataprocesser) - # Training if training_args.do_train: checkpoint = None @@ -238,9 +244,11 @@ def compute_metrics(eval_preds): metrics = train_result.metrics - #if data_args.streaming: - metrics["train_samples"] = training_args.max_steps * training_args.per_device_train_batch_size - #else: + # if data_args.streaming: + metrics["train_samples"] = ( + training_args.max_steps * training_args.per_device_train_batch_size + ) + # else: # max_train_samples = ( # data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) # ) @@ -271,7 +279,7 @@ def compute_metrics(eval_preds): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - #kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} + # kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} # if data_args.dataset_name is not None: # kwargs["dataset_tags"] = data_args.dataset_name # if data_args.dataset_config_name is not None: @@ -281,6 +289,6 @@ def compute_metrics(eval_preds): # kwargs["dataset"] = data_args.dataset_name if training_args.push_to_hub: - trainer.push_to_hub(**kwargs) + trainer.push_to_hub() else: - trainer.create_model_card(**kwargs) \ No newline at end of file + trainer.create_model_card() diff --git a/pretrain/plugin/huggingface_model_from_config.py b/pretrain/plugin/huggingface_model_from_config.py index 65731f111..5ce38da8f 100644 --- a/pretrain/plugin/huggingface_model_from_config.py +++ b/pretrain/plugin/huggingface_model_from_config.py @@ -1,42 +1,41 @@ import torch import math import transformers -from transformers import PreTrainedModel from common.model.model import Model + # for huggingface model weight random initialization class HuggingFaceModelFromConfig(Model): - def __call__(self, config): name = config.get("name") self.model_config = config.get("config", {}) self.auto_config = None if name is not None: - self.auto_config = transformers.AutoConfig.from_pretrained(pretrained_model_name_or_path=name, **self.model_config) + self.auto_config = transformers.AutoConfig.from_pretrained( + pretrained_model_name_or_path=name, **self.model_config + ) else: self.auto_config = transformers.AutoConfig.for_model(**self.model_config) self.model = transformers.AutoModelForCausalLM.from_config(self.auto_config) - - if config.get('deepspeed_zero_stage', None) != 3: + + if config.get("deepspeed_zero_stage", None) != 3: self.init_weights() - + return self.model def init_weights(self): - - if self.model_config.get('init_method', None): + if self.model_config.get("init_method", None): init_method = self.get_init_methods(self.model_config) self.recursive_initialization(self.model, init_method, init_method) return self.model - + def recursive_initialization( - self, - module, - init_method_linear, - init_method_embedding, - ): - + self, + module, + init_method_linear, + init_method_embedding, + ): if isinstance(module, torch.nn.Linear): init_method_linear(module.weight) if module.bias is not None: @@ -48,22 +47,22 @@ def recursive_initialization( for child in module.children(): self.recursive_initialization(child, init_method_linear, init_method_embedding) 
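To make the config contract of this plugin concrete, a hedged sketch of the dictionary it expects (key names are taken from __call__ and init_weights above; the values are hypothetical):

    model_plugin_config = {
        "name": "gpt2",  # any HF model id; only its config is reused, weights are re-initialized
        "config": {
            "init_method": "xavier_uniform",  # picks one of the init helpers defined below
            "init_method_std": 0.02,
        },
        # "deepspeed_zero_stage": 3 would make __call__ skip init_weights() entirely
    }
    model = HuggingFaceModelFromConfig()(model_plugin_config)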
- def get_init_methods(self, init_config): - init_method = init_config.get('init_method') - init_method_std = init_config.get('init_method_std', 0.02) - + init_method = init_config.get("init_method") + init_method_std = init_config.get("init_method_std", 0.02) + num_layers = self.auto_config.num_hidden_layers hidden_size = self.auto_config.hidden_size - + def _get(name): if name == "normal": return init_method_normal( - init_method_std, + init_method_std, ) elif name == "scaled_normal": return scaled_init_method_normal( - init_method_std, num_layers, + init_method_std, + num_layers, ) elif name == "xavier_uniform": return xavier_uniform_init_method() @@ -71,11 +70,12 @@ def _get(name): return xavier_normal_init_method() elif name == "wang_init": return wang_init_method( - num_layers, hidden_size, + num_layers, + hidden_size, ) elif name == "small_init": return small_init_init_method( - hidden_size, + hidden_size, ) else: raise NotImplementedError(f"Unknown init method {name}") @@ -92,9 +92,7 @@ def init_(tensor): return init_ -def scaled_init_method_normal( - sigma, num_layers, use_mup_outer=False, mup_init_scale=1.0 -): +def scaled_init_method_normal(sigma, num_layers, use_mup_outer=False, mup_init_scale=1.0): """Init method based on N(0, sigma/sqrt(2*num_layers).""" std = sigma / math.sqrt(2.0 * num_layers) @@ -106,7 +104,8 @@ def init_(tensor): def xavier_uniform_init_method(use_mup_outer=False, mup_init_scale=1.0): """Fills the input Tensor with values according to the method described in Understanding the difficulty of - training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution.""" + training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution. + """ def init_(tensor): return torch.nn.init.xavier_uniform_(tensor) @@ -116,7 +115,8 @@ def init_(tensor): def xavier_normal_init_method(use_mup_outer=False, mup_init_scale=1.0): """Fills the input Tensor with values according to the method described in Understanding the difficulty of - training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution.""" + training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution. + """ def init_(tensor): return torch.nn.init.xavier_normal_(tensor) @@ -126,7 +126,8 @@ def init_(tensor): def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0): """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving - the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution.""" + the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution. 
+ """ std = math.sqrt(2 / (5 * dim)) def init_(tensor): diff --git a/pretrain/plugin/megatron_dataset.py b/pretrain/plugin/megatron_dataset.py index 36fc33c9a..944c6b53b 100644 --- a/pretrain/plugin/megatron_dataset.py +++ b/pretrain/plugin/megatron_dataset.py @@ -1,19 +1,16 @@ -import numpy as np - from megatron import get_args, print_rank_0 from megatron.training import build_train_valid_test_datasets, update_train_iters from megatron.data import gpt_dataset -from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from common.dataset import Dataset + class MegatronDataset(Dataset): def __call__(self, config): def _train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() - print_rank_0('> building train, validation, and test datasets ' - 'for GPT ...') + print_rank_0("> building train, validation, and test datasets " "for GPT ...") train_ds, valid_ds, test_ds = gpt_dataset.build_train_valid_test_datasets( data_prefix=args.data_path, data_impl=args.data_impl, @@ -25,7 +22,8 @@ def _train_valid_test_datasets_provider(train_val_test_num_samples): train_data_prefix=args.train_data_path, valid_data_prefix=args.valid_data_path, test_data_prefix=args.test_data_path, - data_cache_path=args.data_cache_path) + data_cache_path=args.data_cache_path, + ) print_rank_0("> finished creating GPT datasets ...") return train_ds, valid_ds, test_ds diff --git a/pretrain/plugin/megatron_pretrainer.py b/pretrain/plugin/megatron_pretrainer.py index f884806bd..4ee76bfa3 100644 --- a/pretrain/plugin/megatron_pretrainer.py +++ b/pretrain/plugin/megatron_pretrainer.py @@ -15,6 +15,7 @@ from .pretrainer import PreTrainer from common.logging import logger + class MegatronPreTrainer(PreTrainer): def __init__(self, config): self.config = config @@ -33,7 +34,9 @@ def _coordinate(self, accelerator): self.size = accelerator.num_processes self.local_rank = accelerator.local_process_index accelerator.wait_for_everyone() - logger.info(f"coordinate workers finish, cluster size:{self.size} worker rank:{self.rank} worker local_rank:{self.local_rank}") + logger.info( + f"coordinate workers finish, cluster size:{self.size} worker rank:{self.rank} worker local_rank:{self.local_rank}" + ) def _get_all_checkpoint_step(self, root_path): if not os.path.exists(f"{root_path}"): @@ -61,16 +64,16 @@ def _get_latest_checkpoint_step(self, root_path): def recovery(self, config): if config is None or config is {}: - logger.warning(f"checkpoint is empty, skip") + logger.warning("checkpoint is empty, skip") return 0 root_path = config.get("root_path") if root_path is None: - logger.error(f"checkpoint root_path is None, exit") + logger.error("checkpoint root_path is None, exit") exit(1) step = self._get_latest_checkpoint_step(root_path) if step is None or step == -1: - logger.warning(f"step is None, skip") + logger.warning("step is None, skip") return 0 local_checkpoint_path = self._get_local_path(root_path, step) try: @@ -87,23 +90,34 @@ def recovery(self, config): self.optimizer.load_state_dict(optimizer_state) # update lr_scheduler status - if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr(self, "lr_scheduler"): - scheduler_state = torch.load(checkpoint_dir / "lr_schduler.pt", map_location="cpu") + if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr( + self, "lr_scheduler" + ): + scheduler_state = torch.load( + checkpoint_dir / "lr_schduler.pt", map_location="cpu" + ) self.lr_scheduler.load_state_dict(scheduler_state) 
logger.info(f"recovery to step {int(step)}") self.starting_step = int(step) + 1 - except Exception as e: - logger.warning(f"recovery error", exc_info=True) + except Exception: + logger.warning("recovery error", exc_info=True) return 0 - def _get_lr_scheduler(self, lr_scheduler_config, optimizer, num_train_epochs, num_steps_per_epoch, accelerator): + def _get_lr_scheduler( + self, + lr_scheduler_config, + optimizer, + num_train_epochs, + num_steps_per_epoch, + accelerator, + ): # gradient_accumulation_steps = accelerator.gradient_accumulation_steps # num_update_steps_per_epoch = math.ceil(num_steps_per_epoch / gradient_accumulation_steps) enable = lr_scheduler_config.get("enable", False) if not enable: return None - max_train_steps = lr_scheduler_config.get("max_train_steps") + max_train_steps = lr_scheduler_config.get("max_train_steps") lr_scheduler_type = lr_scheduler_config.get("lr_scheduler_type", "linear") num_warmup_steps = lr_scheduler_config.get("num_warmup_steps", 0) @@ -125,13 +139,23 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): logger.info(f"model embedding size: {embedding_size}") if len(tokenizer) > embedding_size: model.resize_token_embeddings(len(tokenizer)) - logger.warning(f"model embedding size resize to {len(tokenizer)} because of tokenizer size") + logger.warning( + f"model embedding size resize to {len(tokenizer)} because of tokenizer size" + ) lr_scheduler_config = self.config.get("lr_scheduler") if lr_scheduler_config: - num_steps_per_epoch = len(dataset) // self.config.get("dataprocesser").get("per_device_train_batch_size") + num_steps_per_epoch = len(dataset) // self.config.get("dataprocesser").get( + "per_device_train_batch_size" + ) num_train_epochs = self.config.get("num_train_epochs", 1) - lr_scheduler = self._get_lr_scheduler(lr_scheduler_config, optimizer, num_train_epochs, num_steps_per_epoch, accelerator) + lr_scheduler = self._get_lr_scheduler( + lr_scheduler_config, + optimizer, + num_train_epochs, + num_steps_per_epoch, + accelerator, + ) else: lr_scheduler = None @@ -151,7 +175,9 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): # only ddp support pass - self.train_dataloader, self.eval_dataloader, _ = self.dataprocesser.prepare(tokenizer, dataset, step=self.starting_step) + self.train_dataloader, self.eval_dataloader, _ = self.dataprocesser.prepare( + tokenizer, dataset, step=self.starting_step + ) def train(self): checkpoint = self.config.get("checkpoint") @@ -162,15 +188,15 @@ def train(self): # megatron arguments args = get_args() - logger.info(f"start train") + logger.info("start train") self.model.train() start = time.time() for step, batch in enumerate(self.train_dataloader): step = step + self.starting_step batch["input_ids"] = batch["text"] - batch["labels"] = batch["text"] + batch["labels"] = batch["text"] del batch["text"] - #del batch["dummy_sample"] + # del batch["dummy_sample"] with self.accelerator.accumulate(self.model): outputs = self.model(**batch) loss = outputs.loss @@ -182,25 +208,36 @@ def train(self): self.lr_scheduler.step() self.optimizer.zero_grad() if step % log_step == 0: - logger.info(f"step:[{step}/{len(self.train_dataloader)}]\tlr:{self.lr_scheduler.get_last_lr() if self.lr_scheduler else None}\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}") + logger.info( + f"step:[{step}/{len(self.train_dataloader)}]\tlr:{self.lr_scheduler.get_last_lr() if self.lr_scheduler else None}\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}" + ) start = time.time() if 
max_train_step is not None: if step >= max_train_step: break - if self.eval_dataloader and args.eval_interval and step and step % args.eval_interval == 0: + if ( + self.eval_dataloader + and args.eval_interval + and step + and step % args.eval_interval == 0 + ): logger.info(f"start eval step {step}") self.model.eval() start = time.time() losses = [] for step, batch in enumerate(self.eval_dataloader): batch["input_ids"] = batch["text"] - batch["labels"] = batch["text"] + batch["labels"] = batch["text"] del batch["text"] with torch.no_grad(): outputs = self.model(**batch) loss = outputs.loss - losses.append(self.accelerator.gather_for_metrics(loss.repeat(batch["input_ids"].shape[0]))) + losses.append( + self.accelerator.gather_for_metrics( + loss.repeat(batch["input_ids"].shape[0]) + ) + ) if max_eval_step is not None: if step >= max_eval_step: break @@ -212,9 +249,11 @@ def train(self): except OverflowError: eval_loss = float("inf") perplexity = float("inf") - logger.info(f"eval step:[{step}/{len(self.train_dataloader)}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]") + logger.info( + f"eval step:[{step}/{len(self.train_dataloader)}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]" + ) - if checkpoint is not None and (step+1) % checkpoint_step == 0: + if checkpoint is not None and (step + 1) % checkpoint_step == 0: self.save(checkpoint, step) self.accelerator.wait_for_everyone() @@ -223,7 +262,9 @@ def train(self): logger.info(f"start save model to {output}") unwrapped_model = self.accelerator.unwrap_model(self.model) unwrapped_model.save_pretrained( - output, is_main_process=self.accelerator.is_main_process, save_function=self.accelerator.save + output, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, ) logger.info(f"finish save model to {output}") self.accelerator.wait_for_everyone() @@ -243,10 +284,13 @@ def _save(self, local_checkpoint_path): torch.save(unwrapped_model.state_dict(), os.path.join(tmpdir, "model.pt")) torch.save(self.optimizer.state_dict(), os.path.join(tmpdir, "optim.pt")) if self.lr_scheduler: - torch.save(self.lr_scheduler.state_dict(), os.path.join(tmpdir, "lr_schduler.pt")) + torch.save( + self.lr_scheduler.state_dict(), + os.path.join(tmpdir, "lr_schduler.pt"), + ) checkpoint = Checkpoint.from_directory(tmpdir) checkpoint.to_directory(local_checkpoint_path) - logger.info(f"save checkpoint finish") + logger.info("save checkpoint finish") def _get_donefile_path(self, root_path, step): return f"{root_path}/{step}/donefile" @@ -269,14 +313,14 @@ def _remove_stale_checkpoint(self, root_path, num_to_keep): def save(self, config, step): if config is None or config is {}: - logger.warning(f"checkpoint is empty, skip") + logger.warning("checkpoint is empty, skip") return root_path = config.get("root_path") if root_path is None: - logger.warning(f"checkpoint root_path is empty, skip") + logger.warning("checkpoint root_path is empty, skip") num_to_keep = config.get("num_to_keep") if num_to_keep <= 0: - logger.warning(f"checkpoint num_to_keep cannot be zero, ignored") + logger.warning("checkpoint num_to_keep cannot be zero, ignored") num_to_keep = None local_checkpoint_path = self._get_local_path(root_path, step) if self.mode == "ddp": diff --git a/pretrain/plugin/megatron_processer.py b/pretrain/plugin/megatron_processer.py index f9d71c3bc..178256ad5 100644 --- a/pretrain/plugin/megatron_processer.py +++ b/pretrain/plugin/megatron_processer.py @@ -1,38 +1,34 @@ -import torch - from megatron 
import get_args, print_rank_0 from megatron.core import mpu from megatron.data.data_samplers import build_pretraining_data_loader -from deepspeed.accelerator import get_accelerator from common.dataprocesser import DataProcesser + class MegatronProcesser(DataProcesser): def prepare(self, tokenizer, dataset, **kwargs): args = get_args() (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) - print_rank_0('> building train, validation, and test datasets ...') + print_rank_0("> building train, validation, and test datasets ...") iteration = kwargs.get("step", 0) if iteration: # passed value is starting step iteration -= 1 args.consumed_train_samples = iteration * args.global_batch_size - args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ - args.eval_iters * args.global_batch_size + args.consumed_valid_samples = ( + (args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size + ) # Data loader only on rank 0 of each model parallel group. if args.use_dataset_only or mpu.get_tensor_model_parallel_rank() == 0: - - # Build datasets. + # Build datasets. train_ds, valid_ds, test_ds = dataset # Build dataloders. - train_dataloader = build_pretraining_data_loader( - train_ds, args.consumed_train_samples) - valid_dataloader = build_pretraining_data_loader( - valid_ds, args.consumed_valid_samples) + train_dataloader = build_pretraining_data_loader(train_ds, args.consumed_train_samples) + valid_dataloader = build_pretraining_data_loader(valid_ds, args.consumed_valid_samples) test_dataloader = build_pretraining_data_loader(test_ds, 0) return train_dataloader, valid_dataloader, test_dataloader diff --git a/pretrain/plugin/megtron_initializer.py b/pretrain/plugin/megtron_initializer.py index 9a9f6be41..cad268603 100644 --- a/pretrain/plugin/megtron_initializer.py +++ b/pretrain/plugin/megtron_initializer.py @@ -1,4 +1,3 @@ -from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.initialize import initialize_megatron from common.initializer import Initializer from common.logging import logger @@ -10,10 +9,10 @@ def __init__(self, config): self.args = {} def init(self): - #self._parse_arguments(ARGUMENTS_SCHEMA, config) + # self._parse_arguments(ARGUMENTS_SCHEMA, config) args = None - if "megatron_config" in self.config : - args = self.config["megatron_config"] + if "megatron_config" in self.config: + args = self.config["megatron_config"] initialize_megatron(ignore_unknown_args=True, external_args=args, allow_no_cuda=True) else: logger.error("cannot initialize the megatron without the megatron_config") diff --git a/pretrain/plugin/plain_id_processer.py b/pretrain/plugin/plain_id_processer.py index f3cf64215..20117cdcf 100644 --- a/pretrain/plugin/plain_id_processer.py +++ b/pretrain/plugin/plain_id_processer.py @@ -1,29 +1,29 @@ -import math -import time -from itertools import chain - import torch import transformers from common.dataprocesser import DataProcesser + class PlainIDProcesser(DataProcesser): def prepare(self, tokenizer, datasets): per_device_train_batch_size = self.config.get("per_device_train_batch_size", 1) - per_device_eval_batch_size = self.config.get("per_device_eval_batch_size", 1) + self.config.get("per_device_eval_batch_size", 1) def label(examples): examples["input_ids"] = examples["tokens"].copy() examples["labels"] = examples["tokens"].copy() return examples - train_datasets = [dataset["train"].map(label, remove_columns=["tokens"]) for dataset in datasets] + train_datasets = [ + 
dataset["train"].map(label, remove_columns=["tokens"]) for dataset in datasets + ] train_dataloaders = [ torch.utils.data.DataLoader( - train_dataset, - shuffle=False, - collate_fn=transformers.default_data_collator, - batch_size=per_device_train_batch_size - ) for train_dataset in train_datasets + train_dataset, + shuffle=False, + collate_fn=transformers.default_data_collator, + batch_size=per_device_train_batch_size, + ) + for train_dataset in train_datasets ] return train_dataloaders, None diff --git a/pretrain/plugin/pretrainer.py b/pretrain/plugin/pretrainer.py index be1494d54..1bde38f62 100755 --- a/pretrain/plugin/pretrainer.py +++ b/pretrain/plugin/pretrainer.py @@ -16,6 +16,7 @@ from common.trainer import Trainer from common.logging import logger + class PreTrainer(Trainer): def __init__(self, config): self.config = config @@ -34,7 +35,9 @@ def _coordinate(self, accelerator): self.size = accelerator.num_processes self.local_rank = accelerator.local_process_index accelerator.wait_for_everyone() - logger.info(f"coordinate workers finish, cluster size:{self.size} worker rank:{self.rank} worker local_rank:{self.local_rank}") + logger.info( + f"coordinate workers finish, cluster size:{self.size} worker rank:{self.rank} worker local_rank:{self.local_rank}" + ) def _get_all_checkpoint_episode(self, root_path): if not os.path.exists(f"{root_path}"): @@ -62,18 +65,18 @@ def _get_latest_checkpoint_episode(self, root_path): def recovery(self, config): if config is None or config is {}: - logger.warning(f"checkpoint is empty, skip") + logger.warning("checkpoint is empty, skip") return 0 root_path = config.get("root_path") episode = config.get("episode", None) if root_path is None: - logger.error(f"checkpoint root_path is None, exit") + logger.error("checkpoint root_path is None, exit") exit(1) if episode is None: episode = self._get_latest_checkpoint_episode(root_path) if episode is None or episode == -1: - logger.warning(f"episode is None, skip") + logger.warning("episode is None, skip") return 0 local_checkpoint_path = self._get_local_path(root_path, episode) try: @@ -90,22 +93,33 @@ def recovery(self, config): self.optimizer.load_state_dict(optimizer_state) # update lr_scheduler status - if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr(self, "lr_scheduler"): - scheduler_state = torch.load(checkpoint_dir / "lr_schduler.pt", map_location="cpu") + if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr( + self, "lr_scheduler" + ): + scheduler_state = torch.load( + checkpoint_dir / "lr_schduler.pt", map_location="cpu" + ) self.lr_scheduler.load_state_dict(scheduler_state) logger.info(f"recovery to episode {int(episode)}") self.starting_episode = int(episode) + 1 - except Exception as e: - logger.warning(f"recovery error", exc_info=True) - - def _get_lr_scheduler(self, lr_scheduler_config, optimizer, num_train_epochs, num_steps_per_epoch, accelerator): + except Exception: + logger.warning("recovery error", exc_info=True) + + def _get_lr_scheduler( + self, + lr_scheduler_config, + optimizer, + num_train_epochs, + num_steps_per_epoch, + accelerator, + ): # gradient_accumulation_steps = accelerator.gradient_accumulation_steps # num_update_steps_per_epoch = math.ceil(num_steps_per_epoch / gradient_accumulation_steps) enable = lr_scheduler_config.get("enable", False) if not enable: return None - max_train_steps = lr_scheduler_config.get("max_train_steps") + max_train_steps = lr_scheduler_config.get("max_train_steps") lr_scheduler_type = 
lr_scheduler_config.get("lr_scheduler_type", "linear") num_warmup_steps = lr_scheduler_config.get("num_warmup_steps", 0) @@ -127,17 +141,23 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): logger.info(f"model embedding size: {embedding_size}") if len(tokenizer) > embedding_size: model.resize_token_embeddings(len(tokenizer)) - logger.warning(f"model embedding size resize to {len(tokenizer)} because of tokenizer size") + logger.warning( + f"model embedding size resize to {len(tokenizer)} because of tokenizer size" + ) - train_dataloader, eval_dataloader = self.dataprocesser.prepare( - tokenizer, dataset - ) + train_dataloader, eval_dataloader = self.dataprocesser.prepare(tokenizer, dataset) lr_scheduler_config = self.config.get("lr_scheduler") if lr_scheduler_config: num_steps_per_epoch = len(train_dataloader) num_train_epochs = self.config.get("num_train_epochs", 1) - lr_scheduler = self._get_lr_scheduler(lr_scheduler_config, optimizer, num_train_epochs, num_steps_per_epoch, accelerator) + lr_scheduler = self._get_lr_scheduler( + lr_scheduler_config, + optimizer, + num_train_epochs, + num_steps_per_epoch, + accelerator, + ) else: lr_scheduler = None @@ -164,21 +184,22 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): pass self.train_dataloader, self.eval_dataloader = accelerator.prepare( - train_dataloader, eval_dataloader, + train_dataloader, + eval_dataloader, ) def _check_and_mkdir(self, path): - path = Path(path) + path = Path(path) if not path.exists(): path.mkdir(parents=True) - + def _write_json(self, target_dict, save_path): json_object = json.dumps(target_dict, indent=4) with open(save_path, "w") as outfile: outfile.write(json_object) def train(self): - num_train_epochs = self.config.get("num_train_epochs", 1) + self.config.get("num_train_epochs", 1) checkpoint = self.config.get("checkpoint") log_step = self.config.get("log_step", 1) max_train_step_per_episode = self.config.get("max_train_step_per_episode") @@ -189,21 +210,25 @@ def train(self): self._check_and_mkdir(save_state_path) training_state = {} else: - training_state = None - + training_state = None + for idx in range(self.starting_episode, len(self.train_dataloader), 1): logger.info(f"start train episode {idx}") if training_state is not None and int(self.rank) == 0: - training_state[f'episode_{idx}'] = {} + training_state[f"episode_{idx}"] = {} self.model.train() current_train_dataloader = self.train_dataloader[idx] start = time.time() for step, batch in enumerate(current_train_dataloader): if training_state is not None and int(self.rank) == 0: - training_state[f'episode_{idx}'][f'step_{step}'] = {} - training_state[f'episode_{idx}'][f'step_{step}']['data'] = batch['input_ids'][0].tolist()[:50] - training_state[f'episode_{idx}'][f'step_{step}']['learning_rate'] = self.lr_scheduler.state_dict()['_last_lr'] - + training_state[f"episode_{idx}"][f"step_{step}"] = {} + training_state[f"episode_{idx}"][f"step_{step}"]["data"] = batch["input_ids"][ + 0 + ].tolist()[:50] + training_state[f"episode_{idx}"][f"step_{step}"][ + "learning_rate" + ] = self.lr_scheduler.state_dict()["_last_lr"] + with self.accelerator.accumulate(self.model): outputs = self.model(**batch) loss = outputs.loss @@ -215,14 +240,20 @@ def train(self): self.lr_scheduler.step() self.optimizer.zero_grad() if step % log_step == 0: - logger.info(f"train episode:[{idx}/{len(self.train_dataloader)}]\tstep:[{step}]\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}") + logger.info( + f"train 
episode:[{idx}/{len(self.train_dataloader)}]\tstep:[{step}]\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}" + ) start = time.time() if training_state is not None and int(self.rank) == 0: - training_state[f'episode_{idx}'][f'step_{step}']['loss'] = loss.item() - training_state[f'episode_{idx}'][f'step_{step}']['ppl'] = math.exp(loss) - file_name = "stepwise_training_state_recovery" if self.starting_episode > 0 else "stepwise_training_state" + training_state[f"episode_{idx}"][f"step_{step}"]["loss"] = loss.item() + training_state[f"episode_{idx}"][f"step_{step}"]["ppl"] = math.exp(loss) + file_name = ( + "stepwise_training_state_recovery" + if self.starting_episode > 0 + else "stepwise_training_state" + ) self._write_json(training_state, f"{save_state_path}/{file_name}.json") - + if max_train_step_per_episode is not None: if step >= max_train_step_per_episode: break @@ -236,7 +267,11 @@ def train(self): with torch.no_grad(): outputs = self.model(**batch) loss = outputs.loss - losses.append(self.accelerator.gather_for_metrics(loss.repeat(batch["input_ids"].shape[0]))) + losses.append( + self.accelerator.gather_for_metrics( + loss.repeat(batch["input_ids"].shape[0]) + ) + ) if max_eval_step_per_episode is not None: if step >= max_eval_step_per_episode: break @@ -248,7 +283,9 @@ def train(self): except OverflowError: eval_loss = float("inf") perplexity = float("inf") - logger.info(f"eval episode:[{idx}/{len(self.train_dataloader)}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]") + logger.info( + f"eval episode:[{idx}/{len(self.train_dataloader)}]\tloss:[{eval_loss}]\tppl:[{perplexity}]\ttime:[{time.time()-start}]" + ) if checkpoint is not None: self.save(checkpoint, idx) @@ -259,7 +296,9 @@ def train(self): logger.info(f"start save model to {output}") unwrapped_model = self.accelerator.unwrap_model(self.model) unwrapped_model.save_pretrained( - output, is_main_process=self.accelerator.is_main_process, save_function=self.accelerator.save + output, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, ) logger.info(f"finish save model to {output}") self.accelerator.wait_for_everyone() @@ -279,10 +318,13 @@ def _save(self, local_checkpoint_path): torch.save(unwrapped_model.state_dict(), os.path.join(tmpdir, "model.pt")) torch.save(self.optimizer.state_dict(), os.path.join(tmpdir, "optim.pt")) if self.lr_scheduler: - torch.save(self.lr_scheduler.state_dict(), os.path.join(tmpdir, "lr_schduler.pt")) + torch.save( + self.lr_scheduler.state_dict(), + os.path.join(tmpdir, "lr_schduler.pt"), + ) checkpoint = Checkpoint.from_directory(tmpdir) checkpoint.to_directory(local_checkpoint_path) - logger.info(f"save checkpoint finish") + logger.info("save checkpoint finish") def _get_donefile_path(self, root_path, episode): return f"{root_path}/{episode}/donefile" @@ -305,14 +347,14 @@ def _remove_stale_checkpoint(self, root_path, num_to_keep): def save(self, config, episode): if config is None or config is {}: - logger.warning(f"checkpoint is empty, skip") + logger.warning("checkpoint is empty, skip") return root_path = config.get("root_path") if root_path is None: - logger.warning(f"checkpoint root_path is empty, skip") + logger.warning("checkpoint root_path is empty, skip") num_to_keep = config.get("num_to_keep") if num_to_keep <= 0: - logger.warning(f"checkpoint num_to_keep cannot be zero, ignored") + logger.warning("checkpoint num_to_keep cannot be zero, ignored") num_to_keep = None local_checkpoint_path = 
self._get_local_path(root_path, episode) if self.mode == "ddp": @@ -323,5 +365,5 @@ def save(self, config, episode): else: pass self._save_done(root_path, episode) - if num_to_keep > 0 and self.mode == "ddp"and int(self.rank) == 0: + if num_to_keep > 0 and self.mode == "ddp" and int(self.rank) == 0: self._remove_stale_checkpoint(root_path, num_to_keep) diff --git a/pretrain/pretrain.py b/pretrain/pretrain.py index 69ce217d3..3e045c19d 100644 --- a/pretrain/pretrain.py +++ b/pretrain/pretrain.py @@ -1,8 +1,6 @@ #!/usr/bin/env python import os -import time -import traceback from typing import Any, Dict import accelerate @@ -13,14 +11,17 @@ from ray.air import RunConfig, FailureConfig import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import common -import importlib +from importlib import util + use_habana = False -loader = importlib.util.find_spec('habana_frameworks') +loader = util.find_spec("habana_frameworks") if loader is not None: from backend.habana_backend import TorchConfig + use_habana = True else: from ray.train.torch import TorchConfig @@ -32,83 +33,84 @@ def train_func(config: Dict[str, Any]): if cwd: os.chdir(cwd) from common.common import import_all_module - import_all_module(f"{os.path.dirname(os.path.realpath(__file__))}/plugin","plugin") - common.init(config) + + import_all_module(f"{os.path.dirname(os.path.realpath(__file__))}/plugin", "plugin") + common.init(config) # type: ignore initializer_config = config.get("initializer") if initializer_config: - try : + try: initializer = common.get_initializer(initializer_config) initializer.init() except Exception as e: common.logger.critical(e, exc_info=True) - exit(1) - common.logger.info(f"Initializer is initialized") + exit(1) + common.logger.info("Initializer is initialized") accelerator = None accelerator_config = config.get("accelerator") - if accelerator_config != None: - try : + if accelerator_config is not None: + try: common.logger.info(f"accelerator_config: {accelerator_config}") accelerator = accelerate.Accelerator(**accelerator_config) except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"accelerator generate finish") - + common.logger.info("accelerator generate finish") + model = None datasets = None tokenizer = None optimizer = None - + datasets_config = config.get("datasets") if datasets_config: datasets = common.load_dataset(datasets_config) - common.logger.info(f" ") + common.logger.info(" ") else: - common.logger.warn(f"No datasets plugin provided, use the built-in datasets of trainer") - + common.logger.warn("No datasets plugin provided, use the built-in datasets of trainer") + tokenizer_config = config.get("tokenizer") if tokenizer_config: tokenizer = common.load_tokenizer(tokenizer_config) else: - common.logger.warn(f"No tokenizer plugin provided, use the built-in tokenizer of trainer") - + common.logger.warn("No tokenizer plugin provided, use the built-in tokenizer of trainer") + model_config = config.get("model") if model_config: model = common.load_model(model_config) else: - common.logger.warn(f"No model plugin provided, use the built-in model of trainer") - + common.logger.warn("No model plugin provided, use the built-in model of trainer") + optimizer_config = config.get("optimizer") if optimizer_config: optimizer = common.load_optimizer(model, config.get("optimizer")) else: - common.logger.warn(f"No optimizer plugin provided, use the built-in optimizer of trainer") + 
common.logger.warn("No optimizer plugin provided, use the built-in optimizer of trainer") - trainer_config = config.get("trainer") + trainer_config = config.get("trainer") if trainer_config: trainer = common.get_trainer(config.get("trainer")) - try : + try: trainer.prepare(model, tokenizer, datasets, optimizer, accelerator) except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"trainer prepare finish") + common.logger.info("trainer prepare finish") - try : - common.logger.info(f"train start") + try: + common.logger.info("train start") trainer.train() - common.logger.info(f"train done") + common.logger.info("train done") except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"train finish") + common.logger.info("train finish") else: - common.logger.error(f"Trainer isn't found!") + common.logger.error("Trainer isn't found!") -def main(external_config = None): +def main(external_config=None): config = common.Config() if external_config is not None: config.merge(external_config) @@ -120,19 +122,18 @@ def main(external_config = None): ray_init_config = ray_config.get("init", {}) common.logger.info(f"ray init config: {ray_init_config}") - runtime_env = ray_init_config.get("runtime_env") ray.init(**ray_init_config) scaling_config = ScalingConfig(**ray_config.get("scaling_config", {})) common.logger.info(f"ray scaling config: {scaling_config}") if ( - config['trainer'].get("training_config", None) and - config['trainer'].get("training_config").get("deepspeed", None) and - use_habana == False + config["trainer"].get("training_config", None) + and config["trainer"].get("training_config").get("deepspeed", None) + and use_habana is False ): torch_config = DeepSpeedTorchConfig(**ray_config.get("torch_config", {})) - else: + else: torch_config = TorchConfig(**ray_config.get("torch_config", {})) common.logger.info(f"ray torch config: {torch_config}") @@ -146,10 +147,10 @@ def main(external_config = None): train_func, train_loop_config=config, scaling_config=scaling_config, - torch_config = torch_config, - run_config = run_config + torch_config=torch_config, + run_config=run_config, ) - results = trainer.fit() + trainer.fit() elif config.get("run_mode") == "initialized": ray_config = config.get("ray_config") @@ -168,12 +169,13 @@ def main(external_config = None): train_func, train_loop_config=config, scaling_config=scaling_config, - torch_config = torch_config, - run_config = run_config + torch_config=torch_config, + run_config=run_config, ) - results = trainer.fit() + trainer.fit() else: pass + if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 5baee1eb3..8fbec0007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,3 +74,6 @@ packages = ["finetune", "inference"] [project.urls] Repository = "https://github.com/intel/llm-on-ray.git" Issues = "https://github.com/intel/llm-on-ray.git/issues" + +[tool.black] +line-length = 100 diff --git a/rlhf/ppo.py b/rlhf/ppo.py index bfbab6998..cc9fab6ae 100644 --- a/rlhf/ppo.py +++ b/rlhf/ppo.py @@ -1,9 +1,6 @@ #!/usr/bin/env python import os -import time -import traceback -from typing import Any, Dict import ray from ray import air, tune @@ -16,38 +13,38 @@ from rl_algo.ppo.rlhf_ppo_torch_learner import RLHFPPOTorchLearner import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import common from common.agentenv.rlhf_env import RLHFEnv class 
ValueFunctionInitializerCallback(DefaultCallbacks): - def on_algorithm_init(self, *, algorithm, **kwargs) -> None: - learner_group = algorithm.learner_group + learner_group = algorithm.learner_group # noqa: F841 # assigned to but never used def init_ray(config): - num_training_workers = config["Training"].get("num_training_workers") resources_per_worker = config["Training"].get("resources_per_worker") runtime_env = { "env_vars": { - "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), - "ACCELERATE_USE_CPU": "True", + "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), + "ACCELERATE_USE_CPU": "True", "ACCELERATE_MIXED_PRECISION": "no", "CCL_WORKER_COUNT": "1", "CCL_LOG_LEVEL": "info", "WORLD_SIZE": str(num_training_workers), } } - ray.init(runtime_env = runtime_env, local_mode = True) + ray.init(runtime_env=runtime_env, local_mode=True) -def prepare_ppo(config): +def prepare_ppo(config): env_creator = lambda config: RLHFEnv(config) + tune.register_env("RLHFEnv", env_creator) agentenv_config = { @@ -56,42 +53,40 @@ def prepare_ppo(config): "tokenizer": { "type": "HuggingFaceTokenizer", "name": config["General"]["model_name"], - "config": {} + "config": {}, }, "reward_model": { "type": "HuggingFaceRewardModel", "name": config["General"]["rm_name"], - "config": {} + "config": {}, }, "sft_model": { "type": "HuggingFaceModelForCausalLM", "name": config["General"]["model_name"], - "config": {} + "config": {}, }, "datasets": { "type": "HuggingfaceDataset", "name": config["Dataset"]["train_file"], - "load_config" : { - } + "load_config": {}, }, "kl_coeff": config["Training"]["kl_coeff"], "max_generation_length": 50, - "model_max_length": 1024 - } + "model_max_length": 1024, + }, } ppo_config = ( PPOConfig(algo_class=PPORLHF) .framework("torch") .environment( - "RLHFEnv", + "RLHFEnv", env_config=agentenv_config, disable_env_checking=True, ) .rl_module( _enable_rl_module_api=True, - rl_module_spec=SingleAgentRLModuleSpec - ( + rl_module_spec=SingleAgentRLModuleSpec( RLHFPPOTorchRLModule, model_config_dict={ "actor_base_model": config["General"]["model_name"], @@ -100,7 +95,7 @@ def prepare_ppo(config): ), ) .training( - learner_class = RLHFPPOTorchLearner, + learner_class=RLHFPPOTorchLearner, # optimizer=config["Training"]["optimizer"], lr=config["Training"]["learning_rate"], num_sgd_iter=1, @@ -108,9 +103,7 @@ def prepare_ppo(config): train_batch_size=config["Training"]["experience_batch_size"], _enable_learner_api=True, ) - .rollouts( - num_rollout_workers=0 - ) + .rollouts(num_rollout_workers=0) .evaluation( evaluation_interval=1, evaluation_duration_unit="episodes", @@ -120,16 +113,18 @@ def prepare_ppo(config): _disable_initialize_loss_from_dummy_batch=True, ) .callbacks( - callbacks_class=make_multi_callbacks([ - ValueFunctionInitializerCallback, - ]) + callbacks_class=make_multi_callbacks( + [ + ValueFunctionInitializerCallback, + ] + ) ) ) return ppo_config -def main(external_config = None): +def main(external_config=None): config = common.Config() if external_config is not None: config.merge(external_config) @@ -146,13 +141,13 @@ def main(external_config = None): checkpoint_frequency=50, checkpoint_at_end=True, ), - stop={"training_iteration": config["Training"]["training_iteration"]} + stop={"training_iteration": config["Training"]["training_iteration"]}, ), - tune_config=tune.TuneConfig(reuse_actors=False) + tune_config=tune.TuneConfig(reuse_actors=False), ) - results = tuner.fit() + tuner.fit() + if __name__ == "__main__": main() - diff --git a/rlhf/reward.py b/rlhf/reward.py index 
5cf48ec06..7045a6c44 100644 --- a/rlhf/reward.py +++ b/rlhf/reward.py @@ -1,8 +1,6 @@ #!/usr/bin/env python import os -import time -import traceback from typing import Any, Dict import accelerate @@ -13,83 +11,97 @@ from ray.air import RunConfig, FailureConfig import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import common + def train_func(config: Dict[str, Any]): cwd = config.get("cwd") if cwd: os.chdir(cwd) - + gradient_accumulation_steps = config["Training"].get("gradient_accumulation_steps", 1) - accelerator = accelerate.Accelerator(gradient_accumulation_steps = gradient_accumulation_steps) - common.logger.info(f"accelerator generate finish") + accelerator = accelerate.Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + common.logger.info("accelerator generate finish") seed = config["Training"].get("seed") if seed is not None: accelerate.utils.set_seed(seed) - datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()(config = { - "name": config["Dataset"]["train_file"], - "validation_file": config["Dataset"]["validation_file"], - "validation_split_percentage": config["Dataset"]["validation_split_percentage"] - }) + datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()( + config={ + "name": config["Dataset"]["train_file"], + "validation_file": config["Dataset"]["validation_file"], + "validation_split_percentage": config["Dataset"]["validation_split_percentage"], + } + ) - tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()(config = { - "name": config["General"]["base_model"], - }) + tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()( + config={ + "name": config["General"]["base_model"], + } + ) - model = common.model.Model.registory.get("HuggingFaceRewardModel")()(config = { - "name": config["General"]["base_model"], - }) + model = common.model.Model.registory.get("HuggingFaceRewardModel")()( + config={ + "name": config["General"]["base_model"], + } + ) - optimizer = common.optimizer.Optimizer.registory.get("DefaultOptimizer")()(model, config = { - "name": config["Training"]["optimizer"], - "config": { - "lr": config["Training"]["learning_rate"] + optimizer = common.optimizer.Optimizer.registory.get("DefaultOptimizer")()( + model, + config={ + "name": config["Training"]["optimizer"], + "config": {"lr": config["Training"]["learning_rate"]}, }, - }) - - trainer = common.trainer.Trainer.registory.get("RMTrainer")(config = { - "num_train_epochs": config["Training"]["epochs"], - "max_train_step": config["Training"].get("max_train_steps", None), - "output": config["General"]["output_dir"], - "dataprocesser": { - "type": "RMDataProcesser", - "per_device_train_batch_size": config["Training"]["batch_size"], - "per_device_eval_batch_size": config["Training"]["batch_size"], - "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1), - "shuffle": True - }, - "lr_scheduler": { - "enable": True, - "max_train_steps": None, - "lr_scheduler_type": config["Training"]["lr_scheduler"], - "num_warmup_steps": 0, - }, - "checkpoint": { - "root_path": config["General"]["checkpoint_dir"], - } if config["General"].get("checkpoint_dir") else None - }) + ) + + trainer = common.trainer.Trainer.registory.get("RMTrainer")( + config={ + "num_train_epochs": config["Training"]["epochs"], + "max_train_step": config["Training"].get("max_train_steps", None), + "output": config["General"]["output_dir"], + 
"dataprocesser": { + "type": "RMDataProcesser", + "per_device_train_batch_size": config["Training"]["batch_size"], + "per_device_eval_batch_size": config["Training"]["batch_size"], + "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1), + "shuffle": True, + }, + "lr_scheduler": { + "enable": True, + "max_train_steps": None, + "lr_scheduler_type": config["Training"]["lr_scheduler"], + "num_warmup_steps": 0, + }, + "checkpoint": { + "root_path": config["General"]["checkpoint_dir"], + } + if config["General"].get("checkpoint_dir") + else None, + } + ) - try : - common.logger.info(f"trainer prepare start") + try: + common.logger.info("trainer prepare start") trainer.prepare(model, tokenizer, datasets, optimizer, accelerator) except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"trainer prepare finish") + common.logger.info("trainer prepare finish") - try : - common.logger.info(f"train start") + try: + common.logger.info("train start") trainer.train() except Exception as e: common.logger.critical(e, exc_info=True) exit(1) - common.logger.info(f"train finish") + common.logger.info("train finish") + -def main(external_config = None): +def main(external_config=None): config = common.Config() if external_config is not None: config.merge(external_config) @@ -101,23 +113,23 @@ def main(external_config = None): if not ray.is_initialized(): runtime_env = { "env_vars": { - "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), - "ACCELERATE_USE_CPU": "True", + "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), + "ACCELERATE_USE_CPU": "True", "ACCELERATE_MIXED_PRECISION": "no", "CCL_WORKER_COUNT": "1", "CCL_LOG_LEVEL": "info", "WORLD_SIZE": str(num_training_workers), } } - ray.init(runtime_env = runtime_env) + ray.init(runtime_env=runtime_env) scaling_config = ScalingConfig( - num_workers = num_training_workers, - resources_per_worker = resources_per_worker, - placement_strategy = "SPREAD" + num_workers=num_training_workers, + resources_per_worker=resources_per_worker, + placement_strategy="SPREAD", ) - torch_config = common.TorchConfig(backend = "ccl") + torch_config = common.TorchConfig(backend="ccl") failure_config = FailureConfig() @@ -130,10 +142,10 @@ def main(external_config = None): train_func, train_loop_config=config, scaling_config=scaling_config, - torch_config = torch_config, - run_config = run_config + torch_config=torch_config, + run_config=run_config, ) - results = trainer.fit() + trainer.fit() if __name__ == "__main__": diff --git a/rlhf/rl_algo/ppo/ppo_rlhf.py b/rlhf/rl_algo/ppo/ppo_rlhf.py index 41c3cb28d..55657a507 100644 --- a/rlhf/rl_algo/ppo/ppo_rlhf.py +++ b/rlhf/rl_algo/ppo/ppo_rlhf.py @@ -1,57 +1,48 @@ - import torch import numpy as np -from typing import List, Optional, Type, Union, TYPE_CHECKING -from ray.rllib.algorithms import Algorithm, AlgorithmConfig +from ray.rllib.algorithms import AlgorithmConfig from ray.rllib.algorithms.ppo import PPO -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.sample_batch import SampleBatch, concat_samples, DEFAULT_POLICY_ID -from ray.rllib.core.learner.learner_group import LearnerGroup +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.utils.metrics import ( - NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, LEARNER_STATS_KEY + NUM_AGENT_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED, + LEARNER_STATS_KEY, ) from ray.rllib.evaluation.metrics import RolloutMetrics 
-import sys, os -sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../")) from common.agentenv.rlhf_env import generate_response from .rlhf_buffer import Buffer, BufferItem -from ray.rllib.evaluation.metrics import ( - collect_episodes, - collect_metrics, - summarize_episodes, -) - class RLHFSampler: - """This sampler is a local sampler for LLMEnv. - - The underlying env is an LLMEnv which creates a batch of prompts and the agent has - to generate a response for each prompt. Then the env evaluate those responses and - returns a reward signal. + """This sampler is a local sampler for LLMEnv. + + The underlying env is an LLMEnv which creates a batch of prompts and the agent has + to generate a response for each prompt. Then the env evaluate those responses and + returns a reward signal. """ def __init__(self, module, env): - self._env = env self._module = module self.max_generation_length = self._env.max_generation_length - def sample(self, batch_size: int, **kwargs) -> SampleBatch: - - # TODO (Kourosh): Can we use batch inference here? + # TODO (Kourosh): Can we use batch inference here? batches = Buffer() for i in range(batch_size): obs, _ = self._env.reset() output = generate_response( - self._module.actor, - input_ids=torch.tensor(obs['input_ids'])[None], + self._module.actor, + input_ids=torch.tensor(obs["input_ids"])[None], max_length=self.max_generation_length, eos_token_id=self._env.tokenizer.eos_token_id, ) @@ -62,51 +53,49 @@ def sample(self, batch_size: int, **kwargs) -> SampleBatch: generated_tokens = output["sequence"][-n_generated_tokens:] value = self._module.critic(output["sequence"]).detach().item() - + action = { "sequence": generated_tokens.detach().numpy()[0], - "response_mask": np.array([0]*n_input_tokens + [1]*n_generated_tokens), - "logits": output["logits"].detach().numpy()[0], # remove batch dimension - "attention_mask": np.array([1]*(n_input_tokens + n_generated_tokens)), + "response_mask": np.array([0] * n_input_tokens + [1] * n_generated_tokens), + "logits": output["logits"].detach().numpy()[0], # remove batch dimension + "attention_mask": np.array([1] * (n_input_tokens + n_generated_tokens)), } next_obs, reward, terminated, truncated, info = self._env.step(action) - assert terminated == True, "The env should be terminated after each step." + assert terminated is True, "The env should be terminated after each step." # value and reward should be both float scalars here. advantages = value - reward batches.append( - BufferItem(**{ - SampleBatch.OBS: obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: reward, - SampleBatch.INFOS: info, - Postprocessing.VALUE_TARGETS: value, - Postprocessing.ADVANTAGES: advantages, - }) + BufferItem( + **{ + SampleBatch.OBS: obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: reward, + SampleBatch.INFOS: info, + Postprocessing.VALUE_TARGETS: value, + Postprocessing.ADVANTAGES: advantages, + } + ) ) - return batches.convert_to_sample_batch() class PPORLHF(PPO): - - def setup(self, config: AlgorithmConfig) -> None: super().setup(config) self.rlhf_module = self.learner_group._learner.module[DEFAULT_POLICY_ID] self.env = self.workers.local_worker().env - + # create a copy of module and env in the algorithm. 
self.sampler = RLHFSampler(self.rlhf_module, self.env) def training_step(self): - train_batch = self.sampler.sample(batch_size=self.config.train_batch_size) train_batch = train_batch.as_multi_agent() self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() @@ -120,8 +109,7 @@ def training_step(self): policies_to_update = {DEFAULT_POLICY_ID} kl_dict = { - pid: train_results[pid][LEARNER_STATS_KEY].get("kl") - for pid in policies_to_update + pid: train_results[pid][LEARNER_STATS_KEY].get("kl") for pid in policies_to_update } self.learner_group.additional_update( module_ids_to_update=policies_to_update, @@ -130,29 +118,25 @@ def training_step(self): ) return train_results - + def evaluate(self): # breakpoint() train_batch = self.sampler.sample(batch_size=1) - rewards = train_batch[SampleBatch.INFOS]['r_align'] - - self.evaluation_metrics = {"evaluation": - { - "reward": rewards.item() - } - } + rewards = train_batch[SampleBatch.INFOS]["r_align"] + + self.evaluation_metrics = {"evaluation": {"reward": rewards.item()}} eval_metric = RolloutMetrics( - episode_length = 1, - episode_reward = rewards.item(), - agent_rewards = {}, - custom_metrics = {}, - perf_stats = {}, - hist_data = {}, - media = {}, - episode_faulty = {}, - connector_metrics = {}, + episode_length=1, + episode_reward=rewards.item(), + agent_rewards={}, + custom_metrics={}, + perf_stats={}, + hist_data={}, + media={}, + episode_faulty={}, + connector_metrics={}, ) self.workers.local_worker().sampler.metrics_queue.put(eval_metric) @@ -160,5 +144,3 @@ def evaluate(self): self.workers.local_worker().sampler._env_runner_obj._perf_stats.incr("iters", 1) return self.evaluation_metrics - - diff --git a/rlhf/rl_algo/ppo/rlhf_buffer.py b/rlhf/rl_algo/ppo/rlhf_buffer.py index 8175a5905..512fba6de 100644 --- a/rlhf/rl_algo/ppo/rlhf_buffer.py +++ b/rlhf/rl_algo/ppo/rlhf_buffer.py @@ -1,31 +1,34 @@ from dataclasses import dataclass import torch import numpy as np -import tree # pip install dm-tree +import tree # pip install dm-tree from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.postprocessing import Postprocessing +from typing import Dict, Any + @dataclass class BufferItem: - # TODO (Kourosh): These names have to match those in the SampleBatch and + # TODO (Kourosh): These names have to match those in the SampleBatch and # PostProcessing. - obs: dict # keys (shape): input_ids (T,), attention_mask (T,) - actions: dict # keys: sequence (T,), response_mask (T,), logits (T, VS), attention_mask (T, ) - infos: dict - rewards: float # scalar (python float) - value_targets: float # scalar (python float) - advantages: float # scalar (python float) + obs: dict # keys (shape): input_ids (T,), attention_mask (T,) + actions: dict # keys: sequence (T,), response_mask (T,), logits (T, VS), attention_mask (T, ) + infos: dict + rewards: float # scalar (python float) + value_targets: float # scalar (python float) + advantages: float # scalar (python float) + class Buffer: """This buffer should work for both torch and numpy types in the buffer items. - - Its job is to collect simple BufferItems but then upon calling - convert_to_sample_batch, figure out the padding required to create blocks for + + Its job is to collect simple BufferItems but then upon calling + convert_to_sample_batch, figure out the padding required to create blocks for tensors inside a SampleBatch. 
""" - + def __init__(self): self._buffer = [] self._framework = None @@ -35,54 +38,60 @@ def append(self, item: BufferItem): self._framework = torch if isinstance(item.obs["input_ids"], torch.Tensor) else np else: if self._framework == torch: - assert isinstance(item.obs["input_ids"], torch.Tensor), "The buffer items should be of the same framework." + assert isinstance( + item.obs["input_ids"], torch.Tensor + ), "The buffer items should be of the same framework." else: - assert isinstance(item.obs["input_ids"], np.ndarray), "The buffer items should be of the same framework." - + assert isinstance( + item.obs["input_ids"], np.ndarray + ), "The buffer items should be of the same framework." # under the same key, the values should be of the same length for k in (SampleBatch.ACTIONS, SampleBatch.OBS): flattened = tree.flatten(getattr(item, k)) for i in range(len(flattened) - 1): - if not flattened[i].shape[0] == flattened[i+1].shape[0]: + if not flattened[i].shape[0] == flattened[i + 1].shape[0]: raise ValueError("The values under the same key should be of the same length.") - + self._buffer.append(item) - + def convert_to_sample_batch(self, padding_type: str = "right") -> SampleBatch: - assert padding_type in ("left", "right"), "The padding should be either 'left' or 'right'." + assert padding_type in ( + "left", + "right", + ), "The padding should be either 'left' or 'right'." keys = BufferItem.__dataclass_fields__.keys() sample_batch_dict = {} for key in keys: values = [] - for item in self._buffer: + for item in self._buffer: val = getattr(item, key) if isinstance(val, float): val = torch.tensor(val) if self._framework == torch else np.array(val) elif isinstance(val, dict): val = NestedDict(val) - + values.append(val) # some values may not have the same sequence length, so we need to pad them if key in (SampleBatch.ACTIONS, SampleBatch.OBS): # we should first obtain the max length for each value. Remember that each value is possibly a nested dict where the values are tensors. - # TODO (Kourosh): This is not optimal since we are flattening the whole - # tree structure, while all we need is the DFS traversal of the tree + # TODO (Kourosh): This is not optimal since we are flattening the whole + # tree structure, while all we need is the DFS traversal of the tree # and obtaining the first leave. 
# Each v is a nested dict where the leave values can be iterated easily max_length = max(next(iter(v.values())).shape[0] for v in values) - + for item in values: for nested_key, val in item.items(): if val.shape[0] < max_length: padding = self._framework.zeros( - (max_length - val.shape[0], *val.shape[1:]), - dtype=val.dtype + (max_length - val.shape[0], *val.shape[1:]), + dtype=val.dtype, ) if padding_type == "left": @@ -95,129 +104,117 @@ def convert_to_sample_batch(self, padding_type: str = "right") -> SampleBatch: item[nested_key] = torch.cat((val, padding), 0) else: item[nested_key] = np.concatenate((val, padding), 0) - - values = tree.map_structure(lambda *x: self._framework.stack(x,0), *values) + + values = tree.map_structure(lambda *x: self._framework.stack(x, 0), *values) sample_batch_dict[key] = values.asdict() if isinstance(values, NestedDict) else values return SampleBatch(sample_batch_dict) - -if __name__ == "__main__": +if __name__ == "__main__": foo = Buffer() - foo.append( - BufferItem(**{ - SampleBatch.OBS: { - "input_ids": torch.tensor([1, 2, 3]), - "attention_mask": torch.tensor([1, 1, 1]) - }, - SampleBatch.ACTIONS: { - "sequence": torch.tensor([1, 2, 3, 4]), - "logits": torch.tensor([[0.5, 0.5] for _ in range(4)]), - "attention_mask": torch.tensor([1, 1, 1, 1]) - }, - SampleBatch.REWARDS: 1.0, - Postprocessing.VALUE_TARGETS: 1.0, - Postprocessing.ADVANTAGES: 1.0, - }) - ) - - foo.append( - BufferItem(**{ - SampleBatch.OBS: { - "input_ids": torch.tensor([4, 5, 6]), - "attention_mask": torch.tensor([1, 1, 1]) - }, - SampleBatch.ACTIONS: { - "sequence": torch.tensor([4, 5, 6, 7]), - "logits": torch.tensor([[0.5, 0.5] for _ in range(4)]), - "attention_mask": torch.tensor([1, 1, 1, 1]) - }, - SampleBatch.REWARDS: 1.0, - Postprocessing.VALUE_TARGETS: 1.0, - Postprocessing.ADVANTAGES: 1.0, - }) - ) + bufferitem: Dict[str, Any] = {} + bufferitem = { + SampleBatch.OBS: { + "input_ids": torch.tensor([1, 2, 3]), + "attention_mask": torch.tensor([1, 1, 1]), + }, + SampleBatch.ACTIONS: { + "sequence": torch.tensor([1, 2, 3, 4]), + "logits": torch.tensor([[0.5, 0.5] for _ in range(4)]), + "attention_mask": torch.tensor([1, 1, 1, 1]), + }, + SampleBatch.REWARDS: 1.0, + Postprocessing.VALUE_TARGETS: 1.0, + Postprocessing.ADVANTAGES: 1.0, + } + foo.append(BufferItem(**bufferitem)) + bufferitem = { + SampleBatch.OBS: { + "input_ids": torch.tensor([4, 5, 6]), + "attention_mask": torch.tensor([1, 1, 1]), + }, + SampleBatch.ACTIONS: { + "sequence": torch.tensor([4, 5, 6, 7]), + "logits": torch.tensor([[0.5, 0.5] for _ in range(4)]), + "attention_mask": torch.tensor([1, 1, 1, 1]), + }, + SampleBatch.REWARDS: 1.0, + Postprocessing.VALUE_TARGETS: 1.0, + Postprocessing.ADVANTAGES: 1.0, + } + foo.append(BufferItem(**bufferitem)) # action sequence length is different from the previous two - foo.append( - BufferItem(**{ - SampleBatch.OBS: { - "input_ids": torch.tensor([4, 5, 6]), - "attention_mask": torch.tensor([1, 1, 1]) - }, - SampleBatch.ACTIONS: { - "sequence": torch.tensor([4, 5, 6]), - "logits": torch.tensor([[0.5, 0.5] for _ in range(3)]), - "attention_mask": torch.tensor([1, 1, 1]) - }, - SampleBatch.REWARDS: 1.0, - Postprocessing.VALUE_TARGETS: 1.0, - Postprocessing.ADVANTAGES: 1.0, - }) - ) - - sb = foo.convert_to_sample_batch() - + bufferitem = { + SampleBatch.OBS: { + "input_ids": torch.tensor([4, 5, 6]), + "attention_mask": torch.tensor([1, 1, 1]), + }, + SampleBatch.ACTIONS: { + "sequence": torch.tensor([4, 5, 6]), + "logits": torch.tensor([[0.5, 0.5] for _ in range(3)]), + 
"attention_mask": torch.tensor([1, 1, 1]), + }, + SampleBatch.REWARDS: 1.0, + Postprocessing.VALUE_TARGETS: 1.0, + Postprocessing.ADVANTAGES: 1.0, + } + foo.append(BufferItem(**bufferitem)) + sb = foo.convert_to_sample_batch() # numpy version + bufferitem = { + SampleBatch.OBS: { + "input_ids": np.array([1, 2, 3]), + "attention_mask": np.array([1, 1, 1]), + }, + SampleBatch.ACTIONS: { + "sequence": np.array([1, 2, 3, 4]), + "logits": np.array([[0.5, 0.5] for _ in range(4)]), + "attention_mask": np.array([1, 1, 1, 1]), + }, + SampleBatch.REWARDS: 1.0, + Postprocessing.VALUE_TARGETS: 1.0, + Postprocessing.ADVANTAGES: 1.0, + } foo = Buffer() - foo.append( - BufferItem(**{ - SampleBatch.OBS: { - "input_ids": np.array([1, 2, 3]), - "attention_mask": np.array([1, 1, 1]) - }, - SampleBatch.ACTIONS: { - "sequence": np.array([1, 2, 3, 4]), - "logits": np.array([[0.5, 0.5] for _ in range(4)]), - "attention_mask": np.array([1, 1, 1, 1]) - }, - SampleBatch.REWARDS: 1.0, - Postprocessing.VALUE_TARGETS: 1.0, - Postprocessing.ADVANTAGES: 1.0, - }) - ) - - foo.append( - BufferItem(**{ - SampleBatch.OBS: { - "input_ids": np.array([4, 5, 6]), - "attention_mask": np.array([1, 1, 1]) - }, - SampleBatch.ACTIONS: { - "sequence": np.array([4, 5, 6, 7]), - "logits": np.array([[0.5, 0.5] for _ in range(4)]), - "attention_mask": np.array([1, 1, 1, 1]) - }, - SampleBatch.REWARDS: 1.0, - Postprocessing.VALUE_TARGETS: 1.0, - Postprocessing.ADVANTAGES: 1.0, - }) - ) + foo.append(BufferItem(**bufferitem)) + + bufferitem = { + SampleBatch.OBS: { + "input_ids": np.array([4, 5, 6]), + "attention_mask": np.array([1, 1, 1]), + }, + SampleBatch.ACTIONS: { + "sequence": np.array([4, 5, 6, 7]), + "logits": np.array([[0.5, 0.5] for _ in range(4)]), + "attention_mask": np.array([1, 1, 1, 1]), + }, + SampleBatch.REWARDS: 1.0, + Postprocessing.VALUE_TARGETS: 1.0, + Postprocessing.ADVANTAGES: 1.0, + } + foo.append(BufferItem(**bufferitem)) # action sequence length is different from the previous two - foo.append( - BufferItem(**{ - SampleBatch.OBS: { - "input_ids": np.array([4, 5, 6]), - "attention_mask": np.array([1, 1, 1]), - }, - SampleBatch.ACTIONS: { - "sequence": np.array([4, 5, 6]), - "logits": np.array([[0.5, 0.5] for _ in range(3)]), - "attention_mask": np.array([1, 1, 1]) - }, - SampleBatch.REWARDS: 1.0, - Postprocessing.VALUE_TARGETS: 1.0, - Postprocessing.ADVANTAGES: 1.0, - }) - ) - - sb = foo.convert_to_sample_batch() + bufferitem = { + SampleBatch.OBS: { + "input_ids": np.array([4, 5, 6]), + "attention_mask": np.array([1, 1, 1]), + }, + SampleBatch.ACTIONS: { + "sequence": np.array([4, 5, 6]), + "logits": np.array([[0.5, 0.5] for _ in range(3)]), + "attention_mask": np.array([1, 1, 1]), + }, + SampleBatch.REWARDS: 1.0, + Postprocessing.VALUE_TARGETS: 1.0, + Postprocessing.ADVANTAGES: 1.0, + } + foo.append(BufferItem(**bufferitem)) + sb = foo.convert_to_sample_batch() breakpoint() - - \ No newline at end of file diff --git a/rlhf/rl_algo/ppo/rlhf_ppo_module.py b/rlhf/rl_algo/ppo/rlhf_ppo_module.py index e308ec2ea..a8729d752 100644 --- a/rlhf/rl_algo/ppo/rlhf_ppo_module.py +++ b/rlhf/rl_algo/ppo/rlhf_ppo_module.py @@ -1,9 +1,6 @@ - from typing import Optional -import gymnasium as gym import torch -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.torch.torch_distributions import TorchCategorical @@ -11,8 +8,6 @@ import transformers -from .util import 
masked_mean - class Critic(torch.nn.Module): def __init__(self, model_base: str): @@ -20,28 +15,25 @@ def __init__(self, model_base: str): self.base = transformers.AutoModel.from_pretrained(model_base) self.trunk = torch.nn.Linear(self.base.config.hidden_size, 1) - + def forward( - self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None + self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: outputs = self.base(input_ids, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] + last_hidden_states = outputs["last_hidden_state"] # only use the hidden state on the last layer for the the last token as the value values = self.trunk(last_hidden_states[:, -1]).squeeze(-1) assert values.ndim == 1, "values should be a 1D tensor with batch size" return values -class RLHFPPOTorchRLModule(TorchRLModule): +class RLHFPPOTorchRLModule(TorchRLModule): def __init__(self, config: RLModuleConfig): super().__init__(config) # Override the default to customize def setup(self): - model_config = self.config.model_config_dict actor_base_model = model_config.get("actor_base_model") critic_base_model = model_config.get("critic_base_model") @@ -51,18 +43,18 @@ def setup(self): def input_specs_exploration(self): return [] - + def input_specs_inference(self): return [] def _forward_exploration(self, batch): - # we skip the default sampler's procedure for inference and exploration + # we skip the default sampler's procedure for inference and exploration pass - + def _forward_inference(self, batch): # we skip the default sampler's procedure for inference and exploration pass - + def _forward_train(self, batch): output = {} @@ -71,18 +63,16 @@ def _forward_train(self, batch): attention_mask=batch[SampleBatch.ACTIONS]["attention_mask"], ) - output[SampleBatch.VF_PREDS] = vf_out # (batch_size,) + output[SampleBatch.VF_PREDS] = vf_out # (batch_size,) actor_out = self.actor( input_ids=batch[SampleBatch.ACTIONS]["sequence"], attention_mask=batch[SampleBatch.ACTIONS]["attention_mask"], ) - actor_logits = actor_out.logits # (batch_size, seq_len, vocab_size) + actor_logits = actor_out.logits # (batch_size, seq_len, vocab_size) actor_dist = TorchCategorical(logits=actor_logits) output[SampleBatch.ACTION_DIST_INPUTS] = actor_logits output[SampleBatch.ACTION_DIST] = actor_dist return output - - diff --git a/rlhf/rl_algo/ppo/rlhf_ppo_torch_learner.py b/rlhf/rl_algo/ppo/rlhf_ppo_torch_learner.py index c4eb5b587..733863703 100644 --- a/rlhf/rl_algo/ppo/rlhf_ppo_torch_learner.py +++ b/rlhf/rl_algo/ppo/rlhf_ppo_torch_learner.py @@ -1,5 +1,5 @@ import logging -from typing import Mapping, Any +from typing import Mapping from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.policy.sample_batch import SampleBatch @@ -10,17 +10,6 @@ from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner from ray.rllib.models.torch.torch_distributions import TorchCategorical -from ray.rllib.algorithms.ppo.ppo_learner import ( - LEARNER_RESULTS_KL_KEY, - LEARNER_RESULTS_CURR_KL_COEFF_KEY, - LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, - LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, - PPOLearner, - PPOLearnerHyperparameters, -) -from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY - -from ray.rllib.utils.nested_dict import NestedDict from .util import masked_mean @@ -30,18 +19,14 @@ class RLHFPPOTorchLearner(PPOTorchLearner): - @override(PPOTorchLearner) def compute_loss_per_module( - self, - 
module_id: str, - batch: SampleBatch, - fwd_out: Mapping[str, TensorType] + self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType] ) -> TensorType: """Extention of PPO loss function to support RLHF. This customization adds attention mask to loss calculation. - It also adds the ptx-loss term introduced in InstructGPT paper for making sure + It also adds the ptx-loss term introduced in InstructGPT paper for making sure the model is aligned with the pre-trained model. """ @@ -81,7 +66,7 @@ def compute_loss_per_module( curr_entropy = masked_mean(curr_entropy_unmasked, attention_mask, dim=-1) mean_entropy = curr_entropy.mean() - surrogate_loss = - torch.min( + surrogate_loss = -torch.min( batch[Postprocessing.ADVANTAGES] * logp_ratio, batch[Postprocessing.ADVANTAGES] * torch.clamp(logp_ratio, 1 - self.hps.clip_param, 1 + self.hps.clip_param), @@ -106,7 +91,6 @@ def compute_loss_per_module( - self.entropy_coeff_scheduler.get_current_value(module_id) * curr_entropy ) - # Add mean_kl_loss (already processed through `reduce_mean_valid`), # if necessary. if self.hps.kl_coeff > 0.0: @@ -127,4 +111,4 @@ def compute_loss_per_module( "mean_reward_total": batch[SampleBatch.REWARDS].mean(), "mean_reward_rm": batch[SampleBatch.INFOS]["r_align"].mean(), "mean_reward_kl": batch[SampleBatch.INFOS]["r_kl"].mean(), - } \ No newline at end of file + } diff --git a/rlhf/rl_algo/ppo/util.py b/rlhf/rl_algo/ppo/util.py index 957415549..56a2a2374 100644 --- a/rlhf/rl_algo/ppo/util.py +++ b/rlhf/rl_algo/ppo/util.py @@ -1,8 +1,9 @@ import torch + def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: tensor = tensor * mask tensor = tensor.sum(dim=dim) mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) - return mean \ No newline at end of file + return mean diff --git a/ui/start_ui.py b/ui/start_ui.py index 1490cb350..33df94177 100644 --- a/ui/start_ui.py +++ b/ui/start_ui.py @@ -18,10 +18,11 @@ import time import os import sys + sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from inference.inference_config import all_models, ModelDescription, Prompt from inference.inference_config import InferenceConfig as FinetunedConfig -from inference.chat_process import ChatModelGptJ, ChatModelLLama +from inference.chat_process import ChatModelGptJ, ChatModelLLama # noqa: F401 from inference.predictor_deployment import PredictorDeployment from ray import serve import ray @@ -34,11 +35,17 @@ from ray.util import queue import paramiko from html_format import cpu_memory_html, ray_status_html, custom_css -from typing import Dict +from typing import Dict, List, Any from langchain.vectorstores import FAISS from langchain.embeddings import HuggingFaceEmbeddings from pyrecdp.LLM import TextPipeline -from pyrecdp.primitives.operations import UrlLoader, DirectoryLoader, DocumentSplit, DocumentIngestion +from pyrecdp.primitives.operations import ( + UrlLoader, + DirectoryLoader, + DocumentSplit, + DocumentIngestion, +) + class CustomStopper(Stopper): def __init__(self): @@ -56,13 +63,13 @@ def stop(self, flag): @ray.remote -class Progress_Actor(): +class Progress_Actor: def __init__(self, config) -> None: self.config = config def track_progress(self): if "epoch_value" not in self.config: - return -1,-1,-1,-1 + return -1, -1, -1, -1 if not self.config["epoch_value"].empty(): total_epochs = self.config["total_epochs"].get(block=False) total_steps = self.config["total_steps"].get(block=False) @@ -75,9 +82,9 @@ def track_progress(self): class 
LoggingCallback(LoggerCallback): def __init__(self, config) -> None: self.config = config - self.results = [] + self.results: List[Any] = [] - def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + def log_trial_result(self, iteration: int, trial, result: Dict): if "train_epoch" in trial.last_result: self.config["epoch_value"].put(trial.last_result["train_epoch"] + 1, block=False) self.config["total_epochs"].put(trial.last_result["total_epochs"], block=False) @@ -88,7 +95,7 @@ def get_result(self): return self.results -class ChatBotUI(): +class ChatBotUI: def __init__( self, all_models: Dict[str, FinetunedConfig], @@ -103,7 +110,7 @@ def __init__( node_port: str, node_user_name: str, conda_env_name: str, - master_ip_port: str + master_ip_port: str, ): self._all_models = all_models self._base_models = base_models @@ -118,12 +125,17 @@ def __init__( self.conda_env_name = conda_env_name self.master_ip_port = master_ip_port self.ray_nodes = ray.nodes() - self.ssh_connect = [None] * (len(self.ray_nodes)+1) + self.ssh_connect = [None] * (len(self.ray_nodes) + 1) self.ip_port = "http://127.0.0.1:8000" self.stopper = CustomStopper() self.test_replica = 4 self.bot_queue = list(range(self.test_replica)) - self.messages = ["What is AI?", "What is Spark?", "What is Ray?", "What is chatbot?"] + self.messages = [ + "What is AI?", + "What is Spark?", + "What is Ray?", + "What is chatbot?", + ] self.process_tool = None self.finetune_actor = None self.finetune_status = False @@ -137,27 +149,34 @@ def __init__( def history_to_messages(history): messages = [] for human_text, bot_text in history: - messages.append({ - "role": "user", - "content": human_text, - }) + messages.append( + { + "role": "user", + "content": human_text, + } + ) if bot_text is not None: - messages.append({ - "role": "assistant", - "content": bot_text, - }) + messages.append( + { + "role": "assistant", + "content": bot_text, + } + ) return messages - + @staticmethod def add_knowledge(prompt, enhance_knowledge): description = "Known knowledge: {knowledge}. Then please answer the question based on follow conversation: {conversation}." 
return description.format(knowledge=enhance_knowledge, conversation=prompt) def clear(self): - return None, f"""| | | + return ( + None, + """| | | |---|---| | Total Latency [s] | - | - | Tokens | - |""" + | Tokens | - |""", + ) def reset(self, id): id = int(id) @@ -166,21 +185,29 @@ def reset(self, id): def user(self, user_message, history): return "", history + [[user_message, None]] - def model_generate(self, prompt, request_url, config): print("prompt: ", prompt) - + sample_input = {"text": prompt, "config": config, "stream": True} - proxies = { "http": None, "https": None} + proxies = {"http": None, "https": None} outputs = requests.post(request_url, proxies=proxies, json=sample_input, stream=True) outputs.raise_for_status() for output in outputs.iter_content(chunk_size=None, decode_unicode=True): # remove context if prompt in output: - output = output[len(prompt):] + output = output[len(prompt) :] yield output - def bot(self, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k, enhance_knowledge=None): + def bot( + self, + history, + model_endpoint, + Max_new_tokens, + Temperature, + Top_p, + Top_k, + enhance_knowledge=None, + ): prompt = self.history_to_messages(history) prompt = self.process_tool.get_prompt(prompt) if enhance_knowledge: @@ -201,9 +228,9 @@ def bot(self, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k if len(output) != 0: time_end = time.time() if history[-1][1] is None: - history[-1][1]=output + history[-1][1] = output else: - history[-1][1]+=output + history[-1][1] += output history[-1][1] = self.process_tool.convert_output(history[-1][1]) time_spend = round(time_end - time_start, 3) token_num += 1 @@ -212,9 +239,19 @@ def bot(self, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k |---|---| | Total Latency [s] | {time_spend} | | Tokens | {token_num} |""" - yield [history, new_token_latency] + yield [history, new_token_latency] - def bot_test(self, bot_queue, queue_id, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k): + def bot_test( + self, + bot_queue, + queue_id, + history, + model_endpoint, + Max_new_tokens, + Temperature, + Top_p, + Top_k, + ): prompt = self.history_to_messages(history) prompt = self.process_tool.get_prompt(prompt) request_url = model_endpoint @@ -232,15 +269,26 @@ def bot_test(self, bot_queue, queue_id, history, model_endpoint, Max_new_tokens, if len(output) != 0: time_end = time.time() if history[-1][1] is None: - history[-1][1]=output + history[-1][1] = output else: - history[-1][1]+=output + history[-1][1] += output history[-1][1] = self.process_tool.convert_output(history[-1][1]) time_spend = time_end - time_start bot_queue.put([queue_id, history, time_spend]) bot_queue.put([queue_id, "", ""]) - def bot_rag(self, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k, rag_selector, rag_path, returned_k): + def bot_rag( + self, + history, + model_endpoint, + Max_new_tokens, + Temperature, + Top_p, + Top_k, + rag_selector, + rag_path, + returned_k, + ): enhance_knowledge = None if os.path.isabs(rag_path): tmp_folder = os.getcwd() @@ -257,11 +305,27 @@ def bot_rag(self, history, model_endpoint, Max_new_tokens, Temperature, Top_p, T sim_res = vectorstore.similarity_search(question, k=int(returned_k)) enhance_knowledge = sim_res[0].page_content - bot_generator = self.bot(history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k, enhance_knowledge) + bot_generator = self.bot( + history, + model_endpoint, + Max_new_tokens, + Temperature, + Top_p, + Top_k, 
+ enhance_knowledge, + ) for output in bot_generator: yield output - def regenerate(self, db_dir, web_urls, data_pdfs, embedding_model, splitter_chunk_size, cpus_per_worker): + def regenerate( + self, + db_dir, + web_urls, + data_pdfs, + embedding_model, + splitter_chunk_size, + cpus_per_worker, + ): pdf_folder = [] if data_pdfs: for _, file in enumerate(data_pdfs): @@ -274,7 +338,7 @@ def regenerate(self, db_dir, web_urls, data_pdfs, embedding_model, splitter_chun if not os.path.exists(save_dir): os.makedirs(save_dir) web_urls = web_urls.split(";") - target_urls = [url.strip() for url in web_urls if url!=""] + target_urls = [url.strip() for url in web_urls if url != ""] if len(target_urls) > 0 and len(pdf_folder) > 0: raise gr.Warning("Setting both 'web urls' and 'pdf files' is not supported") @@ -282,9 +346,13 @@ def regenerate(self, db_dir, web_urls, data_pdfs, embedding_model, splitter_chun index_name = "knowledge_db" text_splitter = "RecursiveCharacterTextSplitter" splitter_chunk_size = int(splitter_chunk_size) - text_splitter_args = {"chunk_size": splitter_chunk_size, "chunk_overlap": 0, "separators": ["\n\n", "\n", " ", ""]} + text_splitter_args = { + "chunk_size": splitter_chunk_size, + "chunk_overlap": 0, + "separators": ["\n\n", "\n", " ", ""], + } embeddings_type = "HuggingFaceEmbeddings" - embeddings_args = {'model_name': embedding_model} + embeddings_args = {"model_name": embedding_model} if embedding_model != self.embedding_model_name: self.embedding_model_name = embedding_model self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model_name) @@ -292,22 +360,27 @@ def regenerate(self, db_dir, web_urls, data_pdfs, embedding_model, splitter_chun pipeline = TextPipeline() ops = [] if len(target_urls) > 0: - ops.append(UrlLoader(urls=target_urls, target_tag='div', target_attrs={'class': 'main-content'})) + ops.append( + UrlLoader( + urls=target_urls, + target_tag="div", + target_attrs={"class": "main-content"}, + ) + ) if len(pdf_folder) > 0: ops.append(DirectoryLoader(input_files=pdf_folder)) - ops.extend([ - DocumentSplit(text_splitter=text_splitter, text_splitter_args=text_splitter_args), - DocumentIngestion( - vector_store=vector_store_type, - vector_store_args={ - "output_dir": save_dir, - "index": index_name - }, - embeddings=embeddings_type, - embeddings_args=embeddings_args, - num_cpus=cpus_per_worker - ) - ]) + ops.extend( + [ + DocumentSplit(text_splitter=text_splitter, text_splitter_args=text_splitter_args), + DocumentIngestion( + vector_store=vector_store_type, + vector_store_args={"output_dir": save_dir, "index": index_name}, + embeddings=embeddings_type, + embeddings_args=embeddings_args, + num_cpus=cpus_per_worker, + ), + ] + ) pipeline.add_operations(ops) pipeline.execute() return db_dir @@ -315,15 +388,40 @@ def regenerate(self, db_dir, web_urls, data_pdfs, embedding_model, splitter_chun def send_all_bot(self, id, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k): id = int(id) self.bot_queue[id] = Queue() - p = Process(target=self.bot_test, args=(self.bot_queue[id], id, history, model_endpoint, Max_new_tokens, Temperature, Top_p, Top_k)) + p = Process( + target=self.bot_test, + args=( + self.bot_queue[id], + id, + history, + model_endpoint, + Max_new_tokens, + Temperature, + Top_p, + Top_k, + ), + ) p.start() - while(True): + while True: res = self.bot_queue[id].get() if res[1] == "": break yield res[1] - def finetune(self, model_name, custom_model_name, custom_tokenizer_name, dataset, new_model_name, batch_size, num_epochs, 
max_train_step, lr, worker_num, cpus_per_worker_ftn): + def finetune( + self, + model_name, + custom_model_name, + custom_tokenizer_name, + dataset, + new_model_name, + batch_size, + num_epochs, + max_train_step, + lr, + worker_num, + cpus_per_worker_ftn, + ): if model_name == "specify other models": model_desc = None origin_model_path = custom_model_name @@ -339,7 +437,11 @@ def finetune(self, model_name, custom_model_name, custom_tokenizer_name, dataset gpt_base_model = model_desc.gpt_base_model last_gpt_base_model = False finetuned_model_path = os.path.join(self.finetuned_model_path, model_name, new_model_name) - finetuned_checkpoint_path = os.path.join(self.finetuned_checkpoint_path, model_name, new_model_name) if self.finetuned_checkpoint_path != "" else None + finetuned_checkpoint_path = ( + os.path.join(self.finetuned_checkpoint_path, model_name, new_model_name) + if self.finetuned_checkpoint_path != "" + else None + ) finetune_config = self.config.copy() training_config = finetune_config.get("Training") @@ -347,15 +449,21 @@ def finetune(self, model_name, custom_model_name, custom_tokenizer_name, dataset exist_cpus_per_worker_ftn = int(training_config["resources_per_worker"]["CPU"]) ray_resources = ray.available_resources() - if "CPU" not in ray_resources or cpus_per_worker_ftn * worker_num + 1 > int(ray.available_resources()["CPU"]): + if "CPU" not in ray_resources or cpus_per_worker_ftn * worker_num + 1 > int( + ray.available_resources()["CPU"] + ): raise gr.Error("Resources are not meeting the demand") - if worker_num != exist_worker or cpus_per_worker_ftn != exist_cpus_per_worker_ftn or not (gpt_base_model and last_gpt_base_model): + if ( + worker_num != exist_worker + or cpus_per_worker_ftn != exist_cpus_per_worker_ftn + or not (gpt_base_model and last_gpt_base_model) + ): ray.shutdown() new_ray_init_config = { "runtime_env": { "env_vars": { - "OMP_NUM_THREADS": str(cpus_per_worker_ftn), - "ACCELERATE_USE_CPU": "True", + "OMP_NUM_THREADS": str(cpus_per_worker_ftn), + "ACCELERATE_USE_CPU": "True", "ACCELERATE_MIXED_PRECISION": "no", "CCL_WORKER_COUNT": "1", "CCL_LOG_LEVEL": "info", @@ -390,11 +498,22 @@ def finetune(self, model_name, custom_model_name, custom_tokenizer_name, dataset finetune_config["Training"]["max_train_steps"] = max_train_step from finetune.finetune import main - finetune_config["total_epochs"] = queue.Queue(actor_options={"resources": {"queue_hardware": 1}}) - finetune_config["total_steps"] = queue.Queue(actor_options={"resources": {"queue_hardware": 1}}) - finetune_config["epoch_value"] = queue.Queue(actor_options={"resources": {"queue_hardware": 1}}) - finetune_config["step_value"] = queue.Queue(actor_options={"resources": {"queue_hardware": 1}}) - self.finetune_actor = Progress_Actor.options(resources={"queue_hardware": 1}).remote(finetune_config) + + finetune_config["total_epochs"] = queue.Queue( + actor_options={"resources": {"queue_hardware": 1}} + ) + finetune_config["total_steps"] = queue.Queue( + actor_options={"resources": {"queue_hardware": 1}} + ) + finetune_config["epoch_value"] = queue.Queue( + actor_options={"resources": {"queue_hardware": 1}} + ) + finetune_config["step_value"] = queue.Queue( + actor_options={"resources": {"queue_hardware": 1}} + ) + self.finetune_actor = Progress_Actor.options(resources={"queue_hardware": 1}).remote( + finetune_config + ) callback = LoggingCallback(finetune_config) finetune_config["run_config"] = {} @@ -423,17 +542,21 @@ def finetune(self, model_name, custom_model_name, custom_tokenizer_name, dataset 
new_prompt.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" new_prompt.human_id = "\n### Instruction" new_prompt.bot_id = "\n### Response" - new_prompt.stop_words.extend(["### Instruction", "# Instruction", "### Question", "##", " ="]) - new_model_desc = ModelDescription(model_id_or_path=finetuned_model_path, - tokenizer_name_or_path=tokenizer_path, - prompt=new_prompt, - chat_processor=model_desc.chat_processor if model_desc is not None else "ChatModelGptJ", - ) + new_prompt.stop_words.extend( + ["### Instruction", "# Instruction", "### Question", "##", " ="] + ) + new_model_desc = ModelDescription( + model_id_or_path=finetuned_model_path, + tokenizer_name_or_path=tokenizer_path, + prompt=new_prompt, + chat_processor=model_desc.chat_processor if model_desc is not None else "ChatModelGptJ", + ) new_model_desc.config.trust_remote_code = True - new_finetuned = FinetunedConfig(name=new_model_name, - route_prefix="/" + new_model_name, - model_description=new_model_desc - ) + new_finetuned = FinetunedConfig( + name=new_model_name, + route_prefix="/" + new_model_name, + model_description=new_model_desc, + ) self._all_models[new_model_name] = new_finetuned return gr.Dropdown.update(choices=list(self._all_models.keys())) @@ -442,19 +565,32 @@ def finetune_progress(self, progress=gr.Progress()): while True: time.sleep(1) if self.finetune_actor is None: - if finetune_flag == False: + if finetune_flag is False: continue else: break - if self.finetune_status == True: + if self.finetune_status is True: break finetune_flag = True try: - total_epochs, total_steps, value_epoch, value_step = ray.get(self.finetune_actor.track_progress.remote()) + total_epochs, total_steps, value_epoch, value_step = ray.get( + self.finetune_actor.track_progress.remote() + ) if value_epoch == -1: continue - progress(float(int(value_step)/int(total_steps)), desc="Start Training: epoch "+ str(value_epoch)+" / "+str(total_epochs) +" "+"step " + str(value_step)+ " / "+ str(total_steps)) - except Exception as e: + progress( + float(int(value_step) / int(total_steps)), + desc="Start Training: epoch " + + str(value_epoch) + + " / " + + str(total_epochs) + + " " + + "step " + + str(value_step) + + " / " + + str(total_steps), + ) + except Exception: pass self.finetune_status = False return "
Completed the fine-tuning process.
" @@ -469,17 +605,23 @@ def deploy_func(self, model_name: str, replica_num: int, cpus_per_worker_deploy: stop_words = ["### Instruction", "# Instruction", "### Question", "##", " ="] finetuned = self._all_models[model_name] model_desc = finetuned.model_description - prompt = model_desc.prompt if model_desc.prompt else {} + prompt = model_desc.prompt print("model path: ", model_desc.model_id_or_path) - chat_model = getattr(sys.modules[__name__], model_desc.chat_processor, None) - if chat_model is None: - return model_name + " deployment failed. " + model_desc.chat_processor + " does not exist." - self.process_tool = chat_model(**prompt.dict()) + if model_desc.chat_processor is not None: + chat_model = getattr(sys.modules[__name__], model_desc.chat_processor, None) + if chat_model is None: + return ( + model_name + + " deployment failed. " + + model_desc.chat_processor + + " does not exist." + ) + self.process_tool = chat_model(**prompt.dict()) finetuned_deploy = finetuned.copy(deep=True) - finetuned_deploy.device = 'cpu' - finetuned_deploy.ipex.precision = 'bf16' + finetuned_deploy.device = "cpu" + finetuned_deploy.ipex.precision = "bf16" finetuned_deploy.model_description.prompt.stop_words = stop_words finetuned_deploy.cpus_per_worker = cpus_per_worker_deploy # transformers 4.35 is needed for neural-chat-7b-v3-1, will be fixed later @@ -487,63 +629,85 @@ def deploy_func(self, model_name: str, replica_num: int, cpus_per_worker_deploy: pip_env = "transformers==4.35.0" else: pip_env = "transformers==4.31.0" - deployment = PredictorDeployment.options(num_replicas=replica_num, ray_actor_options={"num_cpus": cpus_per_worker_deploy, "runtime_env": {"pip": [pip_env]}}).bind(finetuned_deploy) - handle = serve.run(deployment, _blocking=True, port=finetuned_deploy.port, name=finetuned_deploy.name, route_prefix=finetuned_deploy.route_prefix) - return self.ip_port + finetuned_deploy.route_prefix + deployment = PredictorDeployment.options( # type: ignore + num_replicas=replica_num, + ray_actor_options={ + "num_cpus": cpus_per_worker_deploy, + "runtime_env": {"pip": [pip_env]}, + }, + ).bind(finetuned_deploy) + serve.run( + deployment, + _blocking=True, + port=finetuned_deploy.port, + name=finetuned_deploy.name, + route_prefix=finetuned_deploy.route_prefix, + ) + return ( + self.ip_port + if finetuned_deploy.route_prefix is None + else self.ip_port + finetuned_deploy.route_prefix + ) def shutdown_finetune(self): self.stopper.stop(True) def shutdown_deploy(self): serve.shutdown() - + def get_ray_cluster(self): - command = 'conda activate ' + self.conda_env_name + '; ray status' + command = "conda activate " + self.conda_env_name + "; ray status" stdin, stdout, stderr = self.ssh_connect[-1].exec_command(command) - out = stdout.read().decode('utf-8') - out_words = [word for word in out.split("\n") if 'CPU' in word][0] + out = stdout.read().decode("utf-8") + out_words = [word for word in out.split("\n") if "CPU" in word][0] cpu_info = out_words.split(" ")[1].split("/") total_core = int(float(cpu_info[1])) used_core = int(float(cpu_info[0])) - utilization = float(used_core/total_core) - return ray_status_html.format(str(round(utilization*100, 1)), used_core, total_core) + utilization = float(used_core / total_core) + return ray_status_html.format(str(round(utilization * 100, 1)), used_core, total_core) def get_cpu_memory(self, index): if self.ray_nodes[index]["Alive"] == "False": return cpu_memory_html.format(str(round(0, 1)), str(round(0, 1))) - cpu_command = 'export TERM=xterm; echo $(top -n 1 -b | head -n 
4 | tail -n 2)' + cpu_command = "export TERM=xterm; echo $(top -n 1 -b | head -n 4 | tail -n 2)" _, cpu_stdout, _ = self.ssh_connect[index].exec_command(cpu_command) - cpu_out = cpu_stdout.read().decode('utf-8') + cpu_out = cpu_stdout.read().decode("utf-8") cpu_out_words = cpu_out.split(" ") cpu_value = 100 - float(cpu_out_words[7]) - memory_command = 'export TERM=xterm; echo $(free -m)' + memory_command = "export TERM=xterm; echo $(free -m)" _, memory_stdout, _ = self.ssh_connect[index].exec_command(memory_command) - memory_out = memory_stdout.read().decode('utf-8') + memory_out = memory_stdout.read().decode("utf-8") memory_out_words = memory_out.split("Mem:")[1].split("Swap")[0].split(" ") memory_out_words = [m for m in memory_out_words if m != ""] total_memory = float(memory_out_words[0].strip()) free_memory = float(memory_out_words[2].strip()) buffer_memory = float(memory_out_words[4].strip()) - used_memory =(total_memory - free_memory - buffer_memory) / total_memory - return cpu_memory_html.format(str(round(cpu_value, 1)), str(round(used_memory*100, 1))) - + used_memory = (total_memory - free_memory - buffer_memory) / total_memory + return cpu_memory_html.format(str(round(cpu_value, 1)), str(round(used_memory * 100, 1))) + def kill_node(self, btn_txt, index): serve.shutdown() - if btn_txt=="Kill": + if btn_txt == "Kill": index = int(index) - command = 'conda activate ' + self.conda_env_name + '; ray stop' + command = "conda activate " + self.conda_env_name + "; ray stop" self.ssh_connect[index].exec_command(command) self.ray_nodes[index]["Alive"] = "False" time.sleep(2) return "Start", "" - elif btn_txt=="Start": + elif btn_txt == "Start": index = int(index) - command = "conda activate " + self.conda_env_name + "; RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 ray start --address=" + self.master_ip_port + r""" --resources='{"special_hardware": 2}'""" + command = ( + "conda activate " + + self.conda_env_name + + "; RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 ray start --address=" + + self.master_ip_port + + r""" --resources='{"special_hardware": 2}'""" + ) self.ssh_connect[index].exec_command(command) self.ray_nodes[index]["Alive"] = "True" time.sleep(2) return "Kill", "" - + def watch_node_status(self, index): if self.ray_nodes[index]["Alive"] == "False": return "
DEAD
" @@ -553,15 +717,15 @@ def watch_node_status(self, index): def set_custom_model(self, base_model_name): visible = True if base_model_name == "specify other models" else False return gr.Textbox.update(visible=visible), gr.Textbox.update(visible=visible) - + def set_rag_default_path(self, selector, rag_path): if rag_path: return rag_path - if selector == False: + if selector is False: return None else: return self.default_rag_path - + def _init_ui(self): mark_alive = None for index in range(len(self.ray_nodes)): @@ -571,14 +735,20 @@ def _init_ui(self): self.ssh_connect[index] = paramiko.SSHClient() self.ssh_connect[index].load_system_host_keys() self.ssh_connect[index].set_missing_host_key_policy(paramiko.RejectPolicy()) - self.ssh_connect[index].connect(hostname=node_ip, port=self.node_port, username=self.user_name) + self.ssh_connect[index].connect( + hostname=node_ip, port=self.node_port, username=self.user_name + ) self.ssh_connect[-1] = paramiko.SSHClient() self.ssh_connect[-1].load_system_host_keys() self.ssh_connect[-1].set_missing_host_key_policy(paramiko.RejectPolicy()) - self.ssh_connect[-1].connect(hostname=self.ray_nodes[mark_alive]["NodeName"], port=self.node_port, username=self.user_name) - + self.ssh_connect[-1].connect( + hostname=self.ray_nodes[mark_alive]["NodeName"], + port=self.node_port, + username=self.user_name, + ) + title = "Manage LLM Lifecycle" - with gr.Blocks(css=custom_css,title=title) as gr_chat: + with gr.Blocks(css=custom_css, title=title) as gr_chat: head_content = """
The workflow is powered by Ray to provide infrastructure management, distributed training, model serving with reliability and auto scaling.
""" - notice = gr.Markdown(head_content, elem_classes="notice_markdown") + gr.Markdown(head_content, elem_classes="notice_markdown") with gr.Tab("Finetune"): step1 = "Finetune the model with the base model and data" - gr.HTML("

"+ step1 + "

") + gr.HTML("

" + step1 + "

") with gr.Group(): base_models_list = list(self._base_models.keys()) base_models_list.append("specify other models") - base_model_dropdown = gr.Dropdown(base_models_list, value=base_models_list[2], - label="Select Base Model", allow_custom_value=True) - custom_model_name = gr.Textbox(label="Model id", placeholder="The model id of a pretrained model configuration hosted inside a model repo on huggingface.co", visible=False, interactive=True, elem_classes="disable_status") - custom_tokenizer_name = gr.Textbox(label="Tokenizer id", placeholder="The model id of a predefined tokenizer hosted inside a model repo on huggingface.co", visible=False, interactive=True, elem_classes="disable_status") + base_model_dropdown = gr.Dropdown( + base_models_list, + value=base_models_list[2], + label="Select Base Model", + allow_custom_value=True, + ) + custom_model_name = gr.Textbox( + label="Model id", + placeholder="The model id of a pretrained model configuration hosted inside a model repo on huggingface.co", + visible=False, + interactive=True, + elem_classes="disable_status", + ) + custom_tokenizer_name = gr.Textbox( + label="Tokenizer id", + placeholder="The model id of a predefined tokenizer hosted inside a model repo on huggingface.co", + visible=False, + interactive=True, + elem_classes="disable_status", + ) with gr.Accordion("Parameters", open=False, visible=True): - batch_size = gr.Slider(0, 1000, 2, step=1, interactive=True, label="Batch Size", info="train batch size per worker.") + batch_size = gr.Slider( + 0, + 1000, + 2, + step=1, + interactive=True, + label="Batch Size", + info="train batch size per worker.", + ) num_epochs = gr.Slider(1, 100, 1, step=1, interactive=True, label="Epochs") - max_train_step = gr.Slider(0, 1000, 10, step=1, interactive=True, label="Step per Epoch", info="value 0 means use the entire dataset.") - lr = gr.Slider(0, 0.001, 0.00001, step=0.00001, interactive=True, label="Learning Rate") - worker_num = gr.Slider(1, 8, 2, step=1, interactive=True, label="Worker Number", info="the number of workers used for finetuning.") - cpus_per_worker_ftn = gr.Slider(1, 100, 24, step=1, interactive=True, label="Cpus per Worker", info="the number of cpu cores used for every worker.") - gpus_per_worker_ftn = gr.Slider(0, 16, 0, step=1, interactive=True, label="Gpus per Worker", info="the number of gpu used for every worker.") + max_train_step = gr.Slider( + 0, + 1000, + 10, + step=1, + interactive=True, + label="Step per Epoch", + info="value 0 means use the entire dataset.", + ) + lr = gr.Slider( + 0, + 0.001, + 0.00001, + step=0.00001, + interactive=True, + label="Learning Rate", + ) + worker_num = gr.Slider( + 1, + 8, + 2, + step=1, + interactive=True, + label="Worker Number", + info="the number of workers used for finetuning.", + ) + cpus_per_worker_ftn = gr.Slider( + 1, + 100, + 24, + step=1, + interactive=True, + label="Cpus per Worker", + info="the number of cpu cores used for every worker.", + ) + gr.Slider( + 0, + 16, + 0, + step=1, + interactive=True, + label="Gpus per Worker", + info="the number of gpu used for every worker.", + ) with gr.Row(): with gr.Column(scale=0.6): - data_url = gr.Text(label="Data URL", - value=self.default_data_path) + data_url = gr.Text(label="Data URL", value=self.default_data_path) with gr.Column(scale=0.2): - finetuned_model_name = gr.Text(label="New Model Name", - value="my_alpaca") + finetuned_model_name = gr.Text(label="New Model Name", value="my_alpaca") with gr.Column(scale=0.2, min_width=0): finetune_btn = gr.Button("Start to Finetune") 
stop_finetune_btn = gr.Button("Stop") - + with gr.Row(): - finetune_res = gr.HTML("

", show_label=False, elem_classes="disable_status") + finetune_res = gr.HTML( + "

", + show_label=False, + elem_classes="disable_status", + ) with gr.Tab("Deployment"): step2 = "Deploy the finetuned model as an online inference service" - gr.HTML("

"+ step2 + "

") + gr.HTML("

" + step2 + "

") with gr.Row(): with gr.Column(scale=0.8): all_models_list = list(self._all_models.keys()) - all_model_dropdown = gr.Dropdown(all_models_list, value=all_models_list[3], label="Select Model to Deploy", - elem_classes="disable_status", allow_custom_value=True) + all_model_dropdown = gr.Dropdown( + all_models_list, + value=all_models_list[3], + label="Select Model to Deploy", + elem_classes="disable_status", + allow_custom_value=True, + ) with gr.Column(scale=0.2, min_width=0): deploy_btn = gr.Button("Deploy") stop_deploy_btn = gr.Button("Stop") - + with gr.Accordion("Parameters", open=False, visible=True): - replica_num = gr.Slider(1, 8, 4, step=1, interactive=True, label="Model Replica Number") - cpus_per_worker_deploy = gr.Slider(1, 100, 24, step=1, interactive=True, label="Cpus per Worker", info="the number of cpu cores used for every worker.") - gpus_per_worker_deploy = gr.Slider(0, 16, 0, step=1, interactive=True, label="Gpus per Worker", info="the number of gpu used for every worker.") + replica_num = gr.Slider( + 1, 8, 4, step=1, interactive=True, label="Model Replica Number" + ) + cpus_per_worker_deploy = gr.Slider( + 1, + 100, + 24, + step=1, + interactive=True, + label="Cpus per Worker", + info="the number of cpu cores used for every worker.", + ) + gr.Slider( + 0, + 16, + 0, + step=1, + interactive=True, + label="Gpus per Worker", + info="the number of gpu used for every worker.", + ) with gr.Row(): with gr.Column(scale=1): - deployed_model_endpoint = gr.Text(label="Deployed Model Endpoint", value="" ,elem_classes="disable_status") + deployed_model_endpoint = gr.Text( + label="Deployed Model Endpoint", + value="", + elem_classes="disable_status", + ) with gr.Tab("Inference"): step3 = "Access the online inference service in your own application" - gr.HTML("

"+ step3 + "

") + gr.HTML("

" + step3 + "

") with gr.Accordion("Configuration", open=False, visible=True): - max_new_tokens = gr.Slider(1, 2000, 128, step=1, interactive=True, label="Max New Tokens", info="The maximum numbers of tokens to generate.") - Temperature = gr.Slider(0, 1, 0.7, step=0.01, interactive=True, label="Temperature", info="The value used to modulate the next token probabilities.") - Top_p = gr.Slider(0, 1, 1.0, step=0.01, interactive=True, label="Top p", info="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`Top p` or higher are kept for generation.") - Top_k = gr.Slider(0, 100, 0, step=1, interactive=True, label="Top k", info="The number of highest probability vocabulary tokens to keep for top-k-filtering.") - + max_new_tokens = gr.Slider( + 1, + 2000, + 128, + step=1, + interactive=True, + label="Max New Tokens", + info="The maximum numbers of tokens to generate.", + ) + Temperature = gr.Slider( + 0, + 1, + 0.7, + step=0.01, + interactive=True, + label="Temperature", + info="The value used to modulate the next token probabilities.", + ) + Top_p = gr.Slider( + 0, + 1, + 1.0, + step=0.01, + interactive=True, + label="Top p", + info="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`Top p` or higher are kept for generation.", + ) + Top_k = gr.Slider( + 0, + 100, + 0, + step=1, + interactive=True, + label="Top k", + info="The number of highest probability vocabulary tokens to keep for top-k-filtering.", + ) + with gr.Tab("Dialogue"): - chatbot = gr.Chatbot(elem_id="chatbot", label="chatbot", elem_classes="disable_status") + chatbot = gr.Chatbot( + elem_id="chatbot", + label="chatbot", + elem_classes="disable_status", + ) with gr.Row(): with gr.Column(scale=0.8): - msg = gr.Textbox(show_label=False, container=False, - placeholder="Input your question and press Enter") + msg = gr.Textbox( + show_label=False, + container=False, + placeholder="Input your question and press Enter", + ) with gr.Column(scale=0.2, min_width=20): - latency_status = gr.Markdown(""" + latency_status = gr.Markdown( + """ | | | |---|---| | Total Latency [s] | - | - | Tokens | - |""", elem_classes=["disable_status", "output-stats", "disablegenerating", "div_height"]) + | Tokens | - |""", + elem_classes=[ + "disable_status", + "output-stats", + "disablegenerating", + "div_height", + ], + ) with gr.Row(): with gr.Column(scale=0.5, min_width=0): send_btn = gr.Button("Send") @@ -683,10 +992,19 @@ def _init_ui(self): msgs = list(range(self.test_replica)) for i in range(self.test_replica): with gr.Column(scale=scale_num, min_width=1): - chatbots[i] = gr.Chatbot(elem_id="chatbot"+str(i+1), label="chatbot"+str(i+1), min_width=1, elem_classes="disable_status") - msgs[i] = gr.Textbox(show_label=False, container=False, - placeholder="Input your question and press Enter", - value=self.messages[i], min_width=1) + chatbots[i] = gr.Chatbot( + elem_id="chatbot" + str(i + 1), + label="chatbot" + str(i + 1), + min_width=1, + elem_classes="disable_status", + ) + msgs[i] = gr.Textbox( + show_label=False, + container=False, + placeholder="Input your question and press Enter", + value=self.messages[i], + min_width=1, + ) with gr.Row(visible=False): ids = list(range(self.test_replica)) for i in range(self.test_replica): @@ -698,53 +1016,125 @@ def _init_ui(self): send_all_btn = gr.Button("Send all requsts") with gr.Column(scale=0.5): reset_all_btn = gr.Button("Reset") - + with gr.Tab("RAG"): step3_rag = "Use RAG to enhance generation capabilities" - gr.HTML("

"+ step3_rag + "

") + gr.HTML("

" + step3_rag + "

") with gr.Accordion("Configuration", open=False, visible=True): - max_new_tokens_rag = gr.Slider(1, 2000, 128, step=1, interactive=True, label="Max New Tokens", info="The maximum numbers of tokens to generate.") - Temperature_rag = gr.Slider(0, 1, 0.7, step=0.01, interactive=True, label="Temperature", info="The value used to modulate the next token probabilities.") - Top_p_rag = gr.Slider(0, 1, 1.0, step=0.01, interactive=True, label="Top p", info="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`Top p` or higher are kept for generation.") - Top_k_rag = gr.Slider(0, 100, 0, step=1, interactive=True, label="Top k", info="The number of highest probability vocabulary tokens to keep for top-k-filtering.") + max_new_tokens_rag = gr.Slider( + 1, + 2000, + 128, + step=1, + interactive=True, + label="Max New Tokens", + info="The maximum numbers of tokens to generate.", + ) + Temperature_rag = gr.Slider( + 0, + 1, + 0.7, + step=0.01, + interactive=True, + label="Temperature", + info="The value used to modulate the next token probabilities.", + ) + Top_p_rag = gr.Slider( + 0, + 1, + 1.0, + step=0.01, + interactive=True, + label="Top p", + info="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`Top p` or higher are kept for generation.", + ) + Top_k_rag = gr.Slider( + 0, + 100, + 0, + step=1, + interactive=True, + label="Top k", + info="The number of highest probability vocabulary tokens to keep for top-k-filtering.", + ) with gr.Accordion("RAG parameters", open=False, visible=True): with gr.Row(): with gr.Column(scale=0.5): - data_web_urls = gr.Textbox(label="web urls", value="https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results", placeholder="The urls of web dataset. Support multiple web urls seperated by ';'") + data_web_urls = gr.Textbox( + label="web urls", + value="https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results", + placeholder="The urls of web dataset. 
Support multiple web urls seperated by ';'", + ) with gr.Column(scale=0.5): # data_pdf_path = gr.Textbox(label="pdf folder", value='', placeholder="The folder of pdf files") - data_pdfs = gr.File(label="upload pdf files", file_count="multiple", file_types=[".pdf"], elem_classes="file_height") + data_pdfs = gr.File( + label="upload pdf files", + file_count="multiple", + file_types=[".pdf"], + elem_classes="file_height", + ) with gr.Row(): with gr.Column(scale=0.4): - embedding_model = gr.Textbox(label="embedding model", value="sentence-transformers/all-mpnet-base-v2", placeholder="Model name to use") + embedding_model = gr.Textbox( + label="embedding model", + value="sentence-transformers/all-mpnet-base-v2", + placeholder="Model name to use", + ) with gr.Column(scale=0.3): - splitter_chunk_size = gr.Textbox(label="splitter_chunk_size", value="500", placeholder="Maximum size of chunks to return") + splitter_chunk_size = gr.Textbox( + label="splitter_chunk_size", + value="500", + placeholder="Maximum size of chunks to return", + ) with gr.Column(scale=0.3): - returned_k = gr.Textbox(label="returned_k", value=1, placeholder="Number of retrieved chunks to return") - - + returned_k = gr.Textbox( + label="returned_k", + value=1, + placeholder="Number of retrieved chunks to return", + ) + with gr.Row(): with gr.Column(scale=0.2): rag_selector = gr.Checkbox(label="RAG", min_width=0) with gr.Column(scale=0.6): - rag_path = gr.Textbox(show_label=False, container=False, placeholder="The path of vectorstore", elem_classes="disable_status") + rag_path = gr.Textbox( + show_label=False, + container=False, + placeholder="The path of vectorstore", + elem_classes="disable_status", + ) with gr.Column(scale=0.2): regenerate_btn = gr.Button("Regenerate", min_width=0) with gr.Tab("Dialogue"): - chatbot_rag = gr.Chatbot(elem_id="chatbot", label="chatbot", elem_classes="disable_status") + chatbot_rag = gr.Chatbot( + elem_id="chatbot", + label="chatbot", + elem_classes="disable_status", + ) with gr.Row(): with gr.Column(scale=0.8): - msg_rag = gr.Textbox(show_label=False, container=False, - placeholder="Input your question and press Enter") + msg_rag = gr.Textbox( + show_label=False, + container=False, + placeholder="Input your question and press Enter", + ) with gr.Column(scale=0.2, min_width=0): - latency_status_rag = gr.Markdown(""" + latency_status_rag = gr.Markdown( + """ | | | |---|---| | Total Latency [s] | - | - | Tokens | - |""", elem_classes=["disable_status", "output-stats", "disablegenerating", "div_height"]) + | Tokens | - |""", + elem_classes=[ + "disable_status", + "output-stats", + "disablegenerating", + "div_height", + ], + ) with gr.Row(): with gr.Column(scale=0.5, min_width=0): send_btn_rag = gr.Button("Send") @@ -756,51 +1146,88 @@ def _init_ui(self): with gr.Column(scale=0.1, min_width=45): with gr.Row(): node_pic = r"./ui/images/Picture2.png" - gr.Image(type="pil", value=node_pic, show_label=False, min_width=45, height=45, width=45, elem_id="notshowimg", container=False) + gr.Image( + type="pil", + value=node_pic, + show_label=False, + min_width=45, + height=45, + width=45, + elem_id="notshowimg", + container=False, + ) with gr.Row(): - gr.HTML("

Ray Cluster
") + gr.HTML( + "
Ray Cluster
" + ) with gr.Column(scale=0.9): with gr.Row(): with gr.Column(scale=0.05, min_width=40): gr.HTML("

cpu core
") with gr.Column(): - gr.HTML(self.get_ray_cluster, elem_classes="disablegenerating", every=2) - + gr.HTML( + self.get_ray_cluster, + elem_classes="disablegenerating", + every=2, + ) + stop_btn = [] node_status = [] node_index = [] for index in range(len(self.ray_nodes)): - if self.ray_nodes[index]["Alive"] == False: + if self.ray_nodes[index]["Alive"] is False: continue node_ip = self.ray_nodes[index]["NodeName"] with gr.Row(): with gr.Column(scale=0.1, min_width=25): with gr.Row(): - if index==0: + if index == 0: func = lambda: self.watch_node_status(index=0) - elif index==1: + elif index == 1: func = lambda: self.watch_node_status(index=1) - elif index==2: + elif index == 2: func = lambda: self.watch_node_status(index=2) - elif index==3: + elif index == 3: func = lambda: self.watch_node_status(index=3) - node_status.append(gr.HTML(func, elem_classes="statusstyle", every=2)) + + node_status.append( + gr.HTML(func, elem_classes="statusstyle", every=2) + ) with gr.Row(): node_index.append(gr.Text(value=len(stop_btn), visible=False)) if node_ip == self.head_node_ip: - stop_btn.append(gr.Button("Kill", interactive=False, elem_classes="btn-style")) + stop_btn.append( + gr.Button( + "Kill", + interactive=False, + elem_classes="btn-style", + ) + ) else: stop_btn.append(gr.Button("Kill", elem_classes="btn-style")) with gr.Column(scale=0.065, min_width=45): with gr.Row(): node_pic = r"./ui/images/Picture1.png" - gr.Image(type="pil", value=node_pic, show_label=False, min_width=45, height=45, width=45, elem_id="notshowimg", container=False) + gr.Image( + type="pil", + value=node_pic, + show_label=False, + min_width=45, + height=45, + width=45, + elem_id="notshowimg", + container=False, + ) with gr.Row(): if node_ip == self.head_node_ip: - gr.HTML("

head node
") + gr.HTML( + "
head node
" + ) else: - gr.HTML("

work node
") + gr.HTML( + "
work node
" + ) with gr.Column(scale=0.835): with gr.Row(): with gr.Column(scale=0.05, min_width=40): @@ -808,108 +1235,243 @@ def _init_ui(self): gr.HTML("

") gr.HTML("

memory

") with gr.Column(): - if index==0: + if index == 0: func = lambda: self.get_cpu_memory(index=0) - elif index==1: + elif index == 1: func = lambda: self.get_cpu_memory(index=1) - elif index==2: + elif index == 2: func = lambda: self.get_cpu_memory(index=2) - elif index==3: + elif index == 3: func = lambda: self.get_cpu_memory(index=3) + gr.HTML(func, elem_classes="disablegenerating", every=2) msg.submit(self.user, [msg, chatbot], [msg, chatbot], queue=False).then( - self.bot, [chatbot, deployed_model_endpoint, max_new_tokens, Temperature, Top_p, Top_k], - [chatbot, latency_status] + self.bot, + [ + chatbot, + deployed_model_endpoint, + max_new_tokens, + Temperature, + Top_p, + Top_k, + ], + [chatbot, latency_status], ) clear_btn.click(self.clear, None, [chatbot, latency_status], queue=False) send_btn.click(self.user, [msg, chatbot], [msg, chatbot], queue=False).then( - self.bot, [chatbot, deployed_model_endpoint, max_new_tokens, Temperature, Top_p, Top_k], - [chatbot, latency_status] + self.bot, + [ + chatbot, + deployed_model_endpoint, + max_new_tokens, + Temperature, + Top_p, + Top_k, + ], + [chatbot, latency_status], ) - regenerate_btn.click(self.regenerate, [rag_path, data_web_urls, data_pdfs, embedding_model, splitter_chunk_size, cpus_per_worker_deploy], [rag_path]) + regenerate_btn.click( + self.regenerate, + [ + rag_path, + data_web_urls, + data_pdfs, + embedding_model, + splitter_chunk_size, + cpus_per_worker_deploy, + ], + [rag_path], + ) clear_btn_rag.click(self.clear, None, [chatbot_rag, latency_status_rag], queue=False) rag_selector.select(self.set_rag_default_path, [rag_selector, rag_path], rag_path) - msg_rag.submit(self.user, [msg_rag, chatbot_rag], [msg_rag, chatbot_rag], queue=False).then( - self.bot_rag, [chatbot_rag, deployed_model_endpoint, max_new_tokens_rag, Temperature_rag, Top_p_rag, Top_k_rag, rag_selector, rag_path, returned_k], - [chatbot_rag, latency_status_rag] + msg_rag.submit( + self.user, [msg_rag, chatbot_rag], [msg_rag, chatbot_rag], queue=False + ).then( + self.bot_rag, + [ + chatbot_rag, + deployed_model_endpoint, + max_new_tokens_rag, + Temperature_rag, + Top_p_rag, + Top_k_rag, + rag_selector, + rag_path, + returned_k, + ], + [chatbot_rag, latency_status_rag], ) - send_btn_rag.click(self.user, [msg_rag, chatbot_rag], [msg_rag, chatbot_rag], queue=False).then( - self.bot_rag, [chatbot_rag, deployed_model_endpoint, max_new_tokens_rag, Temperature_rag, Top_p_rag, Top_k_rag, rag_selector, rag_path, returned_k], - [chatbot_rag, latency_status_rag] + send_btn_rag.click( + self.user, [msg_rag, chatbot_rag], [msg_rag, chatbot_rag], queue=False + ).then( + self.bot_rag, + [ + chatbot_rag, + deployed_model_endpoint, + max_new_tokens_rag, + Temperature_rag, + Top_p_rag, + Top_k_rag, + rag_selector, + rag_path, + returned_k, + ], + [chatbot_rag, latency_status_rag], ) for i in range(self.test_replica): - send_all_btn.click(self.user, [msgs[i], chatbots[i]], [msgs[i], chatbots[i]], queue=False).then( - self.send_all_bot, [ids[i], chatbots[i], deployed_model_endpoint, max_new_tokens, Temperature, Top_p, Top_k], - chatbots[i] + send_all_btn.click( + self.user, + [msgs[i], chatbots[i]], + [msgs[i], chatbots[i]], + queue=False, + ).then( + self.send_all_bot, + [ + ids[i], + chatbots[i], + deployed_model_endpoint, + max_new_tokens, + Temperature, + Top_p, + Top_k, + ], + chatbots[i], ) for i in range(self.test_replica): reset_all_btn.click(self.reset, [ids[i]], [msgs[i], chatbots[i]], queue=False) - + for i in range(len(stop_btn)): - stop_btn[i].click(self.kill_node, 
[stop_btn[i], node_index[i]], [stop_btn[i], deployed_model_endpoint]) + stop_btn[i].click( + self.kill_node, + [stop_btn[i], node_index[i]], + [stop_btn[i], deployed_model_endpoint], + ) - base_model_dropdown.select(self.set_custom_model, [base_model_dropdown], [custom_model_name, custom_tokenizer_name]) - finetune_event = finetune_btn.click(self.finetune, [base_model_dropdown, custom_model_name, custom_tokenizer_name, data_url, finetuned_model_name, batch_size, num_epochs, max_train_step, lr, worker_num, cpus_per_worker_ftn], [all_model_dropdown]) - finetune_progress_event = finetune_btn.click(self.finetune_progress, None, [finetune_res]) - stop_finetune_btn.click(fn=self.shutdown_finetune, inputs=None, outputs=None, cancels=[finetune_event, finetune_progress_event]) - deploy_event = deploy_btn.click(self.deploy_func, [all_model_dropdown, replica_num, cpus_per_worker_deploy], [deployed_model_endpoint]) - stop_deploy_btn.click(fn=self.shutdown_deploy, inputs=None, outputs=None, cancels=[deploy_event]) + base_model_dropdown.select( + self.set_custom_model, + [base_model_dropdown], + [custom_model_name, custom_tokenizer_name], + ) + finetune_event = finetune_btn.click( + self.finetune, + [ + base_model_dropdown, + custom_model_name, + custom_tokenizer_name, + data_url, + finetuned_model_name, + batch_size, + num_epochs, + max_train_step, + lr, + worker_num, + cpus_per_worker_ftn, + ], + [all_model_dropdown], + ) + finetune_progress_event = finetune_btn.click( + self.finetune_progress, None, [finetune_res] + ) + stop_finetune_btn.click( + fn=self.shutdown_finetune, + inputs=None, + outputs=None, + cancels=[finetune_event, finetune_progress_event], + ) + deploy_event = deploy_btn.click( + self.deploy_func, + [all_model_dropdown, replica_num, cpus_per_worker_deploy], + [deployed_model_endpoint], + ) + stop_deploy_btn.click( + fn=self.shutdown_deploy, + inputs=None, + outputs=None, + cancels=[deploy_event], + ) gr.Markdown(foot_content) self.gr_chat = gr_chat + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Web UI for LLMs", add_help=True) - parser.add_argument("--finetune_model_path", default="./", type=str, help="Where to save the finetune model.") - parser.add_argument("--finetune_checkpoint_path", default="", type=str, help="Where to save checkpoints.") - parser.add_argument("--default_rag_path", default="./vector_store/", type=str, help="The path of vectorstore used by RAG.") - parser.add_argument("--node_port", default="22", type=str, help="The node port that ssh connects.") - parser.add_argument("--node_user_name", default="root", type=str, help="The node user name that ssh connects.") - parser.add_argument("--conda_env_name", default="test_gradio", type=str, help="The environment used to execute ssh commands.") - parser.add_argument("--master_ip_port", default="None", type=str, help="The ip:port of head node to connect when restart a worker node.") + parser = argparse.ArgumentParser(description="Web UI for LLM on Ray", add_help=True) + parser.add_argument( + "--finetune_model_path", + default="./", + type=str, + help="Where to save the finetune model.", + ) + parser.add_argument( + "--finetune_checkpoint_path", + default="", + type=str, + help="Where to save checkpoints.", + ) + parser.add_argument( + "--default_rag_path", + default="./vector_store/", + type=str, + help="The path of vectorstore used by RAG.", + ) + parser.add_argument( + "--node_port", default="22", type=str, help="The node port that ssh connects." 
+ ) + parser.add_argument( + "--node_user_name", + default="root", + type=str, + help="The node user name that ssh connects.", + ) + parser.add_argument( + "--conda_env_name", + default="base", + type=str, + help="The environment used to execute ssh commands.", + ) + parser.add_argument( + "--master_ip_port", + default="None", + type=str, + help="The ip:port of head node to connect when restart a worker node.", + ) args = parser.parse_args() file_path = os.path.abspath(__file__) infer_path = os.path.dirname(file_path) repo_path = os.path.abspath(infer_path + os.path.sep + "../") - default_data_path = os.path.abspath(infer_path + os.path.sep + "../examples/data/sample_finetune_data.jsonl") + default_data_path = os.path.abspath( + infer_path + os.path.sep + "../examples/data/sample_finetune_data.jsonl" + ) sys.path.append(repo_path) from finetune.finetune import get_accelerate_environment_variable - finetune_config = { - "General": { - "config": {} - }, - "Dataset": { - "validation_file": None, - "validation_split_percentage": 0 - }, + + finetune_config: Dict[str, Any] = { + "General": {"config": {}}, + "Dataset": {"validation_file": None, "validation_split_percentage": 0}, "Training": { "optimizer": "AdamW", "lr_scheduler": "linear", "weight_decay": 0.0, "device": "CPU", "num_training_workers": 2, - "resources_per_worker": { - "CPU": 24 - }, - "accelerate_mode": "CPU_DDP" + "resources_per_worker": {"CPU": 24}, + "accelerate_mode": "CPU_DDP", }, - "failure_config": { - "max_failures": 5 - } + "failure_config": {"max_failures": 5}, } - ray_init_config = { + ray_init_config: Dict[str, Any] = { "runtime_env": { "env_vars": { - "OMP_NUM_THREADS": "24", - "ACCELERATE_USE_CPU": "True", + "OMP_NUM_THREADS": "24", + "ACCELERATE_USE_CPU": "True", "ACCELERATE_MIXED_PRECISION": "no", "CCL_WORKER_COUNT": "1", "CCL_LOG_LEVEL": "info", @@ -919,7 +1481,9 @@ def _init_ui(self): "address": "auto", "_node_ip_address": "127.0.0.1", } - accelerate_env_vars = get_accelerate_environment_variable(finetune_config["Training"]["accelerate_mode"]) + accelerate_env_vars = get_accelerate_environment_variable( + finetune_config["Training"]["accelerate_mode"], config=None + ) ray_init_config["runtime_env"]["env_vars"].update(accelerate_env_vars) context = ray.init(**ray_init_config) head_node_ip = context.get("address").split(":")[0] @@ -928,6 +1492,22 @@ def _init_ui(self): finetune_checkpoint_path = args.finetune_checkpoint_path default_rag_path = args.default_rag_path - initial_model_list = {k : all_models[k] for k in sorted(all_models.keys())} - ui = ChatBotUI(initial_model_list, initial_model_list, finetune_model_path, finetune_checkpoint_path, repo_path, default_data_path, default_rag_path, finetune_config, head_node_ip, args.node_port, args.node_user_name, args.conda_env_name, args.master_ip_port) - ui.gr_chat.queue(concurrency_count=10).launch(share=True, server_port=8080, server_name="0.0.0.0") + initial_model_list = {k: all_models[k] for k in sorted(all_models.keys())} + ui = ChatBotUI( + initial_model_list, + initial_model_list, + finetune_model_path, + finetune_checkpoint_path, + repo_path, + default_data_path, + default_rag_path, + finetune_config, + head_node_ip, + args.node_port, + args.node_user_name, + args.conda_env_name, + args.master_ip_port, + ) + ui.gr_chat.queue(concurrency_count=10).launch( + share=True, server_port=8080, server_name="0.0.0.0" + )
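
For reference, a minimal launch sketch for the reformatted UI entry point (assuming the script lives at ui/start_ui.py and is started from the repository root; the conda environment name and head-node address below are illustrative placeholders, and the remaining options fall back to the argparse defaults shown above):

    # placeholders: replace <head_node_ip>:<ray_port> with the Ray head address used by the cluster
    python ui/start_ui.py --node_user_name root --conda_env_name base --master_ip_port <head_node_ip>:<ray_port>

The script then serves the Gradio app on 0.0.0.0:8080 with share=True, as configured in the final launch() call.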