main.py
import datetime
from pathlib import Path

import torch

from bert.dataset import IMDBBertDataset
from bert.model import BERT
from bert.trainer import BertTrainer

BASE_DIR = Path(__file__).resolve().parent

# Model and training hyperparameters
EMB_SIZE = 64
HIDDEN_SIZE = 36
EPOCHS = 4
BATCH_SIZE = 12
NUM_HEADS = 4

# Checkpoints and TensorBoard-style logs go under data/, next to this script
CHECKPOINT_DIR = BASE_DIR.joinpath('data/bert_checkpoints')

timestamp = datetime.datetime.utcnow().timestamp()
LOG_DIR = BASE_DIR.joinpath(f'data/logs/bert_experiment_{timestamp}')

# Train on GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

if __name__ == '__main__':
    print("Prepare dataset")
    # Build the pre-training dataset from the first 1000 IMDB reviews
    ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)

    bert = BERT(len(ds.vocab), EMB_SIZE, HIDDEN_SIZE, NUM_HEADS).to(device)

    trainer = BertTrainer(
        model=bert,
        dataset=ds,
        log_dir=LOG_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        print_progress_every=20,
        print_accuracy_every=200,
        batch_size=BATCH_SIZE,
        learning_rate=0.00007,
        epochs=15,  # note: hardcoded here, does not use the EPOCHS constant above
    )

    trainer.print_summary()
    trainer()
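
To reuse a trained model after this script finishes, a checkpoint from CHECKPOINT_DIR can be loaded back into the same BERT configuration. The following is a minimal sketch only: the checkpoint file name is hypothetical, and it assumes the trainer saves a plain state dict via torch.save(model.state_dict(), path); if BertTrainer saves a richer dict instead, the relevant key would need to be extracted first.

# sketch: reload a saved checkpoint for inspection or inference
# (checkpoint file name below is hypothetical; format assumed to be a raw state dict)
import torch
from pathlib import Path

from bert.dataset import IMDBBertDataset
from bert.model import BERT

BASE_DIR = Path(__file__).resolve().parent
CHECKPOINT_DIR = BASE_DIR.joinpath('data/bert_checkpoints')

# rebuild the vocabulary and the model with the same hyperparameters as training
ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)
model = BERT(len(ds.vocab), 64, 36, 4)

checkpoint_path = CHECKPOINT_DIR.joinpath('bert_checkpoint.pt')  # hypothetical name
state = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state)  # assumes the file holds model.state_dict() directly
model.eval()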