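"""Transformer encoder-based language model for sequence classification (``ClassifierLM``)."""
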
from __future__ import annotations

import typing as t

import torch
import pydantic as pyd
from torch import nn
from transformers import PreTrainedTokenizer

from transformer import utils
from transformer.models.base import BaseLM
from transformer.modules.transformers import TransformerEncoder
from transformer.modules.embedding import InputEmbedding
from transformer.params import TransformerParams

__all__ = ["ClassifierLM"]


class ClassifierLM(BaseLM):
    @pyd.validate_call(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self: t.Self,
        params: TransformerParams,
        tokenizer: PreTrainedTokenizer,
        num_classes: pyd.PositiveInt,
    ) -> None:
        super().__init__(params=params)
        self.tokenizer = tokenizer
        self.model = nn.ModuleDict(
            {
                # token embeddings (with dropout) sized to the tokenizer's vocabulary
                "input": nn.Sequential(
                    InputEmbedding(len(self.tokenizer), params.model_dim),
                    nn.Dropout(0.1),
                ),
                "encoder": TransformerEncoder(params),
                # masked mean pooling over the context dimension (ignores padding)
                "mean": utils.nn.MaskedMean(),
                # classification head producing per-class log-probabilities
                "softmax": nn.Sequential(
                    nn.Linear(params.model_dim, num_classes),
                    nn.Tanh(),
                    nn.LogSoftmax(dim=-1),
                ),
            }
        )

    def forward(
        self: t.Self, ids: torch.LongTensor, masks: torch.LongTensor
    ) -> torch.FloatTensor:
        # ids/masks shape: [batch_size, context_length]
        # create input embeddings for tokens and pass them through the transformer encoder
        emb = self.model["input"](ids)
        hidden = self.model["encoder"](emb, masks=masks)
        # emb/hidden shape: [batch_size, context_length, model_dim]
        # calculate the avg. embedding for each sequence (ignoring padding)
        avg = self.model["mean"](hidden, masks=masks)
        # avg shape: [batch_size, model_dim]
        # calculate log-softmax over the averaged encoder output (passed through a linear layer and tanh)
        # output shape: [batch_size, num_classes]
        return self.model["softmax"](avg)

    def configure_optimizers(self: t.Self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=3e-4)

    def step(
        self: t.Self, batch: tuple[torch.LongTensor, ...], *, stage: str
    ) -> torch.FloatTensor:
        ids, targets, weights, masks = batch
        # make predictions
        preds = self(ids, masks)
        # calculate (weighted) loss
        loss = nn.functional.nll_loss(preds, targets, reduction="none")
        loss = (weights * loss).mean()
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> torch.FloatTensor:
        return self.step(batch, stage="train")

    def validation_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> torch.FloatTensor:
        return self.step(batch, stage="val")

    def test_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> torch.FloatTensor:
        return self.step(batch, stage="test")

    def predict_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> list[tuple[str, int, int]]:
        # prediction batches carry no per-sample weights
        ids, targets, masks = batch
        # pick the most likely class for each sequence
        preds = self(ids, masks).argmax(dim=-1)
        # return (decoded text, predicted class, target class) triples
        return list(
            zip(
                self.tokenizer.batch_decode(
                    ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                ),
                preds.tolist(),
                targets.tolist(),
            )
        )
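

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# lines below rest on assumptions: the required `TransformerParams` fields
# beyond `model_dim`, and the exact `BaseLM` interface (assumed to behave like
# a torch.nn.Module / LightningModule when called), are defined elsewhere in
# the package and may differ from what is shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # use_fast=False yields a slow `PreTrainedTokenizer`, which is the type
    # this class validates against
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
    # hypothetical construction: only `model_dim` is referenced in this file,
    # so any other required parameters are omitted here
    params = TransformerParams(model_dim=128)
    model = ClassifierLM(params=params, tokenizer=tokenizer, num_classes=2)

    # tokenize a toy batch; `input_ids`/`attention_mask` play the roles of
    # `ids`/`masks` expected by `forward`
    enc = tokenizer(
        ["an example sentence", "another, rather different one"],
        padding=True,
        return_tensors="pt",
    )
    log_probs = model(enc["input_ids"], enc["attention_mask"])
    print(log_probs.shape)  # expected: [batch_size, num_classes]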