Skip to content

Commit

Permalink
check test path to file
Browse files Browse the repository at this point in the history
  • Loading branch information
dayyass committed Dec 7, 2020
1 parent 835deb3 commit bcd5aea
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ def main(path_to_config: str):
verbose=config["prepare_data"]["val_data"]["verbose"],
)

test_token_seq, test_label_seq = prepare_conll_data_format(
path=config["prepare_data"]["test_data"]["path"],
sep=config["prepare_data"]["test_data"]["sep"],
lower=config["prepare_data"]["test_data"]["lower"],
verbose=config["prepare_data"]["test_data"]["verbose"],
)
if "test_data" in config["prepare_data"]:
test_token_seq, test_label_seq = prepare_conll_data_format(
path=config["prepare_data"]["test_data"]["path"],
sep=config["prepare_data"]["test_data"]["sep"],
lower=config["prepare_data"]["test_data"]["lower"],
verbose=config["prepare_data"]["test_data"]["verbose"],
)

# token2idx / label2idx

Expand Down Expand Up @@ -91,13 +92,14 @@ def main(path_to_config: str):
preprocess=config["dataloader"]["preprocess"],
)

testset = NERDataset(
token_seq=test_token_seq,
label_seq=test_label_seq,
token2idx=token2idx,
label2idx=label2idx,
preprocess=config["dataloader"]["preprocess"],
)
if "test_data" in config["prepare_data"]:
testset = NERDataset(
token_seq=test_token_seq,
label_seq=test_label_seq,
token2idx=token2idx,
label2idx=label2idx,
preprocess=config["dataloader"]["preprocess"],
)

# collators

Expand All @@ -113,11 +115,12 @@ def main(path_to_config: str):
percentile=100, # hardcoded
)

test_collator = NERCollator(
token_padding_value=token2idx[config["dataloader"]["token_padding"]],
label_padding_value=label2idx[config["dataloader"]["label_padding"]],
percentile=100, # hardcoded
)
if "test_data" in config["prepare_data"]:
test_collator = NERCollator(
token_padding_value=token2idx[config["dataloader"]["token_padding"]],
label_padding_value=label2idx[config["dataloader"]["label_padding"]],
percentile=100, # hardcoded
)

# dataloaders

Expand All @@ -136,12 +139,13 @@ def main(path_to_config: str):
collate_fn=val_collator,
)

testloader = DataLoader(
dataset=testset,
batch_size=1, # hardcoded
shuffle=False, # hardcoded
collate_fn=test_collator,
)
if "test_data" in config["prepare_data"]:
testloader = DataLoader(
dataset=testset,
batch_size=1, # hardcoded
shuffle=False, # hardcoded
collate_fn=test_collator,
)

# INIT MODEL

Expand Down Expand Up @@ -208,7 +212,7 @@ def main(path_to_config: str):
model=model,
trainloader=trainloader,
valloader=valloader,
testloader=testloader,
testloader=testloader if "test_data" in config["prepare_data"] else None,
criterion=criterion,
optimizer=optimizer,
device=device,
Expand Down

0 comments on commit bcd5aea

Please sign in to comment.