Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Montecarlo configuration refactor #2525

Merged
merged 15 commits into from
Apr 19, 2024
Merged
1 change: 1 addition & 0 deletions tardis/io/model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ def transport_from_hdf(fname):
nthreads=d["nthreads"],
enable_virtual_packet_logging=d["virt_logging"],
use_gpu=d["use_gpu"],
montecarlo_configuration=d["montecarlo_configuration"],
)

new_transport.Edotlu_estimator = d["Edotlu_estimator"]
Expand Down
12 changes: 8 additions & 4 deletions tardis/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def no_of_raw_shells(self):
return self.geometry.no_of_shells

@classmethod
def from_config(cls, config, atom_data):
def from_config(cls, config, atom_data, legacy_mode_enabled=False):
"""
Create a new SimulationState instance from a Configuration object.

Expand Down Expand Up @@ -269,7 +269,9 @@ def from_config(cls, config, atom_data):
atom_data.atom_data.mass.copy(),
)

packet_source = parse_packet_source(config, geometry)
packet_source = parse_packet_source(
config, geometry, legacy_mode_enabled
)
radiation_field_state = parse_radiation_field_state(
config,
t_radiative,
Expand All @@ -288,7 +290,7 @@ def from_config(cls, config, atom_data):
)

@classmethod
def from_csvy(cls, config, atom_data=None):
def from_csvy(cls, config, atom_data=None, legacy_mode_enabled=False):
"""
Create a new SimulationState instance from a Configuration object.

Expand Down Expand Up @@ -366,7 +368,9 @@ def from_csvy(cls, config, atom_data=None):
geometry,
)

packet_source = parse_packet_source(config, geometry)
packet_source = parse_packet_source(
config, geometry, legacy_mode_enabled
)

