diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index b65d5bd..29cf75b 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -4,7 +4,8 @@ from typing import Iterator import matplotlib.pyplot as plt -import numpy +import numpy as np +from scipy.ndimage import binary_erosion import sirf.STIR as STIR from petric import QualityMetrics @@ -15,16 +16,31 @@ def read_objectives(datadir='.'): with (Path(datadir) / 'objectives.csv').open() as csvfile: reader = csv.reader(csvfile) next(reader) # skip first (header) line - return numpy.asarray([tuple(map(float, row)) for row in reader]) + return np.asarray([tuple(map(float, row)) for row in reader]) def get_metrics(qm: QualityMetrics, iters: Iterator[int], srcdir='.'): """Read 'iter_{iter_glob}.hv' images from datadir, compute metrics and return as 2d array""" - return numpy.asarray([ + return np.asarray([ list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def plot_metrics(iters: Iterator[int], m: numpy.ndarray, labels=None, suffix=""): +def pass_index(metrics: np.ndarray, thresh: Iterator, window: int = 1) -> int: + """ + Returns first index of `metrics` with value <= `thresh`. + The values must remain below the respective thresholds for at least `window` number of entries. + Otherwise raises IndexError. + """ + thr_arr = np.asanyarray(thresh) + assert metrics.ndim == 2 + assert thr_arr.ndim == 1 + assert metrics.shape[1] == thr_arr.shape[0] + passed = (metrics <= thr_arr[None]).all(axis=1) + res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2)) + return np.where(res)[0][0] + + +def plot_metrics(iters: Iterator[int], m: np.ndarray, labels=None, suffix=""): """Make 2 subplots of metrics""" if labels is None: labels = [""] * m.shape[1]