Skip to content

Commit

Permalink
Make seed health checks optional in synthesizers.
Browse files Browse the repository at this point in the history
  • Loading branch information
tymorrow committed Feb 14, 2024
1 parent ab0ec57 commit 5f33251
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
10 changes: 7 additions & 3 deletions riid/data/synthetic/passby.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float,

return fg_ss, gross_ss

def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, verbose: bool = True) \
def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
skip_health_check: bool = False, verbose: bool = True) \
-> List[Tuple[SampleSet, SampleSet, SampleSet]]:
"""Generate a list of `SampleSet`s where each contains a pass-by as a sequence of spectra.
Expand All @@ -234,6 +235,7 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, verbose: bool
source component(s) of spectra
bg_seeds_ss: spectra normalized by total counts to be used as the
background components of gross spectra
skip_health_check: whether to skip seed health checks
verbose: whether to display output from synthesis
Returns:
Expand All @@ -245,8 +247,10 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, verbose: bool
"""
if not fg_seeds_ss or not bg_seeds_ss:
raise ValueError("At least one foreground and background seed must be provided.")
fg_seeds_ss.check_seed_health()
bg_seeds_ss.check_seed_health()

if not skip_health_check:
fg_seeds_ss.check_seed_health()
bg_seeds_ss.check_seed_health()

self._reset_progress()
if verbose:
Expand Down
23 changes: 16 additions & 7 deletions riid/data/synthetic/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def __init__(self, seeds_ss: SampleSet, mixture_size: int = 2, dirichlet_alpha:

self._check_seeds()

def _check_seeds(self):
self.seeds_ss.check_seed_health()
def _check_seeds(self, skip_health_check: bool = False):
if not skip_health_check:
self.seeds_ss.check_seed_health()
n_sources_per_row = np.count_nonzero(
self.seeds_ss.get_source_contributions(),
axis=1
Expand All @@ -187,7 +188,8 @@ def _check_seeds(self):
"All seeds must have the same energy calibration."
))

def __call__(self, n_samples: int, max_batch_size: int = 100) -> Iterator[SampleSet]:
def __call__(self, n_samples: int, max_batch_size: int = 100,
skip_health_check: bool = False) -> Iterator[SampleSet]:
"""Yields batches of seeds one at a time until a specified number of samples has
been reached.
Expand All @@ -204,12 +206,13 @@ def __call__(self, n_samples: int, max_batch_size: int = 100) -> Iterator[Sample
Args:
n_samples: total number of mixture seeds to produce across all batches
max_batch_size: maxmimum size of a batch per yield
max_batch_size: maximum size of a batch per yield
skip_health_check: whether to skip the seed health check
Returns:
Generator of `SampleSet`s
"""
self._check_seeds()
self._check_seeds(skip_health_check)

isotope_to_seeds = self.seeds_ss.sources_columns_to_dict(target_level="Isotope")
isotopes = list(isotope_to_seeds.keys())
Expand Down Expand Up @@ -298,11 +301,17 @@ def __call__(self, n_samples: int, max_batch_size: int = 100) -> Iterator[Sample

yield batch_ss

def generate(self, n_samples: int, max_batch_size: int = 100) -> SampleSet:
def generate(self, n_samples: int, max_batch_size: int = 100,
skip_health_check: bool = False) -> SampleSet:
"""Computes random mixtures of seeds at the isotope level.
"""
batches = []
for batch_ss in self(n_samples, max_batch_size=max_batch_size):
batch_iterable = self(
n_samples,
max_batch_size=max_batch_size,
skip_health_check=skip_health_check
)
for batch_ss in batch_iterable:
batches.append(batch_ss)
mixtures_ss = SampleSet()
mixtures_ss.spectra_type = self.seeds_ss.spectra_type
Expand Down
7 changes: 5 additions & 2 deletions riid/data/synthetic/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def _get_synthetic_samples(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
return fg_ss, gross_ss

def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
skip_health_check: bool = False,
verbose: bool = True) -> Tuple[SampleSet, SampleSet]:
"""Generate a `SampleSet` of gamma spectra from the provided config.
Expand All @@ -240,6 +241,7 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
`bg_seeds_ss`, which represent mixtures of K-U-T, get added on top.
Note: this spectrum is not considered part of the `bg_cps` parameter,
but is instead added on top of it.
skip_health_check: whether to skip seed health checks
verbose: whether to show detailed output
Returns:
Expand All @@ -251,8 +253,9 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet,
"""
if not fg_seeds_ss or not bg_seeds_ss:
raise ValueError("At least one foreground and background seed must be provided.")
fg_seeds_ss.check_seed_health()
bg_seeds_ss.check_seed_health()
if not skip_health_check:
fg_seeds_ss.check_seed_health()
bg_seeds_ss.check_seed_health()

self._reset_progress()
if verbose:
Expand Down

0 comments on commit 5f33251

Please sign in to comment.