diff --git a/score.py b/score.py index 30f91df..a0157bb 100644 --- a/score.py +++ b/score.py @@ -5,15 +5,18 @@ - ROC curve - PR curve """ - +from glob import glob import argparse import json import os import pandas as pd +import numpy as np from sklearn.metrics import roc_auc_score, average_precision_score -from glob import glob +GOLDSTANDARD_COLS = {"id": str, "disease": int} +PREDICTION_COLS = {"id": str, "disease_probability": np.float64} + def get_args(): """Set up command-line interface and get arguments.""" @@ -37,7 +40,10 @@ def extract_gs_file(folder): """Extract gold standard file from folder.""" files = glob(os.path.join(folder, "*")) if len(files) != 1: - raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.") + raise ValueError( + "Expected exactly one gold standard file in folder. " + f"Got {len(files)}. Exiting." + ) return files[0] @@ -52,8 +58,16 @@ def main(): gold_file = extract_gs_file(args.goldstandard_folder) if res.get("validation_status") == "VALIDATED": - pred = pd.read_csv(args.predictions_file) - gold = pd.read_csv(gold_file) + pred = pd.read_csv( + args.predictions_file, + usecols=GOLDSTANDARD_COLS, + dtype=GOLDSTANDARD_COLS + ) + gold = pd.read_csv( + gold_file, + usecols=PREDICTION_COLS, + dtype=PREDICTION_COLS + ) scores = score(gold, "disease", pred, "disease_probability") status = "SCORED" else: diff --git a/validate.py b/validate.py index d5d32d4..c9373a2 100644 --- a/validate.py +++ b/validate.py @@ -5,7 +5,7 @@ - `id` is a string - `disease_probability` is a float between 0 and 1 """ - +from glob import glob import argparse import json import os @@ -13,23 +13,16 @@ import numpy as np import pandas as pd -from glob import glob - -EXPECTED_COLS = { - 'id': str, - 'disease_probability': np.float64 -} +GOLDSTANDARD_COLS = {"id": str, "disease": int} +EXPECTED_COLS = {"id": str, "disease_probability": np.float64} def get_args(): """Set up command-line interface and get arguments.""" parser = argparse.ArgumentParser() - parser.add_argument("-p", "--predictions_file", - type=str, required=True) - parser.add_argument("-g", "--goldstandard_folder", - type=str, required=True) - parser.add_argument("-o", "--output", - type=str, default="results.json") + parser.add_argument("-p", "--predictions_file", type=str, required=True) + parser.add_argument("-g", "--goldstandard_folder", type=str, required=True) + parser.add_argument("-o", "--output", type=str, default="results.json") return parser.parse_args() @@ -38,7 +31,7 @@ def check_dups(pred): duplicates = pred.duplicated(subset=["id"]) if duplicates.any(): return ( - f"Found {duplicates.sum()} duplicate participant ID(s): " + f"Found {duplicates.sum()} duplicate ID(s): " f"{pred[duplicates].id.to_list()}" ) return "" @@ -50,7 +43,7 @@ def check_missing_ids(gold, pred): missing_ids = gold.index.difference(pred.index) if missing_ids.any(): return ( - f"Found {missing_ids.shape[0]} missing participant ID(s): " + f"Found {missing_ids.shape[0]} missing ID(s): " f"{missing_ids.to_list()}" ) return "" @@ -62,7 +55,7 @@ def check_unknown_ids(gold, pred): unknown_ids = pred.index.difference(gold.index) if unknown_ids.any(): return ( - f"Found {unknown_ids.shape[0]} unknown participant ID(s): " + f"Found {unknown_ids.shape[0]} unknown ID(s): " f"{unknown_ids.to_list()}" ) return "" @@ -72,34 +65,34 @@ def check_nan_values(pred): """Check for NAN predictions.""" missing_probs = pred["disease_probability"].isna().sum() if missing_probs: - return ( - f"'disease_probability' column contains {missing_probs} NaN value(s)." - ) + return f"'disease_probability' column contains {missing_probs} NaN value(s)." return "" def check_prob_values(pred): """Check that probabilities are between [0, 1].""" - if (pred["disease_probability"] < 0).any() or (pred["disease_probability"] > 1).any(): - return "'disease_probability' values should be between [0, 1] inclusive." + if (pred["disease_probability"] < 0).any() or \ + (pred["disease_probability"] > 1).any(): + return "'disease_probability' values should be between [0, 1]." return "" def extract_gs_file(folder): - """Extract gold standard file from folder.""" + """Extract goldstandard file from folder.""" files = glob(os.path.join(folder, "*")) if len(files) != 1: - raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.") - + raise ValueError( + "Expected exactly one goldstandard file in folder. " + f"Got {len(files)}. Exiting." + ) return files[0] def validate(gold_folder, pred_file): """Validate predictions file against goldstandard.""" errors = [] - gold_file = extract_gs_file(gold_folder) - gold = pd.read_csv(gold_file, index_col="id") + gold = pd.read_csv(gold_file, dtype=GOLDSTANDARD_COLS, index_col="id") try: pred = pd.read_csv( pred_file, @@ -140,10 +133,9 @@ def main(): # truncate validation errors if >500 (character limit for sending email) if len(invalid_reasons) > 500: invalid_reasons = invalid_reasons[:496] + "..." - res = json.dumps({ - "validation_status": status, - "validation_errors": invalid_reasons - }) + res = json.dumps( + {"validation_status": status, "validation_errors": invalid_reasons} + ) with open(args.output, "w") as out: out.write(res)