Skip to content

Commit

Permalink
Merge pull request #68 from SyneRBI/metrics-pass
Browse files Browse the repository at this point in the history
metrics: pass criterion
  • Loading branch information
casperdcl authored Jul 17, 2024
2 parents d38a919 + 8256725 commit d9fbdc0
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Iterator

import matplotlib.pyplot as plt
import numpy
import numpy as np
from scipy.ndimage import binary_erosion

import sirf.STIR as STIR
from petric import QualityMetrics
Expand All @@ -15,16 +16,31 @@ def read_objectives(datadir='.'):
with (Path(datadir) / 'objectives.csv').open() as csvfile:
reader = csv.reader(csvfile)
next(reader) # skip first (header) line
return numpy.asarray([tuple(map(float, row)) for row in reader])
return np.asarray([tuple(map(float, row)) for row in reader])


def get_metrics(qm: QualityMetrics, iters: Iterator[int], srcdir='.'):
"""Read 'iter_{iter_glob}.hv' images from datadir, compute metrics and return as 2d array"""
return numpy.asarray([
return np.asarray([
list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters])


def plot_metrics(iters: Iterator[int], m: numpy.ndarray, labels=None, suffix=""):
def pass_index(metrics: np.ndarray, thresh: Iterator, window: int = 1) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


def plot_metrics(iters: Iterator[int], m: np.ndarray, labels=None, suffix=""):
"""Make 2 subplots of metrics"""
if labels is None:
labels = [""] * m.shape[1]
Expand Down

0 comments on commit d9fbdc0

Please sign in to comment.