diff --git a/validate.py b/validate.py index 96e8cfe..d5d32d4 100644 --- a/validate.py +++ b/validate.py @@ -8,9 +8,12 @@ import argparse import json +import os -import pandas as pd import numpy as np +import pandas as pd + +from glob import glob EXPECTED_COLS = { 'id': str, @@ -23,7 +26,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("-p", "--predictions_file", type=str, required=True) - parser.add_argument("-g", "--goldstandard_file", + parser.add_argument("-g", "--goldstandard_folder", type=str, required=True) parser.add_argument("-o", "--output", type=str, default="results.json") @@ -82,10 +85,20 @@ def check_prob_values(pred): return "" -def validate(gold_file, pred_file): +def extract_gs_file(folder): + """Extract gold standard file from folder.""" + files = glob(os.path.join(folder, "*")) + if len(files) != 1: + raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.") + + return files[0] + + +def validate(gold_folder, pred_file): """Validate predictions file against goldstandard.""" errors = [] + gold_file = extract_gs_file(gold_folder) gold = pd.read_csv(gold_file, index_col="id") try: pred = pd.read_csv( @@ -117,7 +130,7 @@ def main(): errors = [f.read()] else: errors = validate( - gold_file=args.goldstandard_file, + gold_folder=args.goldstandard_folder, pred_file=args.predictions_file )