diff --git a/config.yaml b/config.yaml
index f770aa6..a8d092f 100644
--- a/config.yaml
+++ b/config.yaml
@@ -30,10 +30,11 @@ dataloader:
   batch_size: 64
 
 model:
+  architecture: BiLSTM
   embedding:
     embedding_dim: 128
   rnn:
-    rnn_unit: nn.LSTM
+    rnn_unit: LSTM  # GRU, RNN
     hidden_size: 256
     num_layers: 1
     dropout: 0
diff --git a/main.py b/main.py
index 289d551..4b6014a 100644
--- a/main.py
+++ b/main.py
@@ -8,7 +8,6 @@
 from torch.utils.data import DataLoader
 
 from pytorch_ner.dataset import NERCollator, NERDataset
-from pytorch_ner.nn_modules.architecture import BiLSTM
 from pytorch_ner.nn_modules.embedding import Embedding
 from pytorch_ner.nn_modules.linear import LinearHead
 from pytorch_ner.nn_modules.rnn import DynamicRNN
@@ -19,6 +18,7 @@
 )
 from pytorch_ner.save import save_model
 from pytorch_ner.train import train
+from pytorch_ner.utils import str_to_class
 
 
 def main(path_to_config: str):
@@ -148,10 +148,11 @@ def main(path_to_config: str):
     )
 
     rnn_layer = DynamicRNN(
-        rnn_unit=eval(config["model"]["rnn"]["rnn_unit"]),  # TODO: fix eval
-        input_size=config["model"]["embedding"][
-            "embedding_dim"
-        ],  # reference to embedding_dim
+        rnn_unit=str_to_class(
+            module_name="torch.nn",
+            class_name=config["model"]["rnn"]["rnn_unit"],
+        ),
+        input_size=config["model"]["embedding"]["embedding_dim"],  # ref to emb_dim
         hidden_size=config["model"]["rnn"]["hidden_size"],
         num_layers=config["model"]["rnn"]["num_layers"],
         dropout=config["model"]["rnn"]["dropout"],
@@ -169,9 +170,12 @@ def main(path_to_config: str):
         ),
     )
 
-    # TODO: add model architecture in config
     # TODO: add attention if needed
-    model = BiLSTM(
+    model_class = str_to_class(
+        module_name="pytorch_ner.nn_modules.architecture",
+        class_name=config["model"]["architecture"],
+    )
+    model = model_class(
         embedding_layer=embedding_layer,
         rnn_layer=rnn_layer,
         linear_head=linear_head,
diff --git a/pytorch_ner/utils.py b/pytorch_ner/utils.py
index 149a068..58d3690 100644
--- a/pytorch_ner/utils.py
+++ b/pytorch_ner/utils.py
@@ -1,3 +1,4 @@
+import importlib
 import os
 import random
 import shutil
@@ -44,3 +45,16 @@
 def rmdir(path: str):
     if os.path.exists(path):
         shutil.rmtree(path)
+
+
+def str_to_class(module_name, class_name):
+    """
+    Convert string to Python class object.
+    https://stackoverflow.com/questions/1176136/convert-string-to-python-class-object
+    """
+
+    # load the module, will raise ImportError if module cannot be loaded
+    module = importlib.import_module(module_name)
+    # get the class, will raise AttributeError if class cannot be found
+    cls = getattr(module, class_name)
+    return cls
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
index f7e1043..b674291 100644
--- a/tests/test_onnx.py
+++ b/tests/test_onnx.py
@@ -1,6 +1,7 @@
+from test_nn_modules.test_architecture import model_bilstm as model
+
 from pytorch_ner.onnx import onnx_export_and_check
 from pytorch_ner.utils import mkdir
-from tests.test_nn_modules.test_architecture import model_bilstm as model
 
 path_to_save = "models/model.onnx"
 mkdir("models")
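
For reference, a minimal sketch of how the new str_to_class helper resolves the strings in config.yaml (assumes torch and pytorch_ner are importable; the class names mirror the config values above, "NotALayer" is a made-up example):

    import torch.nn as nn

    from pytorch_ner.utils import str_to_class

    # "LSTM" from config["model"]["rnn"]["rnn_unit"] resolves to the actual class
    rnn_unit = str_to_class(module_name="torch.nn", class_name="LSTM")
    assert rnn_unit is nn.LSTM  # "GRU" and "RNN" resolve the same way

    # unlike eval, an unknown name fails with AttributeError instead of executing code
    try:
        str_to_class(module_name="torch.nn", class_name="NotALayer")
    except AttributeError as err:
        print(err)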