-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
67 lines (53 loc) · 2.67 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import pytorch_lightning as pl
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer
class Dataset(pl.LightningDataModule):
    """LightningDataModule wrapping the GLUE CoLA sentence-acceptability task.

    Downloads the dataset splits and a BERT tokenizer on construction,
    tokenizes the splits in ``setup``, and serves train/validation/test
    dataloaders.

    Args:
        model: HuggingFace checkpoint name used to load the tokenizer.
        batch_size: Samples per batch in every dataloader.
        max_length: Sequences are truncated / padded to this token length.
        num_workers: Worker processes per dataloader (defaults to 8, the
            previously hard-coded value, for backward compatibility).
    """

    def __init__(self, model='google/bert_uncased_L-2_H-128_A-2', batch_size=32,
                 max_length=128, num_workers=8):
        super().__init__()
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_workers = num_workers
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        # Download (or load from cache) all CoLA splits up front.
        cola_dataset = load_dataset('glue', 'cola')
        self.train_dataset = cola_dataset['train']
        self.validation_dataset = cola_dataset['validation']
        self.test_dataset = cola_dataset['test']

    def tokenize(self, sample):
        """Tokenize a (batched) sample's 'sentence' field to fixed length."""
        return self.tokenizer(
            sample['sentence'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length)

    def setup(self, stage=None):
        """Tokenize the splits required by the given Lightning stage.

        BUG FIX: the test split was previously tokenized only inside the
        "fit" branch, so ``setup(stage="test")`` left it raw. It now has
        its own branch; ``stage=None`` still prepares everything.
        """
        torch_columns = ["input_ids", "attention_mask", "label"]
        if stage == "fit" or stage is None:
            self.train_dataset = self.train_dataset.map(self.tokenize, batched=True)
            self.train_dataset.set_format(type="torch", columns=torch_columns)
            self.validation_dataset = self.validation_dataset.map(self.tokenize, batched=True)
            # output_all_columns keeps the raw fields (e.g. the sentence
            # text) available alongside the tensors during validation.
            self.validation_dataset.set_format(type="torch", columns=torch_columns,
                                               output_all_columns=True)
        if stage == "test" or stage is None:
            self.test_dataset = self.test_dataset.map(self.tokenize, batched=True)
            self.test_dataset.set_format(type="torch", columns=torch_columns)

    def train_dataloader(self):
        """Shuffled dataloader over the tokenized training split."""
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size,
                                           shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Deterministic (unshuffled) dataloader over the validation split."""
        return torch.utils.data.DataLoader(self.validation_dataset, batch_size=self.batch_size,
                                           shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        """Deterministic dataloader over the tokenized test split.

        New backward-compatible addition: the test split was prepared but
        never exposed, so ``Trainer.test`` could not use this module.
        """
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size,
                                           shuffle=False, num_workers=self.num_workers)
# to test run the script. (optional for testing if everything works fine..)
# Smoke test: build the datamodule, pull one validation batch, and print
# its tensor shapes plus the first sample's token ids.
if __name__ == "__main__":
    datamodule = Dataset()
    datamodule.setup()
    loader = datamodule.val_dataloader()
    sample = next(iter(loader))
    print(sample['input_ids'].shape, sample['label'].shape, sample['attention_mask'].shape)
    print(sample['input_ids'][0])