Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add model checkpoint #42

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ train:

save:
path_to_folder: 'models/test_main/'
model_checkpoint:
save_frequency: 2
save_best_weights: true
export_onnx: true
5 changes: 5 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.utils.data import DataLoader

from pytorch_ner.dataset import NERCollator, NERDataset
from pytorch_ner.model_checkpoint import model_checkpoint
from pytorch_ner.nn_modules.architecture import BiLSTM
from pytorch_ner.nn_modules.embedding import Embedding
from pytorch_ner.nn_modules.linear import LinearHead
Expand Down Expand Up @@ -201,6 +202,10 @@ def main(path_to_config: str):
optimizer=optimizer,
device=device,
n_epoch=config["train"]["n_epoch"],
export_onnx=config["save"]["export_onnx"],
path_to_folder=config["save"]["path_to_folder"],
save_frequency=config["save"]["model_checkpoint"]["save_frequency"],
save_best_weights=config["save"]["model_checkpoint"]["save_best_weights"],
verbose=config["train"]["verbose"],
)

Expand Down
53 changes: 53 additions & 0 deletions pytorch_ner/model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json
import os
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import yaml

from pytorch_ner.onnx import onnx_export_and_check
from pytorch_ner.utils import mkdir, rmdir


def model_checkpoint(
    model: nn.Module,
    epoch: int,
    save_best_weights: bool,
    val_metrics,
    val_losses,
    path_to_folder: str,
    export_onnx: bool,
    save_frequency: int,
) -> None:
    """
    Save a checkpoint of ``model`` according to one of two strategies.

    1. ``save_best_weights=True``: save only when the mean of
       ``val_metrics["loss"]`` is strictly lower than every value already
       stored in ``val_losses``. Files are named ``best_model.pth`` /
       ``best_model.onnx`` and overwrite the previous best.
    2. ``save_best_weights=False``: save whenever ``epoch`` is a multiple of
       ``save_frequency``. Files are named ``model_epoch_<epoch>.pth`` /
       ``model_epoch_<epoch>.onnx``.

    Args:
        model: module whose ``state_dict()`` is written to disk.
        epoch: epoch index exactly as supplied by the caller; it is used
            both for the frequency test and in the file name.
        save_best_weights: selects strategy 1 (True) or strategy 2 (False).
        val_metrics: mapping with a ``"loss"`` entry holding the validation
            losses of the current epoch (anything ``np.mean`` accepts).
        val_losses: mutable list of the best mean losses seen so far.
            It is appended to IN PLACE when a new best is found, so the
            caller must pass the same list on every call (typically
            seeded with ``[np.inf]``).
        path_to_folder: directory that receives the checkpoint files; it is
            created if it does not exist yet.
        export_onnx: when True, additionally export the model through
            ``onnx_export_and_check``.
        save_frequency: positive epoch interval used by strategy 2; ignored
            by strategy 1.

    Raises:
        ValueError: if strategy 2 is selected and ``save_frequency`` is not
            a positive integer.
    """
    if save_best_weights:
        # Compute the epoch's mean loss once and reuse it for both the
        # comparison and the bookkeeping.
        current_loss = np.mean(val_metrics["loss"])
        if not current_loss < min(val_losses):
            # Not an improvement (or NaN): nothing to save.
            return
        val_losses.append(current_loss)
        pth_file_name = "best_model.pth"
        onnx_file_name = "best_model.onnx"
    else:
        # Fail with a clear message instead of an opaque ZeroDivisionError
        # coming from the modulo below when the config holds a bad value.
        if save_frequency <= 0:
            raise ValueError(
                f"save_frequency must be a positive integer, got {save_frequency!r}"
            )
        if epoch % save_frequency != 0:
            # Not on a save boundary: nothing to save.
            return
        pth_file_name = f"model_epoch_{epoch}.pth"
        onnx_file_name = f"model_epoch_{epoch}.onnx"

    # Do not rely on the caller having prepared the directory beforehand.
    os.makedirs(path_to_folder, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(path_to_folder, pth_file_name))
    if export_onnx:
        onnx_export_and_check(
            model=model, path_to_save=os.path.join(path_to_folder, onnx_file_name)
        )
7 changes: 4 additions & 3 deletions pytorch_ner/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ def save_model(
config: Dict,
export_onnx: bool = False,
):
# make empty dir
rmdir(path_to_folder)
mkdir(path_to_folder)

if not os.path.exists(path_to_folder):
# make empty dir
mkdir(path_to_folder)

model.cpu()
model.eval()
Expand Down
28 changes: 27 additions & 1 deletion pytorch_ner/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections import defaultdict
from typing import Callable, DefaultDict, List, Optional

Expand All @@ -9,7 +10,9 @@
from tqdm import tqdm

from pytorch_ner.metrics import calculate_metrics
from pytorch_ner.utils import to_numpy
from pytorch_ner.model_checkpoint import model_checkpoint
from pytorch_ner.onnx import onnx_export_and_check
from pytorch_ner.utils import mkdir, rmdir, to_numpy


def masking(lengths: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -144,12 +147,23 @@ def train(
optimizer: optim.Optimizer,
device: torch.device,
n_epoch: int,
export_onnx: bool,
path_to_folder: str,
save_frequency: int,
save_best_weights: bool,
testloader: Optional[DataLoader] = None,
verbose: bool = True,
):
"""
Training / validation loop for n_epoch with final testing.
"""
if os.path.exists(path_to_folder):
# delete any previous versions of models
rmdir(path_to_folder)
mkdir(path_to_folder)

# List that tracks val_loss over training to save best weights
val_losses = [np.inf]

for epoch in range(n_epoch):

Expand Down Expand Up @@ -183,6 +197,18 @@ def train(
print(f"val {metric_name}: {np.mean(metric_list)}")
print()

# Model Checkpoint
model_checkpoint(
model=model,
epoch=epoch,
save_best_weights=save_best_weights,
val_metrics=val_metrics,
val_losses=val_losses,
path_to_folder=path_to_folder,
export_onnx=export_onnx,
save_frequency=save_frequency,
)

if testloader is not None:

test_metrics = validate_loop(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
optimizer=optimizer,
device=device,
n_epoch=5,
export_onnx=True,
path_to_folder="models/test_main/",
save_frequency=1,
save_best_weights=True,
verbose=False,
)

Expand Down