[Finetune] Integrate Chat template #178

Open · wants to merge 24 commits into main
1 change: 1 addition & 0 deletions docs/finetune_parameters.md
@@ -16,6 +16,7 @@ The following are the parameters supported in the finetuning workflow.
|lora_config|task_type: CAUSAL_LM<br>r: 8<br>lora_alpha: 32<br>lora_dropout: 0.1|Will be passed to the LoraConfig `__init__()` method, then it'll be used as config to build Peft model object.|
|deltatuner_config|"algo": "lora"<br>"denas": True<br>"best_model_structure": "/path/to/best_structure_of_deltatuner_model"|Will be passed to the DeltaTunerArguments `__init__()` method, then it'll be used as config to build [Deltatuner model](https://github.com/intel/e2eAIOK/tree/main/e2eAIOK/deltatuner) object.|
|enable_gradient_checkpointing|False|Enable gradient checkpointing to save GPU memory, at the cost of extra compute time.|
|chat_template|None|User-defined chat template.|
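For illustration, a chat template is a Jinja string in the format understood by Hugging Face tokenizers. A minimal sketch, reusing the SlimOrca-style template this PR adds in `general_processer.py` (the variable name is only illustrative):

```python
# Hypothetical user-defined chat template (Jinja), mirroring the SlimOrca
# template added elsewhere in this PR.
custom_chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "{{ '### System: ' + message['content'] }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '### User: ' + message['content'] }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '### Assistant: ' + message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
)
```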
Contributor:
Have you compared the impact of different templates on fine-tuning performance?

Contributor Author:
not yet



## Dataset Parameters
11 changes: 10 additions & 1 deletion llm_on_ray/common/__init__.py
@@ -18,4 +18,13 @@
from llm_on_ray.common.torch_config import TorchConfig
from llm_on_ray.common.config import Config
from llm_on_ray.common.init import init
from llm_on_ray.common import agentenv, dataset, initializer, model, optimizer, tokenizer, trainer
from llm_on_ray.common import (
agentenv,
dataset,
initializer,
model,
optimizer,
tokenizer,
trainer,
dataprocesser,
)
3 changes: 2 additions & 1 deletion llm_on_ray/common/dataprocesser/__init__.py
@@ -15,7 +15,8 @@
#

from llm_on_ray.common.dataprocesser.dataprocesser import DataProcesser
from llm_on_ray.common.dataprocesser.general_processer import GeneralProcesser
from llm_on_ray.common.dataprocesser.general_processer import ChatDataPreprocess
from llm_on_ray.common.dataprocesser.general_processer import SlimOrcaDataPreprocess
from llm_on_ray.common.dataprocesser.rm_dataprocesser import RMDataProcesser


214 changes: 169 additions & 45 deletions llm_on_ray/common/dataprocesser/general_processer.py
@@ -24,10 +24,9 @@
from llm_on_ray.common.dataprocesser import DataProcesser

INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
INSTRUCTION_KEY = "### Instruction: "
INPUT_KEY = "Input: "
RESPONSE_KEY = "### Response: "
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"

PROMPT_NO_INPUT_FORMAT = """{intro}
@@ -36,15 +35,12 @@
{instruction}

{response_key}
{response}

{end_key}""".format(
{response}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
response_key=RESPONSE_KEY,
response="{response}",
end_key=END_KEY,
)

PROMPT_WITH_INPUT_FORMAT = """{intro}
@@ -56,20 +52,22 @@
{input}

{response_key}
{response}

{end_key}""".format(
{response}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
input_key=INPUT_KEY,
input="{input}",
response_key=RESPONSE_KEY,
response="{response}",
end_key=END_KEY,
)
TEXT_COLUMN_NAME = "text"

SLIMORCA_PROMPT_DICT = {
"prompt_with_input": ("### System: {system} \n" "### User: {user} \n### Assistant: {gpt}"),
"prompt_without_input": ("### System: {system} \n" "### Assistant: {gpt}"),
}


class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling):
def torch_call(self, examples):
@@ -98,9 +96,74 @@ def torch_call(self, examples):
return batch


class GeneralProcesser(DataProcesser):
class ChatDataPreprocess(DataProcesser):
base_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"""

def __init__(self, config):
super().__init__(config)
self.prompt_template = self.base_template
self.user = "### Instruction:\n"
self.assistant = "### Response:\n"
self.end = "### End\n"

def create_data(self, examples):
if self.config.get("gpt_base_model"):
instruction = examples["instruction"]
response = examples["response"]
context = examples.get("context")
if not instruction:
raise ValueError(f"Expected an instruction in: {examples}")
if not response:
raise ValueError(f"Expected a response in: {examples}")
if context:
new_messages = PROMPT_WITH_INPUT_FORMAT.format(
instruction=instruction, response=response, input=context
)
else:
new_messages = PROMPT_NO_INPUT_FORMAT.format(
instruction=instruction, response=response
)
else:
new_messages = [
{
"role": "system",
"content": INTRO_BLURB + "\n",
},
{
"role": "user",
"content": examples["instruction"]
+ "\n"
+ INPUT_KEY
+ examples["context"]
+ "\n",
},
{"role": "assistant", "content": examples["response"] + "\n"},
]

return new_messages

def tokenize_func(self, tokenizer, message):
if self.config.get("gpt_base_model"):
return tokenizer(
message, add_special_tokens=False, max_length=self.config.get("max_length")
)
else:
if self.config.get("chat_template") is not None:
tokenizer.chat_template = self.config.get("chat_template")
elif tokenizer.chat_template is not None:
pass
else:
tokenizer.chat_template = self.config.get("default_chat_template")

new_tokenizer = tokenizer.apply_chat_template(
message,
tokenize=False,
)
return tokenizer(
new_tokenizer, add_special_tokens=False, max_length=self.config.get("max_length")
)

def tokenize_dataset(self, tokenizer, dataset):
max_length = self.config.get("max_length")
group = self.config.get("group")
block_size = self.config.get("block_size")
tokenizer.pad_token = tokenizer.eos_token
@@ -111,38 +174,8 @@ def tokenize_dataset(self, tokenizer, dataset):
if isinstance(dataset, datasets.DatasetDict):
column_names = dataset["train"].column_names

if column_names and TEXT_COLUMN_NAME not in column_names:

def prompt(rec):
instruction = rec["instruction"]
response = rec["response"]
context = rec.get("context")
if not instruction:
raise ValueError(f"Expected an instruction in: {rec}")
if not response:
raise ValueError(f"Expected a response in: {rec}")
if context:
rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(
instruction=instruction, response=response, input=context
)
else:
rec["text"] = PROMPT_NO_INPUT_FORMAT.format(
instruction=instruction, response=response
)
return rec

dataset = dataset.map(
prompt,
load_from_cache_file=False,
desc="Prompt",
)
column_names += [TEXT_COLUMN_NAME]

def tokenize_function(examples):
return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length)

tokenized_datasets = dataset.map(
tokenize_function,
lambda examples: self.tokenize_func(tokenizer, self.create_data(examples)),
remove_columns=column_names,
load_from_cache_file=False,
desc="Tokenize dataset",
@@ -208,3 +241,94 @@ def prepare_dataloader(self, tokenizer, dataset):
}
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, **eval_dataloader_params)
return train_dataloader, eval_dataloader


class SlimOrcaDataPreprocess(ChatDataPreprocess):
chat_template = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{{ '### System: ' + message['content'] }}"
"{% elif message['role'] == 'user' %}"
"{{ '### User: ' + message['content'] }}"
"{% elif message['role'] == 'assistant' %}"
"{{ '### Assistant: ' + message['content'] }}"
"{% endif %}"
"{% endfor %}"
)

def __init__(self, config):
super().__init__(config)
self.config["chat_template"] = self.chat_template
self.default_system = "You are a helpful, respectful and honest assistant."

def create_data(self, data):
examples = {}
conv = data["conversations"]
# system
if conv[0]["from"] != "system":
examples["system"] = self.default_system
start = 0
elif conv[0]["from"] == "system" and conv[0]["value"] == "":
examples[conv[0]["from"]] = self.default_system
start = 1
else:
examples[conv[0]["from"]] = conv[0]["value"]
start = 1

for j in range(start, len(conv) - 1, 2):
examples[conv[j]["from"]] = conv[j]["value"]
examples[conv[j + 1]["from"]] = conv[j + 1]["value"]

if self.config.get("gpt_base_model"):
if examples["human"]:
return SLIMORCA_PROMPT_DICT["prompt_with_input"].format(
instruction=examples["system"],
response=examples["gpt"],
input=examples["human"],
)
else:
return SLIMORCA_PROMPT_DICT["prompt_without_input"].format(
instruction=examples["system"], response=examples["gpt"]
)
else:
new_messages = [
{"role": "system", "content": examples["system"] + "\n"},
{
"role": "user",
"content": examples["system"] + "\n" + INPUT_KEY + examples["human"] + "\n",
},
{"role": "assistant", "content": examples["gpt"] + "\n"},
]
return new_messages


class OpenOrcaDataPreprocess(ChatDataPreprocess):
def __init__(self, config):
super().__init__(config)
self.default_system = "You are an AI assistant. You will be given a task. You must generate a detailed and long answer."

def create_data(self, examples):
if self.config.get("gpt_base_model"):
if not examples["system"]:
examples["system"] = self.default_system

if examples["question"]:
return PROMPT_WITH_INPUT_FORMAT.format(
instruction=examples["system"],
response=examples["chosen"],
input=examples["question"],
)
else:
return PROMPT_NO_INPUT_FORMAT.format(
instruction=examples["system"], response=examples["chosen"]
)
else:
new_messages = [
{"role": "system", "content": INTRO_BLURB + "\n"},
{
"role": "user",
"content": examples["system"] + "\n" + INPUT_KEY + examples["question"] + "\n",
},
{"role": "assistant", "content": examples["chosen"] + "\n"},
]
return new_messages
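
As a rough usage sketch (not part of this PR's diff), the chat-style preprocessors above build a `messages` list and hand it to the tokenizer's chat-template machinery: `apply_chat_template(..., tokenize=False)` renders the Jinja template into a plain string, which is then tokenized as usual. The tokenizer name and template below are placeholders.

```python
from transformers import AutoTokenizer

# Placeholder tokenizer; any Hugging Face tokenizer works once chat_template is set.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] + ': ' + message['content'] }}"
    "{% endfor %}"
)

messages = [
    {
        "role": "system",
        "content": "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n",
    },
    {"role": "user", "content": "Summarize the text.\nInput: Ray scales Python workloads.\n"},
    {"role": "assistant", "content": "Ray is a framework for scaling Python workloads.\n"},
]

# Render the conversation with the Jinja template, then tokenize the rendered
# string, mirroring ChatDataPreprocess.tokenize_func above.
text = tokenizer.apply_chat_template(messages, tokenize=False)
tokens = tokenizer(text, add_special_tokens=False, max_length=512)
```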
16 changes: 14 additions & 2 deletions llm_on_ray/common/trainer/default_trainer.py
@@ -33,10 +33,17 @@
class DefaultTrainer(Trainer):
def __init__(self, config):
self.model = None
self.tokenizer = None
self.config = config
dataprocesser_config = config.get("dataprocesser")
dataprocesser_type = dataprocesser_config.get("type")
Factory = dataprocesser.DataProcesser.registory.get(dataprocesser_type)
if dataprocesser_type == "chat":
Factory = dataprocesser.DataProcesser.registory.get("ChatDataPreprocess")
elif dataprocesser_type == "SlimOrca":
Factory = dataprocesser.DataProcesser.registory.get("SlimOrcaDataPreprocess")
else:
raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")

if Factory is None:
raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")
self.dataprocesser = Factory(dataprocesser_config)
@@ -121,7 +128,7 @@ def _get_lr_scheduler(

def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
self._coordinate(accelerator)

self.tokenizer = tokenizer
embedding_size = model.get_input_embeddings().weight.shape[0]
logger.info(f"model embedding size: {embedding_size}")
if len(tokenizer) > embedding_size:
@@ -290,6 +297,11 @@ def train(self):
is_main_process=self.accelerator.is_main_process,
save_function=self.accelerator.save,
)
self.tokenizer.save_pretrained(
output,
is_main_process=self.accelerator.is_main_process,
save_function=self.accelerator.save,
)
logger.info(f"finish save model to {output}")

self.accelerator.wait_for_everyone()
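
For reference, the `if/elif` dispatch added in `__init__` above is equivalent to a small lookup table; a sketch only, assuming `DataProcesser.registory` maps preprocessor class names to their classes as it is used elsewhere in this diff:

```python
from llm_on_ray.common import dataprocesser

# Hypothetical table-driven equivalent of the dispatch in DefaultTrainer.__init__.
_TYPE_TO_PREPROCESSOR = {
    "chat": "ChatDataPreprocess",
    "SlimOrca": "SlimOrcaDataPreprocess",
}


def build_dataprocesser(dataprocesser_config):
    dataprocesser_type = dataprocesser_config.get("type")
    name = _TYPE_TO_PREPROCESSOR.get(dataprocesser_type)
    factory = dataprocesser.DataProcesser.registory.get(name) if name else None
    if factory is None:
        raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")
    return factory(dataprocesser_config)
```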
29 changes: 26 additions & 3 deletions llm_on_ray/finetune/finetune.py
@@ -14,7 +14,7 @@
# limitations under the License.
#

#!/usr/bin/env python
# !/usr/bin/env python

import os
import argparse
@@ -221,7 +221,17 @@ def train_func(config: Dict[str, Any]):
}
)

dataprocesser = common.dataprocesser.DataProcesser.registory.get("GeneralProcesser")(
dataprocesser_type = config["Dataset"]["type"]
if dataprocesser_type == "chat":
preprocesser_name = "ChatDataPreprocess"
elif dataprocesser_type == "OpenOrca":
preprocesser_name = "OpenOrcaDataPreprocess"
elif dataprocesser_type == "SlimOrca":
preprocesser_name = "SlimOrcaDataPreprocess"
else:
raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")

dataprocesser = common.dataprocesser.DataProcesser.registory.get(preprocesser_name)(
config={
"per_device_train_batch_size": config["Training"]["batch_size"],
"per_device_eval_batch_size": config["Training"]["batch_size"],
@@ -230,6 +240,11 @@ def train_func(config: Dict[str, Any]):
"group": config["Dataset"].get("group", True),
"block_size": config["Dataset"].get("block_size", 512),
"shuffle": config["Dataset"].get("shuffle", False),
"name": tokenizer_name,
"config": config["General"]["config"],
"gpt_base_model": config["General"].get("gpt_base_model", False),
"chat_template": config["General"]["chat_template"],
"default_chat_template": config["General"]["default_chat_template"],
}
)
tokenized_datasets = dataprocesser.tokenize_dataset(tokenizer, datasets)
@@ -356,7 +371,15 @@ def main(external_config=None):
) # additional 1 for head worker
ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
else:
ray.init(runtime_env=runtime_env)
import intel_extension_for_pytorch as ipex

if "xpu" in ipex.__version__:
num_cpus = (
resources_per_worker["CPU"] * num_training_workers + 1
) # additional 1 for head worker
ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
else:
ray.init(runtime_env=runtime_env)

common.logger.info(f"ray available resources = {ray.available_resources()}")
use_gpu = True if device == "gpu" else False
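
A hypothetical config fragment showing only the keys that `train_func` and `main` above read; a real config also needs the remaining `General`/`Dataset`/`Training` settings, and every value here is a placeholder.

```python
# Hypothetical external_config fragment for finetune.main(); only the keys
# referenced in this PR are shown.
external_config = {
    "General": {
        "gpt_base_model": False,
        # If None, the tokenizer's own template is kept, or default_chat_template is used.
        "chat_template": None,
        "default_chat_template": (
            "{% for message in messages %}"
            "{{ message['role'] + ': ' + message['content'] }}"
            "{% endfor %}"
        ),
        "config": {},
    },
    "Dataset": {
        "type": "SlimOrca",  # "chat" and "OpenOrca" are also recognized by train_func
        "group": True,
        "block_size": 512,
        "shuffle": False,
    },
    "Training": {
        "batch_size": 2,
    },
}
```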