Skip to content

Commit

Permalink
Lazily compute masked arrays
Browse files Browse the repository at this point in the history
Signed-off-by: Brianna Major <[email protected]>
  • Loading branch information
bnmajor committed Jan 18, 2024
1 parent 36c9da8 commit 14cd0b1
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion hexrdgui/calibration/polar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def write_image(self, filename='polar_image.npz'):
if mask.type == MaskType.threshold or not mask.visible:
continue

data[f'mask_{name}'] = mask.masked_arrays
data[f'mask_{name}'] = mask.get_masked_arrays(self.type)

# Delete the file if it already exists
if filename.exists():
Expand Down
4 changes: 2 additions & 2 deletions hexrdgui/calibration/polarview.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from hexrd import instrument

from hexrdgui.hexrd_config import HexrdConfig
from hexrdgui.masking.constants import MaskType
from hexrdgui.masking.constants import MaskType, ViewType
from hexrdgui.utils import SnipAlgorithmType, run_snip1d, snip_width_pixels

tvec_c = ct.zeros_3
Expand Down Expand Up @@ -417,7 +417,7 @@ def apply_masks(self, img):
for mask in MaskManager().masks.values():
if mask.type == MaskType.threshold or not mask.visible:
continue
mask_arr = mask.masked_arrays
mask_arr = mask.get_masked_arrays(ViewType.polar)
total_mask = np.logical_or(total_mask, ~mask_arr)
if (tm := MaskManager().threshold_mask) and tm.visible:
lt_val, gt_val = tm.data
Expand Down
12 changes: 4 additions & 8 deletions hexrdgui/hexrd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,18 +892,14 @@ def raw_masks_dict(self):

if mask.type == MaskType.threshold:
idx = HexrdConfig().current_imageseries_idx
thresh_mask = mask.masked_arrays[name][idx]
thresh_mask = mask.get_masked_arrays()
thresh_mask = thresh_mask[name][idx]
final_mask = np.logical_and(final_mask, thresh_mask)
else:
if MaskManager().view_mode != constants.ViewType.raw:
# Make sure we have the raw masked arrays
mask.update_masked_arrays(constants.ViewType.raw)
for det, arr in mask.masked_arrays:
masks = mask.get_masked_arrays(constants.ViewType.raw)
for det, arr in masks:
if det == name:
final_mask = np.logical_and(final_mask, arr)
if MaskManager().view_mode != constants.ViewType.raw:
# Reset the masked arrays for the current view
mask.update_masked_arrays(MaskManager().view_mode)
masks_dict[name] = final_mask

return masks_dict
Expand Down
14 changes: 13 additions & 1 deletion hexrdgui/masking/mask_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def __init__(self, name='', mtype='', visible=True):
self.name = name
self.visible = visible
self.masked_arrays = None
self.last_view_mode = ViewType.raw

def get_masked_arrays(self, image_mode=ViewType.raw):
if self.masked_arrays is None or self.last_view_mode != image_mode:
self.update_masked_arrays(image_mode)

return self.masked_arrays

def invalidate_masked_arrays(self):
self.masked_arrays = None

# Abstract methods
@property
Expand Down Expand Up @@ -68,6 +78,7 @@ def data(self, values):
self.update_masked_arrays()

def update_masked_arrays(self, view=ViewType.raw):
self.last_view_mode = view
if view == ViewType.raw:
self.masked_arrays = create_raw_mask(self._raw)
else:
Expand Down Expand Up @@ -113,6 +124,7 @@ def data(self, values):
self.update_masked_arrays()

def update_masked_arrays(self, view=ViewType.raw):
self.last_view_mode = view
self.masked_arrays = recompute_raw_threshold_mask()

def serialize(self):
Expand Down Expand Up @@ -249,7 +261,7 @@ def load_masks(self, h5py_group):
else:
new_mask = RegionMask.deserialize(data)
self.masks[key] = new_mask
new_mask.update_masked_arrays()
new_mask.update_masked_arrays(self.view_mode)

if not HexrdConfig().loading_state:
# We're importing masks directly,
Expand Down

0 comments on commit 14cd0b1

Please sign in to comment.