
Commit

update
Signed-off-by: minmingzhu <[email protected]>
minmingzhu committed Jun 17, 2024
1 parent b32bb20 commit db1b800
Showing 2 changed files with 65 additions and 30 deletions.
llm_on_ray/finetune/finetune.py (92 changes: 64 additions & 28 deletions)
@@ -18,6 +18,7 @@
 
 import os
 import argparse
+import re
 import sys
 import copy
 
@@ -199,11 +200,10 @@ def local_load(name, **load_config):
 
 
 def tokenize_dataset(config: Dict, tokenizer, dataset):
-    max_length = config["Dataset"].get("max_length", 512)
+    max_seq_length = config["Dataset"].get("max_length", 512)
     group = config["Dataset"].get("group", True)
     block_size = config["Dataset"].get("block_size", 512)
-    mask_input = config["Dataset"].get("mask_input", False)
-    mask_response = config["Dataset"].get("mask_response", False)
+    max_source_length = config["Dataset"].get("max_source_length", 384)
     tokenizer.pad_token = tokenizer.eos_token
 
     def prompt(rec, tokenizer):
@@ -234,47 +234,83 @@ def prompt(rec, tokenizer):
         prompts = prompt(dataset[key], tokenizer)
         dataset[key] = datasets.Dataset.from_dict(prompts)
 
+    def truncate_sequences(sequences, max_length):
+        """
+        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40
+        """
+        words_to_cut = sum(list(map(len, sequences))) - max_length
+        if words_to_cut <= 0:
+            return sequences
+
+        while words_to_cut > 0 and len(sequences) > 0:
+            words_to_cut -= len(sequences[0])
+            sequences = sequences[1:]
+
+        return sequences
+
     def tokenize_function(examples):
-        max_seq_length = max_length
-        keys = list(examples.data.keys())
-        if len(keys) != 2:
-            raise ValueError("Unsupported dataset format")
+        """
+        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225
+        The only differences are:
+        - using our own prompt style
+        """
+        assistant = "### Response:\n"
+        end = tokenizer.eos_token
+        assistant_tokens = tokenizer.tokenize(assistant)
+        header = (
+            "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+            + end
+            + "\n"
+        )
+
+        instructions = [q.strip() for q in examples["prompt_sources"]]
+        responses = [q.strip() for q in examples["prompt_targets"]]
 
         examples["input_ids"] = []
         examples["labels"] = []
         examples["attention_mask"] = []
-        for s, t in zip(examples[keys[0]], examples[keys[1]]):
-            results = tokenizer(
-                s + t, padding=False, truncation=True, return_tensors=None, max_length=max_length
+
+        for instruction, response in zip(instructions, responses):
+            convs = re.findall(
+                r"### Instruction.*?{0}|### Response.*?{0}".format(end), instruction, re.DOTALL
             )
-            input_ids = results["input_ids"] + [tokenizer.eos_token_id]
-            input_len = len(input_ids)
-            labels = copy.deepcopy(input_ids)
-            # mask input
-            if mask_input:
-                sources_tokenized = tokenizer(
-                    s, padding=False, truncation=True, return_tensors=None, max_length=max_length
-                )
-                input_id_len = len(sources_tokenized["input_ids"])
-                labels[:input_id_len] = [IGNORE_INDEX] * input_id_len
-            # mask response
-            if mask_response:
-                sources_tokenized = tokenizer(
-                    s, padding=False, truncation=True, return_tensors=None, max_length=max_length
-                )
-                input_id_len = len(sources_tokenized["input_ids"])
-                labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len)
+            convs_tokens = [tokenizer.tokenize(conv) + tokenizer.tokenize("\n") for conv in convs]
+            header_tokens = tokenizer.tokenize(header) + tokenizer.tokenize("\n")
+
+            max_input = max_source_length - len(header_tokens) - len(assistant_tokens)
+
+            truncated_convs = truncate_sequences(convs_tokens, max_input)
+
+            if len(truncated_convs) == 0:
+                truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]]
+
+            prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens]
+            prompt_ids = [
+                tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens
+            ]
+            prompt_ids = list(chain(*prompt_ids))
+
+            resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(response.strip()))
+            # keep last and eos_id
+            max_resp = max_seq_length - len(prompt_ids) - 1
+            if len(resp_ids) > max_resp:
+                resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:]
+
+            input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id]
+            labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id]
+
+            # padding
+            input_len = len(input_ids)
+            pad_len = max_seq_length - input_len
+            input_ids = input_ids + [tokenizer.eos_token_id] * pad_len
+            labels = labels + [IGNORE_INDEX] * pad_len
+            attention_mask = [1] * input_len + [0] * pad_len
+
+            assert len(input_ids) == max_seq_length
+            assert len(prompt_ids) <= max_source_length
+            assert len(labels) == len(input_ids) == len(attention_mask)
 
-            examples["input_ids"].append(input_ids)
+            examples["input_ids"].append(torch.tensor(input_ids))
             examples["labels"].append(labels)
             examples["attention_mask"].append(attention_mask)
 
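The new import re powers the turn splitting above: the findall pattern carves a multi-turn prompt into complete "### Instruction"/"### Response" blocks, and truncate_sequences then drops the oldest blocks until the prompt fits its budget. A minimal sketch of both steps, with an invented eos marker and sample text (in the diff, end comes from tokenizer.eos_token):

import re

end = "</s>"  # assumed eos token, for illustration only
sample = (
    "### Instruction:\nTranslate to French: hello" + end + "\n"
    "### Response:\nbonjour" + end + "\n"
    "### Instruction:\nNow translate to German: hello" + end + "\n"
)

# Same pattern as tokenize_function: each match is one complete turn.
convs = re.findall(
    r"### Instruction.*?{0}|### Response.*?{0}".format(end), sample, re.DOTALL
)
print(len(convs))  # 3

# truncate_sequences, condensed from the diff: drop whole turns from the
# front until the total length fits max_length.
def truncate_sequences(sequences, max_length):
    words_to_cut = sum(map(len, sequences)) - max_length
    while words_to_cut > 0 and len(sequences) > 0:
        words_to_cut -= len(sequences[0])
        sequences = sequences[1:]
    return sequences

turns = [list(c) for c in convs]  # char lists as a crude stand-in for tokenizer.tokenize
print(len(truncate_sequences(turns, 60)))  # 1: only the newest turn survives a 60-token budget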
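The label construction that replaces the old mask_input/mask_response branches always masks the prompt: every prompt position gets IGNORE_INDEX so only response tokens and the final eos contribute to the loss, and padding is masked the same way. A self-contained sketch with toy token ids (all values invented):

IGNORE_INDEX = -100  # sentinel the training loss skips
EOS_ID = 2           # stand-in for tokenizer.eos_token_id
max_seq_length = 12

prompt_ids = [5, 6, 7, 8]  # header + truncated turns + "### Response:\n"
resp_ids = [9, 10, 11]     # tokenized response

input_ids = prompt_ids + resp_ids + [EOS_ID]
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [EOS_ID]

pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [EOS_ID] * pad_len   # pad with eos, as the diff does
labels = labels + [IGNORE_INDEX] * pad_len   # padding is never supervised
attention_mask = [1] * (max_seq_length - pad_len) + [0] * pad_len

assert len(input_ids) == len(labels) == len(attention_mask) == max_seq_length
print(labels)  # [-100, -100, -100, -100, 9, 10, 11, 2, -100, -100, -100, -100]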
llm_on_ray/finetune/finetune_config.py (3 changes: 1 addition & 2 deletions)
@@ -71,8 +71,7 @@ class Dataset(BaseModel):
     group: bool = True
     block_size: int = 512
     shuffle: bool = False
-    mask_input: bool = True
-    mask_response: bool = False
+    max_source_length: int = 384
 
 
 class RayResourceConfig(BaseModel):
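With mask_input and mask_response removed, prompt masking is always on, and max_source_length becomes the one knob that splits the max_length budget between prompt and response. A hedged sketch of a matching Dataset section (field names come from this diff; the file path and values are illustrative):

dataset_config = {
    "train_file": "examples/data/sample.jsonl",  # assumed path, not from this diff
    "max_length": 512,           # total sequence budget (max_seq_length)
    "group": True,
    "block_size": 512,
    "shuffle": False,
    "max_source_length": 384,    # prompt budget; at least 512 - 384 - 1 tokens remain for the response
}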
