Commit 9dd3a8a: build up
yul091 committed May 5, 2023 (parent 3714f40)

Showing 22 changed files with 5,325 additions and 0 deletions.

298 changes: 298 additions & 0 deletions DG_dataset.py
@@ -0,0 +1,298 @@
import torch
from datasets import Dataset
from itertools import chain
from typing import List, Optional
from transformers import AutoTokenizer
from transformers.testing_utils import CaptureLogger
from transformers.utils.logging import get_logger


class DGDataset:
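    """Dataset wrapper that loads and tokenizes open-domain dialogue corpora
    (blended_skill_talk, conv_ai_2, empathetic_dialogues, AlekseyKorshuk/persona-chat)
    for either seq2seq or causal-LM training."""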
def __init__(
self,
dataset: str = "blended_skill_talk",
task: str = "seq2seq",
        tokenizer: Optional[AutoTokenizer] = None,
max_source_length: int = 512,
max_target_length: int = 512,
padding: str = "max_length",
ignore_pad_token_for_loss: bool = True,
        preprocessing_num_workers: Optional[int] = None,
overwrite_cache: bool = True,
):
self.dataset = dataset
self.task = task
self.tokenizer = tokenizer
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.padding = padding
self.ignore_pad_token_for_loss = ignore_pad_token_for_loss
self.preprocessing_num_workers = preprocessing_num_workers
self.overwrite_cache = overwrite_cache
        # Since this object will be pickled (to avoid a _LazyModule error in Hasher),
        # force logger loading here, before tokenize_function is called.
self.tok_logger = get_logger("transformers.tokenization_utils_base")


def prepare_context(self, instance: dict):
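        """Build the dataset-specific dialogue context for one instance.

        Returns the number of exchange rounds to use, the total number of entries,
        the persona/context prefix (non-empty only for seq2seq), and any previous
        utterances to prepend."""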
if self.dataset == 'blended_skill_talk':
num_entries = len(instance["free_messages"])
total_entries = num_entries
if self.task == 'seq2seq':
persona_pieces = f"<PS>{instance['personas'][1]}"
if instance['context'] == "wizard_of_wikipedia":
additional_context_pieces = f"<CTX>{instance['additional_context']}."
else:
additional_context_pieces = ""
context = persona_pieces + additional_context_pieces
else:
num_entries = min(num_entries, 2)
context = ''
prev_utt_pc = [sent for sent in instance["previous_utterance"] if sent != '']

elif self.dataset == 'conv_ai_2':
total_entries = len(instance['dialog'])
num_entries = total_entries//2
if self.task == 'seq2seq':
user_profile = ' '.join([''.join(x) for x in instance['user_profile']])
persona_pieces = f"<PS>{user_profile}"
context = persona_pieces
else:
num_entries = min(num_entries, 2)
context = ''
prev_utt_pc = []

elif self.dataset == 'empathetic_dialogues':
total_entries = len(instance['dialog'])
num_entries = total_entries//2
if self.task == 'seq2seq':
persona_pieces = f"<PS>{instance['prompt']}"
additional_context_pieces = f"<CTX>{instance['context']}."
context = persona_pieces + additional_context_pieces
else:
num_entries = min(num_entries, 2)
context = ''
prev_utt_pc = []

elif self.dataset == 'AlekseyKorshuk/persona-chat':
total_entries = len(instance['utterances'])
num_entries = total_entries//2
if self.task == 'seq2seq':
user_profile = ' '.join(instance['personality'])
persona_pieces = f"<PS>{user_profile}"
context = persona_pieces
else:
num_entries = min(num_entries, 2)
context = ''
prev_utt_pc = []

else:
            raise ValueError(f"Dataset {self.dataset} is not supported.")
return num_entries, total_entries, context, prev_utt_pc


def prepare_entry(
self,
instance: dict,
entry_idx: int,
context: str,
prev_utt_pc: List[str],
total_entries: int,
):
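        """Extract one exchange: the user's free message, the partner's guided
        (target) response, the accumulated context string, and reference responses."""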
if self.dataset == 'blended_skill_talk':
free_message = instance['free_messages'][entry_idx]
guided_message = instance['guided_messages'][entry_idx]
references = [values[entry_idx] for key, values in instance['suggestions'].items()]

        elif self.dataset in ('conv_ai_2', 'empathetic_dialogues'):
            free_message = instance['dialog'][entry_idx*2]['text']
            if entry_idx*2+1 >= total_entries:
                guided_message = None
            else:
                guided_message = instance['dialog'][entry_idx*2+1]['text']
            references = []

elif self.dataset == 'AlekseyKorshuk/persona-chat':
free_message = instance['utterances'][entry_idx*2]['history'][-1]
if entry_idx*2+1 >= total_entries:
guided_message = None
else:
guided_message = instance['utterances'][entry_idx*2+1]['history'][-1]
references = instance['utterances'][entry_idx*2]['candidates']

else:
            raise ValueError(f"Dataset {self.dataset} is not supported.")

if not prev_utt_pc:
original_context = context
else:
sp_token = '<SEP>' if self.task == 'seq2seq' else ' '
original_context = context + sp_token + sp_token.join(prev_utt_pc)

        if guided_message is not None:
            references.append(guided_message)
return free_message, guided_message, original_context, references


def tokenize_and_align_labels(self, instance: dict):
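        """Convert one dialogue instance into tokenized inputs and labels,
        iterating over its exchanges and growing the context as the dialogue unfolds."""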
num_entries, total_entries, context, prev_utt_pc = self.prepare_context(instance)
inputs, labels = [], []
for entry_idx in range(num_entries):
free_message, guided_message, original_context, references = self.prepare_entry(
instance,
entry_idx,
context,
prev_utt_pc,
total_entries,
)
if guided_message is None:
continue
# Input & Output
if self.task == 'seq2seq':
text = original_context + self.tokenizer.eos_token + free_message
else:
text = original_context + free_message + guided_message

inputs.append(text)
labels.append(guided_message)
prev_utt_pc += [
free_message,
guided_message,
]

if not inputs:
return {"input_ids": [], "labels": [], "attention_mask": []}

if self.task == 'seq2seq':
inputs = self.tokenizer(inputs, max_length=self.max_source_length, padding=self.padding, truncation=True)
# Setup the tokenizer for targets
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(labels, max_length=self.max_target_length, padding=self.padding, truncation=True)

# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
# when we want to ignore padding in the loss.
if self.padding == "max_length" and self.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
inputs["labels"] = labels["input_ids"]
return inputs
else:
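            # Causal LM: tokenize inputs and responses separately, concatenate them
            # along the sequence dimension, and mask the input half with -100 so the
            # loss is computed only on the response tokens.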
with CaptureLogger(self.tok_logger) as cl:
inputs = self.tokenizer(
inputs,
return_tensors="pt",
max_length=self.max_source_length,
padding=self.padding,
truncation=True,
)
labels = self.tokenizer(
labels,
return_tensors="pt",
max_length=self.max_target_length,
padding=self.padding,
truncation=True,
)

new_inputs = inputs.copy()
for k, v1 in inputs.items():
v2 = labels[k]
new_inputs[k] = torch.cat((v1, v2), dim=1)

new_labels = torch.cat((-100*torch.ones_like(inputs["input_ids"]), labels["input_ids"]), dim=1)
new_inputs["labels"] = new_labels

            # CLM inputs can be much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
self.tok_logger.warning(
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
" before being passed to the model."
)
return new_inputs


def group_texts(self, examples):
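        """Flatten batched feature lists into one concatenated list per key."""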
        # examples maps each key ('input_ids', 'attention_mask', 'labels') to a list
        # of per-instance lists; chain them into a single flat list per key.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
return concatenated_examples


def group_ED(self, dataset: Dataset):
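        """Regroup the flat empathetic_dialogues split (one utterance per row)
        into one row per conversation, collecting its prompt, context, and dialog turns."""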
results = {
'conv_id': [],
'prompt': [],
'dialog': [],
'context': [],
}
for i, instance in enumerate(dataset):
if instance['utterance_idx'] == 1:
results['conv_id'].append(instance['conv_id'])
results['dialog'].append([])
results['prompt'].append(instance['prompt'])
results['context'].append(instance['context'])

response = {'text': instance['utterance'], 'speaker_idx': instance['speaker_idx']}
results['dialog'][-1].append(response)
return Dataset.from_dict(results)


def preprocess(self, dataset: Dataset):
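        """Tokenize every dialogue instance and flatten the results into a
        training-ready Dataset."""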
if self.dataset == "empathetic_dialogues":
dataset = self.group_ED(dataset)

dataset = dataset.map(
self.tokenize_and_align_labels,
batched=False,
num_proc=self.preprocessing_num_workers,
remove_columns=dataset.column_names,
load_from_cache_file=not self.overwrite_cache,
)
dataset = dataset.map(
self.group_texts,
batched=True,
num_proc=self.preprocessing_num_workers,
load_from_cache_file=not self.overwrite_cache,
)
return dataset




if __name__ == "__main__":
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
tokenizer.pad_token = tokenizer.eos_token
data_names = [
"conv_ai_2",
"empathetic_dialogues",
"AlekseyKorshuk/persona-chat",
"blended_skill_talk",
]
task = "seq2seq"
max_length = 256

for data_name in data_names:
train_dataset = load_dataset(data_name)["train"]
dg = DGDataset(
dataset=data_name,
task=task,
tokenizer=tokenizer,
max_source_length=max_length,
max_target_length=max_length,
)
print('{}: {}'.format(data_name, train_dataset))
train_dataset = dg.preprocess(train_dataset)
print("processed dataset: ", train_dataset)
print("processed dataset[0]: ", train_dataset[0])



