diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py index 91f6819538d..ccc6a9bb0b4 100644 --- a/tardis/montecarlo/base.py +++ b/tardis/montecarlo/base.py @@ -221,6 +221,7 @@ def _initialize_packets(self, T, no_of_packets, iteration, radius): self.input_energy = energies self._output_nu = np.ones(no_of_packets, dtype=np.float64) * -99.0 + self._output_r = np.ones(no_of_packets, dtype=np.float64) * -99.0 self._output_energy = np.ones(no_of_packets, dtype=np.float64) * -99.0 self.last_line_interaction_in_id = -1 * np.ones( @@ -385,6 +386,10 @@ def get_line_interaction_id(self, line_interaction_type): def output_nu(self): return u.Quantity(self._output_nu, u.Hz) + @property + def output_r(self): + return u.Quantity(self._output_r, u.cm) + @property def output_energy(self): return u.Quantity(self._output_energy, u.erg) diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py index 6dd4af64b94..7aabd006537 100644 --- a/tardis/montecarlo/montecarlo_numba/base.py +++ b/tardis/montecarlo/montecarlo_numba/base.py @@ -47,6 +47,7 @@ def montecarlo_radial1d( runner.input_mu, runner.input_energy, runner._output_nu, + runner._output_r, runner._output_energy, ) @@ -76,6 +77,7 @@ def montecarlo_radial1d( v_packets_energy_hist, last_interaction_type, last_interaction_in_nu, + last_interaction_in_r, last_line_interaction_in_id, last_line_interaction_out_id, virt_packet_nus, @@ -83,6 +85,7 @@ def montecarlo_radial1d( virt_packet_initial_mus, virt_packet_initial_rs, virt_packet_last_interaction_in_nu, + virt_packet_last_interaction_in_r, virt_packet_last_interaction_type, virt_packet_last_line_interaction_in_id, virt_packet_last_line_interaction_out_id, @@ -104,6 +107,7 @@ def montecarlo_radial1d( runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist runner.last_interaction_type = last_interaction_type runner.last_interaction_in_nu = last_interaction_in_nu + runner.last_interaction_in_r = last_interaction_in_r runner.last_line_interaction_in_id = last_line_interaction_in_id runner.last_line_interaction_out_id = last_line_interaction_out_id @@ -121,6 +125,9 @@ def montecarlo_radial1d( runner.virt_packet_last_interaction_in_nu = np.concatenate( virt_packet_last_interaction_in_nu ).ravel() + runner.virt_packet_last_interaction_in_r = np.concatenate( + np.array(virt_packet_last_interaction_in_r) + ).ravel() runner.virt_packet_last_interaction_type = np.concatenate( virt_packet_last_interaction_type ).ravel() @@ -170,12 +177,14 @@ def montecarlo_main_loop( Option to enable virtual packet logging. """ output_nus = np.empty_like(packet_collection.packets_input_nu) + output_rs = np.empty_like(packet_collection.packets_output_r) last_interaction_types = ( np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1 ) output_energies = np.empty_like(packet_collection.packets_output_nu) last_interaction_in_nus = np.empty_like(packet_collection.packets_output_nu) + last_interaction_in_rs = np.empty_like(packet_collection.packets_output_r) last_line_interaction_in_ids = ( np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1 ) @@ -227,6 +236,7 @@ def montecarlo_main_loop( virt_packet_initial_mus = [] virt_packet_initial_rs = [] virt_packet_last_interaction_in_nu = [] + virt_packet_last_interaction_in_r = [] virt_packet_last_interaction_type = [] virt_packet_last_line_interaction_in_id = [] virt_packet_last_line_interaction_out_id = [] @@ -268,7 +278,9 @@ def montecarlo_main_loop( ) output_nus[i] = r_packet.nu + output_rs[i] = r_packet.r last_interaction_in_nus[i] = r_packet.last_interaction_in_nu + last_interaction_in_rs[i] = r_packet.last_interaction_in_r last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id last_line_interaction_out_ids[i] = r_packet.last_line_interaction_out_id @@ -329,6 +341,11 @@ def montecarlo_main_loop( ] ) ) + virt_packet_last_interaction_in_r.append( + vpacket_collection.last_interaction_in_r[ + : vpacket_collection.idx + ] + ) virt_packet_last_interaction_type.append( np.ascontiguousarray( vpacket_collection.last_interaction_type[ @@ -357,10 +374,12 @@ def montecarlo_main_loop( packet_collection.packets_output_energy[:] = output_energies[:] packet_collection.packets_output_nu[:] = output_nus[:] + packet_collection.packets_output_r[:] = output_rs[:] return ( v_packets_energy_hist, last_interaction_types, last_interaction_in_nus, + last_interaction_in_rs, last_line_interaction_in_ids, last_line_interaction_out_ids, virt_packet_nus, @@ -368,6 +387,7 @@ def montecarlo_main_loop( virt_packet_initial_mus, virt_packet_initial_rs, virt_packet_last_interaction_in_nu, + virt_packet_last_interaction_in_r, virt_packet_last_interaction_type, virt_packet_last_line_interaction_in_id, virt_packet_last_line_interaction_out_id, diff --git a/tardis/montecarlo/montecarlo_numba/numba_interface.py b/tardis/montecarlo/montecarlo_numba/numba_interface.py index 884991e45a1..2dc72435ed3 100644 --- a/tardis/montecarlo/montecarlo_numba/numba_interface.py +++ b/tardis/montecarlo/montecarlo_numba/numba_interface.py @@ -282,6 +282,7 @@ def numba_plasma_initialize(plasma, line_interaction_type): ("packets_input_mu", float64[:]), ("packets_input_energy", float64[:]), ("packets_output_nu", float64[:]), + ("packets_output_r", float64[:]), ("packets_output_energy", float64[:]), ] @@ -295,6 +296,7 @@ def __init__( packets_input_mu, packets_input_energy, packets_output_nu, + packets_output_r, packets_output_energy, ): self.packets_input_radius = packets_input_radius @@ -302,6 +304,7 @@ def __init__( self.packets_input_mu = packets_input_mu self.packets_input_energy = packets_input_energy self.packets_output_nu = packets_output_nu + self.packets_output_r = packets_output_r self.packets_output_energy = packets_output_energy @@ -318,6 +321,7 @@ def __init__( ("number_of_vpackets", int64), ("length", int64), ("last_interaction_in_nu", float64[:]), + ("last_interaction_in_r", float64[:]), ("last_interaction_type", int64[:]), ("last_interaction_in_id", int64[:]), ("last_interaction_out_id", int64[:]), @@ -346,6 +350,9 @@ def __init__( self.last_interaction_in_nu = np.zeros( temporary_v_packet_bins, dtype=np.float64 ) + self.last_interaction_in_r = np.zeros( + temporary_v_packets_bins, dtype=np.float64 + ) self.last_interaction_type = -1 * np.ones( temporary_v_packet_bins, dtype=np.int64 ) @@ -366,6 +373,7 @@ def set_properties( initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -379,6 +387,7 @@ def set_properties( temp_last_interaction_in_nu = np.empty( temp_length, dtype=np.float64 ) + temp_last_interaction_in_r = np.empty(temp_length, dtype=np.float64) temp_last_interaction_type = np.empty(temp_length, dtype=np.int64) temp_last_interaction_in_id = np.empty(temp_length, dtype=np.int64) temp_last_interaction_out_id = np.empty(temp_length, dtype=np.int64) @@ -390,6 +399,9 @@ def set_properties( temp_last_interaction_in_nu[ : self.length ] = self.last_interaction_in_nu + temp_last_interaction_in_r[ + : self.length + ] = self.last_interaction_in_r temp_last_interaction_type[ : self.length ] = self.last_interaction_type @@ -405,7 +417,8 @@ def set_properties( self.initial_mus = temp_initial_mus self.initial_rs = temp_initial_rs self.last_interaction_in_nu = temp_last_interaction_in_nu - self.last_interaction_type = temp_last_interaction_type + self.last_interaction_in_r = temp_last_interaction_in_r + self.last_interaction_type = temp_last_interaction_type self.last_interaction_in_id = temp_last_interaction_in_id self.last_interaction_out_id = temp_last_interaction_out_id self.length = temp_length @@ -415,6 +428,7 @@ def set_properties( self.initial_mus[self.idx] = initial_mu self.initial_rs[self.idx] = initial_r self.last_interaction_in_nu[self.idx] = last_interaction_in_nu + self.last_interaction_in_r[self.idx] = last_interaction_in_r self.last_interaction_type[self.idx] = last_interaction_type self.last_interaction_in_id[self.idx] = last_interaction_in_id self.last_interaction_out_id[self.idx] = last_interaction_out_id diff --git a/tardis/montecarlo/montecarlo_numba/r_packet.py b/tardis/montecarlo/montecarlo_numba/r_packet.py index 30fb6fbe0d5..c93bd0ebce8 100644 --- a/tardis/montecarlo/montecarlo_numba/r_packet.py +++ b/tardis/montecarlo/montecarlo_numba/r_packet.py @@ -40,6 +40,7 @@ class PacketStatus(IntEnum): ("index", int64), ("last_interaction_type", int64), ("last_interaction_in_nu", float64), + ("last_interaction_in_r", float64), ("last_line_interaction_in_id", int64), ("last_line_interaction_out_id", int64), ] @@ -58,6 +59,7 @@ def __init__(self, r, mu, nu, energy, seed, index=0): self.index = index self.last_interaction_type = -1 self.last_interaction_in_nu = 0.0 + self.last_interaction_in_r = 0.0 self.last_line_interaction_in_id = -1 self.last_line_interaction_out_id = -1 diff --git a/tardis/montecarlo/montecarlo_numba/vpacket.py b/tardis/montecarlo/montecarlo_numba/vpacket.py index af4c4c9f763..57234d19181 100644 --- a/tardis/montecarlo/montecarlo_numba/vpacket.py +++ b/tardis/montecarlo/montecarlo_numba/vpacket.py @@ -282,6 +282,7 @@ def trace_vpacket_volley( v_packet_mu, r_packet.r, r_packet.last_interaction_in_nu, + r_packet.last_interaction_in_r, r_packet.last_interaction_type, r_packet.last_line_interaction_in_id, r_packet.last_line_interaction_out_id, diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py index 4b806fdd147..9edc30d3805 100644 --- a/tardis/visualization/__init__.py +++ b/tardis/visualization/__init__.py @@ -9,3 +9,4 @@ from tardis.visualization.widgets.line_info import LineInfoWidget from tardis.visualization.widgets.custom_abundance import CustomAbundanceWidget from tardis.visualization.tools.sdec_plot import SDECPlotter +from tardis.visualization.tools.interaction_radius_plot import InteractionRadiusPlotter diff --git a/tardis/visualization/tools/interaction_radius_plot.py b/tardis/visualization/tools/interaction_radius_plot.py new file mode 100644 index 00000000000..a98ecb4adea --- /dev/null +++ b/tardis/visualization/tools/interaction_radius_plot.py @@ -0,0 +1,334 @@ +""" +Last interaction radius plot package for TARDIS simulations. + +This plot is a spectral diagnostics plot similar to those originally +proposed in Williamson et al. (2021). +""" + +import tardis.visualization.tools.sdec_plot as sdec + +import numpy as np +import pandas as pd +import astropy.units as u + +from tardis.util.base import ( + atomic_number2element_symbol, + element_symbol2atomic_number, + species_string_to_tuple, + species_tuple_to_string, + roman_to_int, + int_to_roman, +) + +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import matplotlib.colors as clr +import plotly.graph_objects as go + + +class InteractionRadiusPlotter: + """ + Plotting interface for the interaction radius plot. + """ + + def __init__(self, data, time_explosion): + """ + Initialize the plotter with required data from the simulation. + + Parameters + ---------- + data : dict of SDECData + Dictionary to store data required for interaction radius plot, + for both packet modes (real, virtual). + """ + + self.data = data + self.time_explosion = time_explosion + return + + @classmethod + def from_simulation(cls, sim): + """ + Create an instance of the plotter from a TARDIS simulation object. + + Parameters + ---------- + sim : tardis.simulation.Simulation + TARDIS simulation object produced by running a simulation. + + Returns + ------- + Plotter + """ + + return cls(dict(virtual=sdec.SDECData.from_simulation(sim, "virtual"), + real=sdec.SDECData.from_simulation(sim, "real")), + sim.model.time_explosion) + + @classmethod + def from_hdf(cls, hdf_fpath): + """ + Create an instance of the Plotter from a simulation HDF file. + + Parameters + ---------- + hdf_fpath : str + Valid path to the HDF file where simulation is saved. + + Returns + ------- + Plotter + """ + hdfstore = pd.HDFStore(hdf_fpath) + time_explosion = hdfstore['/simulation/plasma/scalars']['time_explosion'] * u.s + return cls(dict(virtual=sdec.SDECData.from_hdf(hdf_fpath, "virtual"), + real=sdec.SDECData.from_hdf(hdf_fpath, "real")), + ) + + def _parse_species_list(self, species_list): + """ + Parse user requested species list and create list of species ids to be used. + + Parameters + ---------- + species_list : list of species to plot + List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. + Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions + (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + + """ + if species_list is not None: + # check if there are any digits in the species list. If there are, then exit. + # species_list should only contain species in the Roman numeral + # format, e.g. Si II, and each ion must contain a space + if any(char.isdigit() for char in " ".join(species_list)) == True: + raise ValueError( + "All species must be in Roman numeral form, e.g. Si II" + ) + else: + full_species_list = [] + for species in species_list: + # check if a hyphen is present. If it is, then it indicates a + # range of ions. Add each ion in that range to the list as a new entry + if "-" in species: + # split the string on spaces. First thing in the list is then the element + element = species.split(" ")[0] + # Next thing is the ion range + # convert the requested ions into numerals + first_ion_numeral = roman_to_int( + species.split(" ")[-1].split("-")[0] + ) + second_ion_numeral = roman_to_int( + species.split(" ")[-1].split("-")[-1] + ) + # add each ion between the two requested into the species list + for ion_number in np.arange( + first_ion_numeral, second_ion_numeral + 1 + ): + full_species_list.append( + f"{element} {int_to_roman(ion_number)}" + ) + else: + # Otherwise it's either an element or ion so just add to the list + full_species_list.append(species) + + # full_species_list is now a list containing each individual species requested + # e.g. it parses species_list = [Si I - V] into species_list = [Si I, Si II, Si III, Si IV, Si V] + self._full_species_list = full_species_list + requested_species_ids = [] + keep_colour = [] + + # go through each of the requested species. Check whether it is + # an element or ion (ions have spaces). If it is an element, + # add all possible ions to the ions list. Otherwise just add + # the requested ion + for species in full_species_list: + if " " in species: + requested_species_ids.append( + [ + species_string_to_tuple(species)[0] * 100 + + species_string_to_tuple(species)[1] + ] + ) + else: + atomic_number = element_symbol2atomic_number(species) + requested_species_ids.append( + [ + atomic_number * 100 + ion_number + for ion_number in np.arange(atomic_number) + ] + ) + # add the atomic number to a list so you know that this element should + # have all species in the same colour, i.e. it was requested like + # species_list = [Si] + keep_colour.append(atomic_number) + requested_species_ids = [ + species_id + for temp_list in requested_species_ids + for species_id in temp_list + ] + + self._species_list = requested_species_ids + self._keep_colour = keep_colour + else: + self._species_list = None + return + + def _make_colorbar_labels(self): + """Get the labels for the species in the colorbar.""" + if self._species_list is None: + # If species_list is none then the labels are just elements + species_name = [ + atomic_number2element_symbol(atomic_num) + for atomic_num in self.species + ] + else: + species_name = [] + for species in self.species: + # Go through each species requested + ion_number = species % 100 + atomic_number = (species - ion_number) / 100 + + ion_numeral = int_to_roman(ion_number + 1) + atomic_symbol = atomic_number2element_symbol(atomic_number) + + # if the element was requested, and not a specific ion, then + # add the element symbol to the label list + if (atomic_number in self._keep_colour) & ( + atomic_symbol not in species_name + ): + # compiling the label, and adding it to the list + label = f"{atomic_symbol}" + species_name.append(label) + elif atomic_number not in self._keep_colour: + # otherwise add the ion to the label list + label = f"{atomic_symbol} {ion_numeral}" + species_name.append(label) + + self._species_name = species_name + return + + def _make_colorbar_colors(self): + """Get the colours for the species to be plotted.""" + # the colours depends on the species present in the model and what's requested + # some species need to be shown in the same colour, so the exact colours have to be + # worked out + + color_list = [] + + # Colors for each element + # Create new variables to keep track of the last atomic number that was plotted + # This is used when plotting species in case an element was given in the list + # This is to ensure that all ions of that element are grouped together + # ii is to track the colour index + # e.g. if Si is given in species_list, this is to ensure Si I, Si II, etc. all have the same colour + color_counter = 0 + previous_atomic_number = 0 + for species_counter, identifier in enumerate(self.species): + if self._species_list is not None: + # Get the ion number and atomic number for each species + ion_number = identifier % 100 + atomic_number = (identifier - ion_number) / 100 + if previous_atomic_number == 0: + # If this is the first species being plotted, then take note of the atomic number + # don't update the colour index + color_counter = color_counter + previous_atomic_number = atomic_number + elif previous_atomic_number in self._keep_colour: + # If the atomic number is in the list of elements that should all be plotted in the same colour + # then don't update the colour index if this element has been plotted already + if previous_atomic_number == atomic_number: + color_counter = color_counter + previous_atomic_number = atomic_number + else: + # Otherwise, increase the colour counter by one, because this is a new element + color_counter = color_counter + 1 + previous_atomic_number = atomic_number + else: + # If this is just a normal species that was requested then increment the colour index + color_counter = color_counter + 1 + previous_atomic_number = atomic_number + # Calculate the colour of this species + color = self.cmap(color_counter / len(self._species_name)) + + else: + # If you're not using species list then this is just a fraction based on the total + # number of columns in the dataframe + color = self.cmap(species_counter / len(self.species)) + + color_list.append(color) + + self._color_list = color_list + + return + + def _show_colorbar_mpl(self): + """Show matplotlib colorbar with labels of elements mapped to colors.""" + + color_values = [ + self.cmap(species_counter / len(self._species_name)) + for species_counter in range(len(self._species_name)) + ] + + custcmap = clr.ListedColormap(color_values) + norm = clr.Normalize(vmin=0, vmax=len(self._species_name)) + mappable = cm.ScalarMappable(norm=norm, cmap=custcmap) + mappable.set_array(np.linspace(1, len(self._species_name) + 1, 256)) + cbar = plt.colorbar(mappable, ax=self.ax) + + bounds = np.arange(len(self._species_name)) + 0.5 + cbar.set_ticks(bounds) + + cbar.set_ticklabels(self._species_name) + return + + def generate_plot_mpl(self, + packets_mode="virtual", + ax=None, + figsize=(12, 7), + cmapname="jet", + species_list=None): + """ + Generate the last interaction radius distribution plot + using matplotlib. + """ + + # Parse the requested species list + self._parse_species_list(species_list=species_list) + species_in_model = np.unique( + self.data[packets_mode].packets_df_line_interaction['last_line_interaction_species'].values) + msk = np.isin(self._species_list, species_in_model) + self.species = np.array(self._species_list)[msk] + + if ax is None: + self.ax = plt.figure(figsize=figsize).add_subplot(111) + else: + self.ax = ax + + # Get the labels in the color bar. This determines the number of unique colors + self._make_colorbar_labels() + # Set colormap to be used in elements of emission and absorption plots + self.cmap = cm.get_cmap(cmapname, len(self._species_name)) + # Get the number of unqie colors + self._make_colorbar_colors() + self._show_colorbar_mpl() + + groups = self.data[packets_mode].packets_df_line_interaction.groupby(by='last_line_interaction_species') + + plot_colors = [] + plot_data = [] + + for species_counter, identifier in enumerate(self.species): + g_df = groups.get_group(identifier) + r_last_interaction = g_df['last_interaction_in_r'].values * u.cm + v_last_interaction = (r_last_interaction / self.time_explosion).to('km/s') + plot_data.append(v_last_interaction) + plot_colors.append(self._color_list[species_counter]) + + self.ax.hist(plot_data, bins=50, color=plot_colors) + self.ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0)) + self.ax.tick_params('both', labelsize=20) + self.ax.set_xlabel('Last Interaction Velocity (km/s)', fontsize=25) + self.ax.set_ylabel('Packet Count', fontsize=25) + + return plt.gca() diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 4dffec02d91..de5a3abe968 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -42,6 +42,7 @@ def __init__( last_line_interaction_in_id, last_line_interaction_out_id, last_line_interaction_in_nu, + last_interaction_in_r, lines_df, packet_nus, packet_energies, @@ -69,6 +70,8 @@ def __init__( emission (interaction out) last_line_interaction_in_nu : np.array Frequency values of the last absorption of emitted packets + last_line_interaction_in_r : np.array + Radius of the last interaction experienced by emitted packets lines_df : pd.DataFrame Data about the atomic lines present in simulation model's plasma packet_nus : astropy.Quantity @@ -99,6 +102,7 @@ def __init__( "last_line_interaction_out_id": last_line_interaction_out_id, "last_line_interaction_in_id": last_line_interaction_in_id, "last_line_interaction_in_nu": last_line_interaction_in_nu, + "last_interaction_in_r": last_interaction_in_r } ) @@ -175,6 +179,7 @@ def from_simulation(cls, sim, packets_mode): last_line_interaction_in_id=sim.runner.virt_packet_last_line_interaction_in_id, last_line_interaction_out_id=sim.runner.virt_packet_last_line_interaction_out_id, last_line_interaction_in_nu=sim.runner.virt_packet_last_interaction_in_nu, + last_interaction_in_r=sim.runner.virt_packet_last_interaction_in_r, lines_df=lines_df, packet_nus=u.Quantity(sim.runner.virt_packet_nus, "Hz"), packet_energies=u.Quantity( @@ -205,6 +210,9 @@ def from_simulation(cls, sim, packets_mode): last_line_interaction_in_nu=sim.runner.last_interaction_in_nu[ sim.runner.emitted_packet_mask ], + last_interaction_in_r=sim.runner.last_interaction_in_r[ + sim.runner.emitted_packet_mask + ], lines_df=lines_df, packet_nus=sim.runner.output_nu[sim.runner.emitted_packet_mask], packet_energies=sim.runner.output_energy[ @@ -271,6 +279,12 @@ def from_hdf(cls, hdf_fpath, packets_mode): ].to_numpy(), "Hz", ), + last_interaction_in_r=u.Quantity( + hdf[ + "/simulation/runner/virt_packet_last_interaction_in_r" + ].to_numpy(), + "cm", + ), lines_df=lines_df, packet_nus=u.Quantity( hdf["/simulation/runner/virt_packet_nus"].to_numpy(), @@ -333,6 +347,12 @@ def from_hdf(cls, hdf_fpath, packets_mode): ].to_numpy()[emitted_packet_mask], "Hz", ), + last_interaction_in_r=u.Quantity( + hdf[ + "/simulation/runner/last_interaction_in_r" + ].to_numpy()[emitted_packet_mask], + "cm", + ), lines_df=lines_df, packet_nus=u.Quantity( hdf["/simulation/runner/output_nu"].to_numpy()[ @@ -381,6 +401,8 @@ def from_hdf(cls, hdf_fpath, packets_mode): ) + + class SDECPlotter: """ Plotting interface for Spectral element DEComposition (SDEC) Plot.