radiation_field_state = parse_csvy_radiation_field_state(
config, csvy_model_config, csvy_model_data, geometry, packet_source
Expand Down
22 changes: 17 additions & 5 deletions tardis/model/parse_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ def parse_radiation_field_state(
)


def initialize_packet_source(config, geometry, packet_source):
def initialize_packet_source(
config, geometry, packet_source, legacy_mode_enabled
):
"""
Initialize the packet source based on config and geometry

Expand All @@ -613,9 +615,13 @@ def initialize_packet_source(config, geometry, packet_source):
packet_source = BlackBodySimpleSourceRelativistic(
base_seed=config.montecarlo.seed,
time_explosion=config.supernova.time_explosion,
legacy_mode_enabled=legacy_mode_enabled,
)
else:
packet_source = BlackBodySimpleSource(base_seed=config.montecarlo.seed)
packet_source = BlackBodySimpleSource(
base_seed=config.montecarlo.seed,
legacy_mode_enabled=legacy_mode_enabled,
)

luminosity_requested = config.supernova.luminosity_requested
if config.plasma.initial_t_inner > 0.0 * u.K:
Expand All @@ -635,7 +641,7 @@ def initialize_packet_source(config, geometry, packet_source):
return packet_source


def parse_packet_source(config, geometry):
def parse_packet_source(config, geometry, legacy_mode_enabled):
"""
Parse the packet source based on the given configuration and geometry.

Expand All @@ -655,11 +661,17 @@ def parse_packet_source(config, geometry):
packet_source = BlackBodySimpleSourceRelativistic(
base_seed=config.montecarlo.seed,
time_explosion=config.supernova.time_explosion,
legacy_mode_enabled=legacy_mode_enabled,
)
else:
packet_source = BlackBodySimpleSource(base_seed=config.montecarlo.seed)
packet_source = BlackBodySimpleSource(
base_seed=config.montecarlo.seed,
legacy_mode_enabled=legacy_mode_enabled,
)

return initialize_packet_source(config, geometry, packet_source)
return initialize_packet_source(
config, geometry, packet_source, legacy_mode_enabled
)


def parse_csvy_radiation_field_state(
Expand Down
25 changes: 18 additions & 7 deletions tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from tardis import constants as const
from tardis.io.logger import montecarlo_tracking as mc_tracker
from tardis.io.util import HDFWriterMixin
from tardis.montecarlo import montecarlo_configuration
from tardis.montecarlo.estimators.radfield_mc_estimators import (
initialize_estimator_statistics,
)
from tardis.montecarlo.montecarlo_configuration import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.montecarlo.montecarlo_numba import (
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(
debug_packets=False,
logger_buffer=1,
use_gpu=False,
montecarlo_configuration=None,
):
# inject different packets
self.disable_electron_scattering = disable_electron_scattering
Expand All @@ -86,6 +87,7 @@ def __init__(

self.enable_vpacket_tracking = enable_virtual_packet_logging
self.enable_rpacket_tracking = enable_rpacket_tracking
self.montecarlo_configuration = montecarlo_configuration

self.packet_source = packet_source

Expand Down Expand Up @@ -124,7 +126,10 @@ def initialize_transport_state(

geometry_state = simulation_state.geometry.to_numba()
opacity_state = opacity_state_initialize(
plasma, self.line_interaction_type
plasma,
self.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
transport_state = MonteCarloTransportState(
packet_collection,
Expand All @@ -139,7 +144,9 @@ def initialize_transport_state(
transport_state._integrator = FormalIntegrator(
simulation_state, plasma, self
)
configuration_initialize(self, no_of_virtual_packets)
configuration_initialize(
self.montecarlo_configuration, self, no_of_virtual_packets
)

return transport_state

Expand Down Expand Up @@ -172,7 +179,7 @@ def run(

numba_model = NumbaModel(time_explosion.to("s").value)

number_of_vpackets = montecarlo_configuration.NUMBER_OF_VPACKETS
number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS

(
v_packets_energy_hist,
Expand All @@ -184,6 +191,7 @@ def run(
transport_state.geometry_state,
numba_model,
transport_state.opacity_state,
self.montecarlo_configuration,
transport_state.radfield_mc_estimators,
transport_state.spectrum_frequency.value,
number_of_vpackets,
Expand All @@ -208,15 +216,15 @@ def run(
last_interaction_tracker.shell_ids
)

if montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
if self.montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
number_of_vpackets > 0
):
transport_state.vpacket_tracker = vpacket_tracker

update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
Expand All @@ -226,7 +234,7 @@ def run(
)
)
transport_state.virt_logging = (
montecarlo_configuration.ENABLE_VPACKET_TRACKING
self.montecarlo_configuration.ENABLE_VPACKET_TRACKING
)

def legacy_return(self):
Expand Down Expand Up @@ -300,6 +308,8 @@ def from_config(
valid values are 'GPU', 'CPU', and 'Automatic'."""
)

montecarlo_configuration = MonteCarloConfiguration()

montecarlo_configuration.DISABLE_LINE_SCATTERING = (
config.plasma.disable_line_scattering
)
Expand Down Expand Up @@ -329,4 +339,5 @@ def from_config(
enable_rpacket_tracking=config.montecarlo.tracking.track_rpacket,
nthreads=config.montecarlo.nthreads,
use_gpu=use_gpu,
montecarlo_configuration=montecarlo_configuration,
)
120 changes: 120 additions & 0 deletions tardis/montecarlo/estimators/radfield_estimator_calcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from math import exp

from numba import njit

from tardis.montecarlo.montecarlo_numba import (
njit_dict_no_parallel,
)
from tardis.montecarlo.montecarlo_numba.numba_config import KB, H
from tardis.transport.frame_transformations import (
calc_packet_energy,
calc_packet_energy_full_relativity,
)


@njit(**njit_dict_no_parallel)
def update_base_estimators(
r_packet, distance, estimator_state, comov_nu, comov_energy
):
"""
Updating the estimators
"""
estimator_state.j_estimator[r_packet.current_shell_id] += (
comov_energy * distance
)
estimator_state.nu_bar_estimator[r_packet.current_shell_id] += (
comov_energy * distance * comov_nu
)


@njit(**njit_dict_no_parallel)
def update_bound_free_estimators(
comov_nu,
comov_energy,
shell_id,
distance,
estimator_state,
t_electron,
x_sect_bfs,
current_continua,
bf_threshold_list_nu,
):
"""
Update the estimators for bound-free processes.

Parameters
----------
comov_nu : float
comov_energy : float
shell_id : int
distance : float
numba_estimator : tardis.montecarlo.montecarlo_numba.numba_interface.Estimators
t_electron : float
Electron temperature in the current cell.
x_sect_bfs : numpy.ndarray, dtype float
Photoionization cross-sections of all bound-free continua for
which absorption is possible for frequency `comov_nu`.
current_continua : numpy.ndarray, dtype int
Continuum ids for which absorption is possible for frequency `comov_nu`.
bf_threshold_list_nu : numpy.ndarray, dtype float
Threshold frequencies for photoionization sorted by decreasing frequency.
"""
# TODO: Add full relativity mode
boltzmann_factor = exp(-(H * comov_nu) / (KB * t_electron))
for i, current_continuum in enumerate(current_continua):
photo_ion_rate_estimator_increment = (
comov_energy * distance * x_sect_bfs[i] / comov_nu
)
estimator_state.photo_ion_estimator[
current_continuum, shell_id
] += photo_ion_rate_estimator_increment
estimator_state.stim_recomb_estimator[current_continuum, shell_id] += (
photo_ion_rate_estimator_increment * boltzmann_factor
)
estimator_state.photo_ion_estimator_statistics[
current_continuum, shell_id
] += 1

nu_th = bf_threshold_list_nu[current_continuum]
bf_heating_estimator_increment = (
comov_energy * distance * x_sect_bfs[i] * (1 - nu_th / comov_nu)
)
estimator_state.bf_heating_estimator[
current_continuum, shell_id
] += bf_heating_estimator_increment
estimator_state.stim_recomb_cooling_estimator[
current_continuum, shell_id
] += (bf_heating_estimator_increment * boltzmann_factor)


@njit(**njit_dict_no_parallel)
def update_line_estimators(
radfield_mc_estimators,
r_packet,
cur_line_id,
distance_trace,
time_explosion,
enable_full_relativity,
):
"""
Function to update the line estimators

Parameters
----------
estimators : tardis.montecarlo.montecarlo_numba.numba_interface.Estimators
r_packet : tardis.montecarlo.montecarlo_numba.r_packet.RPacket
cur_line_id : int
distance_trace : float
time_explosion : float
"""
if not enable_full_relativity:
energy = calc_packet_energy(r_packet, distance_trace, time_explosion)
else:
energy = calc_packet_energy_full_relativity(r_packet)

radfield_mc_estimators.j_blue_estimator[
cur_line_id, r_packet.current_shell_id
] += (energy / r_packet.nu)
radfield_mc_estimators.Edotlu_estimator[
cur_line_id, r_packet.current_shell_id
] += energy
Loading
Loading