
Commit

fix comments
Signed-off-by: minmingzhu <[email protected]>
minmingzhu committed Jun 26, 2024
1 parent a633a13 commit d3c99ea
Showing 2 changed files with 149 additions and 149 deletions.
287 changes: 141 additions & 146 deletions llm_on_ray/finetune/data_process.py
@@ -25,13 +25,21 @@

class DataProcessor:
# We used the following prompts for fine-tuning the Alpaca model. You can find the reference doc at this URL (https://github.com/tatsu-lab/stanford_alpaca/blob/main/README.md#data-release)
def __init__(self, config, eos_token):
self.config = config
self.end = eos_token
def __init__(self, config, tokenizer):
self.tokenizer = tokenizer
self.end = tokenizer.eos_token
self.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
self.instruction = "### Instruction:\n"
self.input = "### Input:\n"
self.response = "### Response:\n"
self.padding_side = config["Dataset"].get("padding_side", "right")
self.truncation_side = config["Dataset"].get("truncation_side", "right")
self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 512)
self.max_source_length = config["Dataset"].get("max_source_length", 384)
self.truncation = config["Dataset"].get("truncation", True)
self.padding = config["Dataset"].get("padding", True)
self.mask_input = config["Dataset"].get("mask_input", True)
self.mask_response = config["Dataset"].get("mask_response", True)

def make_prompt(self, examples):
prompts = {}
@@ -75,151 +83,138 @@ def make_prompt(self, examples):
prompts["prompt_targets"].append(prompt_response)
return prompts

def tokenize(self, tokenizer):
padding_side = self.config["Dataset"].get("padding_side", "right")
truncation_side = self.config["Dataset"].get("truncation_side", "right")
max_length = max_seq_length = self.config["Dataset"].get("max_length", 512)
max_source_length = self.config["Dataset"].get("max_source_length", 384)
truncation = self.config["Dataset"].get("truncation", True)
padding = self.config["Dataset"].get("padding", True)
mask_input = self.config["Dataset"].get("mask_input", True)
mask_response = self.config["Dataset"].get("mask_response", True)

def truncate_sequences(sequences, max_length):
"""
Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40
"""
words_to_cut = sum(list(map(len, sequences))) - max_length
if words_to_cut <= 0:
return sequences

while words_to_cut > 0 and len(sequences) > 0:
words_to_cut -= len(sequences[0])
sequences = sequences[1:]
def __truncate_sequences(self, sequences, max_length):
"""
Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40
"""
words_to_cut = sum(list(map(len, sequences))) - max_length
if words_to_cut <= 0:
return sequences

def preprocess_function_with_neural_chat(examples):
"""
Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225
The only differences are:
- using our own prompt style
- add left or right padding and truncation
- add mask_input and mask_response
"""
keys = list(examples.data.keys())
if len(keys) != 2:
raise ValueError("Unsupported dataset format")
assistant_tokens = tokenizer.tokenize(self.response)
header = self.intro + self.end + "\n"

examples["input_ids"] = []
examples["labels"] = []
examples["attention_mask"] = []
for instruction, response in zip(examples[keys[0]], examples[keys[1]]):
convs = re.findall(
r"{0}.*?{2}|{1}.*?{2}".format(self.instruction, self.response, self.end),
instruction,
re.DOTALL,
)
convs_tokens = [
tokenizer.tokenize(conv) + tokenizer.tokenize("\n") for conv in convs
]
header_tokens = tokenizer.tokenize(header) + tokenizer.tokenize("\n")
max_input = max_source_length - len(header_tokens) - len(assistant_tokens)
truncated_convs = truncate_sequences(convs_tokens, max_input)
if len(truncated_convs) == 0:
truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]]

prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens]
prompt_ids = [
tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens
]
prompt_ids = list(chain(*prompt_ids))

resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(response.strip()))
# keep last and eos_id
max_resp = max_seq_length - len(prompt_ids) - 1

# truncating response
if len(resp_ids) > max_resp:
if truncation_side == "right":
resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:]
else:
resp_ids = resp_ids[-max_resp:]

# masking
input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id]
if mask_input:
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id]
elif mask_response:
labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [tokenizer.eos_token_id]
else:
labels = input_ids

# padding
input_len = len(input_ids)
pad_len = max_seq_length - input_len
if padding_side == "right":
input_ids = input_ids + [tokenizer.eos_token_id] * pad_len
labels = labels + [IGNORE_INDEX] * pad_len
attention_mask = [1] * input_len + [0] * pad_len
while words_to_cut > 0 and len(sequences) > 0:
words_to_cut -= len(sequences[0])
sequences = sequences[1:]
return sequences

def tokenize_by_neural_chat(self, examples):
"""
Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225
The only differences are:
- using our own prompt style
- add left or right padding and truncation
- add mask_input and mask_response
"""
keys = list(examples.data.keys())
if len(keys) != 2:
raise ValueError("Unsupported dataset format")
assistant_tokens = self.tokenizer.tokenize(self.response)
header = self.intro + self.end + "\n"

examples["input_ids"] = []
examples["labels"] = []
examples["attention_mask"] = []
for instruction, response in zip(examples[keys[0]], examples[keys[1]]):
convs = re.findall(
r"{0}.*?{2}|{1}.*?{2}".format(self.instruction, self.response, self.end),
instruction,
re.DOTALL,
)
convs_tokens = [
self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs
]
header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n")
max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens)
truncated_convs = self.__truncate_sequences(convs_tokens, max_input)
if len(truncated_convs) == 0:
truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]]

prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens]
prompt_ids = [
self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens
]
prompt_ids = list(chain(*prompt_ids))

resp_ids = self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(response.strip())
)
# keep last and eos_id
max_resp = self.max_seq_length - len(prompt_ids) - 1

# truncating response
if len(resp_ids) > max_resp:
if self.truncation_side == "right":
resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:]
else:
input_ids = [tokenizer.eos_token_id] * pad_len + input_ids
labels = [IGNORE_INDEX] * pad_len + labels
attention_mask = [0] * pad_len + [1] * input_len

assert len(input_ids) == max_seq_length
assert len(prompt_ids) <= max_source_length
assert len(labels) == len(input_ids) == len(attention_mask)

examples["input_ids"].append(torch.tensor(input_ids))
examples["labels"].append(labels)
examples["attention_mask"].append(attention_mask)

return examples

def preprocess_function_encode_inputs(examples):
keys = list(examples.data.keys())
if len(keys) != 2:
raise ValueError("Unsupported dataset format")

examples["input_ids"] = []
examples["labels"] = []
examples["attention_mask"] = []
for s, t in zip(examples[keys[0]], examples[keys[1]]):
results = tokenizer(
s + t,
padding=padding,
truncation=truncation,
resp_ids = resp_ids[-max_resp:]

# masking
input_ids = prompt_ids + resp_ids + [self.tokenizer.eos_token_id]
if self.mask_input:
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [self.tokenizer.eos_token_id]
elif self.mask_response:
labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [self.tokenizer.eos_token_id]
else:
labels = input_ids

# padding
input_len = len(input_ids)
pad_len = self.max_seq_length - input_len
if self.padding_side == "right":
input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len
labels = labels + [IGNORE_INDEX] * pad_len
attention_mask = [1] * input_len + [0] * pad_len
else:
input_ids = [self.tokenizer.eos_token_id] * pad_len + input_ids
labels = [IGNORE_INDEX] * pad_len + labels
attention_mask = [0] * pad_len + [1] * input_len

assert len(input_ids) == self.max_seq_length
assert len(prompt_ids) <= self.max_source_length
assert len(labels) == len(input_ids) == len(attention_mask)

examples["input_ids"].append(torch.tensor(input_ids))
examples["labels"].append(labels)
examples["attention_mask"].append(attention_mask)

return examples

def tokenize(self, examples):
keys = list(examples.data.keys())
if len(keys) != 2:
raise ValueError("Unsupported dataset format")

examples["input_ids"] = []
examples["labels"] = []
examples["attention_mask"] = []
for s, t in zip(examples[keys[0]], examples[keys[1]]):
results = self.tokenizer(
s + t,
padding=self.padding,
truncation=self.truncation,
return_tensors=None,
max_length=self.max_length,
)

input_ids = results["input_ids"]
input_len = len(input_ids)
labels = copy.deepcopy(input_ids)
if self.mask_input or self.mask_response:
sources_tokenized = self.tokenizer(
s,
padding=False,
truncation=True,
return_tensors=None,
max_length=max_length,
max_length=self.max_length,
)

input_ids = results["input_ids"]
input_len = len(input_ids)
labels = copy.deepcopy(input_ids)
if mask_input or mask_response:
sources_tokenized = tokenizer(
s,
padding=False,
truncation=True,
return_tensors=None,
max_length=max_length,
)
input_id_len = len(sources_tokenized["input_ids"])
# mask input
if mask_input:
labels[:input_id_len] = [IGNORE_INDEX] * input_id_len
# mask response
if mask_response:
labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len)

examples["input_ids"].append(results["input_ids"])
examples["labels"].append(labels)
examples["attention_mask"].append(results["attention_mask"])
return examples

if self.config["Dataset"].get("data_preprocess_type", "neural_chat") == "neural_chat":
return preprocess_function_with_neural_chat

return preprocess_function_encode_inputs
input_id_len = len(sources_tokenized["input_ids"])
# mask input
if self.mask_input:
labels[:input_id_len] = [IGNORE_INDEX] * input_id_len
# mask response
if self.mask_response:
labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len)

examples["input_ids"].append(results["input_ids"])
examples["labels"].append(labels)
examples["attention_mask"].append(results["attention_mask"])
return examples
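
A quick illustration of the __truncate_sequences helper added above: it drops whole tokenized turns from the front of the list until the combined token count fits the budget. The snippet below restates that logic standalone with made-up inputs; it is a sketch for illustration only, not part of the commit.

# Standalone restatement of the truncation logic above (illustrative only).
def truncate_sequences(sequences, max_length):
    words_to_cut = sum(len(s) for s in sequences) - max_length
    if words_to_cut <= 0:
        return sequences
    while words_to_cut > 0 and len(sequences) > 0:
        # Drop the oldest turn first, then re-check the budget.
        words_to_cut -= len(sequences[0])
        sequences = sequences[1:]
    return sequences

# Three tokenized turns of lengths 4, 3 and 2 against a 6-token budget:
convs = [["a"] * 4, ["b"] * 3, ["c"] * 2]
print(truncate_sequences(convs, 6))  # -> [['b', 'b', 'b'], ['c', 'c']] (oldest turn dropped)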
11 changes: 8 additions & 3 deletions llm_on_ray/finetune/finetune.py
@@ -207,16 +207,21 @@ def tokenize_dataset(config: Dict, tokenizer, dataset):
block_size = config["Dataset"].get("block_size", 512)
tokenizer.pad_token = tokenizer.eos_token

processor = DataProcessor(config, tokenizer.eos_token)
processor = DataProcessor(config, tokenizer)

for key in dataset:
prompts = processor.make_prompt(dataset[key])
dataset[key] = datasets.Dataset.from_dict(prompts)

column_names = list(dataset["train"].features)
processor_fn = processor.tokenize(tokenizer)
tokenize_fn = (
processor.tokenize_by_neural_chat
if config["Dataset"].get("data_preprocess_type", "neural_chat") == "neural_chat"
else processor.tokenize
)

tokenized_dataset = dataset.map(
processor_fn,
tokenize_fn,
remove_columns=column_names,
batched=True,
load_from_cache_file=False,
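
For context, a minimal sketch of the new call pattern after this change, assuming a Hugging Face tokenizer and a config whose "Dataset" section carries the keys now read in DataProcessor.__init__; the model name and values below are placeholders, not part of the commit.

from transformers import AutoTokenizer
from llm_on_ray.finetune.data_process import DataProcessor

# Placeholder config; only the keys mirrored from the diff are meaningful here.
config = {
    "Dataset": {
        "max_length": 512,
        "max_source_length": 384,
        "padding_side": "right",
        "truncation_side": "right",
        "mask_input": True,
        "mask_response": True,
        "data_preprocess_type": "neural_chat",
    }
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer works here
tokenizer.pad_token = tokenizer.eos_token

# New signature: the whole tokenizer is passed in (not just its eos_token), and
# every tokenization option is resolved once in __init__.
processor = DataProcessor(config, tokenizer)

# The preprocessing path is now chosen by picking a bound method, as in tokenize_dataset above.
tokenize_fn = (
    processor.tokenize_by_neural_chat
    if config["Dataset"].get("data_preprocess_type", "neural_chat") == "neural_chat"
    else processor.tokenize
)
# tokenize_fn is then handed to datasets.Dataset.map(..., batched=True).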
