From f537b24668e44a242a1b44b76c85fb7ae93c306a Mon Sep 17 00:00:00 2001 From: Asish Kumar Date: Thu, 18 Jul 2024 16:44:38 +0530 Subject: [PATCH] fast benchmark Signed-off-by: Asish Kumar --- .github/workflows/benchmarks.yml | 8 +- .github/workflows/tests.yml | 88 ++- benchmarks/benchmark_base.py | 37 +- .../transport_geometry_calculate_distances.py | 3 +- .../transport_montecarlo_interaction.py | 14 +- ...port_montecarlo_numba_formal_integral_p.py | 12 +- .../transport_montecarlo_numba_interface.py | 5 +- benchmarks/transport_montecarlo_opacities.py | 88 ++- benchmarks/transport_montecarlo_packet.py | 48 +- benchmarks/transport_montecarlo_vpacket.py | 29 +- docs/io/output/vpacket_logging.rst | 3 + docs/io/visualization/how_to_liv_plot.ipynb | 529 +++++++++++++++++ docs/io/visualization/index.rst | 1 + pyproject.toml | 2 + tardis/opacities/opacities.py | 2 +- tardis/opacities/opacity_state.py | 6 +- tardis/simulation/base.py | 3 +- tardis/transport/frame_transformations.py | 2 +- .../transport/geometry/calculate_distances.py | 2 +- tardis/transport/montecarlo/base.py | 19 +- .../montecarlo/configuration/__init__.py | 0 .../base.py} | 10 +- .../constants.py} | 0 .../configuration/montecarlo_globals.py | 2 + .../estimators/radfield_estimator_calcs.py | 2 +- .../transport/montecarlo/formal_integral.py | 7 +- .../montecarlo/formal_integral_cuda.py | 3 +- tardis/transport/montecarlo/interaction.py | 14 +- .../montecarlo/montecarlo_main_loop.py | 5 +- .../montecarlo/montecarlo_transport_state.py | 17 + .../transport/montecarlo/numba_interface.py | 9 +- .../montecarlo/packet_collections.py | 23 + tardis/transport/montecarlo/r_packet.py | 3 +- .../montecarlo/r_packet_transport.py | 7 +- .../montecarlo/single_packet_loop.py | 23 +- tardis/transport/montecarlo/tests/conftest.py | 1 - .../transport/montecarlo/tests/test_base.py | 2 + .../montecarlo/tests/test_interaction.py | 1 - .../montecarlo/tests/test_numba_interface.py | 13 +- .../transport/montecarlo/tests/test_packet.py | 7 +- .../montecarlo/tests/test_packet_source.py | 3 - .../montecarlo/tests/test_vpacket.py | 4 - tardis/transport/montecarlo/vpacket.py | 11 +- tardis/visualization/__init__.py | 1 + tardis/visualization/plot_util.py | 19 + tardis/visualization/tools/liv_plot.py | 532 ++++++++++++++++++ tardis/visualization/tools/sdec_plot.py | 70 +-- .../tools/tests/test_rpacket_plot.py | 1 + 48 files changed, 1354 insertions(+), 337 deletions(-) create mode 100644 docs/io/visualization/how_to_liv_plot.ipynb create mode 100644 tardis/transport/montecarlo/configuration/__init__.py rename tardis/transport/montecarlo/{montecarlo_configuration.py => configuration/base.py} (91%) rename tardis/transport/montecarlo/{numba_config.py => configuration/constants.py} (100%) create mode 100644 tardis/transport/montecarlo/configuration/montecarlo_globals.py create mode 100644 tardis/visualization/tools/liv_plot.py diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index fcefb32cf68..66e962f3525 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -68,11 +68,11 @@ jobs: - name: Accept all asv questions run: asv machine --yes - - name: Run benchmarks for last 5 commits if not PR + - name: Run benchmarks for last 4 commits if not PR if: github.event_name != 'pull_request_target' run: | - git log -n 5 --pretty=format:"%H" >> tag_commits.txt - asv run HASHFILE:tag_commits.txt | tee asv-output.log + git log -n 4 --pretty=format:"%H" >> tag_commits.txt + asv run -a repeat=1 -a rounds=1 HASHFILE:tag_commits.txt | tee asv-output.log if grep -q failed asv-output.log; then echo "Some benchmarks have failed!" exit 1 @@ -126,7 +126,7 @@ jobs: echo $(git rev-parse HEAD) > commit_hashes.txt echo $(git merge-base HEAD upstream/master) >> commit_hashes.txt - asv run HASHFILE:commit_hashes.txt | tee asv-output-PR.log + asv run -a repeat=1 -a rounds=1 HASHFILE:commit_hashes.txt | tee asv-output-PR.log if grep -q failed asv-output-PR.log; then echo "Some benchmarks have failed!" exit 1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index dab7e450cdf..58351efa50d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,6 @@ on: push: branches: - '*' - pull_request: branches: - '*' @@ -22,24 +21,52 @@ on: env: CACHE_NUMBER: 0 # increase to reset cache manually PYTEST_FLAGS: --tardis-refdata=${{ github.workspace }}/tardis-refdata --tardis-regression-data=${{ github.workspace }}/tardis-regression-data - --cov=tardis --cov-report=xml --cov-report=html CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - defaults: run: shell: bash -l {0} + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: - build: + codecov: if: github.repository_owner == 'tardis-sn' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup LFS + uses: ./.github/actions/setup_lfs + - name: Setup environment + uses: ./.github/actions/setup_env + with: + os-label: linux-64 + - name: Generate coverage report + run: | + pytest --cov=tardis --cov-report=xml --cov-report=html + - uses: codecov/codecov-action@v4 + if: always() + with: + fail_ci_if_error: true + token: ${{ env.CODECOV_TOKEN }} + verbose: true + + tests: + name: ${{ matrix.continuum }} continuum ${{ matrix.rpacket_tracking }} rpacket_tracking ${{ matrix.os }} + if: github.repository_owner == 'tardis-sn' + runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - pip: [true, false] label: [osx-arm64, linux-64] + continuum: ['not', ''] + rpacket_tracking: ['not', ''] + exclude: + - continuum: '' + rpacket_tracking: '' include: - label: osx-arm64 os: macos-latest @@ -48,9 +75,6 @@ jobs: - label: linux-64 os: ubuntu-latest prefix: /usr/share/miniconda3/envs/tardis - - name: ${{ matrix.label }}-pip-${{ matrix.pip }} - runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -59,46 +83,16 @@ jobs: - name: Setup environment uses: ./.github/actions/setup_env - with: + with: os-label: ${{ matrix.label }} - - - name: Install package editable - run: | - pip install -e . - echo "TARDIS_PIP_PATH=tardis" >> $GITHUB_ENV - if: matrix.pip == false - - - name: Install package git - run: pip install git+https://github.com/tardis-sn/tardis.git@${{ github.ref }} - if: matrix.pip == true - - - name: Set pip path - if: matrix.pip == true - run: | - location_line=$(pip show tardis | grep -i -x "Location:.*") - directory_path=$(echo $location_line | awk -F " " '{print $2}') - echo "TARDIS_PIP_PATH=$directory_path" >> $GITHUB_ENV - - name: Set install path - if: matrix.pip == false + - name: Install package editable run: | - directory_path="." - echo "TARDIS_PIP_PATH=$directory_path" >> $GITHUB_ENV + pip install -e . --user - name: Run tests - run: pytest tardis ${{ env.PYTEST_FLAGS }} -m "not continuum" - working-directory: ${{ env.TARDIS_PIP_PATH }} - if: always() - - - name: Run continuum tests - run: pytest tardis ${{ env.PYTEST_FLAGS }} -m continuum - working-directory: ${{ env.TARDIS_PIP_PATH }} - if: always() - - - name: Upload to Codecov - run: bash <(curl -s https://codecov.io/bash) + run: pytest tardis ${{ env.PYTEST_FLAGS }} -m "${{ matrix.continuum }} continuum and ${{ matrix.rpacket_tracking }} rpacket_tracking" - name: Refdata Generation tests - run: pytest tardis ${{ env.PYTEST_FLAGS }} --generate-reference - working-directory: ${{ env.TARDIS_PIP_PATH }} - if: contains(github.event.pull_request.labels.*.name, 'run-generation-tests') || github.ref == 'refs/heads/master' + run: pytest tardis ${{ env.PYTEST_FLAGS }} --generate-reference -m "${{ matrix.continuum }} continuum and ${{ matrix.rpacket_tracking }} rpacket_tracking" + if: contains(github.event.pull_request.labels.*.name, 'run-generation-tests') || github.ref == 'refs/heads/master' \ No newline at end of file diff --git a/benchmarks/benchmark_base.py b/benchmarks/benchmark_base.py index 4136fd8b6eb..f57e56ea531 100644 --- a/benchmarks/benchmark_base.py +++ b/benchmarks/benchmark_base.py @@ -18,7 +18,11 @@ from tardis.simulation import Simulation from tardis.tests.fixtures.atom_data import DEFAULT_ATOM_DATA_UUID from tardis.tests.fixtures.regression_data import RegressionData -from tardis.transport.montecarlo import RPacket, montecarlo_configuration +from tardis.transport.montecarlo import RPacket +from tardis.transport.montecarlo.configuration import montecarlo_globals +from tardis.transport.montecarlo.configuration.base import ( + MonteCarloConfiguration, +) from tardis.transport.montecarlo.estimators import radfield_mc_estimators from tardis.transport.montecarlo.numba_interface import opacity_state_initialize from tardis.transport.montecarlo.packet_collections import ( @@ -66,7 +70,8 @@ def tardis_ref_path(self): # /app/tardis-refdata ref_data_path = Path( Path(__file__).parent.parent, - "tardis-refdata", + "benchmarks", + "data" ).resolve() return ref_data_path @@ -83,7 +88,7 @@ def atomic_dataset(self) -> AtomData: @property def atomic_data_fname(self): atomic_data_fname = ( - f"{self.tardis_ref_path}/atom_data/kurucz_cd23_chianti_H_He.h5" + f"{self.tardis_ref_path}/kurucz_cd23_chianti_H_He.h5" ) if not Path(atomic_data_fname).exists(): @@ -235,9 +240,7 @@ def packet(self): @property def verysimple_packet_collection(self): - return ( - self.nb_simulation_verysimple.transport.transport_state.packet_collection - ) + return self.nb_simulation_verysimple.transport.transport_state.packet_collection @property def nb_simulation_verysimple(self): @@ -259,7 +262,6 @@ def verysimple_opacity_state(self): self.nb_simulation_verysimple.plasma, line_interaction_type="macroatom", disable_line_scattering=self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING, - continuum_processes_enabled=self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) @property @@ -268,27 +270,19 @@ def verysimple_enable_full_relativity(self): @property def verysimple_disable_line_scattering(self): - return ( - self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING - ) + return self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING @property def verysimple_continuum_processes_enabled(self): - return ( - self.nb_simulation_verysimple.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED - ) + return montecarlo_globals.CONTINUUM_PROCESSES_ENABLED @property def verysimple_tau_russian(self): - return ( - self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN - ) + return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN @property def verysimple_survival_probability(self): - return ( - self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY - ) + return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY @property def static_packet(self): @@ -359,10 +353,10 @@ def verysimple_radfield_mc_estimators(self): @property def montecarlo_configuration(self): - return montecarlo_configuration.MonteCarloConfiguration() + return MonteCarloConfiguration() @property - def rpacket_tracker(self): + def rpacket_tracker(self): return RPacketTracker(0) @property @@ -396,7 +390,6 @@ def geometry(self): v_outer=np.array([-1, -1], dtype=np.float64), ) - @property def estimators(self): return radfield_mc_estimators.RadiationFieldMCEstimators( diff --git a/benchmarks/transport_geometry_calculate_distances.py b/benchmarks/transport_geometry_calculate_distances.py index 4f5829f40b8..dfda0b3ba66 100644 --- a/benchmarks/transport_geometry_calculate_distances.py +++ b/benchmarks/transport_geometry_calculate_distances.py @@ -1,10 +1,11 @@ -from asv_runner.benchmarks.mark import parameterize +from asv_runner.benchmarks.mark import parameterize, skip_benchmark import tardis.transport.frame_transformations as frame_transformations import tardis.transport.geometry.calculate_distances as calculate_distances from benchmarks.benchmark_base import BenchmarkBase +@skip_benchmark class BenchmarkTransportGeometryCalculateDistances(BenchmarkBase): """ Class to benchmark the calculate distances function. diff --git a/benchmarks/transport_montecarlo_interaction.py b/benchmarks/transport_montecarlo_interaction.py index 12fb8789830..1f8ac706a0a 100644 --- a/benchmarks/transport_montecarlo_interaction.py +++ b/benchmarks/transport_montecarlo_interaction.py @@ -10,7 +10,6 @@ from asv_runner.benchmarks.mark import parameterize - class BenchmarkMontecarloMontecarloNumbaInteraction(BenchmarkBase): """ Class to benchmark the numba interaction function. @@ -52,7 +51,6 @@ def time_line_scatter(self, line_interaction_type): line_interaction_type, self.verysimple_opacity_state, self.verysimple_enable_full_relativity, - self.verysimple_continuum_processes_enabled, ) @parameterize( @@ -63,17 +61,7 @@ def time_line_scatter(self, line_interaction_type): "emission_line_id": 1000, "energy": 0.9114437898710559, }, - { - "mu": -0.6975116557422458, - "emission_line_id": 2000, - "energy": 0.8803098648913266, - }, - { - "mu": -0.7115661419975774, - "emission_line_id": 0, - "energy": 0.8800385929341252, - }, - ] + ] } ) def time_line_emission(self, test_packet): diff --git a/benchmarks/transport_montecarlo_numba_formal_integral_p.py b/benchmarks/transport_montecarlo_numba_formal_integral_p.py index 35b05902a13..015d29a15f8 100644 --- a/benchmarks/transport_montecarlo_numba_formal_integral_p.py +++ b/benchmarks/transport_montecarlo_numba_formal_integral_p.py @@ -24,14 +24,6 @@ class BenchmarkMontecarloMontecarloNumbaNumbaFormalIntegral(BenchmarkBase): { "nu": 1e14, "temperature": 1e4, - }, - { - "nu": 0, - "temperature": 1, - }, - { - "nu": 1, - "temperature": 1, } ] } @@ -41,7 +33,7 @@ def time_intensity_black_body(self, parameters): temperature = parameters["temperature"] formal_integral.intensity_black_body(nu, temperature) - @parameterize({"N": (1e2, 1e3, 1e4, 1e5)}) + @parameterize({"N": (1e2, 1e3)}) def time_trapezoid_integration(self, n): h = 1.0 data = np.random.random(int(n)) @@ -98,7 +90,7 @@ def time_calculate_z(self, p, test_data): for r in r_outer: formal_integral.calculate_z(r, p, inv_t) - @parameterize({"N": [100, 1000, 10000]}) + @parameterize({"N": [100, 1000]}) def time_calculate_p_values(self, N): r = 1.0 formal_integral.calculate_p_values(r, N) diff --git a/benchmarks/transport_montecarlo_numba_interface.py b/benchmarks/transport_montecarlo_numba_interface.py index ac8f79791f6..c3009c1ef12 100644 --- a/benchmarks/transport_montecarlo_numba_interface.py +++ b/benchmarks/transport_montecarlo_numba_interface.py @@ -14,7 +14,7 @@ class BenchmarkMontecarloMontecarloNumbaNumbaInterface(BenchmarkBase): Class to benchmark the numba interface function. """ - @parameterize({"Input params": ["scatter", "macroatom", "downbranch"]}) + @parameterize({"Input params": ["scatter", "macroatom"]}) def time_opacity_state_initialize(self, input_params): line_interaction_type = input_params plasma = self.nb_simulation_verysimple.plasma @@ -22,5 +22,4 @@ def time_opacity_state_initialize(self, input_params): plasma, line_interaction_type, self.verysimple_disable_line_scattering, - self.verysimple_continuum_processes_enabled, - ) \ No newline at end of file + ) diff --git a/benchmarks/transport_montecarlo_opacities.py b/benchmarks/transport_montecarlo_opacities.py index c5fe0196f47..7ce2a4d72ed 100644 --- a/benchmarks/transport_montecarlo_opacities.py +++ b/benchmarks/transport_montecarlo_opacities.py @@ -17,12 +17,10 @@ class BenchmarkMontecarloMontecarloNumbaOpacities(BenchmarkBase): { "Electron number density": [ 1.0e11, - 1e15, 1e5, ], "Energy": [ 511.0, - 255.5, 511.0e7, ], } @@ -34,58 +32,56 @@ def time_compton_opacity_calculation(self, electron_number_density, energy): @parameterize( { - "Ejecta density": [ - 1.0, - 1e-2, - 1e-2, - 1e5, - ], - "Energy": [ - 511.0, - 255.5, - 255.5, - 511.0e7, - ], - "Iron group fraction": [ - 0.0, - 0.5, - 0.25, - 1.0, - ], + "Parameters": [ + { + "Ejecta_density": 1.0, + "Energy": 255.5, + "Iron_group_fraction": 0.5 + }, + { + "Ejecta_density": 0.01, + "Energy": 255.5, + "Iron_group_fraction": 1.0 + }, + { + "Ejecta_density": 0.01, + "Energy": 255.5, + "Iron_group_fraction": 0.5 + } + ] } ) - def time_photoabsorption_opacity_calculation( - self, ejecta_density, energy, iron_group_fraction - ): + def time_photoabsorption_opacity_calculation(self, parameters): calculate_opacity.photoabsorption_opacity_calculation( - energy, ejecta_density, iron_group_fraction + parameters["Energy"], + parameters["Ejecta_density"], + parameters["Iron_group_fraction"] ) @parameterize( { - "Ejecta density": [ - 1.0, - 1e-2, - 1e-2, - 1e5, - ], - "Energy": [ - 511.0, - 1500, - 1200, - 511.0e7, - ], - "Iron group fraction": [ - 0.0, - 0.5, - 0.25, - 1.0, - ], + "Parameters": [ + { + "Ejecta_density": 1.0, + "Energy": 255.5, + "Iron_group_fraction": 0.5 + }, + { + "Ejecta_density": 0.01, + "Energy": 255.5, + "Iron_group_fraction": 1.0 + }, + { + "Ejecta_density": 0.01, + "Energy": 255.5, + "Iron_group_fraction": 0.0 + } + ] } ) - def time_pair_creation_opacity_calculation( - self, ejecta_density, energy, iron_group_fraction - ): + def time_pair_creation_opacity_calculation(self, parameters): calculate_opacity.pair_creation_opacity_calculation( - energy, ejecta_density, iron_group_fraction + parameters["Energy"], + parameters["Ejecta_density"], + parameters["Iron_group_fraction"] ) diff --git a/benchmarks/transport_montecarlo_packet.py b/benchmarks/transport_montecarlo_packet.py index fd9a262a93a..5a6511ab97e 100644 --- a/benchmarks/transport_montecarlo_packet.py +++ b/benchmarks/transport_montecarlo_packet.py @@ -46,18 +46,10 @@ def time_calculate_distance_electron(self, parameters): "electron_density": 1e-5, "distance": 1.0, }, - { - "electron_density": 1e10, - "distance": 1e10, - }, { "electron_density": -1, "distance": 0, }, - { - "electron_density": -1e10, - "distance": -1e10, - }, ] } ) @@ -78,12 +70,6 @@ def time_get_random_mu(self): "time_explosion": 5.2e7, "enable_full_relativity": True, }, - { - "cur_line_id": 0, - "distance_trace": 0, - "time_explosion": 5.2e7, - "enable_full_relativity": True, - }, { "cur_line_id": 1, "distance_trace": 1e5, @@ -115,16 +101,6 @@ def time_update_line_estimators(self, parameters): "delta_shell": 11, "no_of_shells": 132, }, - { - "current_shell_id": 132, - "delta_shell": 1, - "no_of_shells": 133, - }, - { - "current_shell_id": 132, - "delta_shell": 2, - "no_of_shells": 133, - }, ] } ) @@ -145,17 +121,7 @@ def time_move_packet_across_shell_boundary_emitted(self, parameters): "current_shell_id": 132, "delta_shell": 132, "no_of_shells": 132, - }, - { - "current_shell_id": -133, - "delta_shell": -133, - "no_of_shells": -1e9, - }, - { - "current_shell_id": 132, - "delta_shell": 133, - "no_of_shells": 133, - }, + } ] } ) @@ -176,17 +142,7 @@ def time_move_packet_across_shell_boundary_reabsorbed(self, parameters): "current_shell_id": 132, "delta_shell": -1, "no_of_shells": 199, - }, - { - "current_shell_id": 132, - "delta_shell": 0, - "no_of_shells": 132, - }, - { - "current_shell_id": 132, - "delta_shell": 20, - "no_of_shells": 154, - }, + } ] } ) diff --git a/benchmarks/transport_montecarlo_vpacket.py b/benchmarks/transport_montecarlo_vpacket.py index 22198150f43..79dae89da6e 100644 --- a/benchmarks/transport_montecarlo_vpacket.py +++ b/benchmarks/transport_montecarlo_vpacket.py @@ -29,7 +29,7 @@ def v_packet(self): next_line_id=0, index=0, ) - + @property def r_packet(self): return RPacket( @@ -62,9 +62,6 @@ def time_trace_vpacket_within_shell(self): verysimple_time_explosion = self.verysimple_time_explosion verysimple_opacity_state = self.verysimple_opacity_state enable_full_relativity = self.verysimple_enable_full_relativity - continuum_processes_enabled = ( - self.verysimple_continuum_processes_enabled - ) # Give the vpacket a reasonable line ID self.v_packet_initialize_line_id( @@ -80,7 +77,6 @@ def time_trace_vpacket_within_shell(self): verysimple_time_explosion, verysimple_opacity_state, enable_full_relativity, - continuum_processes_enabled, ) def time_trace_vpacket(self): @@ -91,9 +87,6 @@ def time_trace_vpacket(self): verysimple_time_explosion = self.verysimple_time_explosion verysimple_opacity_state = self.verysimple_opacity_state enable_full_relativity = self.verysimple_enable_full_relativity - continuum_processes_enabled = ( - self.verysimple_continuum_processes_enabled - ) tau_russian = self.verysimple_tau_russian survival_probability = self.verysimple_survival_probability @@ -116,7 +109,6 @@ def time_trace_vpacket(self): tau_russian, survival_probability, enable_full_relativity, - continuum_processes_enabled, ) @property @@ -139,9 +131,6 @@ def time_trace_bad_vpacket(self): enable_full_relativity = self.verysimple_enable_full_relativity verysimple_time_explosion = self.verysimple_time_explosion verysimple_opacity_state = self.verysimple_opacity_state - continuum_processes_enabled = ( - self.verysimple_continuum_processes_enabled - ) tau_russian = self.verysimple_tau_russian survival_probability = self.verysimple_survival_probability @@ -153,20 +142,12 @@ def time_trace_bad_vpacket(self): tau_russian, survival_probability, enable_full_relativity, - continuum_processes_enabled, ) @parameterize( { - "Paramters": [ - { - "tau_russian": 10.0, - "survival_possibility": 0.0 - }, - { - "tau_russian": 15.0, - "survival_possibility": 0.1 - }, + "Parameters": [ + {"tau_russian": 10.0, "survival_possibility": 0.0} ] } ) @@ -179,7 +160,5 @@ def time_trace_vpacket_volley(self, parameters): self.verysimple_opacity_state, False, parameters["tau_russian"], - parameters["survival_possibility"], - False + parameters["survival_possibility"] ) - diff --git a/docs/io/output/vpacket_logging.rst b/docs/io/output/vpacket_logging.rst index 308ce6af409..12c9f0cb18d 100644 --- a/docs/io/output/vpacket_logging.rst +++ b/docs/io/output/vpacket_logging.rst @@ -42,6 +42,9 @@ After running the simulation, the following information can be retrieved: * - ``transport.virt_packet_last_interaction_in_nu`` - Numpy array - Frequencies of the r-packets which spawned the virtual packet + * - ``transport.virt_packet_last_interaction_in_r`` + - Numpy array + - Radii of the r-packets which spawned the virtual packet * - ``transport.virt_packet_last_line_interaction_in_id`` - Numpy array - | If the last interaction was a line interaction, the diff --git a/docs/io/visualization/how_to_liv_plot.ipynb b/docs/io/visualization/how_to_liv_plot.ipynb new file mode 100644 index 00000000000..795b9090277 --- /dev/null +++ b/docs/io/visualization/how_to_liv_plot.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to Generate a Last Interaction Velocity (LIV) Plot\n", + "The Last Interaction Velocity Plot tracks and display the velocities at which different elements (or species) last interacted with packets in the simulation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, create and run a simulation for which you want to generate this plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tardis import run_tardis\n", + "from tardis.io.atom_data.util import download_atom_data\n", + "\n", + "# We download the atomic data needed to run the simulation\n", + "download_atom_data('kurucz_cd23_chianti_H_He')\n", + "\n", + "sim = run_tardis(\"tardis_example.yml\", virtual_packet_logging=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "Note\n", + "\n", + "The virtual packet logging capability must be active in order to produce the Last Interaction Velocity Plot for virtual packets population. Thus, make sure to set `virtual_packet_logging: True` in your configuration file if you want to generate the Last Interaction Velocity Plot with virtual packets. It should be added under the `virtual` property of the `spectrum` property, as described in the [configuration schema](https://tardis-sn.github.io/tardis/io/configuration/components/spectrum.html).\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, import the plotting interface for Last Interaction Velocity Plot, i.e. the `LIVPlotter` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tardis.visualization.tools.liv_plot import LIVPlotter" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And create a plotter object to process the data of simulation object `sim` for generating the Last Interaction Velocity plot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter = LIVPlotter.from_simulation(sim)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Static Plot (in matplotlib)\n", + "You can now call the `generate_plot_mpl()` method on your plotter object to create a highly informative and visually appealing Last Interaction Velocity plot using matplotlib." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Virtual packets mode\n", + "By default, a Last Interaction Velocity plot is produced for the virtual packet population of the simulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Real packets mode\n", + "You can produce a Last Interaction Velocity plot for the real packet population of the simulation by setting `packets_mode=\"real\"` which is `\"virtual\"` by default." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(packets_mode=\"real\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting only the top contributing elements" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `nelements` option allows you to plot the top contributing elements to the spectrum. Only the top elements are shown in the plot. Please note this works only for elements and not for ions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(nelements=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Choosing what elements/ions to plot\n", + "\n", + "You can pass a `species_list` for the species you want plotted in the Last Interaction Velocity Plot. Valid options include elements (e.g., Si), ions (specified in Roman numeral format, e.g., Si II), a range of ions (e.g., Si I-III), or any combination of these." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using both the `nelements` and the `species_list` options, `species_list` takes precedence. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"Ca\", \"S\"], nelements=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific number of bins\n", + "You can regroup the bins with broader or narrower widths within the same velocity range using `num_bins`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"], num_bins=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting on the Log Scale\n", + "You can plot on the log scale on x-axis using `xlog_scale=True` and on y-axis using `ylog_scale=True` by default both are set to `False` which plots on a linear scale." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"], xlog_scale=True, ylog_scale=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific velocity range\n", + "You can restrict the range of bins to plot in the Last Interaction Velocity Plot by specifying a valid `velocity_range`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"], velocity_range=(12500, 15050))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional plotting options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To list all available options (or parameters) with their description\n", + "help(plotter.generate_plot_mpl)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `generate_plot_mpl` method also has options specific to the matplotlib API, thereby providing you with more control over how your last interaction velocity looks. Possible cases where you may use them are:\n", + "\n", + "- `ax`: To plot on an Axis of a plot you're already working with, e.g. for subplots.\n", + "\n", + "- `figsize`: To resize the plot as per your requirements.\n", + "\n", + "- `cmapname`: To use a colormap of your preference, instead of \"jet\"." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interactive Plot (in plotly)\n", + "If you're using the Last Interaction Velocity plot for exploration, consider creating an interactive version with `generate_plot_ply()`. This allows you to zoom, pan, inspect data values by hovering, resize the scale, and more conveniently.\n", + "\n", + "\n", + "\n", + "**This method takes the same arguments as `generate_plot_mpl` except for a few specific to the Plotly library.** You can produce all the plots shown above in Plotly by passing the same arguments." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Virtual packets mode\n", + "By default, a Last Interaction Velocity plot is produced for the virtual packet population of the simulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Real packets mode\n", + "You can produce a Last Interaction Velocity plot for the real packet population of the simulation by setting `packets_mode=\"real\"` which is `\"virtual\"` by default." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(packets_mode=\"real\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting only the top contributing elements" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `nelements` option allows you to plot the top contributing elements to the spectrum. Only the top elements are shown in the plot. Please note this works only for elements and not for ions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(nelements=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Choosing what elements/ions to plot\n", + "\n", + "You can pass a `species_list` for the species you want plotted in the Last Interaction Velocity Plot. Valid options include elements (e.g., Si), ions (specified in Roman numeral format, e.g., Si II), a range of ions (e.g., Si I-III), or any combination of these." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using both the `nelements` and the `species_list` options, `species_list` takes precedence. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], nelements=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific number of bins\n", + "You can regroup the bins with broader and narrower widths within the same velocity range using `num_bins`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], num_bins=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting on the Log Scale\n", + "You can plot on the log scale on x-axis using `xlog_scale=True` and on y-axis using `ylog_scale=True` by default both are set to `False`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], xlog_scale=True, ylog_scale=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific velocity range\n", + "You can restrict the range of bins to plot in the Last Interaction Velocity Plot by specifying a valid `velocity_range`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], velocity_range=(12500, 15050))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional plotting options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To list all available options (or parameters) with their description\n", + "help(plotter.generate_plot_ply)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The `generate_plot_ply` method also has options specific to the plotly API, thereby providing you with more control over how your last interaction velocity plot looks. Possible cases where you may use them are:\n", + "\n", + " - `fig`: To plot the last interaction velocity plot on a figure you are already using e.g. for subplots.\n", + "\n", + " - `graph_height`: To specify the height of the graph as needed.\n", + " \n", + " - `cmapname`: To use a colormap of your preference instead of \"jet\"." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using simulation saved as HDF\n", + "Other than producing the Last Interaction Velocity Plot for simulation objects in runtime, you can also produce it for saved TARDIS simulations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# hdf_plotter = LIVPlotter.from_hdf(\"demo.h5\") ## Files is too large - just as an example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This `hdf_plotter` object is similar to the `plotter` object we used above, **so you can use each plotting method demonstrated above with this too.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Static plot with virtual packets mode\n", + "# hdf_plotter.generate_plot_mpl()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Static plot with real packets mode\n", + "#hdf_plotter.generate_plot_mpl(packets_mode=\"real\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive plot with virtual packets mode and specific list of species\n", + "# hdf_plotter.generate_plot_ply(species_list=[\"Si I-III\", \"Ca\", \"O\", \"S\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive plot with virtual packets mode and regrouped bins\n", + "# hdf_plotter.generate_plot_ply(num_bins=10)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tardis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/io/visualization/index.rst b/docs/io/visualization/index.rst index b0ff58a6f3b..c98a32caae0 100644 --- a/docs/io/visualization/index.rst +++ b/docs/io/visualization/index.rst @@ -13,6 +13,7 @@ diagnostic visualizations. :maxdepth: 2 how_to_sdec_plot + how_to_liv_plot tutorial_convergence_plot tutorial_montecarlo_packet_visualization diff --git a/pyproject.toml b/pyproject.toml index d2d9c0a5e6f..92d3326248f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,6 +189,8 @@ text_file_format = "rst" markers = [ # continuum tests "continuum", + # rpacket tracking tests + "rpacket_tracking" ] [tool.tardis] diff --git a/tardis/opacities/opacities.py b/tardis/opacities/opacities.py index c917f59294f..dd7aeae83d4 100644 --- a/tardis/opacities/opacities.py +++ b/tardis/opacities/opacities.py @@ -6,7 +6,7 @@ from tardis.transport.montecarlo import ( njit_dict_no_parallel, ) -from tardis.transport.montecarlo.numba_config import ( +from tardis.transport.montecarlo.configuration.constants import ( SIGMA_THOMSON, ) diff --git a/tardis/opacities/opacity_state.py b/tardis/opacities/opacity_state.py index 6bfa5e308c0..1f8905cd12f 100644 --- a/tardis/opacities/opacity_state.py +++ b/tardis/opacities/opacity_state.py @@ -3,6 +3,7 @@ from numba.experimental import jitclass from tardis.opacities.tau_sobolev import calculate_sobolev_line_opacity +from tardis.transport.montecarlo.configuration import montecarlo_globals opacity_state_spec = [ ("electron_density", float64[:]), @@ -110,7 +111,6 @@ def opacity_state_initialize( plasma, line_interaction_type, disable_line_scattering, - continuum_processes_enabled, ): """ Initialize the OpacityState object and copy over the data over from TARDIS Plasma @@ -156,7 +156,7 @@ def opacity_state_initialize( ) # TODO: Fix setting of block references for non-continuum mode - if continuum_processes_enabled: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: macro_block_references = plasma.macro_block_references else: macro_block_references = plasma.atomic_data.macro_atom_references[ @@ -169,7 +169,7 @@ def opacity_state_initialize( "destination_level_idx" ].values transition_line_id = plasma.macro_atom_data["lines_idx"].values - if continuum_processes_enabled: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: bf_threshold_list_nu = plasma.nu_i.loc[ plasma.level2continuum_idx.index ].values diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 7efa377a314..e8e23271ee9 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -20,6 +20,7 @@ from tardis.plasma.standard_plasmas import assemble_plasma from tardis.simulation.convergence import ConvergenceSolver from tardis.transport.montecarlo.base import MonteCarloTransportSolver +from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.util.base import is_notebook from tardis.visualization import ConvergencePlots @@ -199,7 +200,7 @@ def __init__( self._callbacks = OrderedDict() self._cb_next_id = 0 - self.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = ( + montecarlo_globals.CONTINUUM_PROCESSES_ENABLED = ( not self.plasma.continuum_interaction_species.empty ) diff --git a/tardis/transport/frame_transformations.py b/tardis/transport/frame_transformations.py index a9768fc6f1a..9813664b98a 100644 --- a/tardis/transport/frame_transformations.py +++ b/tardis/transport/frame_transformations.py @@ -6,7 +6,7 @@ njit_dict_no_parallel, ) -from tardis.transport.montecarlo.numba_config import C_SPEED_OF_LIGHT +from tardis.transport.montecarlo.configuration.constants import C_SPEED_OF_LIGHT @njit(**njit_dict_no_parallel) diff --git a/tardis/transport/geometry/calculate_distances.py b/tardis/transport/geometry/calculate_distances.py index 3649f6c811f..02b1c13cb2b 100644 --- a/tardis/transport/geometry/calculate_distances.py +++ b/tardis/transport/geometry/calculate_distances.py @@ -6,7 +6,7 @@ njit_dict_no_parallel, ) -from tardis.transport.montecarlo.numba_config import ( +from tardis.transport.montecarlo.configuration.constants import ( C_SPEED_OF_LIGHT, MISS_DISTANCE, SIGMA_THOMSON, diff --git a/tardis/transport/montecarlo/base.py b/tardis/transport/montecarlo/base.py index ce62695eb4d..fc580ac2de6 100644 --- a/tardis/transport/montecarlo/base.py +++ b/tardis/transport/montecarlo/base.py @@ -3,21 +3,22 @@ from astropy import units as u from numba import cuda, set_num_threads +import tardis.transport.montecarlo.configuration.constants as constants from tardis import constants as const from tardis.io.logger import montecarlo_tracking as mc_tracker from tardis.io.util import HDFWriterMixin from tardis.transport.montecarlo import ( montecarlo_main_loop, - numba_config, +) +from tardis.transport.montecarlo.configuration import montecarlo_globals +from tardis.transport.montecarlo.configuration.base import ( + MonteCarloConfiguration, + configuration_initialize, ) from tardis.transport.montecarlo.estimators.radfield_mc_estimators import ( initialize_estimator_statistics, ) from tardis.transport.montecarlo.formal_integral import FormalIntegrator -from tardis.transport.montecarlo.montecarlo_configuration import ( - MonteCarloConfiguration, - configuration_initialize, -) from tardis.transport.montecarlo.montecarlo_transport_state import ( MonteCarloTransportState, ) @@ -116,7 +117,6 @@ def initialize_transport_state( plasma, self.line_interaction_type, self.montecarlo_configuration.DISABLE_LINE_SCATTERING, - self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) transport_state = MonteCarloTransportState( packet_collection, @@ -192,6 +192,7 @@ def run( ] = v_packets_energy_hist transport_state.last_interaction_type = last_interaction_tracker.types transport_state.last_interaction_in_nu = last_interaction_tracker.in_nus + transport_state.last_interaction_in_r = last_interaction_tracker.in_rs transport_state.last_line_interaction_in_id = ( last_interaction_tracker.in_ids ) @@ -210,7 +211,7 @@ def run( update_iterations_pbar(1) refresh_packet_pbar() # Condition for Checking if RPacket Tracking is enabled - if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING: + if self.enable_rpacket_tracking: transport_state.rpacket_tracker = rpacket_trackers if self.transport_state.rpacket_tracker is not None: @@ -245,10 +246,10 @@ def from_config( "Likely bug in formal integral - " "will not give same results." ) - numba_config.SIGMA_THOMSON = 1e-200 + constants.SIGMA_THOMSON = 1e-200 else: logger.debug("Electron scattering switched on") - numba_config.SIGMA_THOMSON = const.sigma_T.to("cm^2").value + constants.SIGMA_THOMSON = const.sigma_T.to("cm^2").value spectrum_frequency = quantity_linspace( config.spectrum.stop.to("Hz", u.spectral()), diff --git a/tardis/transport/montecarlo/configuration/__init__.py b/tardis/transport/montecarlo/configuration/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tardis/transport/montecarlo/montecarlo_configuration.py b/tardis/transport/montecarlo/configuration/base.py similarity index 91% rename from tardis/transport/montecarlo/montecarlo_configuration.py rename to tardis/transport/montecarlo/configuration/base.py index 27ad015fd8b..5c8e482e08c 100644 --- a/tardis/transport/montecarlo/montecarlo_configuration.py +++ b/tardis/transport/montecarlo/configuration/base.py @@ -3,6 +3,7 @@ from numba.experimental import jitclass import numpy as np +from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.transport.montecarlo.numba_interface import ( LineInteractionType, ) @@ -20,8 +21,6 @@ ("VPACKET_TAU_RUSSIAN", float64), ("INITIAL_TRACKING_ARRAY_LENGTH", int64), ("LEGACY_MODE_ENABLED", boolean), - ("ENABLE_RPACKET_TRACKING", boolean), - ("CONTINUUM_PROCESSES_ENABLED", boolean), ("VPACKET_SPAWN_START_FREQUENCY", float64), ("VPACKET_SPAWN_END_FREQUENCY", float64), ("ENABLE_VPACKET_TRACKING", boolean), @@ -45,9 +44,6 @@ def __init__(self): self.INITIAL_TRACKING_ARRAY_LENGTH = 0 self.LEGACY_MODE_ENABLED = False - self.ENABLE_RPACKET_TRACKING = False - self.CONTINUUM_PROCESSES_ENABLED = False - self.VPACKET_SPAWN_START_FREQUENCY = 0 self.VPACKET_SPAWN_END_FREQUENCY = 1e200 self.ENABLE_VPACKET_TRACKING = False @@ -81,4 +77,6 @@ def configuration_initialize(config, transport, number_of_vpackets): ).value ) config.ENABLE_VPACKET_TRACKING = transport.enable_vpacket_tracking - config.ENABLE_RPACKET_TRACKING = transport.enable_rpacket_tracking + montecarlo_globals.ENABLE_RPACKET_TRACKING = ( + transport.enable_rpacket_tracking + ) diff --git a/tardis/transport/montecarlo/numba_config.py b/tardis/transport/montecarlo/configuration/constants.py similarity index 100% rename from tardis/transport/montecarlo/numba_config.py rename to tardis/transport/montecarlo/configuration/constants.py diff --git a/tardis/transport/montecarlo/configuration/montecarlo_globals.py b/tardis/transport/montecarlo/configuration/montecarlo_globals.py new file mode 100644 index 00000000000..4faad365010 --- /dev/null +++ b/tardis/transport/montecarlo/configuration/montecarlo_globals.py @@ -0,0 +1,2 @@ +ENABLE_RPACKET_TRACKING = False +CONTINUUM_PROCESSES_ENABLED = False diff --git a/tardis/transport/montecarlo/estimators/radfield_estimator_calcs.py b/tardis/transport/montecarlo/estimators/radfield_estimator_calcs.py index dd574abcedd..6be62d7af0c 100644 --- a/tardis/transport/montecarlo/estimators/radfield_estimator_calcs.py +++ b/tardis/transport/montecarlo/estimators/radfield_estimator_calcs.py @@ -5,7 +5,7 @@ from tardis.transport.montecarlo import ( njit_dict_no_parallel, ) -from tardis.transport.montecarlo.numba_config import KB, H +from tardis.transport.montecarlo.configuration.constants import KB, H from tardis.transport.frame_transformations import ( calc_packet_energy, calc_packet_energy_full_relativity, diff --git a/tardis/transport/montecarlo/formal_integral.py b/tardis/transport/montecarlo/formal_integral.py index c699490fcc5..1774791dbae 100644 --- a/tardis/transport/montecarlo/formal_integral.py +++ b/tardis/transport/montecarlo/formal_integral.py @@ -14,7 +14,8 @@ OpacityState, opacity_state_initialize, ) -from tardis.transport.montecarlo.numba_config import SIGMA_THOMSON +from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON +from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.transport.montecarlo import njit_dict, njit_dict_no_parallel from tardis.transport.montecarlo.numba_interface import ( opacity_state_initialize, @@ -289,7 +290,6 @@ def __init__(self, simulation_state, plasma, transport, points=1000): plasma, transport.line_interaction_type, self.montecarlo_configuration.DISABLE_LINE_SCATTERING, - self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) self.atomic_data = plasma.atomic_data self.original_plasma = plasma @@ -312,7 +312,6 @@ def generate_numba_objects(self): self.original_plasma, self.transport.line_interaction_type, self.montecarlo_configuration.DISABLE_LINE_SCATTERING, - self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) if self.transport.use_gpu: self.integrator = CudaFormalIntegrator( @@ -364,7 +363,7 @@ def raise_or_return(message): 'and line_interaction_type == "macroatom"' ) - if self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: return raise_or_return( "The FormalIntegrator currently does not work for " "continuum interactions." diff --git a/tardis/transport/montecarlo/formal_integral_cuda.py b/tardis/transport/montecarlo/formal_integral_cuda.py index 75455f19c5f..b1d7f38b93c 100644 --- a/tardis/transport/montecarlo/formal_integral_cuda.py +++ b/tardis/transport/montecarlo/formal_integral_cuda.py @@ -4,7 +4,7 @@ from numba import float64, int64, cuda import math -from tardis.transport.montecarlo.numba_config import SIGMA_THOMSON +from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON C_INV = 3.33564e-11 M_PI = np.arccos(-1) @@ -149,7 +149,6 @@ def cuda_formal_integral( ) for _ in range(max(nu_end_idx - pline, 0)): - # calculate e-scattering optical depth to next resonance point zend = time_explosion / C_INV * (1.0 - line_list_nu[pline] / nu) if first == 1: diff --git a/tardis/transport/montecarlo/interaction.py b/tardis/transport/montecarlo/interaction.py index 044fbcd11a8..0801fec4e34 100644 --- a/tardis/transport/montecarlo/interaction.py +++ b/tardis/transport/montecarlo/interaction.py @@ -1,6 +1,7 @@ import numpy as np from numba import njit +import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals from tardis import constants as const from tardis.transport.montecarlo import njit_dict_no_parallel from tardis.transport.montecarlo.macro_atom import ( @@ -147,7 +148,6 @@ def continuum_event( chi_ff, chi_bf_contributions, current_continua, - continuum_processes_enabled, enable_full_relativity, ): """ @@ -189,7 +189,6 @@ def continuum_event( r_packet, time_explosion, opacity_state, - continuum_processes_enabled, enable_full_relativity, ) @@ -200,7 +199,6 @@ def macro_atom_event( r_packet, time_explosion, opacity_state, - continuum_processes_enabled, enable_full_relativity, ): """ @@ -218,7 +216,7 @@ def macro_atom_event( ) if ( - continuum_processes_enabled + montecarlo_globals.CONTINUUM_PROCESSES_ENABLED and transition_type == MacroAtomTransitionType.FF_EMISSION ): free_free_emission( @@ -226,7 +224,7 @@ def macro_atom_event( ) elif ( - continuum_processes_enabled + montecarlo_globals.CONTINUUM_PROCESSES_ENABLED and transition_type == MacroAtomTransitionType.BF_EMISSION ): bound_free_emission( @@ -237,7 +235,7 @@ def macro_atom_event( enable_full_relativity, ) elif ( - continuum_processes_enabled + montecarlo_globals.CONTINUUM_PROCESSES_ENABLED and transition_type == MacroAtomTransitionType.BF_COOLING ): bf_cooling( @@ -245,7 +243,7 @@ def macro_atom_event( ) elif ( - continuum_processes_enabled + montecarlo_globals.CONTINUUM_PROCESSES_ENABLED and transition_type == MacroAtomTransitionType.ADIABATIC_COOLING ): adiabatic_cooling(r_packet) @@ -427,7 +425,6 @@ def line_scatter( time_explosion, line_interaction_type, opacity_state, - continuum_processes_enabled, enable_full_relativity, ): """ @@ -471,7 +468,6 @@ def line_scatter( r_packet, time_explosion, opacity_state, - continuum_processes_enabled, enable_full_relativity, ) diff --git a/tardis/transport/montecarlo/montecarlo_main_loop.py b/tardis/transport/montecarlo/montecarlo_main_loop.py index d20c21eacef..e3105e95ab1 100644 --- a/tardis/transport/montecarlo/montecarlo_main_loop.py +++ b/tardis/transport/montecarlo/montecarlo_main_loop.py @@ -4,12 +4,13 @@ from numba.typed import List from tardis.transport.montecarlo import njit_dict -from tardis.transport.montecarlo.packet_trackers import RPacketTracker +from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.transport.montecarlo.packet_collections import ( VPacketCollection, consolidate_vpacket_tracker, initialize_last_interaction_tracker, ) +from tardis.transport.montecarlo.packet_trackers import RPacketTracker from tardis.transport.montecarlo.r_packet import ( PacketStatus, RPacket, @@ -185,7 +186,7 @@ def montecarlo_main_loop( 1, ) - if montecarlo_configuration.ENABLE_RPACKET_TRACKING: + if montecarlo_globals.ENABLE_RPACKET_TRACKING: for rpacket_tracker in rpacket_trackers: rpacket_tracker.finalize_array() diff --git a/tardis/transport/montecarlo/montecarlo_transport_state.py b/tardis/transport/montecarlo/montecarlo_transport_state.py index 98fc5bf976c..a2fd0455190 100644 --- a/tardis/transport/montecarlo/montecarlo_transport_state.py +++ b/tardis/transport/montecarlo/montecarlo_transport_state.py @@ -28,6 +28,7 @@ class MonteCarloTransportState(HDFWriterMixin): "emitted_packet_mask", "last_interaction_type", "last_interaction_in_nu", + "last_interaction_in_r", "last_line_interaction_out_id", "last_line_interaction_in_id", "last_line_interaction_shell_id", @@ -39,6 +40,7 @@ class MonteCarloTransportState(HDFWriterMixin): "virt_packet_initial_rs", "virt_packet_initial_mus", "virt_packet_last_interaction_in_nu", + "virt_packet_last_interaction_in_r", "virt_packet_last_interaction_type", "virt_packet_last_line_interaction_in_id", "virt_packet_last_line_interaction_out_id", @@ -49,6 +51,7 @@ class MonteCarloTransportState(HDFWriterMixin): last_interaction_type = None last_interaction_in_nu = None + last_interaction_in_r = None last_line_interaction_out_id = None last_line_interaction_in_id = None last_line_interaction_shell_id = None @@ -399,6 +402,20 @@ def virt_packet_last_interaction_in_nu(self): ) return None + @property + def virt_packet_last_interaction_in_r(self): + try: + return u.Quantity(self.vpacket_tracker.last_interaction_in_r, u.cm) + except AttributeError: + warnings.warn( + "MontecarloTransport.virt_packet_last_interaction_in_r:" + "Set 'virtual_packet_logging: True' in the configuration file" + "to access this property" + "It should be added under 'virtual' property of 'spectrum' property", + UserWarning, + ) + return None + @property def virt_packet_last_interaction_type(self): try: diff --git a/tardis/transport/montecarlo/numba_interface.py b/tardis/transport/montecarlo/numba_interface.py index 58938bf9044..f2253c7ff5d 100644 --- a/tardis/transport/montecarlo/numba_interface.py +++ b/tardis/transport/montecarlo/numba_interface.py @@ -1,11 +1,11 @@ from enum import IntEnum +import numpy as np from numba import float64, int64 from numba.experimental import jitclass -import numpy as np from tardis import constants as const - +from tardis.transport.montecarlo.configuration import montecarlo_globals C_SPEED_OF_LIGHT = const.c.to("cm/s").value @@ -117,7 +117,6 @@ def opacity_state_initialize( plasma, line_interaction_type, disable_line_scattering, - continuum_processes_enabled, ): """ Initialize the OpacityState object and copy over the data over from TARDIS Plasma @@ -157,7 +156,7 @@ def opacity_state_initialize( ) # TODO: Fix setting of block references for non-continuum mode - if continuum_processes_enabled: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: macro_block_references = plasma.macro_block_references else: macro_block_references = plasma.atomic_data.macro_atom_references[ @@ -170,7 +169,7 @@ def opacity_state_initialize( "destination_level_idx" ].values transition_line_id = plasma.macro_atom_data["lines_idx"].values - if continuum_processes_enabled: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: bf_threshold_list_nu = plasma.nu_i.loc[ plasma.level2continuum_idx.index ].values diff --git a/tardis/transport/montecarlo/packet_collections.py b/tardis/transport/montecarlo/packet_collections.py index 9745726d7bb..dde92f6f4f5 100644 --- a/tardis/transport/montecarlo/packet_collections.py +++ b/tardis/transport/montecarlo/packet_collections.py @@ -54,10 +54,12 @@ def initialize_last_interaction_tracker(no_of_packets): ) last_interaction_types = -1 * np.ones(no_of_packets, dtype=np.int64) last_interaction_in_nus = np.zeros(no_of_packets, dtype=np.float64) + last_interaction_in_rs = np.zeros(no_of_packets, dtype=np.float64) return LastInteractionTracker( last_interaction_types, last_interaction_in_nus, + last_interaction_in_rs, last_line_interaction_in_ids, last_line_interaction_out_ids, last_line_interaction_shell_ids, @@ -67,6 +69,7 @@ def initialize_last_interaction_tracker(no_of_packets): last_interaction_tracker_spec = [ ("types", int64[:]), ("in_nus", float64[:]), + ("in_rs", float64[:]), ("in_ids", int64[:]), ("out_ids", int64[:]), ("shell_ids", int64[:]), @@ -79,12 +82,14 @@ def __init__( self, types, in_nus, + in_rs, in_ids, out_ids, shell_ids, ): self.types = types self.in_nus = in_nus + self.in_rs = in_rs self.in_ids = in_ids self.out_ids = out_ids self.shell_ids = shell_ids @@ -92,6 +97,7 @@ def __init__( def update_last_interaction(self, r_packet, i): self.types[i] = r_packet.last_interaction_type self.in_nus[i] = r_packet.last_interaction_in_nu + self.in_rs[i] = r_packet.last_interaction_in_r self.in_ids[i] = r_packet.last_line_interaction_in_id self.out_ids[i] = r_packet.last_line_interaction_out_id self.shell_ids[i] = r_packet.last_line_interaction_shell_id @@ -110,6 +116,7 @@ def update_last_interaction(self, r_packet, i): ("number_of_vpackets", int64), ("length", int64), ("last_interaction_in_nu", float64[:]), + ("last_interaction_in_r", float64[:]), ("last_interaction_type", int64[:]), ("last_interaction_in_id", int64[:]), ("last_interaction_out_id", int64[:]), @@ -139,6 +146,9 @@ def __init__( self.last_interaction_in_nu = np.zeros( temporary_v_packet_bins, dtype=np.float64 ) + self.last_interaction_in_r = np.zeros( + temporary_v_packet_bins, dtype=np.float64 + ) self.last_interaction_type = -1 * np.ones( temporary_v_packet_bins, dtype=np.int64 ) @@ -162,6 +172,7 @@ def add_packet( initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -182,6 +193,8 @@ def add_packet( Initial r of the packet. last_interaction_in_nu : float Frequency of the last interaction of the packet. + last_interaction_in_r : float + Radius of the last interaction of the packet. last_interaction_type : int Type of the last interaction of the packet. last_interaction_in_id : int @@ -205,6 +218,7 @@ def add_packet( temp_last_interaction_in_nu = np.empty( temp_length, dtype=np.float64 ) + temp_last_interaction_in_r = np.empty(temp_length, dtype=np.float64) temp_last_interaction_type = np.empty(temp_length, dtype=np.int64) temp_last_interaction_in_id = np.empty(temp_length, dtype=np.int64) temp_last_interaction_out_id = np.empty(temp_length, dtype=np.int64) @@ -219,6 +233,9 @@ def add_packet( temp_last_interaction_in_nu[ : self.length ] = self.last_interaction_in_nu + temp_last_interaction_in_r[ + : self.length + ] = self.last_interaction_in_r temp_last_interaction_type[ : self.length ] = self.last_interaction_type @@ -237,6 +254,7 @@ def add_packet( self.initial_mus = temp_initial_mus self.initial_rs = temp_initial_rs self.last_interaction_in_nu = temp_last_interaction_in_nu + self.last_interaction_in_r = temp_last_interaction_in_r self.last_interaction_type = temp_last_interaction_type self.last_interaction_in_id = temp_last_interaction_in_id self.last_interaction_out_id = temp_last_interaction_out_id @@ -248,6 +266,7 @@ def add_packet( self.initial_mus[self.idx] = initial_mu self.initial_rs[self.idx] = initial_r self.last_interaction_in_nu[self.idx] = last_interaction_in_nu + self.last_interaction_in_r[self.idx] = last_interaction_in_r self.last_interaction_type[self.idx] = last_interaction_type self.last_interaction_in_id[self.idx] = last_interaction_in_id self.last_interaction_out_id[self.idx] = last_interaction_out_id @@ -268,6 +287,7 @@ def finalize_arrays(self): self.initial_mus = self.initial_mus[: self.idx] self.initial_rs = self.initial_rs[: self.idx] self.last_interaction_in_nu = self.last_interaction_in_nu[: self.idx] + self.last_interaction_in_r = self.last_interaction_in_r[: self.idx] self.last_interaction_type = self.last_interaction_type[: self.idx] self.last_interaction_in_id = self.last_interaction_in_id[: self.idx] self.last_interaction_out_id = self.last_interaction_out_id[: self.idx] @@ -328,6 +348,9 @@ def consolidate_vpacket_tracker( vpacket_tracker.last_interaction_in_nu[ current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx ] = vpacket_collection.last_interaction_in_nu + vpacket_tracker.last_interaction_in_r[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_in_r vpacket_tracker.last_interaction_type[ current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx diff --git a/tardis/transport/montecarlo/r_packet.py b/tardis/transport/montecarlo/r_packet.py index e781ce3572f..038468afcac 100644 --- a/tardis/transport/montecarlo/r_packet.py +++ b/tardis/transport/montecarlo/r_packet.py @@ -11,7 +11,6 @@ from tardis.transport.frame_transformations import ( get_doppler_factor, ) -from tardis.transport.montecarlo import numba_config as nc from tardis.transport.montecarlo import njit_dict_no_parallel @@ -41,6 +40,7 @@ class PacketStatus(IntEnum): ("index", int64), ("last_interaction_type", int64), ("last_interaction_in_nu", float64), + ("last_interaction_in_r", float64), ("last_line_interaction_in_id", int64), ("last_line_interaction_out_id", int64), ("last_line_interaction_shell_id", int64), @@ -60,6 +60,7 @@ def __init__(self, r, mu, nu, energy, seed, index=0): self.index = index self.last_interaction_type = -1 self.last_interaction_in_nu = 0.0 + self.last_interaction_in_r = 0.0 self.last_line_interaction_in_id = -1 self.last_line_interaction_out_id = -1 self.last_line_interaction_shell_id = -1 diff --git a/tardis/transport/montecarlo/r_packet_transport.py b/tardis/transport/montecarlo/r_packet_transport.py index 24238968b87..a7a9aa01cbd 100644 --- a/tardis/transport/montecarlo/r_packet_transport.py +++ b/tardis/transport/montecarlo/r_packet_transport.py @@ -1,6 +1,7 @@ import numpy as np from numba import njit +import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals from tardis.transport.frame_transformations import ( get_doppler_factor, ) @@ -28,7 +29,6 @@ def trace_packet( estimators, chi_continuum, escat_prob, - continuum_processes_enabled, enable_full_relativity, disable_line_scattering, ): @@ -113,7 +113,7 @@ def trace_packet( r_packet.next_line_id = cur_line_id break elif distance == distance_continuum: - if not continuum_processes_enabled: + if not montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: interaction_type = InteractionType.ESCATTERING else: zrand = np.random.random() @@ -140,6 +140,7 @@ def trace_packet( if tau_trace_combined > tau_event and not disable_line_scattering: interaction_type = InteractionType.LINE # Line r_packet.last_interaction_in_nu = r_packet.nu + r_packet.last_interaction_in_r = r_packet.r r_packet.last_line_interaction_in_id = cur_line_id r_packet.last_line_interaction_shell_id = r_packet.current_shell_id r_packet.next_line_id = cur_line_id @@ -163,7 +164,7 @@ def trace_packet( cur_line_id += 1 if distance_continuum < distance_boundary: distance = distance_continuum - if not continuum_processes_enabled: + if not montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: interaction_type = InteractionType.ESCATTERING else: zrand = np.random.random() diff --git a/tardis/transport/montecarlo/single_packet_loop.py b/tardis/transport/montecarlo/single_packet_loop.py index e695ae5dbbc..f32b2715a18 100644 --- a/tardis/transport/montecarlo/single_packet_loop.py +++ b/tardis/transport/montecarlo/single_packet_loop.py @@ -9,6 +9,7 @@ get_doppler_factor, get_inverse_doppler_factor, ) +from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.transport.montecarlo.estimators.radfield_estimator_calcs import ( update_bound_free_estimators, ) @@ -21,16 +22,12 @@ InteractionType, PacketStatus, ) -from tardis.transport.montecarlo.vpacket import trace_vpacket_volley -from tardis.transport.frame_transformations import ( - get_doppler_factor, - get_inverse_doppler_factor, -) from tardis.transport.montecarlo.r_packet_transport import ( move_packet_across_shell_boundary, move_r_packet, trace_packet, ) +from tardis.transport.montecarlo.vpacket import trace_vpacket_volley C_SPEED_OF_LIGHT = const.c.to("cm/s").value @@ -84,10 +81,9 @@ def single_packet_loop( montecarlo_configuration.ENABLE_FULL_RELATIVITY, montecarlo_configuration.VPACKET_TAU_RUSSIAN, montecarlo_configuration.SURVIVAL_PROBABILITY, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) - if montecarlo_configuration.ENABLE_RPACKET_TRACKING: + if montecarlo_globals.ENABLE_RPACKET_TRACKING: rpacket_tracker.track(r_packet) # this part of the code is temporary and will be better incorporated @@ -105,7 +101,7 @@ def single_packet_loop( chi_e = chi_electron_calculator( opacity_state, comov_nu, r_packet.current_shell_id ) - if montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: ( chi_bf_tot, chi_bf_contributions, @@ -128,7 +124,6 @@ def single_packet_loop( estimators, chi_continuum, escat_prob, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, montecarlo_configuration.ENABLE_FULL_RELATIVITY, montecarlo_configuration.DISABLE_LINE_SCATTERING, ) @@ -156,7 +151,6 @@ def single_packet_loop( estimators, chi_continuum, escat_prob, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, montecarlo_configuration.ENABLE_FULL_RELATIVITY, montecarlo_configuration.DISABLE_LINE_SCATTERING, ) @@ -189,7 +183,6 @@ def single_packet_loop( time_explosion, line_interaction_type, opacity_state, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, montecarlo_configuration.ENABLE_FULL_RELATIVITY, ) trace_vpacket_volley( @@ -201,7 +194,6 @@ def single_packet_loop( montecarlo_configuration.ENABLE_FULL_RELATIVITY, montecarlo_configuration.VPACKET_TAU_RUSSIAN, montecarlo_configuration.SURVIVAL_PROBABILITY, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) elif interaction_type == InteractionType.ESCATTERING: @@ -229,10 +221,9 @@ def single_packet_loop( montecarlo_configuration.ENABLE_FULL_RELATIVITY, montecarlo_configuration.VPACKET_TAU_RUSSIAN, montecarlo_configuration.SURVIVAL_PROBABILITY, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) elif ( - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED + montecarlo_globals.CONTINUUM_PROCESSES_ENABLED and interaction_type == InteractionType.CONTINUUM_PROCESS ): r_packet.last_interaction_type = InteractionType.CONTINUUM_PROCESS @@ -251,7 +242,6 @@ def single_packet_loop( chi_ff, chi_bf_contributions, current_continua, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, montecarlo_configuration.ENABLE_FULL_RELATIVITY, ) @@ -264,11 +254,10 @@ def single_packet_loop( montecarlo_configuration.ENABLE_FULL_RELATIVITY, montecarlo_configuration.VPACKET_TAU_RUSSIAN, montecarlo_configuration.SURVIVAL_PROBABILITY, - montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED, ) else: pass - if montecarlo_configuration.ENABLE_RPACKET_TRACKING: + if montecarlo_globals.ENABLE_RPACKET_TRACKING: rpacket_tracker.track(r_packet) diff --git a/tardis/transport/montecarlo/tests/conftest.py b/tardis/transport/montecarlo/tests/conftest.py index b67fc70b1f4..8d63ca134e2 100644 --- a/tardis/transport/montecarlo/tests/conftest.py +++ b/tardis/transport/montecarlo/tests/conftest.py @@ -34,7 +34,6 @@ def verysimple_opacity_state(nb_simulation_verysimple): nb_simulation_verysimple.plasma, line_interaction_type="macroatom", disable_line_scattering=False, - continuum_processes_enabled=False, ) diff --git a/tardis/transport/montecarlo/tests/test_base.py b/tardis/transport/montecarlo/tests/test_base.py index 9c7ec3bffbe..3e0e0e450c8 100644 --- a/tardis/transport/montecarlo/tests/test_base.py +++ b/tardis/transport/montecarlo/tests/test_base.py @@ -53,6 +53,7 @@ def test_hdf_transport( "emitted_packet_mask", "last_interaction_type", "last_interaction_in_nu", + "last_interaction_in_r", "last_line_interaction_out_id", "last_line_interaction_in_id", "last_line_interaction_shell_id", @@ -61,6 +62,7 @@ def test_hdf_transport( "virt_packet_initial_rs", "virt_packet_initial_mus", "virt_packet_last_interaction_in_nu", + "virt_packet_last_interaction_in_r", "virt_packet_last_interaction_type", "virt_packet_last_line_interaction_in_id", "virt_packet_last_line_interaction_out_id", diff --git a/tardis/transport/montecarlo/tests/test_interaction.py b/tardis/transport/montecarlo/tests/test_interaction.py index 2c6b31ee955..bd6d04422cf 100644 --- a/tardis/transport/montecarlo/tests/test_interaction.py +++ b/tardis/transport/montecarlo/tests/test_interaction.py @@ -48,7 +48,6 @@ def test_line_scatter( time_explosion, line_interaction_type, verysimple_opacity_state, - continuum_processes_enabled=False, enable_full_relativity=False, ) diff --git a/tardis/transport/montecarlo/tests/test_numba_interface.py b/tardis/transport/montecarlo/tests/test_numba_interface.py index 0b41c863eef..25d907049f6 100644 --- a/tardis/transport/montecarlo/tests/test_numba_interface.py +++ b/tardis/transport/montecarlo/tests/test_numba_interface.py @@ -12,7 +12,6 @@ def test_opacity_state_initialize(nb_simulation_verysimple, input_params): plasma, line_interaction_type, disable_line_scattering=False, - continuum_processes_enabled=False, ) npt.assert_allclose( @@ -72,6 +71,9 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): last_interaction_in_nus = np.array( [3.0e15, 0.0, 1e15, 1e5], dtype=np.float64 ) + last_interaction_in_rs = np.array( + [3e42, 4.5e45, 0, 9.0e40], dtype=np.float64 + ) last_interaction_types = np.array([1, 1, 3, 2], dtype=np.int64) last_interaction_in_ids = np.array([100, 0, 1, 1000], dtype=np.int64) last_interaction_out_ids = np.array([1201, 123, 545, 1232], dtype=np.int64) @@ -83,6 +85,7 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -93,6 +96,7 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): initial_mus, initial_rs, last_interaction_in_nus, + last_interaction_in_rs, last_interaction_types, last_interaction_in_ids, last_interaction_out_ids, @@ -104,6 +108,7 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -140,6 +145,12 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): ], last_interaction_in_nus, ) + npt.assert_array_equal( + verysimple_3vpacket_collection.last_interaction_in_r[ + : verysimple_3vpacket_collection.idx + ], + last_interaction_in_rs, + ) npt.assert_array_equal( verysimple_3vpacket_collection.last_interaction_type[ : verysimple_3vpacket_collection.idx diff --git a/tardis/transport/montecarlo/tests/test_packet.py b/tardis/transport/montecarlo/tests/test_packet.py index df9ca7d6ba6..2ab7a636818 100644 --- a/tardis/transport/montecarlo/tests/test_packet.py +++ b/tardis/transport/montecarlo/tests/test_packet.py @@ -5,7 +5,7 @@ import tardis.transport.frame_transformations as frame_transformations import tardis.transport.geometry.calculate_distances as calculate_distances import tardis.transport.montecarlo.estimators.radfield_mc_estimators -import tardis.transport.montecarlo.montecarlo_configuration as numba_config +import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals import tardis.transport.montecarlo.numba_interface as numba_interface import tardis.transport.montecarlo.r_packet as r_packet import tardis.transport.montecarlo.utils as utils @@ -243,7 +243,6 @@ def test_trace_packet( verysimple_time_explosion, verysimple_opacity_state, verysimple_estimators, - continuum_processes_enabled=False, enable_full_relativity=False, disable_line_scattering=False, ) @@ -292,7 +291,7 @@ def test_move_r_packet( packet.energy = packet_params["energy"] packet.r = packet_params["r"] - numba_config.ENABLE_FULL_RELATIVITY = ENABLE_FULL_RELATIVITY + montecarlo_globals.ENABLE_FULL_RELATIVITY = ENABLE_FULL_RELATIVITY r_packet_transport.move_r_packet.recompile() # This must be done as move_r_packet was jitted with ENABLE_FULL_RELATIVITY doppler_factor = frame_transformations.get_doppler_factor( packet.r, packet.mu, time_explosion, ENABLE_FULL_RELATIVITY @@ -316,7 +315,7 @@ def test_move_r_packet( expected_j *= doppler_factor expected_nubar *= doppler_factor - numba_config.ENABLE_FULL_RELATIVITY = False + montecarlo_globals.ENABLE_FULL_RELATIVITY = False assert_allclose( estimators.j_estimator[packet.current_shell_id], expected_j, rtol=5e-7 ) diff --git a/tardis/transport/montecarlo/tests/test_packet_source.py b/tardis/transport/montecarlo/tests/test_packet_source.py index 41b3af595fb..a34dd2c0b71 100644 --- a/tardis/transport/montecarlo/tests/test_packet_source.py +++ b/tardis/transport/montecarlo/tests/test_packet_source.py @@ -10,9 +10,6 @@ BlackBodySimpleSource, BlackBodySimpleSourceRelativistic, ) -from tardis.transport.montecarlo import ( - montecarlo_configuration as montecarlo_configuration, -) from tardis.tests.fixtures.regression_data import RegressionData diff --git a/tardis/transport/montecarlo/tests/test_vpacket.py b/tardis/transport/montecarlo/tests/test_vpacket.py index 43906cf642e..e081abe7670 100644 --- a/tardis/transport/montecarlo/tests/test_vpacket.py +++ b/tardis/transport/montecarlo/tests/test_vpacket.py @@ -60,7 +60,6 @@ def test_trace_vpacket_within_shell( verysimple_time_explosion, verysimple_opacity_state, enable_full_relativity=False, - continuum_processes_enabled=False, ) npt.assert_almost_equal(tau_trace_combined, 8164850.891288479) @@ -91,7 +90,6 @@ def test_trace_vpacket( 10.0, 0.0, enable_full_relativity=False, - continuum_processes_enabled=False, ) npt.assert_almost_equal(tau_trace_combined, 8164850.891288479) @@ -130,7 +128,6 @@ def test_trace_vpacket_volley( enable_full_relativity=False, tau_russian=10.0, survival_probability=0.0, - continuum_processes_enabled=False, ) @@ -161,5 +158,4 @@ def test_trace_bad_vpacket( 10.0, 0.0, enable_full_relativity=False, - continuum_processes_enabled=False, ) diff --git a/tardis/transport/montecarlo/vpacket.py b/tardis/transport/montecarlo/vpacket.py index c84aa74fd00..657e1c016d0 100644 --- a/tardis/transport/montecarlo/vpacket.py +++ b/tardis/transport/montecarlo/vpacket.py @@ -4,6 +4,7 @@ from numba import float64, int64, njit from numba.experimental import jitclass +import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals from tardis.opacities.opacities import ( chi_continuum_calculator, ) @@ -17,7 +18,7 @@ calculate_distance_line, ) from tardis.transport.montecarlo import njit_dict_no_parallel -from tardis.transport.montecarlo.numba_config import ( +from tardis.transport.montecarlo.configuration.constants import ( C_SPEED_OF_LIGHT, SIGMA_THOMSON, ) @@ -69,7 +70,6 @@ def trace_vpacket_within_shell( time_explosion, opacity_state, enable_full_relativity, - continuum_processes_enabled, ): """ Trace VPacket within one shell (relatively simple operation) @@ -100,7 +100,7 @@ def trace_vpacket_within_shell( comov_nu = v_packet.nu * doppler_factor - if continuum_processes_enabled: + if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED: ( chi_bf_tot, chi_bf_contributions, @@ -167,7 +167,6 @@ def trace_vpacket( tau_russian, survival_probability, enable_full_relativity, - continuum_processes_enabled, ): """ Trace single vpacket. @@ -194,7 +193,6 @@ def trace_vpacket( time_explosion, opacity_state, enable_full_relativity, - continuum_processes_enabled, ) tau_trace_combined += tau_trace_combined_shell @@ -239,7 +237,6 @@ def trace_vpacket_volley( enable_full_relativity, tau_russian, survival_probability, - continuum_processes_enabled, ): """ Shoot a volley of vpackets (the vpacket collection specifies how many) @@ -352,7 +349,6 @@ def trace_vpacket_volley( tau_russian, survival_probability, enable_full_relativity, - continuum_processes_enabled, ) v_packet.energy *= math.exp(-tau_vpacket) @@ -363,6 +359,7 @@ def trace_vpacket_volley( v_packet_mu, r_packet.r, r_packet.last_interaction_in_nu, + r_packet.last_interaction_in_r, r_packet.last_interaction_type, r_packet.last_line_interaction_in_id, r_packet.last_line_interaction_out_id, diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py index 73ccae5ce78..cadfff0cdc3 100644 --- a/tardis/visualization/__init__.py +++ b/tardis/visualization/__init__.py @@ -11,3 +11,4 @@ from tardis.visualization.widgets.custom_abundance import CustomAbundanceWidget from tardis.visualization.tools.sdec_plot import SDECPlotter from tardis.visualization.tools.rpacket_plot import RPacketPlotter +from tardis.visualization.tools.liv_plot import LIVPlotter diff --git a/tardis/visualization/plot_util.py b/tardis/visualization/plot_util.py index 7d3b81186f1..54b740411dc 100644 --- a/tardis/visualization/plot_util.py +++ b/tardis/visualization/plot_util.py @@ -60,3 +60,22 @@ def get_mid_point_idx(arr): """ mid_value = (arr[0] + arr[-1]) / 2 return np.abs(arr - mid_value).argmin() + + +def to_rgb255_string(color_tuple): + """ + Convert a matplotlib RGBA tuple to a generic RGB 255 string. + + Parameters + ---------- + color_tuple : tuple + Matplotlib RGBA tuple of float values in closed interval [0, 1] + + Returns + ------- + str + RGB string of format rgb(r,g,b) where r,g,b are integers between + 0 and 255 (both inclusive) + """ + color_tuple_255 = tuple([int(x * 255) for x in color_tuple[:3]]) + return f"rgb{color_tuple_255}" diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py new file mode 100644 index 00000000000..eef3a4e0cf4 --- /dev/null +++ b/tardis/visualization/tools/liv_plot.py @@ -0,0 +1,532 @@ +import logging +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import plotly.graph_objects as go +import numpy as np +import pandas as pd +import astropy.units as u + +from tardis.util.base import ( + atomic_number2element_symbol, + int_to_roman, +) +import tardis.visualization.tools.sdec_plot as sdec +from tardis.visualization import plot_util as pu + +logger = logging.getLogger(__name__) + + +class LIVPlotter: + """ + Plotting interface for the last interaction velocity plot. + """ + + def __init__(self, data, time_explosion, velocity): + """ + Initialize the plotter with required data from the simulation. + + Parameters + ---------- + data : dict of SDECData + Dictionary to store data required for last interaction velocity plot, + for both packet modes (real, virtual). + + time_explosion : astropy.units.Quantity + Time of the explosion. + + velocity : astropy.units.Quantity + Velocity array from the simulation. + """ + + self.data = data + self.time_explosion = time_explosion + self.velocity = velocity + self.sdec_plotter = sdec.SDECPlotter(data) + + @classmethod + def from_simulation(cls, sim): + """ + Create an instance of the plotter from a TARDIS simulation object. + + Parameters + ---------- + sim : tardis.simulation.Simulation + TARDIS simulation object produced by running a simulation. + + Returns + ------- + LIVPlotter + """ + + return cls( + dict( + virtual=sdec.SDECData.from_simulation(sim, "virtual"), + real=sdec.SDECData.from_simulation(sim, "real"), + ), + sim.plasma.time_explosion, + sim.simulation_state.velocity, + ) + + @classmethod + def from_hdf(cls, hdf_fpath): + """ + Create an instance of the Plotter from a simulation HDF file. + + Parameters + ---------- + hdf_fpath : str + Valid path to the HDF file where simulation is saved. + + Returns + ------- + LIVPlotter + """ + with pd.HDFStore(hdf_fpath, "r") as hdf: + time_explosion = ( + hdf["/simulation/plasma/scalars"]["time_explosion"] * u.s + ) + v_inner = hdf["/simulation/simulation_state/v_inner"] * (u.cm / u.s) + v_outer = hdf["/simulation/simulation_state/v_outer"] * (u.cm / u.s) + velocity = pd.concat( + [v_inner, pd.Series([v_outer.iloc[-1]])], ignore_index=True + ).tolist() * (u.cm / u.s) + return cls( + dict( + virtual=sdec.SDECData.from_hdf(hdf_fpath, "virtual"), + real=sdec.SDECData.from_hdf(hdf_fpath, "real"), + ), + time_explosion, + velocity, + ) + + def _parse_species_list(self, species_list, packets_mode, nelements=None): + """ + Parse user requested species list and create list of species ids to be used. + + Parameters + ---------- + species_list : list of species to plot + List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. + Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions + (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + + Raises + ------ + ValueError + If species list contains invalid entries. + + """ + self.sdec_plotter._parse_species_list(species_list) + self._species_list = self.sdec_plotter._species_list + self._species_mapped = self.sdec_plotter._species_mapped + self._keep_colour = self.sdec_plotter._keep_colour + + if nelements: + interaction_counts = ( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .value_counts() + ) + interaction_counts.index = interaction_counts.index // 100 + element_counts = interaction_counts.groupby( + interaction_counts.index + ).sum() + top_elements = element_counts.nlargest(nelements).index + top_species_list = [ + atomic_number2element_symbol(element) + for element in top_elements + ] + self._parse_species_list(top_species_list, packets_mode) + + def _make_colorbar_labels(self): + """ + Generate labels for the colorbar based on species. + + If a species list is provided, uses that to generate labels. + Otherwise, generates labels from the species in the model. + """ + if self._species_list is None: + species_name = [ + atomic_number2element_symbol(atomic_num) + for atomic_num in self.species + ] + else: + species_name = [] + for species_key, species_ids in self._species_mapped.items(): + if any(species in self.species for species in species_ids): + if species_key % 100 == 0: + label = atomic_number2element_symbol(species_key // 100) + else: + atomic_number = species_key // 100 + ion_number = species_key % 100 + ion_numeral = int_to_roman(ion_number + 1) + label = f"{atomic_number2element_symbol(atomic_number)} {ion_numeral}" + species_name.append(label) + + self._species_name = species_name + + def _make_colorbar_colors(self): + """ + Generate colors for the species to be plotted. + + This method creates a list of colors corresponding to the species names. + The colors are generated based on the species present in the model and + the requested species list. + """ + color_list = [] + species_keys = list(self._species_mapped.keys()) + num_species = len(species_keys) + + for species_counter, species_key in enumerate(species_keys): + if any( + species in self.species + for species in self._species_mapped[species_key] + ): + color = self.cmap(species_counter / num_species) + color_list.append(color) + + self._color_list = color_list + + def _generate_plot_data(self, packets_mode): + """ + Generate plot data and colors for species in the model. + + Parameters + ---------- + packets_mode : str + Packet mode, either 'virtual' or 'real'. + + Returns + ------- + plot_data : list + List of velocity data for each species. + + plot_colors : list + List of colors corresponding to each species. + """ + groups = self.data[packets_mode].packets_df_line_interaction.groupby( + by="last_line_interaction_species" + ) + + plot_colors = [] + plot_data = [] + species_counter = 0 + + for specie_list in self._species_mapped.values(): + full_v_last = [] + for specie in specie_list: + if specie in self.species: + g_df = groups.get_group(specie) + r_last_interaction = ( + g_df["last_interaction_in_r"].values * u.cm + ) + v_last_interaction = ( + r_last_interaction / self.time_explosion + ).to("km/s") + full_v_last.extend(v_last_interaction) + if full_v_last: + plot_data.append(full_v_last) + plot_colors.append(self._color_list[species_counter]) + species_counter += 1 + + return plot_data, plot_colors + + def _prepare_plot_data( + self, packets_mode, species_list, cmapname, num_bins, nelements + ): + """ + Prepare data and settings required for generating a plot. + + This method handles the common logic for preparing data and settings + needed to generate both matplotlib and plotly plots. It parses the species + list, generates color labels and colormap, and bins the velocity data. + + Parameters + ---------- + packets_mode : str + Packet mode, either 'virtual' or 'real'. + species_list : list of str + List of species to plot. Species can be specified as an ion + (e.g., Si II), an element (e.g., Si), a range of ions (e.g., Si I-V), + or any combination of these. + cmapname : str + Name of the colormap to use. A specific colormap can be chosen, such + as "jet", "viridis", "plasma", etc. + num_bins : int, optional + Number of bins for regrouping within the same range. If None, + no regrouping is done. + + Raises + ------ + ValueError + If no species are provided for plotting, or if no valid species are + found in the model. + + Returns + ------- + plot_data : list + List of velocity data for each species. + plot_colors : list + List of colors corresponding to each species. + new_bin_edges : np.ndarray + Array of bin edges for the velocity data. + """ + if species_list is None: + # Extract all unique elements from the packets data + species_in_model = np.unique( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .values + ) + species_list = [ + f"{atomic_number2element_symbol(specie // 100)}" + for specie in species_in_model + ] + self._parse_species_list(species_list, packets_mode, nelements) + species_in_model = np.unique( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .values + ) + if self._species_list is None or not self._species_list: + raise ValueError("No species provided for plotting.") + msk = np.isin(self._species_list, species_in_model) + self.species = np.array(self._species_list)[msk] + + if len(self.species) == 0: + raise ValueError("No valid species found for plotting.") + + self._make_colorbar_labels() + self.cmap = cm.get_cmap(cmapname, len(self._species_name)) + self._make_colorbar_colors() + plot_data, plot_colors = self._generate_plot_data(packets_mode) + bin_edges = (self.velocity).to("km/s") + + if num_bins: + if num_bins < 1: + raise ValueError("Number of bins must be positive") + elif num_bins > len(bin_edges) - 1: + logger.warn( + "Number of bins must be less than or equal to number of shells. Plotting with number of bins equals to number of shells." + ) + new_bin_edges = bin_edges + else: + new_bin_edges = np.linspace( + bin_edges[0], bin_edges[-1], num_bins + 1 + ) + else: + new_bin_edges = bin_edges + + return plot_data, plot_colors, new_bin_edges + + def _get_step_plot_data(self, data, bin_edges): + """ + Generate step plot data from histogram data. + + Parameters + ---------- + data : array-like + Data to be binned into a histogram. + bin_edges : array-like + Edges of the bins for the histogram. + + Returns + ------- + step_x : np.ndarray + x-coordinates for the step plot. + step_y : np.ndarray + y-coordinates for the step plot. + """ + hist, _ = np.histogram(data, bins=bin_edges) + step_x = np.repeat(bin_edges, 2)[1:-1] + step_y = np.repeat(hist, 2) + return step_x, step_y + + def generate_plot_mpl( + self, + species_list=None, + nelements=None, + packets_mode="virtual", + ax=None, + figsize=(11, 5), + cmapname="jet", + xlog_scale=False, + ylog_scale=False, + num_bins=None, + velocity_range=None, + ): + """ + Generate the last interaction velocity distribution plot using matplotlib. + + Parameters + ---------- + species_list : list of str, optional + List of species to plot. Default is None which plots all species in the model. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + ax : matplotlib.axes.Axes, optional + Axes object to plot on. If None, creates a new figure. + figsize : tuple, optional + Size of the figure. Default is (11, 5). + cmapname : str, optional + Colormap name. Default is 'jet'. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. + xlog_scale : bool, optional + If True, x-axis is scaled logarithmically. Default is False. + ylog_scale : bool, optional + If True, y-axis is scaled logarithmically. Default is False. + num_bins : int, optional + Number of bins for regrouping within the same range. Default is None. + velocity_range : tuple, optional + Limits for the x-axis. If specified, overrides any automatically determined limits. + + Returns + ------- + matplotlib.axes.Axes + Axes object with the plot. + """ + # If species_list and nelements requested, tell user that nelements is ignored + if species_list is not None and nelements is not None: + logger.info( + "Both nelements and species_list were requested. Species_list takes priority; nelements is ignored" + ) + nelements = None + + plot_data, plot_colors, bin_edges = self._prepare_plot_data( + packets_mode, species_list, cmapname, num_bins, nelements + ) + + if ax is None: + self.ax = plt.figure(figsize=figsize).add_subplot(111) + else: + self.ax = ax + + for data, color, name in zip( + plot_data, plot_colors, self._species_name + ): + step_x, step_y = self._get_step_plot_data(data, bin_edges) + self.ax.plot( + step_x, + step_y, + label=name, + color=color, + linewidth=2.5, + drawstyle="steps-post", + alpha=0.75, + ) + + self.ax.ticklabel_format(axis="y", scilimits=(0, 0)) + self.ax.tick_params("both", labelsize=15) + self.ax.set_xlabel("Last Interaction Velocity (km/s)", fontsize=14) + self.ax.set_ylabel("Packet Count", fontsize=15) + self.ax.legend(fontsize=15, bbox_to_anchor=(1.0, 1.0), loc="upper left") + self.ax.figure.tight_layout() + if xlog_scale: + self.ax.set_xscale("log") + if ylog_scale: + self.ax.set_yscale("log") + if velocity_range: + self.ax.set_xlim(velocity_range[0], velocity_range[1]) + + return self.ax + + def generate_plot_ply( + self, + species_list=None, + nelements=None, + packets_mode="virtual", + fig=None, + graph_height=600, + cmapname="jet", + xlog_scale=False, + ylog_scale=False, + num_bins=None, + velocity_range=None, + ): + """ + Generate the last interaction velocity distribution plot using plotly. + + Parameters + ---------- + species_list : list of str, optional + List of species to plot. Default is None which plots all species in the model. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + fig : plotly.graph_objects.Figure, optional + Plotly figure object to add the plot to. If None, creates a new figure. + graph_height : int, optional + Height (in px) of the plotly graph to display. Default value is 600. + cmapname : str, optional + Colormap name. Default is 'jet'. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. + xlog_scale : bool, optional + If True, x-axis is scaled logarithmically. Default is False. + ylog_scale : bool, optional + If True, y-axis is scaled logarithmically. Default is False. + num_bins : int, optional + Number of bins for regrouping within the same range. Default is None. + velocity_range : tuple, optional + Limits for the x-axis. If specified, overrides any automatically determined limits. + + Returns + ------- + plotly.graph_objects.Figure + Plotly figure object with the plot. + """ + # If species_list and nelements requested, tell user that nelements is ignored + if species_list is not None and nelements is not None: + logger.info( + "Both nelements and species_list were requested. Species_list takes priority; nelements is ignored" + ) + nelements = None + + plot_data, plot_colors, bin_edges = self._prepare_plot_data( + packets_mode, species_list, cmapname, num_bins, nelements + ) + + if fig is None: + self.fig = go.Figure() + else: + self.fig = fig + + for data, color, name in zip( + plot_data, plot_colors, self._species_name + ): + step_x, step_y = self._get_step_plot_data(data, bin_edges) + self.fig.add_trace( + go.Scatter( + x=step_x, + y=step_y, + mode="lines", + line=dict( + color=pu.to_rgb255_string(color), + width=2.5, + shape="hv", + ), + name=name, + opacity=0.75, + ) + ) + self.fig.update_layout( + height=graph_height, + xaxis_title="Last Interaction Velocity (km/s)", + yaxis_title="Packet Count", + font=dict(size=15), + yaxis=dict(exponentformat="power" if ylog_scale else "e"), + xaxis=dict(exponentformat="power" if xlog_scale else "none"), + ) + if xlog_scale: + self.fig.update_xaxes(type="log") + if ylog_scale: + self.fig.update_yaxes(type="log", dtick=1) + + if velocity_range: + self.fig.update_xaxes(range=velocity_range) + + return self.fig diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 75b01aa4f39..3a68a19ae4a 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -4,6 +4,7 @@ This plot is a spectral diagnostics plot similar to those originally proposed by M. Kromer (see, for example, Kromer et al. 2013, figure 4). """ + import logging import astropy.units as u @@ -40,6 +41,7 @@ def __init__( last_line_interaction_in_id, last_line_interaction_out_id, last_line_interaction_in_nu, + last_interaction_in_r, lines_df, packet_nus, packet_energies, @@ -67,6 +69,8 @@ def __init__( emission (interaction out) last_line_interaction_in_nu : np.array Frequency values of the last absorption of emitted packets + last_line_interaction_in_r : np.array + Radius of the last interaction experienced by emitted packets lines_df : pd.DataFrame Data about the atomic lines present in simulation model's plasma packet_nus : astropy.Quantity @@ -98,6 +102,7 @@ def __init__( "last_line_interaction_out_id": last_line_interaction_out_id, "last_line_interaction_in_id": last_line_interaction_in_id, "last_line_interaction_in_nu": last_line_interaction_in_nu, + "last_interaction_in_r": last_interaction_in_r, } ) @@ -177,6 +182,7 @@ def from_simulation(cls, sim, packets_mode): last_line_interaction_in_id=transport_state.vpacket_tracker.last_interaction_in_id, last_line_interaction_out_id=transport_state.vpacket_tracker.last_interaction_out_id, last_line_interaction_in_nu=transport_state.vpacket_tracker.last_interaction_in_nu, + last_interaction_in_r=transport_state.vpacket_tracker.last_interaction_in_r, lines_df=lines_df, packet_nus=u.Quantity( transport_state.vpacket_tracker.nus, "Hz" @@ -210,6 +216,9 @@ def from_simulation(cls, sim, packets_mode): last_line_interaction_in_nu=transport_state.last_interaction_in_nu[ transport_state.emitted_packet_mask ], + last_interaction_in_r=transport_state.last_interaction_in_r[ + transport_state.emitted_packet_mask + ], lines_df=lines_df, packet_nus=transport_state.packet_collection.output_nus[ transport_state.emitted_packet_mask @@ -283,6 +292,12 @@ def from_hdf(cls, hdf_fpath, packets_mode): ].to_numpy(), "Hz", ), + last_interaction_in_r=u.Quantity( + hdf[ + "/simulation/transport/transport_state/virt_packet_last_interaction_in_r" + ].to_numpy(), + "cm", + ), lines_df=lines_df, packet_nus=u.Quantity( hdf[ @@ -347,6 +362,12 @@ def from_hdf(cls, hdf_fpath, packets_mode): ].to_numpy()[emitted_packet_mask], "Hz", ), + last_interaction_in_r=u.Quantity( + hdf[ + "/simulation/transport/transport_state/last_interaction_in_r" + ].to_numpy()[emitted_packet_mask], + "cm", + ), lines_df=lines_df, packet_nus=u.Quantity( hdf[ @@ -509,6 +530,7 @@ def _parse_species_list(self, species_list): ) else: full_species_list = [] + species_mapped = {} for species in species_list: # check if a hyphen is present. If it is, then it indicates a # range of ions. Add each ion in that range to the list as a new entry @@ -546,20 +568,20 @@ def _parse_species_list(self, species_list): # the requested ion for species in full_species_list: if " " in species: - requested_species_ids.append( - [ - species_string_to_tuple(species)[0] * 100 - + species_string_to_tuple(species)[1] - ] + species_id = ( + species_string_to_tuple(species)[0] * 100 + + species_string_to_tuple(species)[1] ) + requested_species_ids.append([species_id]) + species_mapped[species_id] = [species_id] else: atomic_number = element_symbol2atomic_number(species) - requested_species_ids.append( - [ - atomic_number * 100 + ion_number - for ion_number in np.arange(atomic_number) - ] - ) + species_ids = [ + atomic_number * 100 + ion_number + for ion_number in np.arange(atomic_number) + ] + requested_species_ids.append(species_ids) + species_mapped[atomic_number * 100] = species_ids # add the atomic number to a list so you know that this element should # have all species in the same colour, i.e. it was requested like # species_list = [Si] @@ -570,6 +592,7 @@ def _parse_species_list(self, species_list): for species_id in temp_list ] + self._species_mapped = species_mapped self._species_list = requested_species_ids self._keep_colour = keep_colour else: @@ -1692,25 +1715,6 @@ def generate_plot_ply( return self.fig - @staticmethod - def to_rgb255_string(color_tuple): - """ - Convert a matplotlib RGBA tuple to a generic RGB 255 string. - - Parameters - ---------- - color_tuple : tuple - Matplotlib RGBA tuple of float values in closed interval [0, 1] - - Returns - ------- - str - RGB string of format rgb(r,g,b) where r,g,b are integers between - 0 and 255 (both inclusive) - """ - color_tuple_255 = tuple([int(x * 255) for x in color_tuple[:3]]) - return f"rgb{color_tuple_255}" - def _plot_emission_ply(self): """Plot emission part of the SDEC Plot using plotly.""" # By specifying a common stackgroup, plotly will itself add up @@ -1767,7 +1771,7 @@ def _plot_emission_ply(self): name=species_name + " Emission", hovertemplate=f"{species_name:s} Emission
" # noqa: ISC003 + "(%{x:.2f}, %{y:.3g})", - fillcolor=self.to_rgb255_string( + fillcolor=pu.to_rgb255_string( self._color_list[species_counter] ), stackgroup="emission", @@ -1826,7 +1830,7 @@ def _plot_absorption_ply(self): name=species_name + " Absorption", hovertemplate=f"{species_name:s} Absorption
" # noqa: ISC003 + "(%{x:.2f}, %{y:.3g})", - fillcolor=self.to_rgb255_string( + fillcolor=pu.to_rgb255_string( self._color_list[species_counter] ), stackgroup="absorption", @@ -1865,7 +1869,7 @@ def _show_colorbar_ply(self): # twice in a row (https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale) categorical_colorscale = [] for species_counter in range(len(self._species_name)): - color = self.to_rgb255_string( + color = pu.to_rgb255_string( self.cmap(colorscale_bins[species_counter]) ) categorical_colorscale.append( diff --git a/tardis/visualization/tools/tests/test_rpacket_plot.py b/tardis/visualization/tools/tests/test_rpacket_plot.py index 3a46d62756d..dcf8dd2f714 100755 --- a/tardis/visualization/tools/tests/test_rpacket_plot.py +++ b/tardis/visualization/tools/tests/test_rpacket_plot.py @@ -46,6 +46,7 @@ def simulation_simple(config_verysimple, atomic_dataset): return sim +@pytest.mark.rpacket_tracking class TestRPacketPlotter: """Test the RPacketPlotter class."""