diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md
index 55c303e70..94c468032 100644
--- a/docs/finetune_parameters.md
+++ b/docs/finetune_parameters.md
@@ -19,16 +19,25 @@ The following are the parameters supported in the finetuning workflow.
 
 ## Dataset Parameters
 
-|Configuration Name| Default|Meaning|
-|-|-|-|
-|train_file|examples/data/sample_finetune_data.jsonl|A json file containing the training data.|
-|validation_file|None|A json file containing the validation data.|
-|validation_split_percentage|5|The percentage of the train set used as validation set in case there's no validation split|
-|preprocessing_num_workers|None|The number of processes to use for the preprocessing.|
-|max_length|512|Padding sequential data to max length of a batch|
-|group|True|Whether to concatenate the sentence for more efficient training|
-|block_size|512|The block size of concatenated sentence|
-|shuffle|False|Whether shuffle the data at every epoch|
+| Configuration Name           | Default                                  | Meaning                                                                                                                              |
+|------------------------------|------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------|
+| train_file                   | examples/data/sample_finetune_data.jsonl | A JSON file containing the training data.                                                                                            |
+| validation_file              | None                                     | A JSON file containing the validation data.                                                                                          |
+| validation_split_percentage  | 5                                        | The percentage of the training set to use as a validation set, in case there is no validation split.                                 |
+| preprocessing_num_workers    | None                                     | The number of processes to use for preprocessing.                                                                                    |
+| max_length                   | 512                                      | Pad sequential data to the max length of a batch.                                                                                    |
+| group                        | True                                     | Whether to concatenate sentences for more efficient training.                                                                        |
+| block_size                   | 512                                      | The block size of the concatenated sentences.                                                                                        |
+| shuffle                      | False                                    | Whether to shuffle the data at every epoch.                                                                                          |
+| max_source_length            | 384                                      | The maximum total input sequence length after tokenization. Sequences longer than this are truncated; shorter ones are padded.       |
+| padding_side                 | right                                    | The side on which the model should have padding applied. One of ['right', 'left'].                                                   |
+| truncation_side              | right                                    | The side on which the model should have truncation applied. One of ['right', 'left'].                                                |
+| max_seq_length               | max_length                               | The maximum total sequence length (prompt plus response) after tokenization; follows max_length.                                     |
+| truncation                   | True                                     | Truncation strategy. One of ['only_first', 'only_second', 'longest_first/True', 'do_not_truncate/False'].                            |
+| padding                      | True                                     | Padding strategy. One of ['longest/True', 'do_not_pad/False', 'max_length'].                                                         |
+| mask_input                   | True                                     | Whether to mask the input part in the labels.                                                                                        |
+| mask_response                | True                                     | Whether to mask the response part in the labels.                                                                                     |
+| data_preprocess_type         | neural_chat                              | The preprocessing type used to encode the input; 'neural_chat' selects the neural-chat style tokenization, any other value selects the plain tokenization. |
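+
+For reference, a minimal `Dataset` section of the finetuning YAML config might look like this (illustrative values only):
+
+```yaml
+Dataset:
+  train_file: examples/data/sample_finetune_data.jsonl
+  max_length: 512
+  max_source_length: 384
+  padding_side: right
+  truncation_side: right
+  mask_input: true
+  mask_response: true
+  data_preprocess_type: neural_chat
+```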
 
 ## Training Parameters
 
diff --git a/llm_on_ray/finetune/data_process.py b/llm_on_ray/finetune/data_process.py
new file mode 100644
index 000000000..6435928a1
--- /dev/null
+++ b/llm_on_ray/finetune/data_process.py
@@ -0,0 +1,220 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+import re
+from itertools import chain
+
+import torch
+
+IGNORE_INDEX = -100
+
+
+class DataProcessor:
+    # We use the same prompts as those used to fine-tune the Alpaca model. For reference, see
+    # https://github.com/tatsu-lab/stanford_alpaca/blob/main/README.md#data-release
+    def __init__(self, config, tokenizer):
+        self.tokenizer = tokenizer
+        self.end = tokenizer.eos_token
+        self.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+        self.instruction = "### Instruction:\n"
+        self.input = "### Input:\n"
+        self.response = "### Response:\n"
+        self.padding_side = config["Dataset"].get("padding_side", "right")
+        self.truncation_side = config["Dataset"].get("truncation_side", "right")
+        # max_seq_length follows max_length (see docs/finetune_parameters.md)
+        self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 512)
+        self.max_source_length = config["Dataset"].get("max_source_length", 384)
+        self.truncation = config["Dataset"].get("truncation", True)
+        self.padding = config["Dataset"].get("padding", True)
+        self.mask_input = config["Dataset"].get("mask_input", True)
+        self.mask_response = config["Dataset"].get("mask_response", True)
+
+    def make_prompt(self, examples):
+        prompts = {}
+        prompts["prompt_sources"] = []
+        prompts["prompt_targets"] = []
+        for rec in examples:
+            instruction = rec["instruction"]
+            response = rec["response"]
+            context = rec.get("context")
+            if not instruction:
+                raise ValueError(f"Expected an instruction in: {rec}")
+            if not response:
+                raise ValueError(f"Expected a response in: {rec}")
+            if context:
+                prompt = (
+                    self.intro
+                    + self.end
+                    + "\n"
+                    + self.instruction
+                    + instruction
+                    # separate the instruction text from the input section
+                    + "\n"
+                    + self.input
+                    + context
+                    + self.end
+                    + "\n"
+                    + self.response
+                )
+                prompts["prompt_sources"].append(prompt)
+            else:
+                prompt = (
+                    self.intro
+                    + self.end
+                    + "\n"
+                    + self.instruction
+                    + instruction
+                    + self.end
+                    + "\n"
+                    + self.response
+                )
+                prompts["prompt_sources"].append(prompt)
+            prompt_response = response + self.end
+            prompts["prompt_targets"].append(prompt_response)
+        return prompts
+
+    def __truncate_sequences(self, sequences, max_length):
+        """
+        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40
+        Drops whole sequences from the front until the total length fits into max_length.
+        """
+        words_to_cut = sum(list(map(len, sequences))) - max_length
+        if words_to_cut <= 0:
+            return sequences
+
+        while words_to_cut > 0 and len(sequences) > 0:
+            words_to_cut -= len(sequences[0])
+            sequences = sequences[1:]
+        return sequences
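+
+    # Illustrative example of __truncate_sequences (not in the original source):
+    # with sequences = [[0] * 100, [0] * 300] and max_length = 350, words_to_cut
+    # starts at 50, so the first (oldest) sequence is dropped and [[0] * 300]
+    # is returned.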
+
+    def tokenize_by_neural_chat(self, examples):
+        """
+        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225
+        The only differences are:
+        - using our own prompt style
+        - adding left or right padding and truncation
+        - adding mask_input and mask_response
+        """
+        keys = list(examples.data.keys())
+        if len(keys) != 2:
+            raise ValueError("Unsupported dataset format")
+        assistant_tokens = self.tokenizer.tokenize(self.response)
+        header = self.intro + self.end + "\n"
+
+        examples["input_ids"] = []
+        examples["labels"] = []
+        examples["attention_mask"] = []
+        for instruction, response in zip(examples[keys[0]], examples[keys[1]]):
+            # escape the markers so that regex metacharacters in the eos token
+            # (e.g. "|" in "<|endoftext|>") are matched literally
+            convs = re.findall(
+                r"{0}.*?{2}|{1}.*?{2}".format(
+                    re.escape(self.instruction), re.escape(self.response), re.escape(self.end)
+                ),
+                instruction,
+                re.DOTALL,
+            )
+            convs_tokens = [
+                self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs
+            ]
+            header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n")
+            max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens)
+            truncated_convs = self.__truncate_sequences(convs_tokens, max_input)
+            if len(truncated_convs) == 0:
+                truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]]
+
+            prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens]
+            prompt_ids = [
+                self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens
+            ]
+            prompt_ids = list(chain(*prompt_ids))
+
+            resp_ids = self.tokenizer.convert_tokens_to_ids(
+                self.tokenizer.tokenize(response.strip())
+            )
+            # reserve one position for eos_token_id
+            max_resp = self.max_seq_length - len(prompt_ids) - 1
+
+            # truncate the response if it is too long
+            if len(resp_ids) > max_resp:
+                if self.truncation_side == "right":
+                    # keep the last token when cutting on the right
+                    resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:]
+                else:
+                    resp_ids = resp_ids[-max_resp:]
+
+            # build the labels, masking the prompt and/or the response with IGNORE_INDEX
+            input_ids = prompt_ids + resp_ids + [self.tokenizer.eos_token_id]
+            if self.mask_input:
+                labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [self.tokenizer.eos_token_id]
+            elif self.mask_response:
+                labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [self.tokenizer.eos_token_id]
+            else:
+                labels = input_ids
+
+            # pad everything to max_seq_length on the configured side
+            input_len = len(input_ids)
+            pad_len = self.max_seq_length - input_len
+            if self.padding_side == "right":
+                input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len
+                labels = labels + [IGNORE_INDEX] * pad_len
+                attention_mask = [1] * input_len + [0] * pad_len
+            else:
+                input_ids = [self.tokenizer.eos_token_id] * pad_len + input_ids
+                labels = [IGNORE_INDEX] * pad_len + labels
+                attention_mask = [0] * pad_len + [1] * input_len
+
+            assert len(input_ids) == self.max_seq_length
+            assert len(prompt_ids) <= self.max_source_length
+            assert len(labels) == len(input_ids) == len(attention_mask)
+
+            examples["input_ids"].append(torch.tensor(input_ids))
+            examples["labels"].append(labels)
+            examples["attention_mask"].append(attention_mask)
+
+        return examples
+
+    def tokenize(self, examples):
+        keys = list(examples.data.keys())
+        if len(keys) != 2:
+            raise ValueError("Unsupported dataset format")
+
+        examples["input_ids"] = []
+        examples["labels"] = []
+        examples["attention_mask"] = []
+        for s, t in zip(examples[keys[0]], examples[keys[1]]):
+            results = self.tokenizer(
+                s + t,
+                padding=self.padding,
+                truncation=self.truncation,
+                return_tensors=None,
+                max_length=self.max_length,
+            )
+
+            input_ids = results["input_ids"]
+            input_len = len(input_ids)
+            labels = copy.deepcopy(input_ids)
+            if self.mask_input or self.mask_response:
+                # tokenize the source alone to find where the response starts
+                sources_tokenized = self.tokenizer(
+                    s,
+                    padding=False,
+                    truncation=True,
+                    return_tensors=None,
+                    max_length=self.max_length,
+                )
+                input_id_len = len(sources_tokenized["input_ids"])
+                # mask the input part
+                if self.mask_input:
+                    labels[:input_id_len] = [IGNORE_INDEX] * input_id_len
+                # mask the response part
+                if self.mask_response:
+                    labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len)
+
+            examples["input_ids"].append(results["input_ids"])
+            examples["labels"].append(labels)
+            examples["attention_mask"].append(results["attention_mask"])
+        return examples
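+
+
+# Minimal usage sketch (illustrative; the actual wiring lives in
+# llm_on_ray/finetune/finetune.py::tokenize_dataset):
+#
+#     processor = DataProcessor(config, tokenizer)
+#     prompts = processor.make_prompt(dataset["train"])
+#     train_split = datasets.Dataset.from_dict(prompts)
+#     tokenized = train_split.map(processor.tokenize_by_neural_chat, batched=True)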
diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index eb4996cb5..8c67dcb4d 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -18,7 +18,10 @@
 import os
 import argparse
+import re
 import sys
+import copy
+
 from typing import Any, Dict, Union, Optional
 from itertools import chain
 
@@ -37,9 +40,8 @@
 from pydantic_yaml import parse_yaml_raw_as
 
 from llm_on_ray import common
-from llm_on_ray.finetune import template
+from llm_on_ray.finetune.data_process import DataProcessor
 from llm_on_ray.finetune.finetune_config import FinetuneConfig
-from importlib import util
 
 
 def adapt_transformers_to_device(config: Dict):
@@ -140,7 +142,13 @@ def load_tokenizer(config: Dict):
     else:
         tokenizer_name = config["General"]["base_model"]
     load_config = config["General"].get("config", {})
-    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, **load_config)
+    # default padding side is right
+    padding_side = config["Dataset"].get("padding_side", "right")
+    # default truncation side is right
+    truncation_side = config["Dataset"].get("truncation_side", "right")
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        tokenizer_name, padding_side=padding_side, truncation_side=truncation_side, **load_config
+    )
     return tokenizer
 
@@ -195,50 +203,27 @@ def local_load(name, **load_config):
 
 def tokenize_dataset(config: Dict, tokenizer, dataset):
-    max_length = config["Dataset"].get("max_length", 512)
     group = config["Dataset"].get("group", True)
     block_size = config["Dataset"].get("block_size", 512)
     tokenizer.pad_token = tokenizer.eos_token
 
-    if isinstance(dataset, datasets.Dataset):
-        column_names = dataset.column_names
-
-    if isinstance(dataset, datasets.DatasetDict):
-        column_names = dataset["train"].column_names
-
-    if column_names and template.TEXT_COLUMN_NAME not in column_names:
-
-        def prompt(rec):
-            instruction = rec["instruction"]
-            response = rec["response"]
-            context = rec.get("context")
-            if not instruction:
-                raise ValueError(f"Expected an instruction in: {rec}")
-            if not response:
-                raise ValueError(f"Expected a response in: {rec}")
-            if context:
-                rec["text"] = template.PROMPT_WITH_INPUT_FORMAT.format(
-                    instruction=instruction, response=response, input=context
-                )
-            else:
-                rec["text"] = template.PROMPT_NO_INPUT_FORMAT.format(
-                    instruction=instruction, response=response
-                )
-            return rec
+    processor = DataProcessor(config, tokenizer)
 
-        dataset = dataset.map(
-            prompt,
-            load_from_cache_file=False,
-            desc="Prompt",
-        )
-        column_names += [template.TEXT_COLUMN_NAME]
+    for key in dataset:
+        prompts = processor.make_prompt(dataset[key])
+        dataset[key] = datasets.Dataset.from_dict(prompts)
 
-    def tokenize_function(examples):
-        return tokenizer(examples[template.TEXT_COLUMN_NAME], max_length=max_length)
+    column_names = list(dataset["train"].features)
+    tokenize_fn = (
+        processor.tokenize_by_neural_chat
+        if config["Dataset"].get("data_preprocess_type", "neural_chat") == "neural_chat"
+        else processor.tokenize
+    )
 
     tokenized_dataset = dataset.map(
-        tokenize_function,
+        tokenize_fn,
         remove_columns=column_names,
+        batched=True,
         load_from_cache_file=False,
         desc="Tokenize dataset",
     )
@@ -258,7 +243,6 @@ def group_texts(examples):
             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
-        result["labels"] = result["input_ids"].copy()
         return result
 
     tokenized_dataset = tokenized_dataset.map(
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index e1efd0b48..27bbe3cd5 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -71,6 +71,15 @@ class Dataset(BaseModel):
     group: bool = True
     block_size: int = 512
     shuffle: bool = False
+    max_source_length: int = 384
+    padding_side: str = "right"
+    truncation_side: str = "right"
+    max_seq_length: int = 512
+    truncation: bool = True
+    padding: bool = True
+    mask_input: bool = True
+    mask_response: bool = True
+    data_preprocess_type: str = "neural_chat"
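+    # NOTE (illustrative): `truncation` and `padding` are forwarded unchanged to the
+    # Hugging Face tokenizer call in DataProcessor.tokenize, so they accept the same
+    # values as tokenizer(..., truncation=..., padding=...).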
 
 
 class RayResourceConfig(BaseModel):
diff --git a/llm_on_ray/finetune/template.py b/llm_on_ray/finetune/template.py
deleted file mode 100644
index cf8647d7f..000000000
--- a/llm_on_ray/finetune/template.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#
-# Copyright 2023 The LLM-on-Ray Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-#!/usr/bin/env python
-
-INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
-INSTRUCTION_KEY = "### Instruction:"
-INPUT_KEY = "Input:"
-RESPONSE_KEY = "### Response:"
-END_KEY = "### End"
-RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
-
-PROMPT_NO_INPUT_FORMAT = """{intro}
-
-{instruction_key}
-{instruction}
-
-{response_key}
-{response}
-
-{end_key}""".format(
-    intro=INTRO_BLURB,
-    instruction_key=INSTRUCTION_KEY,
-    instruction="{instruction}",
-    response_key=RESPONSE_KEY,
-    response="{response}",
-    end_key=END_KEY,
-)
-
-PROMPT_WITH_INPUT_FORMAT = """{intro}
-
-{instruction_key}
-{instruction}
-
-{input_key}
-{input}
-
-{response_key}
-{response}
-
-{end_key}""".format(
-    intro=INTRO_BLURB,
-    instruction_key=INSTRUCTION_KEY,
-    instruction="{instruction}",
-    input_key=INPUT_KEY,
-    input="{input}",
-    response_key=RESPONSE_KEY,
-    response="{response}",
-    end_key=END_KEY,
-)
-TEXT_COLUMN_NAME = "text"