Skip to content

Commit

Permalink
use scipy.ndimage.binary_erosion
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jul 17, 2024
1 parent 7f3cb70 commit 8256725
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import binary_erosion

import sirf.STIR as STIR
from petric import QualityMetrics
Expand All @@ -24,21 +25,18 @@ def get_metrics(qm: QualityMetrics, iters: Iterator[int], srcdir='.'):
list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters])


def pass_index(metrics: np.ndarray, thresh: np.ndarray, window: int = 1) -> int:
def pass_index(metrics: np.ndarray, thresh: Iterator, window: int = 1) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The value must remain below the threshold for at least `window`.
Raises IndexError if doesn't pass.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
assert metrics.shape[1] == len(thresh)
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thresh.ndim == 1

m = (metrics <= thresh[None]).all(1)
res = m[:-window]
for i in range(1, window):
res &= m[i:-window+i]
res &= m[window:]
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


Expand Down

0 comments on commit 8256725

Please sign in to comment.