forked from vlgiitr/dmn-plus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_test.py
126 lines (99 loc) · 6.34 KB
/
train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
''' This file contains the code for training and testing the model. The Adam optimizer is used for training with its
default learning rate of 0.001 and a batch size of 100. Training runs for up to 256 epochs, with early stopping
if validation accuracy doesn't improve within the last 20 epochs. Weights are initialized using Xavier initialization
except for word embeddings. Dropout and L2 are used as regularization methods on sentence encodings and answer module.'''
import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as f
import torch.autograd as Variable
import torch.utils.data as DataLoader
from modelDMN import DMN
from dmn_loader import BabiDataSet, pad_collate
def evaluate(model, loader):
    """Return the sample-weighted mean accuracy of ``model`` over ``loader``.

    No gradients are computed and no weights are updated. The caller is
    responsible for switching the model to evaluation mode beforehand.

    Args:
        model: object exposing ``loss(context, questions, answers)`` that
            returns a ``(loss, accuracy)`` pair.
        loader: iterable yielding ``(context, questions, answers)`` batches.

    Returns:
        Mean per-sample accuracy as a float, or ``0.0`` for an empty loader.
    """
    total_acc = 0.0
    count = 0
    with torch.no_grad():
        for context, questions, answers in loader:
            batch_size = context.size(0)
            _, acc = model.loss(context.long(), questions.long(), answers)
            # Weight by batch size so a short final batch is not over-counted.
            total_acc += float(acc) * batch_size
            count += batch_size
    return total_acc / count if count else 0.0


def train_one_epoch(model, loader, optimizer, log_prefix='', log_every=20):
    """Run one optimisation pass over ``loader`` and return the mean accuracy.

    For every batch the gradients are reset, the loss is back-propagated and
    the optimizer takes one step. The caller is responsible for switching the
    model to training mode beforehand.

    Args:
        model: object exposing ``loss(context, questions, answers)`` that
            returns a ``(loss, accuracy)`` pair whose loss supports
            ``backward()`` and ``item()``.
        loader: iterable yielding ``(context, questions, answers)`` batches.
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        log_prefix: text placed at the start of each progress line.
        log_every: a progress line is printed every ``log_every`` batches.

    Returns:
        Running mean per-sample accuracy over the epoch as a float, or
        ``0.0`` for an empty loader.
    """
    total_acc = 0.0
    count = 0
    for batch_id, (context, questions, answers) in enumerate(loader):
        batch_size = context.size(0)
        optimizer.zero_grad()
        total_loss, acc = model.loss(context.long(), questions.long(), answers)
        total_loss.backward()
        optimizer.step()
        total_acc += float(acc) * batch_size
        count += batch_size
        if batch_id % log_every == 0:
            print('training error')
            print('{},loss {},total accuracy : {}'.format(
                log_prefix, total_loss.item(), total_acc / count))
    return total_acc / count if count else 0.0


def main():
    """Train, validate and test a DMN on each of the 20 bAbI tasks, 10 times.

    For every (iteration, task) pair a fresh model is trained for at most 256
    epochs. The weights with the best validation accuracy are kept; training
    stops once validation accuracy reaches 1.0 or has failed to improve for
    more than 20 consecutive epochs. The best weights are then scored on the
    test split, saved under ``models/`` and the results appended to
    ``log.txt``.
    """
    # Local imports: at module level ``DataLoader`` is bound to the
    # ``torch.utils.data`` *module*, which is not callable, so the class is
    # imported here. ``copy`` is needed to snapshot the best weights.
    import copy
    from torch.utils.data import DataLoader

    max_epochs = 256
    patience = 20
    batch_size = 100
    hidden_size = 100

    for itr in range(10):
        for task in range(1, 21):
            # ``BabiDataSet`` is the spelling this file imports from dmn_loader.
            dataset = BabiDataSet(task)
            vocab_size = len(dataset.QA.VOCAB)
            model = DMN(hidden_size, vocab_size, num_pass=3, qa=dataset.QA)
            optim = torch.optim.Adam(model.parameters())

            early_stop_count = 0
            best_acc = 0.0
            # Snapshot up front so a best state always exists, and deep-copy
            # because state_dict() returns references to the live tensors.
            best_state = copy.deepcopy(model.state_dict())
            epoch = 0

            for epoch in range(max_epochs):
                dataset.set_mode('train')
                train_load = DataLoader(dataset, batch_size=batch_size,
                                        shuffle=True, collate_fn=pad_collate)
                model.train()
                train_one_epoch(
                    model, train_load, optim,
                    log_prefix='task {},epoch {}'.format(task, epoch))

                # Validation part
                dataset.set_mode('valid')
                valid_load = DataLoader(dataset, batch_size=batch_size,
                                        shuffle=False, collate_fn=pad_collate)
                model.eval()
                valid_acc = evaluate(model, valid_load)

                if valid_acc > best_acc:
                    best_acc = valid_acc
                    best_state = copy.deepcopy(model.state_dict())
                    early_stop_count = 0
                else:
                    early_stop_count += 1

                print('itr {},task_id {},epoch {},total_acc {}'.format(
                    itr, task, epoch, valid_acc))
                with open('log.txt', 'a') as fp:
                    fp.write('itr {}, task_id {},epoch {},total_acc {}\n'.format(
                        itr, task, epoch, valid_acc))

                if valid_acc == 1.0:
                    break
                if early_stop_count > patience:
                    # No improvement for more than `patience` epochs.
                    print('iteration {} task {} Early Stopping at Epoch {} '
                          'validation accuracy : {}'.format(
                              itr, task, epoch, best_acc))
                    break

            # Score the best validation checkpoint on the test split.
            dataset.set_mode('test')
            test_load = DataLoader(dataset, batch_size=batch_size,
                                   shuffle=False, collate_fn=pad_collate)
            model.load_state_dict(best_state)
            model.eval()
            test_acc = evaluate(model, test_load)

            print('itr {} task = {} Epoch {} test accuracy : {}'.format(
                itr, task, epoch, test_acc))
            os.makedirs('models', exist_ok=True)
            checkpoint_path = 'models/task{}_epoch{}_run{}_acc{}.pth'.format(
                task, epoch, itr, test_acc)
            with open(checkpoint_path, 'wb') as fp:
                torch.save(model.state_dict(), fp)
            with open('log.txt', 'a') as fp:
                fp.write('[itr {}, Task {}, Epoch {}] [Test] Accuracy : {}\n'.format(
                    itr, task, epoch, test_acc))


if __name__ == '__main__':
    main()