-
Notifications
You must be signed in to change notification settings - Fork 1
/
pos_tagging_test.py
78 lines (65 loc) · 2.59 KB
/
pos_tagging_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Prints the loss and accuracy of a trained POS-tagger on the train, dev and test
sets of the French-GSD dataset.
"""
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def evaluate(net, criterion, loader, role='Test'):
    """Evaluate a trained POS-tagger on a dataset.

    Switches ``net`` to evaluation mode, runs it over every batch of
    ``loader`` with gradients disabled, prints a one-line summary and
    returns the same numbers.

    Args:
        net: model called as ``net(data, lengths)``. Its output is reshaped
            to ``(seq_length * batch_size, -1)`` and the highest score along
            the last axis is taken as the predicted tag.
        criterion: loss callable applied to the flattened output and target.
        loader: iterable of ``(data, lengths, target)`` batches. ``data``
            must be 2-D; its size is unpacked as
            ``(seq_length, batch_size)``. The loader does not need to
            support ``len()``.
        role: label used in the printed summary (e.g. ``'Train'``).

    Returns:
        ``(mean_loss, acc)``: the mean of the per-batch losses, and the
        fraction of targets different from 0 that were predicted correctly.
        A value is ``nan`` when it is undefined (no batch at all, or no
        target different from 0).
    """
    net.eval()

    # Move the inputs to wherever the model actually lives so that a model
    # kept on the CPU still works on a CUDA machine; fall back to the
    # module-level default when the model exposes no parameters.
    parameters = getattr(net, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) \
        else None
    run_device = device if first_param is None else first_param.device

    correct = 0
    total_preds = 0
    total_loss = 0.
    n_batches = 0

    with torch.no_grad():
        for data, lengths, target in loader:
            data = data.to(run_device)
            lengths = lengths.to(run_device)
            target = target.to(run_device)

            seq_length, batch_size = tuple(data.size())
            output = net(data, lengths)

            # flatten output and target to compute loss on individual words
            target = target.view(-1)
            output = output.view(batch_size * seq_length, -1)

            total_loss += criterion(output, target).item()
            n_batches += 1

            # The arg-max of the raw scores is the predicted tag; a softmax
            # would not change which entry is the largest.
            pred = output.argmax(dim=1)

            # Targets equal to 0 are left out of the accuracy
            # (presumably the padding index -- confirm against the data
            # pipeline).
            mask = target != 0
            correct += pred[mask].eq(target[mask]).sum().item()
            total_preds += mask.sum().item()

    # Report undefined metrics as NaN instead of dividing by zero.
    undefined = float('nan')
    mean_loss = total_loss / n_batches if n_batches else undefined
    acc = correct / total_preds if total_preds else undefined

    print(f"{role} mean loss: {mean_loss:.4e}, {role} accuracy: {acc:.4f}")
    return mean_loss, acc
if __name__ == '__main__':
    def parse_args():
        """Parse command line arguments.

        Returns:
            The ``argparse.Namespace`` with ``saved_path``, ``batch_size``,
            ``embedding_size`` and ``hidden_size``. The defaults reproduce
            the values this script previously hard-coded.
        """
        parser = argparse.ArgumentParser(
            description="Evaluates a POS-tagger on the train, dev and test "
                        "sets of the French-GSD dataset.")
        parser.add_argument("--saved-path",
                            default='./checkpoints/saved/checkpt.pt',
                            type=str,
                            help="path of the checkpoint to load")
        parser.add_argument("--batch-size",
                            default=128,
                            type=int,
                            help="batch size passed to the data loaders")
        parser.add_argument("--embedding-size",
                            default=30,
                            type=int,
                            help="embedding size the tagger is built with")
        parser.add_argument("--hidden-size",
                            default=30,
                            type=int,
                            help="hidden size the tagger is built with")
        return parser.parse_args()

    # Seed torch's RNG before anything else is built.
    torch.manual_seed(42)
    args = parse_args()

    # Project-local modules. get_dataloaders_and_vocabs and CheckpointState
    # are expected to come from the wildcard imports; which module provides
    # which is not visible from this file.
    from tagger import Tagger
    from pos_tagging_data import *
    from utils import *

    batch_size = args.batch_size
    train_loader, val_loader, test_loader, words, tags = \
        get_dataloaders_and_vocabs(batch_size)

    # Index 0 is ignored by the loss, matching the `target != 0` mask that
    # evaluate() uses for the accuracy.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    # load model from saved checkpoint
    # NOTE(review): the sizes presumably have to match the ones the
    # checkpoint was trained with -- confirm against CheckpointState.
    net = Tagger(len(words), len(tags),
                 embedding_size=args.embedding_size,
                 hidden_size=args.hidden_size)
    checkpoint = CheckpointState(net, savepath=args.saved_path)
    checkpoint.load()
    net = net.to(device)

    # evaluate the model on the train, validation and test set,
    # printing loss and accuracy
    evaluate(net, criterion, train_loader, 'Train')
    evaluate(net, criterion, val_loader, 'Val')
    evaluate(net, criterion, test_loader, 'Test')