diff --git a/main.py b/main.py index 07c8745..62ec815 100644 --- a/main.py +++ b/main.py @@ -52,12 +52,13 @@ def main(path_to_config: str): verbose=config["prepare_data"]["val_data"]["verbose"], ) - test_token_seq, test_label_seq = prepare_conll_data_format( - path=config["prepare_data"]["test_data"]["path"], - sep=config["prepare_data"]["test_data"]["sep"], - lower=config["prepare_data"]["test_data"]["lower"], - verbose=config["prepare_data"]["test_data"]["verbose"], - ) + if "test_data" in config["prepare_data"]: + test_token_seq, test_label_seq = prepare_conll_data_format( + path=config["prepare_data"]["test_data"]["path"], + sep=config["prepare_data"]["test_data"]["sep"], + lower=config["prepare_data"]["test_data"]["lower"], + verbose=config["prepare_data"]["test_data"]["verbose"], + ) # token2idx / label2idx @@ -91,13 +92,14 @@ def main(path_to_config: str): preprocess=config["dataloader"]["preprocess"], ) - testset = NERDataset( - token_seq=test_token_seq, - label_seq=test_label_seq, - token2idx=token2idx, - label2idx=label2idx, - preprocess=config["dataloader"]["preprocess"], - ) + if "test_data" in config["prepare_data"]: + testset = NERDataset( + token_seq=test_token_seq, + label_seq=test_label_seq, + token2idx=token2idx, + label2idx=label2idx, + preprocess=config["dataloader"]["preprocess"], + ) # collators @@ -113,11 +115,12 @@ def main(path_to_config: str): percentile=100, # hardcoded ) - test_collator = NERCollator( - token_padding_value=token2idx[config["dataloader"]["token_padding"]], - label_padding_value=label2idx[config["dataloader"]["label_padding"]], - percentile=100, # hardcoded - ) + if "test_data" in config["prepare_data"]: + test_collator = NERCollator( + token_padding_value=token2idx[config["dataloader"]["token_padding"]], + label_padding_value=label2idx[config["dataloader"]["label_padding"]], + percentile=100, # hardcoded + ) # dataloaders @@ -136,12 +139,13 @@ def main(path_to_config: str): collate_fn=val_collator, ) - testloader = DataLoader( - dataset=testset, - batch_size=1, # hardcoded - shuffle=False, # hardcoded - collate_fn=test_collator, - ) + if "test_data" in config["prepare_data"]: + testloader = DataLoader( + dataset=testset, + batch_size=1, # hardcoded + shuffle=False, # hardcoded + collate_fn=test_collator, + ) # INIT MODEL @@ -208,7 +212,7 @@ def main(path_to_config: str): model=model, trainloader=trainloader, valloader=valloader, - testloader=testloader, + testloader=testloader if "test_data" in config["prepare_data"] else None, criterion=criterion, optimizer=optimizer, device=device,