diff --git a/tardis/visualization/widgets/grotrian.py b/tardis/visualization/widgets/grotrian.py index c20b4f06016..b5465670b48 100644 --- a/tardis/visualization/widgets/grotrian.py +++ b/tardis/visualization/widgets/grotrian.py @@ -62,18 +62,26 @@ def standardize( # Compute lower and upper bounds of values if min_value is None: if zero_undefined: - min_value = values[values > 0].min() + min_value = ( + values[values > 0].min() if len(values[values > 0]) > 0 else 0 + ) else: - min_value = values.min() + min_value = values.min() if len(values) > 0 else 0 if max_value is None: if zero_undefined: - max_value = values[values > 0].max() + max_value = ( + values[values > 0].max() if len(values[values > 0]) > 0 else 0 + ) else: - max_value = values.max() + max_value = values.max() if len(values) > 0 else 0 # Apply transformation if given - transformed_min_value = transform(min_value) - transformed_max_value = transform(max_value) + transformed_min_value = ( + transform(min_value) if (min_value > 0 or not zero_undefined) else 0 + ) + transformed_max_value = ( + transform(max_value) if (max_value > 0 or not zero_undefined) else 0 + ) transformed_values = transform(values) # Compute range @@ -86,7 +94,7 @@ def standardize( ) / value_range if zero_undefined: transformed_values = transformed_values + zero_undefined_offset - transformed_values.mask(values == 0, 0, inplace=True) + transformed_values = np.where(values == 0, 0, transformed_values) else: # If only single value present in table, then place it at 0 transformed_values = 0 * values @@ -460,42 +468,45 @@ def _compute_transitions(self): ] ### Compute default wavelengths if not set by user - if self.min_wavelength is None: # Compute default wavelength - self._min_wavelength = np.min( - np.concatenate( - (excite_lines.wavelength, deexcite_lines.wavelength) + if len(excite_lines) + len(deexcite_lines) > 0: + if self.min_wavelength is None: # Compute default wavelength + self._min_wavelength = np.min( + np.concatenate( + (excite_lines.wavelength, deexcite_lines.wavelength) + ) ) - ) - if self.max_wavelength is None: # Compute default wavelength - self._max_wavelength = np.max( - np.concatenate( - (excite_lines.wavelength, deexcite_lines.wavelength) + if self.max_wavelength is None: # Compute default wavelength + self._max_wavelength = np.max( + np.concatenate( + (excite_lines.wavelength, deexcite_lines.wavelength) + ) ) - ) - ### Remove the rows outside the wavelength range for the plot - excite_lines = excite_lines.loc[ - (excite_lines.wavelength >= self.min_wavelength) - & (excite_lines.wavelength <= self.max_wavelength) - ] - deexcite_lines = deexcite_lines.loc[ - (deexcite_lines.wavelength >= self.min_wavelength) - & (deexcite_lines.wavelength <= self.max_wavelength) - ] + ### Remove the rows outside the wavelength range for the plot + excite_lines = excite_lines.loc[ + (excite_lines.wavelength >= self.min_wavelength) + & (excite_lines.wavelength <= self.max_wavelength) + ] + deexcite_lines = deexcite_lines.loc[ + (deexcite_lines.wavelength >= self.min_wavelength) + & (deexcite_lines.wavelength <= self.max_wavelength) + ] - ### Compute the standardized log number of electrons for arrow line width - transition_width_coefficient = standardize( - np.concatenate( - (excite_lines.num_electrons, deexcite_lines.num_electrons) - ), - transform=self._transition_width_transform, - ) - excite_lines[ - "transition_width_coefficient" - ] = transition_width_coefficient[: len(excite_lines)] - deexcite_lines[ - "transition_width_coefficient" - ] = transition_width_coefficient[len(excite_lines) :] + ### Compute the standardized log number of electrons for arrow line width + transition_width_coefficient = standardize( + np.concatenate( + (excite_lines.num_electrons, deexcite_lines.num_electrons) + ), + transform=self._transition_width_transform, + zero_undefined=True, + zero_undefined_offset=1e-3, + ) + excite_lines[ + "transition_width_coefficient" + ] = transition_width_coefficient[: len(excite_lines)] + deexcite_lines[ + "transition_width_coefficient" + ] = transition_width_coefficient[len(excite_lines) :] self.excite_lines = excite_lines self.deexcite_lines = deexcite_lines @@ -605,7 +616,7 @@ def _draw_energy_levels(self): self.fig.add_annotation( x=self.x_max + 0.1, y=level_info.y_coord, - text=f"n={level_number}", + text=f"{level_number}", showarrow=False, xref="x2", yref="y2", @@ -684,6 +695,8 @@ def _draw_transitions(self, is_excitation): lines["color_coefficient"] = standardize( lines.wavelength, transform=self._wavelength_color_transform, + zero_undefined=True, + zero_undefined_offset=1e-5, min_value=self.min_wavelength, max_value=self.max_wavelength, ) @@ -944,10 +957,15 @@ def display(self): ) ### Create transition lines and corresponding width and color scales - self._draw_transitions(is_excitation=True) - self._draw_transitions(is_excitation=False) - self._draw_transition_width_scale() - self._draw_transition_color_scale() + if len(self.excite_lines) > 0: + self._draw_transitions(is_excitation=True) + + if len(self.deexcite_lines) > 0: + self._draw_transitions(is_excitation=False) + + if len(self.excite_lines) + len(self.deexcite_lines) > 0: + self._draw_transition_width_scale() + self._draw_transition_color_scale() return self.fig @@ -998,7 +1016,10 @@ def __init__(self, plot, num_shells, **kwargs): self._ion_change_handler, names="value", ) - self.ion_selector.observe(self._wavelength_resetter, names="value") + self.ion_selector.observe( + self._wavelength_resetter, + names="value", + ) shell_list = ["All"] + [str(i) for i in range(1, num_shells + 1)] self.shell_selector = ipw.Dropdown( @@ -1012,7 +1033,6 @@ def __init__(self, plot, num_shells, **kwargs): ), names="value", ) - self.shell_selector.observe(self._wavelength_resetter, names="value") self.max_level_selector = ipw.BoundedIntText( value=plot.max_levels, @@ -1025,9 +1045,6 @@ def __init__(self, plot, num_shells, **kwargs): lambda change: self._change_handler("max_levels", change["new"]), names="value", ) - self.max_level_selector.observe( - self._wavelength_resetter, names="value" - ) self.y_scale_selector = ipw.ToggleButtons( options=["Linear", "Log"], @@ -1107,6 +1124,7 @@ def _ion_change_handler(self, change): children_list = list(self.fig.children) children_list[index] = self.plot.display() self.fig.children = tuple(children_list) + # self._wavelength_resetter() def _wavelength_change_handler(self, change): """ @@ -1131,6 +1149,15 @@ def _wavelength_resetter(self, change): """ Resets the range of the wavelength slider whenever the ion, level or shell changes """ + if ( + self.plot.min_wavelength is None + or self.plot.max_wavelength is None + or self.plot.min_wavelength >= self.plot.max_wavelength + ): + self.wavelength_range_selector.disabled = True + return + + self.wavelength_range_selector.disabled = False self.wavelength_range_selector.min = self.plot.min_wavelength self.wavelength_range_selector.max = self.plot.max_wavelength self.wavelength_range_selector.value = [