Skip to content

Commit

Permalink
Handle empty transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
AyushiDaksh committed Aug 4, 2023
1 parent 204d5d2 commit 1759798
Showing 1 changed file with 76 additions and 49 deletions.
125 changes: 76 additions & 49 deletions tardis/visualization/widgets/grotrian.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,26 @@ def standardize(
# Compute lower and upper bounds of values
if min_value is None:
if zero_undefined:
min_value = values[values > 0].min()
min_value = (
values[values > 0].min() if len(values[values > 0]) > 0 else 0
)
else:
min_value = values.min()
min_value = values.min() if len(values) > 0 else 0
if max_value is None:
if zero_undefined:
max_value = values[values > 0].max()
max_value = (
values[values > 0].max() if len(values[values > 0]) > 0 else 0
)
else:
max_value = values.max()
max_value = values.max() if len(values) > 0 else 0

# Apply transformation if given
transformed_min_value = transform(min_value)
transformed_max_value = transform(max_value)
transformed_min_value = (
transform(min_value) if (min_value > 0 or not zero_undefined) else 0
)
transformed_max_value = (
transform(max_value) if (max_value > 0 or not zero_undefined) else 0
)
transformed_values = transform(values)

# Compute range
Expand All @@ -86,7 +94,7 @@ def standardize(
) / value_range
if zero_undefined:
transformed_values = transformed_values + zero_undefined_offset
transformed_values.mask(values == 0, 0, inplace=True)
transformed_values = np.where(values == 0, 0, transformed_values)
else:
# If only single value present in table, then place it at 0
transformed_values = 0 * values
Expand Down Expand Up @@ -460,42 +468,45 @@ def _compute_transitions(self):
]

### Compute default wavelengths if not set by user
if self.min_wavelength is None: # Compute default wavelength
self._min_wavelength = np.min(
np.concatenate(
(excite_lines.wavelength, deexcite_lines.wavelength)
if len(excite_lines) + len(deexcite_lines) > 0:
if self.min_wavelength is None: # Compute default wavelength
self._min_wavelength = np.min(
np.concatenate(
(excite_lines.wavelength, deexcite_lines.wavelength)
)
)
)
if self.max_wavelength is None: # Compute default wavelength
self._max_wavelength = np.max(
np.concatenate(
(excite_lines.wavelength, deexcite_lines.wavelength)
if self.max_wavelength is None: # Compute default wavelength
self._max_wavelength = np.max(
np.concatenate(
(excite_lines.wavelength, deexcite_lines.wavelength)
)
)
)

### Remove the rows outside the wavelength range for the plot
excite_lines = excite_lines.loc[
(excite_lines.wavelength >= self.min_wavelength)
& (excite_lines.wavelength <= self.max_wavelength)
]
deexcite_lines = deexcite_lines.loc[
(deexcite_lines.wavelength >= self.min_wavelength)
& (deexcite_lines.wavelength <= self.max_wavelength)
]
### Remove the rows outside the wavelength range for the plot
excite_lines = excite_lines.loc[
(excite_lines.wavelength >= self.min_wavelength)
& (excite_lines.wavelength <= self.max_wavelength)
]
deexcite_lines = deexcite_lines.loc[
(deexcite_lines.wavelength >= self.min_wavelength)
& (deexcite_lines.wavelength <= self.max_wavelength)
]

### Compute the standardized log number of electrons for arrow line width
transition_width_coefficient = standardize(
np.concatenate(
(excite_lines.num_electrons, deexcite_lines.num_electrons)
),
transform=self._transition_width_transform,
)
excite_lines[
"transition_width_coefficient"
] = transition_width_coefficient[: len(excite_lines)]
deexcite_lines[
"transition_width_coefficient"
] = transition_width_coefficient[len(excite_lines) :]
### Compute the standardized log number of electrons for arrow line width
transition_width_coefficient = standardize(
np.concatenate(
(excite_lines.num_electrons, deexcite_lines.num_electrons)
),
transform=self._transition_width_transform,
zero_undefined=True,
zero_undefined_offset=1e-3,
)
excite_lines[
"transition_width_coefficient"
] = transition_width_coefficient[: len(excite_lines)]
deexcite_lines[
"transition_width_coefficient"
] = transition_width_coefficient[len(excite_lines) :]

self.excite_lines = excite_lines
self.deexcite_lines = deexcite_lines
Expand Down Expand Up @@ -605,7 +616,7 @@ def _draw_energy_levels(self):
self.fig.add_annotation(
x=self.x_max + 0.1,
y=level_info.y_coord,
text=f"n={level_number}",
text=f"{level_number}",
showarrow=False,
xref="x2",
yref="y2",
Expand Down Expand Up @@ -684,6 +695,8 @@ def _draw_transitions(self, is_excitation):
lines["color_coefficient"] = standardize(
lines.wavelength,
transform=self._wavelength_color_transform,
zero_undefined=True,
zero_undefined_offset=1e-5,
min_value=self.min_wavelength,
max_value=self.max_wavelength,
)
Expand Down Expand Up @@ -944,10 +957,15 @@ def display(self):
)

### Create transition lines and corresponding width and color scales
self._draw_transitions(is_excitation=True)
self._draw_transitions(is_excitation=False)
self._draw_transition_width_scale()
self._draw_transition_color_scale()
if len(self.excite_lines) > 0:
self._draw_transitions(is_excitation=True)

if len(self.deexcite_lines) > 0:
self._draw_transitions(is_excitation=False)

if len(self.excite_lines) + len(self.deexcite_lines) > 0:
self._draw_transition_width_scale()
self._draw_transition_color_scale()

return self.fig

Expand Down Expand Up @@ -998,7 +1016,10 @@ def __init__(self, plot, num_shells, **kwargs):
self._ion_change_handler,
names="value",
)
self.ion_selector.observe(self._wavelength_resetter, names="value")
self.ion_selector.observe(
self._wavelength_resetter,
names="value",
)

shell_list = ["All"] + [str(i) for i in range(1, num_shells + 1)]
self.shell_selector = ipw.Dropdown(
Expand All @@ -1012,7 +1033,6 @@ def __init__(self, plot, num_shells, **kwargs):
),
names="value",
)
self.shell_selector.observe(self._wavelength_resetter, names="value")

self.max_level_selector = ipw.BoundedIntText(
value=plot.max_levels,
Expand All @@ -1025,9 +1045,6 @@ def __init__(self, plot, num_shells, **kwargs):
lambda change: self._change_handler("max_levels", change["new"]),
names="value",
)
self.max_level_selector.observe(
self._wavelength_resetter, names="value"
)

self.y_scale_selector = ipw.ToggleButtons(
options=["Linear", "Log"],
Expand Down Expand Up @@ -1107,6 +1124,7 @@ def _ion_change_handler(self, change):
children_list = list(self.fig.children)
children_list[index] = self.plot.display()
self.fig.children = tuple(children_list)
# self._wavelength_resetter()

def _wavelength_change_handler(self, change):
"""
Expand All @@ -1131,6 +1149,15 @@ def _wavelength_resetter(self, change):
"""
Resets the range of the wavelength slider whenever the ion, level or shell changes
"""
if (
self.plot.min_wavelength is None
or self.plot.max_wavelength is None
or self.plot.min_wavelength >= self.plot.max_wavelength
):
self.wavelength_range_selector.disabled = True
return

self.wavelength_range_selector.disabled = False
self.wavelength_range_selector.min = self.plot.min_wavelength
self.wavelength_range_selector.max = self.plot.max_wavelength
self.wavelength_range_selector.value = [
Expand Down

0 comments on commit 1759798

Please sign in to comment.