Skip to content

Commit

Permalink
Lazily compute masked arrays
Browse files Browse the repository at this point in the history
Signed-off-by: Brianna Major <[email protected]>
  • Loading branch information
bnmajor committed Jan 21, 2024
1 parent 36c9da8 commit 8ec3332
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 23 deletions.
2 changes: 1 addition & 1 deletion hexrdgui/calibration/polar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def write_image(self, filename='polar_image.npz'):
if mask.type == MaskType.threshold or not mask.visible:
continue

data[f'mask_{name}'] = mask.masked_arrays
data[f'mask_{name}'] = mask.get_masked_arrays(self.type)

# Delete the file if it already exists
if filename.exists():
Expand Down
3 changes: 2 additions & 1 deletion hexrdgui/calibration/polarview.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from hexrd import instrument

from hexrdgui.constants import ViewType
from hexrdgui.hexrd_config import HexrdConfig
from hexrdgui.masking.constants import MaskType
from hexrdgui.utils import SnipAlgorithmType, run_snip1d, snip_width_pixels
Expand Down Expand Up @@ -417,7 +418,7 @@ def apply_masks(self, img):
for mask in MaskManager().masks.values():
if mask.type == MaskType.threshold or not mask.visible:
continue
mask_arr = mask.masked_arrays
mask_arr = mask.get_masked_arrays(ViewType.polar)
total_mask = np.logical_or(total_mask, ~mask_arr)
if (tm := MaskManager().threshold_mask) and tm.visible:
lt_val, gt_val = tm.data
Expand Down
12 changes: 4 additions & 8 deletions hexrdgui/hexrd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,18 +892,14 @@ def raw_masks_dict(self):

if mask.type == MaskType.threshold:
idx = HexrdConfig().current_imageseries_idx
thresh_mask = mask.masked_arrays[name][idx]
thresh_mask = mask.get_masked_arrays()
thresh_mask = thresh_mask[name][idx]
final_mask = np.logical_and(final_mask, thresh_mask)
else:
if MaskManager().view_mode != constants.ViewType.raw:
# Make sure we have the raw masked arrays
mask.update_masked_arrays(constants.ViewType.raw)
for det, arr in mask.masked_arrays:
masks = mask.get_masked_arrays(constants.ViewType.raw)
for det, arr in masks:
if det == name:
final_mask = np.logical_and(final_mask, arr)
if MaskManager().view_mode != constants.ViewType.raw:
# Reset the masked arrays for the current view
mask.update_masked_arrays(MaskManager().view_mode)
masks_dict[name] = final_mask

return masks_dict
Expand Down
12 changes: 4 additions & 8 deletions hexrdgui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,15 +753,13 @@ def run_apply_hand_drawn_mask(self, dets, line_data):
for line in line_data:
name = unique_name(MaskManager().mask_names, 'polar_mask_0')
raw_line = convert_polar_to_raw([line])
mask = MaskManager().add_mask(name, raw_line, MaskType.polygon)
mask.update_masked_arrays(self.image_mode)
MaskManager().add_mask(name, raw_line, MaskType.polygon)
MaskManager().polar_masks_changed.emit()
elif self.image_mode == ViewType.raw:
for det, line in zip(dets, line_data):
name = unique_name(MaskManager().mask_names, 'raw_mask_0')
mask = MaskManager().add_mask(
MaskManager().add_mask(
name, [(det, line.copy())], MaskType.polygon)
mask.update_masked_arrays(self.image_mode)
MaskManager().raw_masks_changed.emit()
self.new_mask_added.emit(self.image_mode)

Expand Down Expand Up @@ -791,8 +789,7 @@ def on_action_edit_apply_laue_mask_to_polar_triggered(self):

name = unique_name(MaskManager().mask_names, 'laue_mask')
raw_data = convert_polar_to_raw(data)
mask = MaskManager().add_mask(name, raw_data, MaskType.laue)
mask.update_masked_arrays(self.image_mode)
MaskManager().add_mask(name, raw_data, MaskType.laue)
self.new_mask_added.emit(self.image_mode)
MaskManager().polar_masks_changed.emit()

Expand Down Expand Up @@ -847,8 +844,7 @@ def action_edit_apply_powder_mask_to_polar(self):

name = unique_name(MaskManager().mask_names, 'powder_mask')
raw_data = convert_polar_to_raw(data)
mask = MaskManager().add_mask(name, raw_data, MaskType.powder)
mask.update_masked_arrays(self.image_mode)
MaskManager().add_mask(name, raw_data, MaskType.powder)
self.new_mask_added.emit(self.image_mode)
MaskManager().polar_masks_changed.emit()

Expand Down
2 changes: 1 addition & 1 deletion hexrdgui/masking/create_polar_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ def rebuild_polar_masks():
for mask in MaskManager().masks.values():
if mask.type == MaskType.threshold:
continue
mask.update_masked_arrays(ViewType.polar)
mask.invalidate_masked_arrays()
2 changes: 1 addition & 1 deletion hexrdgui/masking/create_raw_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,4 @@ def create_raw_mask(line_data):
def rebuild_raw_masks():
from hexrdgui.masking.mask_manager import MaskManager
for mask in MaskManager().masks.values():
mask.update_masked_arrays(ViewType.raw)
mask.invalidate_masked_arrays()
18 changes: 17 additions & 1 deletion hexrdgui/masking/mask_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def __init__(self, name='', mtype='', visible=True):
self.name = name
self.visible = visible
self.masked_arrays = None
self.masked_arrays_view_mode = ViewType.raw

def get_masked_arrays(self):
if self.masked_arrays is None:
self.update_masked_arrays()

return self.masked_arrays

def invalidate_masked_arrays(self):
self.masked_arrays = None

# Abstract methods
@property
Expand Down Expand Up @@ -68,11 +78,18 @@ def data(self, values):
self.update_masked_arrays()

def update_masked_arrays(self, view=ViewType.raw):
self.masked_arrays_view_mode = view
if view == ViewType.raw:
self.masked_arrays = create_raw_mask(self._raw)
else:
self.masked_arrays = create_polar_mask_from_raw(self._raw)

def get_masked_arrays(self, image_mode=ViewType.raw):
if self.masked_arrays is None or self.masked_arrays_view_mode != image_mode:
self.update_masked_arrays(image_mode)

return self.masked_arrays

def serialize(self):
data = {
'name': self.name,
Expand Down Expand Up @@ -249,7 +266,6 @@ def load_masks(self, h5py_group):
else:
new_mask = RegionMask.deserialize(data)
self.masks[key] = new_mask
new_mask.update_masked_arrays()

if not HexrdConfig().loading_state:
# We're importing masks directly,
Expand Down
1 change: 0 additions & 1 deletion hexrdgui/masking/mask_regions_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def create_masks(self):
elif self.image_mode == 'polar':
coords = convert_polar_to_raw(data)
mask = MaskManager().add_mask(name, coords, MaskType.region)
mask.update_masked_arrays(self.image_mode)

masks_changed_signal = {
'raw': MaskManager().raw_masks_changed,
Expand Down
1 change: 0 additions & 1 deletion hexrdgui/masking/threshold_mask_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,4 @@ def accept(self):
'threshold', self.values, MaskType.threshold)
else:
MaskManager().threshold_mask.data = self.values
MaskManager().threshold_mask.update_masked_arrays()
self.mask_applied.emit()

0 comments on commit 8ec3332

Please sign in to comment.