Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LIV Plot Tests #2723

Merged
merged 18 commits into from
Aug 22, 2024
76 changes: 27 additions & 49 deletions tardis/visualization/tools/liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,15 @@
----------
packets_mode : str
Packet mode, either 'virtual' or 'real'.

Returns
-------
plot_data : list
List of velocity data for each species.

plot_colors : list
List of colors corresponding to each species.
"""
groups = (
self.data[packets_mode]
.packets_df_line_interaction.loc[self.packet_nu_line_range_mask]
.groupby(by="last_line_interaction_species")
)

plot_colors = []
plot_data = []
self.plot_colors = []
self.plot_data = []
species_not_wvl_range = []
species_counter = 0

Expand All @@ -239,14 +231,15 @@
).to("km/s")
full_v_last.extend(v_last_interaction)
if full_v_last:
plot_data.append(full_v_last)
plot_colors.append(self._color_list[species_counter])
self.plot_data.append(full_v_last)
self.plot_colors.append(self._color_list[species_counter])
species_counter += 1

if species_not_wvl_range:
logger.info(
f"{species_not_wvl_range} were not found in the provided wavelength range."
"%s were not found in the provided wavelength range.",
species_not_wvl_range,
)
return plot_data, plot_colors

def _prepare_plot_data(
self,
Expand Down Expand Up @@ -291,15 +284,6 @@
ValueError
If no species are provided for plotting, or if no valid species are
found in the model.

Returns
-------
plot_data : list
List of velocity data for each species.
plot_colors : list
List of colors corresponding to each species.
new_bin_edges : np.ndarray
Array of bin edges for the velocity data.
"""
if species_list is None:
# Extract all unique elements from the packets data
Expand Down Expand Up @@ -348,7 +332,7 @@
<= packet_nu_range[0]
)

plot_data, plot_colors = self._generate_plot_data(packets_mode)
self._generate_plot_data(packets_mode)
bin_edges = (self.velocity).to("km/s")

if num_bins:
Expand All @@ -358,15 +342,13 @@
logger.warning(
"Number of bins must be less than or equal to number of shells. Plotting with number of bins equals to number of shells."
)
new_bin_edges = bin_edges
self.new_bin_edges = bin_edges

Check warning on line 345 in tardis/visualization/tools/liv_plot.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/tools/liv_plot.py#L345

Added line #L345 was not covered by tests
else:
new_bin_edges = np.linspace(
self.new_bin_edges = np.linspace(
bin_edges[0], bin_edges[-1], num_bins + 1
)
else:
new_bin_edges = bin_edges

return plot_data, plot_colors, new_bin_edges
self.new_bin_edges = bin_edges

Check warning on line 351 in tardis/visualization/tools/liv_plot.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/tools/liv_plot.py#L351

Added line #L351 was not covered by tests

def _get_step_plot_data(self, data, bin_edges):
"""
Expand All @@ -378,18 +360,10 @@
Data to be binned into a histogram.
bin_edges : array-like
Edges of the bins for the histogram.

Returns
-------
step_x : np.ndarray
x-coordinates for the step plot.
step_y : np.ndarray
y-coordinates for the step plot.
"""
hist, _ = np.histogram(data, bins=bin_edges)
step_x = np.repeat(bin_edges, 2)[1:-1]
step_y = np.repeat(hist, 2)
return step_x, step_y
self.step_x = np.repeat(bin_edges, 2)[1:-1]
self.step_y = np.repeat(hist, 2)

def generate_plot_mpl(
self,
Expand Down Expand Up @@ -448,7 +422,7 @@
)
nelements = None

plot_data, plot_colors, bin_edges = self._prepare_plot_data(
self._prepare_plot_data(
packets_mode,
packet_wvl_range,
species_list,
Expand All @@ -457,18 +431,20 @@
nelements,
)

bin_edges = self.new_bin_edges

if ax is None:
self.ax = plt.figure(figsize=figsize).add_subplot(111)
else:
self.ax = ax

for data, color, name in zip(
plot_data, plot_colors, self._species_name
self.plot_data, self.plot_colors, self._species_name
):
step_x, step_y = self._get_step_plot_data(data, bin_edges)
self._get_step_plot_data(data, bin_edges)
self.ax.plot(
step_x,
step_y,
self.step_x,
self.step_y,
label=name,
color=color,
linewidth=2.5,
Expand Down Expand Up @@ -548,7 +524,7 @@
)
nelements = None

plot_data, plot_colors, bin_edges = self._prepare_plot_data(
self._prepare_plot_data(
packets_mode,
packet_wvl_range,
species_list,
Expand All @@ -557,19 +533,21 @@
nelements,
)

bin_edges = self.new_bin_edges

if fig is None:
self.fig = go.Figure()
else:
self.fig = fig

for data, color, name in zip(
plot_data, plot_colors, self._species_name
self.plot_data, self.plot_colors, self._species_name
):
step_x, step_y = self._get_step_plot_data(data, bin_edges)
self._get_step_plot_data(data, bin_edges)
self.fig.add_trace(
go.Scatter(
x=step_x,
y=step_y,
x=self.step_x,
y=self.step_y,
mode="lines",
line=dict(
color=pu.to_rgb255_string(color),
Expand Down
Loading
Loading