Skip to content

Commit

Permalink
Update validate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jaymedina committed May 7, 2024
1 parent af429ae commit a955184
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

import argparse
import json
import os

import pandas as pd
import numpy as np
import pandas as pd

from glob import glob

EXPECTED_COLS = {
'id': str,
Expand All @@ -23,7 +26,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--predictions_file",
type=str, required=True)
parser.add_argument("-g", "--goldstandard_file",
parser.add_argument("-g", "--goldstandard_folder",
type=str, required=True)
parser.add_argument("-o", "--output",
type=str, default="results.json")
Expand Down Expand Up @@ -82,10 +85,20 @@ def check_prob_values(pred):
return ""


def validate(gold_file, pred_file):
def extract_gs_file(folder):
"""Extract gold standard file from folder."""
files = glob(os.path.join(folder, "*"))
if len(files) != 1:
raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.")

return files[0]


def validate(gold_folder, pred_file):
"""Validate predictions file against goldstandard."""
errors = []

gold_file = extract_gs_file(gold_folder)
gold = pd.read_csv(gold_file, index_col="id")
try:
pred = pd.read_csv(
Expand Down Expand Up @@ -117,7 +130,7 @@ def main():
errors = [f.read()]
else:
errors = validate(
gold_file=args.goldstandard_file,
gold_folder=args.goldstandard_folder,
pred_file=args.predictions_file
)

Expand Down

0 comments on commit a955184

Please sign in to comment.