From 309bb63de7ef99a0287a0f13107f640afa426d98 Mon Sep 17 00:00:00 2001
From: harborn
Date: Wed, 29 May 2024 16:53:34 +0800
Subject: [PATCH] [Refactor] Remove dataset/tokenizer/model packaging under
 common; make pretrain/finetune scripts independent of common. (#233)

* update
* update
* update
* update
* update
* update
* update
* add license header
* update
---
 docs/finetune_parameters.md                 |   1 +
 llm_on_ray/finetune/finetune.py             | 386 +++++++++++---------
 llm_on_ray/finetune/finetune_config.py      |   8 +-
 llm_on_ray/finetune/template.py             |  64 ++++
 llm_on_ray/pretrain/plugin/hf_pretrainer.py |   4 +-
 llm_on_ray/pretrain/pretrain.py             |  12 +-
 6 files changed, 298 insertions(+), 177 deletions(-)
 create mode 100644 llm_on_ray/finetune/template.py

diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md
index 4f113e69f..56f4e6eaf 100644
--- a/docs/finetune_parameters.md
+++ b/docs/finetune_parameters.md
@@ -10,6 +10,7 @@ The following are the parameters supported in the finetuning workflow.
|tokenizer_name|None|Path to pretrained tokenizer from huggingface.co/models. If not provided, the tokenizer will be loaded from the `base_model`.|
|gpt_base_model|True|This parameter is for [Transformers#22482](https://github.com/huggingface/transformers/issues/22482). It needs to be set to True when the pretrained model is related to GPT; otherwise it should be False.|
|output_dir|/tmp/llm-ray/output|The output directory to store the finetuned model.|
+|report_to|none|The list of integrations to report the results and logs to. Possible values are: "none", "tensorboard".|
|resume_from_checkpoint|null|The path to a folder with a valid checkpoint for your model.|
|save_strategy|no|The checkpoint save strategy to adopt during training. Possible values are: "no", "epoch", "steps".|
|config|trust_remote_code: False
use_auth_token: None|Will be passed to the transformers `from_pretrained()` method| diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py index a2d6b60f2..eb8cea170 100644 --- a/llm_on_ray/finetune/finetune.py +++ b/llm_on_ray/finetune/finetune.py @@ -21,10 +21,14 @@ import sys from typing import Any, Dict, Union, Optional -import torch +from itertools import chain +import torch +import datasets import transformers +from peft import get_peft_model, LoraConfig + import ray from ray.train.torch import TorchTrainer from ray.air.config import ScalingConfig @@ -33,92 +37,24 @@ from pydantic_yaml import parse_yaml_raw_as from llm_on_ray import common +from llm_on_ray.finetune import template from llm_on_ray.finetune.finetune_config import FinetuneConfig from importlib import util -use_habana = False -if util.find_spec("habana_frameworks") is not None: - from optimum.habana.utils import set_seed - - use_habana = True -else: - from accelerate.utils import set_seed, is_xpu_available - use_habana = False +def set_seed(config): + seed = config["Training"].get("seed", None) + if seed is None: + return + device = config["Training"]["device"] + if device in ["cpu", "gpu"]: + from accelerate.utils import set_seed as _set_seed + _set_seed(seed) + elif device in ["hpu"]: + from optimum.habana.utils import set_seed as _set_seed -def get_accelerate_environment_variable(config: Dict[str, Any]) -> dict: - device = config["Training"]["device"] - accelerate_mode = config["Training"]["accelerate_mode"] - mixed_precision = config["Training"]["mixed_precision"] - mode_env_vars = { - "cpu": { - "DDP": { - "ACCELERATE_USE_CPU": "true", - "ACCELERATE_USE_IPEX": "true", - "ACCELERATE_MIXED_PRECISION": mixed_precision, - } - }, - "gpu": { - "DDP": { - "ACCELERATE_USE_CPU": "false", - "ACCELERATE_USE_XPU": "true", - "ACCELERATE_USE_IPEX": "true", - "ACCELERATE_MIXED_PRECISION": mixed_precision, - }, - "FSDP": { - "ACCELERATE_USE_CPU": "false", - "ACCELERATE_USE_XPU": "true", - "ACCELERATE_USE_IPEX": "true", - "ACCELERATE_USE_FSDP": "true", - "FSDP_SHARDING_STRATEGY": "1", - "FSDP_OFFLOAD_PARAMS": "false", - "FSDP_AUTO_WRAP_POLICY": "NO_WRAP", - "FSDP_BACKWARD_PREFETCH": "BACKWARD_PRE", - "FSDP_STATE_DICT_TYPE": "SHARDED_STATE_DICT", - "FSDP_FORWARD_PREFETCH": "false", - "FSDP_USE_ORIG_PARAMS": "false", - "FSDP_SYNC_MODULE_STATES": "true", - "ACCELERATE_MIXED_PRECISION": mixed_precision, - }, - "DEEPSPEED": { - "ACCELERATE_USE_CPU": "false", - "ACCELERATE_USE_XPU": "true", - "ACCELERATE_USE_IPEX": "true", - "ACCELERATE_USE_DEEPSPEED": "true", - "ACCELERATE_MIXED_PRECISION": mixed_precision, - }, - }, - "hpu": { - "DDP": { - "ACCELERATE_USE_CPU": "false", - "ACCELERATE_USE_XPU": "false", - "ACCELERATE_USE_IPEX": "false", - "ACCELERATE_MIXED_PRECISION": mixed_precision, - }, - "DEEPSPEED": { - "ACCELERATE_USE_CPU": "false", - "ACCELERATE_USE_XPU": "false", - "ACCELERATE_USE_IPEX": "false", - "ACCELERATE_USE_DEEPSPEED": "true", - "ACCELERATE_MIXED_PRECISION": mixed_precision, - }, - }, - } - if device not in mode_env_vars or accelerate_mode not in mode_env_vars[device]: - supported_mode_info = "" - for k in mode_env_vars.keys(): - supported_mode_info += k + ":[" - for m in mode_env_vars[k]: - supported_mode_info += m + "," - supported_mode_info = supported_mode_info[:-1] - supported_mode_info += "]," - supported_mode_info = supported_mode_info[:-1] - - raise ValueError( - f"device {device} and accelerate mode {accelerate_mode} not supported. 
supported device and accelerate mode is {supported_mode_info}" - ) - return mode_env_vars[device][accelerate_mode] + _set_seed(seed) def convert_to_training_args(cls, config): @@ -128,6 +64,7 @@ def convert_to_training_args(cls, config): args = { "output_dir": config["General"]["output_dir"], + "report_to": config["General"]["report_to"], "resume_from_checkpoint": config["General"]["resume_from_checkpoint"], "gradient_checkpointing": config["General"]["enable_gradient_checkpointing"], "save_strategy": save_strategy if save_strategy != "False" else "no", @@ -141,8 +78,15 @@ def convert_to_training_args(cls, config): "lr_scheduler_type": config["Training"]["lr_scheduler"], "weight_decay": config["Training"]["weight_decay"], "gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"], + "do_train": True, } + # set attr do_eval + vf = config["Dataset"].get("validation_file", None) + vsp = config["Dataset"].get("validation_split_percentage", 0) + if vf is not None or (vsp / 100 > 0.0 and vsp / 100 < 1.0): + args.update({"do_eval": True}) + # set attr max_steps if config["Training"]["max_train_steps"] is not None: args.update({"max_steps": config["Training"]["max_train_steps"]}) @@ -172,15 +116,6 @@ def convert_to_training_args(cls, config): return cls(**args) -def get_device_environment_variable(device): - if device == "hpu": - return { - "HABANA_VISIBLE_DEVICES": "all", - "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES": "true", - } - return {} - - def convert_dtype(dtype: str) -> Optional[torch.dtype]: supported_dtypes = { "fp16": torch.float16, @@ -190,67 +125,175 @@ def convert_dtype(dtype: str) -> Optional[torch.dtype]: return supported_dtypes[dtype] -def train_func(config: Dict[str, Any]): - os.chdir(config["cwd"]) - - device = config["Training"]["device"] - - base_model = config["General"]["base_model"] +def load_tokenizer(config: Dict): if config["General"].get("tokenizer_name") is not None: tokenizer_name = config["General"].get("tokenizer_name") else: - tokenizer_name = base_model - dataset_file = config["Dataset"]["train_file"] + tokenizer_name = config["General"]["base_model"] + load_config = config["General"].get("config", {}) + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, **load_config) + return tokenizer + + +def load_dataset(config: Dict): + dataset_file = config["Dataset"].get("train_file", None) + if dataset_file is None: + return + + if os.path.exists(dataset_file): + # load from local file + def local_load(name, **load_config): + if os.path.isfile(name): + file = os.path.basename(os.path.abspath(name)) + path = os.path.dirname(os.path.abspath(name)) + dataset = datasets.load_dataset(path, data_files=file, **load_config) + else: + dataset = datasets.load_dataset(name, **load_config) + return dataset["train"] + + train_dataset = local_load(dataset_file) + validation_file = config["Dataset"].get("validation_file", None) + if validation_file is not None: + validation_dataset = local_load(validation_file) + return datasets.DatasetDict({"train": train_dataset, "validation": validation_dataset}) + + validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) + if validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0: + dataset_dict = train_dataset.train_test_split( + test_size=validation_split_percentage / 100 + ) + dataset_dict["validation"] = dataset_dict["test"] + return dataset_dict - seed = config["Training"].get("seed") - if seed is not None: - set_seed(seed) + return 
datasets.DatasetDict({"train": train_dataset}) + else: + # try to download and load dataset from huggingface.co + load_config = config["General"].get("config", {}) + use_auth_token = load_config.get("use_auth_token", None) + raw_dataset = datasets.load_dataset(dataset_file, use_auth_token=use_auth_token) + + validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) + if "validation" not in raw_dataset.keys() and ( + validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0 + ): + dataset_dict = raw_dataset["train"].train_test_split( + test_size=validation_split_percentage / 100 + ) + dataset_dict["validation"] = dataset_dict["test"] + return dataset_dict + + return raw_dataset + + +def tokenize_dataset(config: Dict, tokenizer, dataset): + max_length = config["Dataset"].get("max_length", 512) + group = config["Dataset"].get("group", True) + block_size = config["Dataset"].get("block_size", 512) + tokenizer.pad_token = tokenizer.eos_token + + if isinstance(dataset, datasets.Dataset): + column_names = dataset.column_names + + if isinstance(dataset, datasets.DatasetDict): + column_names = dataset["train"].column_names + + if column_names and template.TEXT_COLUMN_NAME not in column_names: + + def prompt(rec): + instruction = rec["instruction"] + response = rec["response"] + context = rec.get("context") + if not instruction: + raise ValueError(f"Expected an instruction in: {rec}") + if not response: + raise ValueError(f"Expected a response in: {rec}") + if context: + rec["text"] = template.PROMPT_WITH_INPUT_FORMAT.format( + instruction=instruction, response=response, input=context + ) + else: + rec["text"] = template.PROMPT_NO_INPUT_FORMAT.format( + instruction=instruction, response=response + ) + return rec + + dataset = dataset.map( + prompt, + load_from_cache_file=False, + desc="Prompt", + ) + column_names += [template.TEXT_COLUMN_NAME] - tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()( - config={ - "name": tokenizer_name, - "config": config["General"]["config"], - } - ) + def tokenize_function(examples): + return tokenizer(examples[template.TEXT_COLUMN_NAME], max_length=max_length) - datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()( - config={ - "name": dataset_file, - "validation_file": config["Dataset"]["validation_file"], - "validation_split_percentage": config["Dataset"]["validation_split_percentage"], - } + tokenized_dataset = dataset.map( + tokenize_function, + remove_columns=column_names, + load_from_cache_file=False, + desc="Tokenize dataset", ) - dataprocesser = common.dataprocesser.DataProcesser.registory.get("GeneralProcesser")( - config={ - "per_device_train_batch_size": config["Training"]["batch_size"], - "per_device_eval_batch_size": config["Training"]["batch_size"], - "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1), - "max_length": config["Dataset"].get("max_length", 512), - "group": config["Dataset"].get("group", True), - "block_size": config["Dataset"].get("block_size", 512), - "shuffle": config["Dataset"].get("shuffle", False), - } - ) - tokenized_datasets = dataprocesser.tokenize_dataset(tokenizer, datasets) - - model = common.model.Model.registory.get("HuggingFaceModelForCausalLM")()( - config={ - "name": base_model, - "dtype": convert_dtype(config["Training"].get("mixed_precision", "no")), - "device": torch.device(device), - "config": config["General"]["config"], - "enable_gradient_checkpointing": config["General"].get( - 
"enable_gradient_checkpointing", False - ), - "lora_config": config["General"].get("lora_config", None), - } - ) + if group: + + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + tokenized_dataset = tokenized_dataset.map( + group_texts, + batched=True, + load_from_cache_file=False, + desc=f"Grouping texts in chunks of {block_size}", + ) - data_collator = common.dataprocesser.general_processer.DataCollatorForCompletionOnlyLM( + return tokenized_dataset + + +def prepare_data_collator(config: Dict, tokenizer): + return transformers.DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 ) + +def load_model(config: Dict): + model_name = config["General"]["base_model"] + model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no")) + model_config = config["General"].get("config", {}) + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=model_dtype, **model_config + ) + + lora_config = config["General"].get("lora_config", None) + if lora_config: + peft_config = LoraConfig(**lora_config) + model = get_peft_model(model, peft_config) + + egc = config["General"].get("enable_gradient_checkpointing", False) + if egc: + model.enable_input_require_grads() + model.gradient_checkpointing_enable() + model.config.use_cache = False + + model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"])) + + return model + + +def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator): + device = config["Training"]["device"] if device in ["cpu", "gpu"]: from transformers import Trainer, TrainingArguments @@ -258,49 +301,54 @@ def train_func(config: Dict[str, Any]): trainer = Trainer( model=model, args=training_args, - train_dataset=tokenized_datasets["train"], - eval_dataset=tokenized_datasets["validation"] - if tokenized_datasets.get("validation") is not None + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"] + if tokenized_dataset.get("validation") is not None else None, tokenizer=tokenizer, data_collator=data_collator, ) - - common.logger.info("train start") - trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) - trainer.save_model() - common.logger.info("train finish") + return training_args, trainer elif device in ["hpu"]: from optimum.habana.transformers import GaudiTrainer from optimum.habana.transformers import GaudiTrainingArguments - from optimum.habana import GaudiConfig - # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config - if config["general"].get("gaudi_config_name") is not None: - gaudi_config = GaudiConfig.from_pretrained( - config["general"].get("gaudi_config_name"), - ) - else: - gaudi_config = GaudiConfig() - gaudi_config.use_fused_adam = True - gaudi_config.use_fused_clip_norm = True 
training_args = convert_to_training_args(GaudiTrainingArguments, config) trainer = GaudiTrainer( model=model, args=training_args, - gaudi_config=gaudi_config, - train_dataset=tokenized_datasets["train"], - eval_dataset=tokenized_datasets["validation"] - if tokenized_datasets.get("validation") is not None + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"] + if tokenized_dataset.get("validation") is not None else None, tokenizer=tokenizer, data_collator=data_collator, ) + return training_args, trainer + return None + + +def train_func(config: Dict[str, Any]): + os.chdir(config["cwd"]) + + set_seed(config) - common.logger.info("train start") - trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) - trainer.save_model() - common.logger.info("train finish") + tokenizer = load_tokenizer(config) + + dataset = load_dataset(config) + + tokenized_dataset = tokenize_dataset(config, tokenizer, dataset) + + data_collator = prepare_data_collator(config, tokenizer) + + model = load_model(config) + + training_args, trainer = get_trainer(config, model, tokenizer, tokenized_dataset, data_collator) + + common.logger.info("train start") + trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + common.logger.info("train finish") def get_finetune_config(): @@ -356,7 +404,6 @@ def main(external_config=None): "CCL_ZE_IPC_EXCHANGE": "sockets", "CCL_WORKER_COUNT": str(ccl_worker_count), "CCL_LOG_LEVEL": "info", - "WORLD_SIZE": str(num_training_workers), "FI_TCP_IFACE": "lo", "FI_PROVIDER": "tcp", } @@ -384,8 +431,11 @@ def main(external_config=None): # if try to use Intel GPU, convert device to 'xpu' # due to accelerate internal use 'xpu' represent Intel GPU - if device == "gpu" and is_xpu_available(): - device = "xpu" + if device == "gpu": + from accelerate.utils import is_xpu_available + + if is_xpu_available(): + device = "xpu" if config.get("torch_config", None) is None: backend = None diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py index e78600a6d..c44e2e57d 100644 --- a/llm_on_ray/finetune/finetune_config.py +++ b/llm_on_ray/finetune/finetune_config.py @@ -15,7 +15,7 @@ # from pydantic import BaseModel, validator -from typing import Optional, List +from typing import Optional, List, Union PRECISION_BF16 = "bf16" @@ -56,6 +56,7 @@ class General(BaseModel): gaudi_config_name: Optional[str] = None gpt_base_model: bool output_dir: str + report_to: str = "none" resume_from_checkpoint: Optional[str] = None save_strategy: str = "no" config: GeneralConfig @@ -63,6 +64,11 @@ class General(BaseModel): deltatuner_config: Optional[DeltatunerConfig] = None enable_gradient_checkpointing: bool = False + @validator("report_to") + def check_report_to(cls, v: str): + assert v in ["none", "tensorboard"] + return v + class Dataset(BaseModel): train_file: str diff --git a/llm_on_ray/finetune/template.py b/llm_on_ray/finetune/template.py new file mode 100644 index 000000000..cf8647d7f --- /dev/null +++ b/llm_on_ray/finetune/template.py @@ -0,0 +1,64 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#!/usr/bin/env python + +INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request." +INSTRUCTION_KEY = "### Instruction:" +INPUT_KEY = "Input:" +RESPONSE_KEY = "### Response:" +END_KEY = "### End" +RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n" + +PROMPT_NO_INPUT_FORMAT = """{intro} + +{instruction_key} +{instruction} + +{response_key} +{response} + +{end_key}""".format( + intro=INTRO_BLURB, + instruction_key=INSTRUCTION_KEY, + instruction="{instruction}", + response_key=RESPONSE_KEY, + response="{response}", + end_key=END_KEY, +) + +PROMPT_WITH_INPUT_FORMAT = """{intro} + +{instruction_key} +{instruction} + +{input_key} +{input} + +{response_key} +{response} + +{end_key}""".format( + intro=INTRO_BLURB, + instruction_key=INSTRUCTION_KEY, + instruction="{instruction}", + input_key=INPUT_KEY, + input="{input}", + response_key=RESPONSE_KEY, + response="{response}", + end_key=END_KEY, +) +TEXT_COLUMN_NAME = "text" diff --git a/llm_on_ray/pretrain/plugin/hf_pretrainer.py b/llm_on_ray/pretrain/plugin/hf_pretrainer.py index 66c6806f2..21eaa4b19 100755 --- a/llm_on_ray/pretrain/plugin/hf_pretrainer.py +++ b/llm_on_ray/pretrain/plugin/hf_pretrainer.py @@ -176,14 +176,14 @@ def train(self): model_config = self.config.get("model") model_config["deepspeed_zero_stage"] = training_args.deepspeed_plugin.zero_stage if model_config: - self.model = common.load_model(model_config) + self.model = common.load.load_model(model_config) else: common.logger.warn("No internal model plugin provided") self.model.train() tokenizer_config = self.config.get("tokenizer") if tokenizer_config: - self.tokenizer = common.load_tokenizer(tokenizer_config) + self.tokenizer = common.load.load_tokenizer(tokenizer_config) else: common.logger.warn("No internal tokenizer plugin provided") diff --git a/llm_on_ray/pretrain/pretrain.py b/llm_on_ray/pretrain/pretrain.py index 74ed16b37..8b2cfa11b 100644 --- a/llm_on_ray/pretrain/pretrain.py +++ b/llm_on_ray/pretrain/pretrain.py @@ -52,7 +52,7 @@ def train_func(config: Dict[str, Any]): initializer_config = config.get("initializer") if initializer_config: try: - initializer = common.get_initializer(initializer_config) + initializer = common.load.get_initializer(initializer_config) initializer.init() except Exception as e: common.logger.critical(e, exc_info=True) @@ -77,32 +77,32 @@ def train_func(config: Dict[str, Any]): datasets_config = config.get("datasets") if datasets_config: - datasets = common.load_dataset(datasets_config) + datasets = common.load.load_dataset(datasets_config) common.logger.info(" ") else: common.logger.warn("No datasets plugin provided, use the built-in datasets of trainer") tokenizer_config = config.get("tokenizer") if tokenizer_config: - tokenizer = common.load_tokenizer(tokenizer_config) + tokenizer = common.load.load_tokenizer(tokenizer_config) else: common.logger.warn("No tokenizer plugin provided, use the built-in tokenizer of trainer") model_config = config.get("model") if model_config: - model = common.load_model(model_config) + model = common.load.load_model(model_config) else: 
common.logger.warn("No model plugin provided, use the built-in model of trainer") optimizer_config = config.get("optimizer") if optimizer_config: - optimizer = common.load_optimizer(model, config.get("optimizer")) + optimizer = common.load.load_optimizer(model, config.get("optimizer")) else: common.logger.warn("No optimizer plugin provided, use the built-in optimizer of trainer") trainer_config = config.get("trainer") if trainer_config: - trainer = common.get_trainer(config.get("trainer")) + trainer = common.load.get_trainer(config.get("trainer")) try: trainer.prepare(model, tokenizer, datasets, optimizer, accelerator)