BiLSTM.py
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


class BiLSTM(nn.Module):
    def __init__(self, config):
        super(BiLSTM, self).__init__()
        self.config = config
        self.device = config.device

        self.emb_dim = config.emb_dim
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.tag2id = config.tag2id
        self.tagset_size = len(self.tag2id)

        # Map each input token to a unique embedding (vector)
        self.emb = nn.Embedding(self.vocab_size, self.emb_dim)
        # LSTM is a variant of the Recurrent Neural Network (RNN)
        self.lstm = nn.LSTM(self.emb_dim, self.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        # Linear layer: predicts a score for each tag at every position
        self.hidden2tag = nn.Linear(2 * self.hidden_dim, self.tagset_size)

        # Loss: measures the distance between our prediction and the gold tag
        self.loss = CrossEntropyLoss()

    def forward(self, sent, labels, lengths, mask):
        embedded = self.emb(sent)

        # The padded batch should be packed before the LSTM
        # (lengths must be in descending order, or pass enforce_sorted=False)
        embedded = pack_padded_sequence(embedded, lengths, batch_first=True)
        lstm_out, _ = self.lstm(embedded)
        # The packed batch should be padded back after the LSTM
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)  # lstm_out: [batch_size, max_len, 2 * hidden_dim]
        logits = self.hidden2tag(lstm_out)  # logits: [batch_size, max_len, tagset_size]

        # Predict the tags
        pred_tag = torch.argmax(logits, dim=-1)

        # Compute the loss. Pad tokens must be masked out of both the
        # logits and the gold labels, so that the two stay aligned
        logits = logits.view(-1, self.tagset_size)[mask.view(-1) == 1.0]
        labels = labels.view(-1)[mask.view(-1) == 1.0]
        loss = self.loss(logits, labels)

        return loss, pred_tag
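As a quick sanity check, here is a minimal usage sketch of the model above. The `config` stub, dimensions, and batch contents are all hypothetical (the real configuration is assembled in train.py), so treat this as an illustration rather than the project's actual training code:

```python
import torch
from types import SimpleNamespace

from BiLSTM import BiLSTM

# Hypothetical config stub -- the real config is built in train.py
config = SimpleNamespace(
    device="cpu",
    emb_dim=100,
    hidden_dim=128,
    vocab_size=5000,
    tag2id={"O": 0, "B-LOC": 1, "I-LOC": 2},
)
model = BiLSTM(config)

# Dummy batch: 2 sentences padded to length 5, sorted by descending length
# (the default expectation of pack_padded_sequence)
sent = torch.randint(1, config.vocab_size, (2, 5))
labels = torch.randint(0, len(config.tag2id), (2, 5))
lengths = torch.tensor([5, 3])
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])

loss, pred_tag = model(sent, labels, lengths, mask)
print(loss.item())      # scalar training loss over non-pad tokens only
print(pred_tag.shape)   # torch.Size([2, 5]) -- one predicted tag id per position
```

Note that the mask zeroes out exactly the padded positions, which is what keeps the flattened logits and labels the same length inside `forward`.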
README.md
@@ -0,0 +1,70 @@
# BiLSTM Baseline

A simple BiLSTM NER baseline, intended as a starting point for our project.

The code is already commented; if anything is unclear, please ask, and if you find a bug, please report it :)

## Project Structure

```shell
├── data                  # The directory of datasets
│   └── renMinRiBao
│       ├── tags.txt
│       ├── test_data.txt
│       ├── train_data.txt
│       └── val_data.txt
├── model                 # The directory to save models
├── BiLSTM.py             # BiLSTM model
├── preprocess_data.py    # Preprocess the datasets
├── train.py              # The entry point of this project
└── utils.py              # Some useful functions
```
## Environment

```shell
Python 3.7
PyTorch 1.7.1
NumPy 1.19.2
```

Install the packages above with `pip install` (the versions do not need to match exactly; the latest releases normally work). A virtual environment manager such as anaconda or miniconda is recommended (search online for a setup guide) to keep your environments organized.

Run `python train.py` to train the model with the default settings.

Run `python train.py --test` to evaluate on the test set.
## Results

After training for 10 epochs: precision 88.58, recall 83.45, F1 85.94 (small deviations in your own runs are normal).
## Learning Resources

Machine learning basics: the first few chapters of Zhou Zhihua's *Machine Learning*

Word embeddings: https://www.zhihu.com/question/32275069

RNNs: https://zhuanlan.zhihu.com/p/123211148, https://zhuanlan.zhihu.com/p/28054589, https://zhuanlan.zhihu.com/p/30844905, https://www.bilibili.com/video/BV1JE411g7XF?p=20

LSTMs: http://colah.github.io/posts/2015-08-Understanding-LSTMs/

Cross entropy: https://zhuanlan.zhihu.com/p/149186719, https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
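To connect the cross-entropy readings to the code: `CrossEntropyLoss` is log-softmax followed by negative log-likelihood, which is exactly what BiLSTM.py applies to the masked token logits. A tiny self-contained sketch with made-up numbers:

```python
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

logits = torch.tensor([[2.0, 0.5, -1.0]])  # made-up scores for 3 tags at one token
gold = torch.tensor([0])                   # gold tag id

manual = -F.log_softmax(logits, dim=-1)[0, gold[0]]  # pick out -log p(gold tag)
auto = CrossEntropyLoss()(logits, gold)
assert torch.allclose(manual, auto)  # the two computations agree
```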
Gradient descent: https://www.bilibili.com/video/BV1JE411g7XF?p=5

PyTorch documentation: https://pytorch.org/docs/stable/index.html

PyTorch tutorials: https://pytorch.org/tutorials/index.html

Course: Hung-yi Lee's (李宏毅) deep learning lectures, https://www.bilibili.com/video/BV1JE411g7XF
data/renMinRiBao/tags.txt
@@ -0,0 +1,9 @@
O
B-LOC
B-PER
I-ORG
B-ORG
I-DATE
I-PER
I-LOC
B-DATE
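For reference, these tags follow the BIO scheme (B- begins an entity, I- continues it, O is outside any entity). Below is a sketch of how the `tag2id` mapping consumed by BiLSTM.py might be built from this file; the project's actual preprocessing lives in preprocess_data.py and may differ:

```python
# Hypothetical helper -- the real mapping is built in preprocess_data.py
def load_tag2id(path="data/renMinRiBao/tags.txt"):
    with open(path, encoding="utf-8") as f:
        tags = [line.strip() for line in f if line.strip()]
    return {tag: i for i, tag in enumerate(tags)}  # one unique id per tag, in file order

tag2id = load_tag2id()
id2tag = {i: t for t, i in tag2id.items()}  # inverse map, for decoding predicted ids
```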