# PyTorch BERT fine-tuning classification model
This directory contains an implementation of the BERT fine-tuning model for sequence (sentence) classification tasks. We present two flavors of the implementation:
- Single sentence classification using Stanford Sentiment Treebank (SST-2) dataset.
- Sentence pair classification using Multi-Genre Natural Language Inference (MNLI) dataset.
This is a fine-tuning task, meaning that the BERT model [1] is initialized with the weights generated by pre-training the model in a self-supervised fashion on the masked language modeling (MLM) and next sentence prediction (NSP) tasks. During fine-tuning, the model architecture is adjusted by replacing the MLM and NSP output heads with a classification head. The new output head consists of two dense layers that are fed the `[CLS]` representation for classification. The first dense layer is loaded from the pre-trained NSP head and the second dense layer is initialized with random weights.
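For orientation, below is a minimal PyTorch sketch of such a classification head. The class and attribute names (`ClassificationHead`, `pooler`, `classifier`) and `num_classes` are illustrative assumptions; see `model.py` and the `BertForSequenceClassification` class for the actual implementation.

```python
import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Two dense layers applied to the [CLS] representation (illustrative sketch)."""

    def __init__(self, hidden_size: int, num_classes: int):
        super().__init__()
        # First dense layer: weights loaded from the pre-trained NSP head (pooler).
        self.pooler = nn.Linear(hidden_size, hidden_size)
        # Second dense layer: randomly initialized classifier.
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, encoder_output: torch.Tensor) -> torch.Tensor:
        # encoder_output: [batch_size, max_sequence_length, hidden_size]
        cls_representation = encoder_output[:, 0, :]  # [CLS] is the first token
        pooled = torch.tanh(self.pooler(cls_representation))
        return self.classifier(pooled)  # [batch_size, num_classes] logits
```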
The following block diagram shows a high-level view of the sequence of steps you will perform in this example.
Fig.1 - Flow Chart of steps to fine-tune classification model
- `configs/`: YAML configuration files.
- `model.py`: Model implementation leveraging the `BertForSequenceClassification` class.
- `run.py`: Training script. Performs training and validation.
- `utils.py`: Miscellaneous helper functions.
In order to download the raw dataset for the GLUE benchmark, first clone the dataset helper scripts from the GLUE benchmark repository:

```
git clone https://github.com/nyu-mll/GLUE-baselines
```

To download the SST-2 and MNLI datasets, run the following command:

```
python download_glue_data.py --data_dir </path/to/data_dir> --tasks SST,MNLI
```

More information about the datasets and the script can be found on the GLUE Benchmark site.
There is no separate data preparation step, as the data input pipeline is designed to use the `.tsv` files downloaded in the previous step. You can use the same path passed to the `--data_dir` flag above to populate the `data_dir` field (for both `train_input` and `eval_input`) in the YAML config file.
The input to the model is either a single sentence (`A`) or a pair of sentences (`A`, `B`). Sentence `A` is preceded by the `[CLS]` special token and is separated from sentence `B` by the `[SEP]` special token. BERT pre-training uses a special segment embedding layer at the input to distinguish the two sentences used in the NSP task. Here, that segment embedding is used to identify sentence `A` and, if present in the input, sentence `B`. See below for details.
The labels for this model indicate the class of the sentence for single-sentence classification, or the class describing the relationship between the input sentence pair. The label is supplied as an integer corresponding to the class; for example, SST-2 uses two classes (negative/positive sentiment), while MNLI uses three (entailment, neutral, contradiction).
If you want to use your own input function with this example code, this section describes the input data format expected by the `BertForSequenceClassificationModel` class defined in `model.py`. When you create your own custom BERT classifier input function, you must ensure that it produces a features dictionary and a label tensor as described in this section.
The input features and labels are passed to the model in one dictionary with the following keys/values:
- `input_ids`: Input token IDs, padded with `0`s to `max_sequence_length`. The tokens in the dataset are mapped to these IDs using the vocabulary file. These values should be between `0` and `vocab_size - 1`, inclusive. The first token should be the special `[CLS]` token. The end of each sentence should be marked by the special `[SEP]` token, so a sentence pair is separated by an additional `[SEP]` token.
  - Shape: `[batch_size, max_sequence_length]`
  - Type: `torch.int32`
- `attention_mask`: Mask for padded positions. Has values `0` on the padded positions and `1` elsewhere.
  - Shape: `[batch_size, max_sequence_length]`
  - Type: `torch.int32`
- `token_type_ids`: Segment IDs. A tensor the same size as `input_ids` designating to which segment each token belongs. Each element of this tensor should only take the value `0` or `1`. The `[CLS]` token at the start, sentence `A`, and the subsequent `[SEP]` token should all have the segment value `0`. Sentence `B` (if present) and the subsequent `[SEP]` token should have the segment value `1`. All padding tokens after the last `[SEP]` should be in segment `0`. Shown as `Segments` in the example below.
  - Shape: `[batch_size, max_sequence_length]`
  - Type: `torch.int32`
- `labels`: One integer per input sequence or sequence pair, corresponding to its classification label.
  - Shape: `[batch_size, ]`
  - Type: `torch.int32`
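To make the expected tensor layout concrete, the following sketch builds such a features dictionary using the Hugging Face `transformers` tokenizer. This is an illustration only and assumes `transformers` is installed; this repository ships its own input pipeline, so treat this as a reference for the data format rather than the actual implementation.

```python
import torch
from transformers import BertTokenizer

max_sequence_length = 128
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Sentence pair classification example (pass only `sentence_a` for single-sentence tasks).
sentence_a = "The Matrix is a good movie."
sentence_b = "So The Matrix Reloaded must also be good."

encoded = tokenizer(
    sentence_a,
    sentence_b,
    padding="max_length",
    truncation=True,
    max_length=max_sequence_length,
    return_tensors="pt",
)

features = {
    "input_ids": encoded["input_ids"].to(torch.int32),            # [1, max_sequence_length]
    "attention_mask": encoded["attention_mask"].to(torch.int32),  # 1 for real tokens, 0 for padding
    "token_type_ids": encoded["token_type_ids"].to(torch.int32),  # 0 for sentence A, 1 for sentence B
    "labels": torch.tensor([1], dtype=torch.int32),               # class index (illustrative)
}
```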
An example of the input string and segment ID structure for single sentence classification:
```
Tokens:   [CLS] The Matrix is a good movie. [SEP] [PAD] [PAD] ...
Segments:   0    0    0    0  0   0    0      0     0     0   ...
```
An example of the input string and segment ID structure for sentence pair classification:
```
Tokens:   [CLS] The Matrix is a good movie. [SEP] So The Matrix Reloaded must also be good. [SEP] [PAD] [PAD] ...
Segments:   0    0    0    0  0   0    0      0   1   1    1      1      1    1   1    1     1     0     0   ...
```
NOTE: The input tokens are converted to IDs using the vocab file.
IMPORTANT: See the following notes before proceeding further.
Parameter settings in YAML config file: The config YAML files are located in the configs directory. Before starting a training run, make sure that in the YAML config file you are using:
- The `train_input.data_dir` parameter points to the correct dataset, and
- The `train_input.max_sequence_length` parameter corresponds to the sequence length of the dataset.
- The `train_input.batch_size` parameter will set the batch size for the training.

The same applies to `eval_input`.
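For orientation, a minimal sketch of these fields is shown below. Field names are taken from this README; the exact structure, values, and additional required fields may differ in the shipped configs under `configs/`.

```yaml
train_input:
  data_dir: "/path/to/data_dir"   # same path passed to --data_dir above
  max_sequence_length: 128        # placeholder; must match the dataset
  batch_size: 32                  # placeholder training batch size

eval_input:
  data_dir: "/path/to/data_dir"
  max_sequence_length: 128
  batch_size: 32
```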
YAML config file differences:
Please check the YAML config section below for details on each config supported out of the box for this model.
Please follow the instructions on our quickstart in the Developer Docs.
Note: To specify a BERT pretrained checkpoint, use the following flags (see the example command after the run instructions below):
- `--checkpoint_path` is the path to the saved checkpoint from BERT pre-training,
- `--load_checkpoint_states="model"` is needed for loading the pre-trained BERT model for fine-tuning, and
- `--disable_strict_checkpoint_loading` is needed to be able to only partially load a model.
If running on a CPU or GPU, activate the environment from the Python GPU Environment setup, and simply run:

```
python run.py CPU --mode train --params /path/to/yaml --model_dir /path/to/model_dir
```

or

```
python run.py GPU --mode train --params /path/to/yaml --model_dir /path/to/model_dir
```

Note: Change `--mode train` to `--mode eval` for evaluation.
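For example, the checkpoint flags described in the note above can be combined with the training command as sketched below; the paths are placeholders.

```
python run.py GPU --mode train --params /path/to/yaml --model_dir /path/to/model_dir \
    --checkpoint_path /path/to/pretrained_checkpoint \
    --load_checkpoint_states="model" \
    --disable_strict_checkpoint_loading
```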
Below is the list of YAML config files included for this model implementation in the configs folder:
- `bert_base_*.yaml` have the standard bert-base config with `hidden_size=768, num_hidden_layers=12, num_heads=12` as the backbone.
- `bert_large_*.yaml` have the standard bert-large config with `hidden_size=1024, num_hidden_layers=24, num_heads=16` as the backbone.

Additionally,
- `bert_*_sst2_*.yaml` corresponds to the SST-2 dataset config.
- `bert_*_mnli_*.yaml` corresponds to the MNLI dataset config.
[1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, https://arxiv.org/abs/1810.04805