Skip to content

Commit

Permalink
Update score.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jaymedina committed May 7, 2024
1 parent a955184 commit f5d5a3c
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions score.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@

import argparse
import json
import os

import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score

from glob import glob

def get_args():
    """Set up command-line interface and get arguments.

    Options:
        -p/--predictions_file: path to the predictions file (required).
        -g/--goldstandard_folder: folder holding the gold standard file
            (required).
        -o/--output: path of the results JSON file (default "results.json").

    Returns:
        argparse.Namespace holding the parsed option values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--predictions_file", type=str, required=True)
    # "-g" may be registered only once: a second add_argument() call reusing
    # the same option string makes argparse raise ArgumentError at start-up.
    parser.add_argument("-g", "--goldstandard_folder", type=str, required=True)
    parser.add_argument("-o", "--output", type=str, default="results.json")
    return parser.parse_args()

Expand All @@ -31,16 +33,27 @@ def score(gold, gold_col, pred, pred_col):
return {"auc_roc": roc, "auprc": pr}


def extract_gs_file(folder):
    """Extract gold standard file from folder.

    Args:
        folder: path of the directory expected to hold exactly one entry.

    Returns:
        The path (folder joined with the entry name) of the single match.

    Raises:
        ValueError: if the folder does not contain exactly one matching
            entry. Entries whose names start with a dot are not matched by
            the "*" pattern and therefore are not counted.
    """
    # Local import: the module only brings the glob() function into scope.
    from glob import escape

    # Escape the folder part so characters such as "[" or "*" in its name are
    # taken literally instead of being interpreted as glob metacharacters.
    pattern = os.path.join(escape(folder), "*")
    files = glob(pattern)
    if len(files) != 1:
        raise ValueError(
            "Expected exactly one gold standard file in folder. "
            f"Got {len(files)}. Exiting."
        )

    return files[0]


def main():
"""Main function."""
args = get_args()

with open(args.output, encoding="utf-8") as out:
res = json.load(out)

gold_file = extract_gs_file(args.goldstandard_folder)

if res.get("validation_status") == "VALIDATED":
pred = pd.read_csv(args.predictions_file)
gold = pd.read_csv(args.goldstandard_file)
gold = pd.read_csv(gold_file)
scores = score(gold, "disease", pred, "disease_probability")
status = "SCORED"
else:
Expand Down

0 comments on commit f5d5a3c

Please sign in to comment.