Skip to content

Commit

Permalink
All tests passing, black applied
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Feb 29, 2024
1 parent ebb4bfc commit 15ebac8
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tardis.simulation import Simulation


@pytest.mark.skip()
def test_continuum_estimators(
continuum_config,
nlte_atomic_dataset,
Expand Down
69 changes: 50 additions & 19 deletions tardis/montecarlo/montecarlo_numba/formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def numba_formal_integral(
p = pp[p_idx]

# initialize z intersections for p values
size_z = populate_z(geometry, model, p, z, shell_id) # check returns
size_z = populate_z(
geometry, model, p, z, shell_id
) # check returns
# initialize I_nu
if p <= R_ph:
I_nu[p_idx] = intensity_black_body(nu * z[0], iT)
Expand Down Expand Up @@ -143,7 +145,9 @@ def numba_formal_integral(
for _ in range(max(nu_end_idx - pline, 0)):
# calculate e-scattering optical depth to next resonance point
zend = (
model.time_explosion / C_INV * (1.0 - line_list_nu[pline] / nu)
model.time_explosion
/ C_INV
* (1.0 - line_list_nu[pline] / nu)
) # check

if first == 1:
Expand Down Expand Up @@ -183,8 +187,12 @@ def numba_formal_integral(
# calculate e-scattering optical depth to grid cell boundary

Jkkp = 0.5 * (Jred_lu[pJred_lu] + Jblue_lu[pJblue_lu])
zend = model.time_explosion / C_INV * (1.0 - nu_end / nu) # check
escat_contrib += (zend - zstart) * escat_op * (Jkkp - I_nu[p_idx])
zend = (
model.time_explosion / C_INV * (1.0 - nu_end / nu)
) # check
escat_contrib += (
(zend - zstart) * escat_op * (Jkkp - I_nu[p_idx])
)
zstart = zend

# advance pointers
Expand Down Expand Up @@ -274,7 +282,9 @@ def __init__(self, simulation_state, plasma, transport, points=1000):
self.transport = transport
self.points = points
if transport:
self.montecarlo_configuration = self.transport.montecarlo_configuration
self.montecarlo_configuration = (
self.transport.montecarlo_configuration
)
if plasma:
self.plasma = opacity_state_initialize(
plasma,
Expand Down Expand Up @@ -430,7 +440,9 @@ def make_source_function(self):
if transport.line_interaction_type == "macroatom":
internal_jump_mask = (macro_data.transition_type >= 0).values
ma_int_data = macro_data[internal_jump_mask]
internal = self.original_plasma.transition_probabilities[internal_jump_mask]
internal = self.original_plasma.transition_probabilities[
internal_jump_mask
]

source_level_idx = ma_int_data.source_level_idx.values
destination_level_idx = ma_int_data.destination_level_idx.values
Expand Down Expand Up @@ -469,13 +481,17 @@ def make_source_function(self):
* Jbluelu_norm_factor
)

upper_level_index = self.atomic_data.lines.index.droplevel("level_number_lower")
upper_level_index = self.atomic_data.lines.index.droplevel(
"level_number_lower"
)
e_dot_lu = pd.DataFrame(Edotlu.values, index=upper_level_index)
e_dot_u = e_dot_lu.groupby(level=[0, 1, 2]).sum()
e_dot_u_src_idx = macro_ref.loc[e_dot_u.index].references_idx.values

if transport.line_interaction_type == "macroatom":
C_frame = pd.DataFrame(columns=np.arange(no_shells), index=macro_ref.index)
C_frame = pd.DataFrame(
columns=np.arange(no_shells), index=macro_ref.index
)
q_indices = (source_level_idx, destination_level_idx)
for shell in range(no_shells):
Q = sp.coo_matrix(
Expand All @@ -492,7 +508,8 @@ def make_source_function(self):
"source_level_number",
] # To make the q_ul e_dot_u product work, could be cleaner
transitions = self.original_plasma.atomic_data.macro_atom_data[
self.original_plasma.atomic_data.macro_atom_data.transition_type == -1
self.original_plasma.atomic_data.macro_atom_data.transition_type
== -1
].copy()
transitions_index = transitions.set_index(
["atomic_number", "ion_number", "source_level_number"]
Expand All @@ -504,9 +521,9 @@ def make_source_function(self):
t = simulation_state.time_explosion.value
t = simulation_state.time_explosion.value
lines = self.atomic_data.lines.set_index("line_id")
wave = lines.wavelength_cm.loc[transitions.transition_line_id].values.reshape(
-1, 1
)
wave = lines.wavelength_cm.loc[
transitions.transition_line_id
].values.reshape(-1, 1)
if transport.line_interaction_type == "macroatom":
e_dot_u = C_frame.loc[e_dot_u.index]
att_S_ul = wave * (q_ul * e_dot_u) * t / (4 * np.pi)
Expand All @@ -529,16 +546,24 @@ def make_source_function(self):
att_S_ul, Jredlu, Jbluelu, e_dot_u
)
else:
transport.r_inner_i = montecarlo_transport_state.geometry_state.r_inner
transport.r_outer_i = montecarlo_transport_state.geometry_state.r_outer
transport.tau_sobolevs_integ = self.original_plasma.tau_sobolevs.values
transport.r_inner_i = (
montecarlo_transport_state.geometry_state.r_inner
)
transport.r_outer_i = (
montecarlo_transport_state.geometry_state.r_outer
)
transport.tau_sobolevs_integ = (
self.original_plasma.tau_sobolevs.values
)
transport.electron_densities_integ = (
self.original_plasma.electron_densities.values
)

return att_S_ul, Jredlu, Jbluelu, e_dot_u

def interpolate_integrator_quantities(self, att_S_ul, Jredlu, Jbluelu, e_dot_u):
def interpolate_integrator_quantities(
self, att_S_ul, Jredlu, Jbluelu, e_dot_u
):
transport = self.transport
mct_state = transport.transport_state
plasma = self.original_plasma
Expand Down Expand Up @@ -577,8 +602,12 @@ def interpolate_integrator_quantities(self, att_S_ul, Jredlu, Jbluelu, e_dot_u):
Jredlu = pd.DataFrame(
interp1d(r_middle, Jredlu, fill_value="extrapolate")(r_middle_integ)
)
Jbluelu = interp1d(r_middle, Jbluelu, fill_value="extrapolate")(r_middle_integ)
e_dot_u = interp1d(r_middle, e_dot_u, fill_value="extrapolate")(r_middle_integ)
Jbluelu = interp1d(r_middle, Jbluelu, fill_value="extrapolate")(
r_middle_integ
)
e_dot_u = interp1d(r_middle, e_dot_u, fill_value="extrapolate")(
r_middle_integ
)

# Set negative values from the extrapolation to zero
att_S_ul = att_S_ul.clip(0.0)
Expand Down Expand Up @@ -624,7 +653,9 @@ def formal_integral(self, nu, N):
]
)
error = np.max(np.abs((L_test - L) / L))
assert error < 1e-7, f"Incorrect I_nu_p values, max relative difference:{error}"
assert (
error < 1e-7
), f"Incorrect I_nu_p values, max relative difference:{error}"

return np.array(L, np.float64)

Expand Down
8 changes: 6 additions & 2 deletions tardis/montecarlo/montecarlo_numba/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def verysimple_estimators(nb_simulation_verysimple):

@pytest.fixture(scope="package")
def verysimple_vpacket_collection(nb_simulation_verysimple):
spectrum_frequency = nb_simulation_verysimple.transport.spectrum_frequency.value
spectrum_frequency = (
nb_simulation_verysimple.transport.spectrum_frequency.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency=spectrum_frequency,
Expand All @@ -83,7 +85,9 @@ def verysimple_vpacket_collection(nb_simulation_verysimple):

@pytest.fixture(scope="package")
def verysimple_3vpacket_collection(nb_simulation_verysimple):
spectrum_frequency = nb_simulation_verysimple.transport.spectrum_frequency.value
spectrum_frequency = (
nb_simulation_verysimple.transport.spectrum_frequency.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency=spectrum_frequency,
Expand Down
1 change: 0 additions & 1 deletion tardis/montecarlo/montecarlo_numba/tests/test_continuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from tardis.simulation import Simulation


@pytest.mark.skip()
def test_montecarlo_continuum(
continuum_config,
regression_data,
Expand Down
8 changes: 6 additions & 2 deletions tardis/montecarlo/montecarlo_numba/tests/test_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test_line_scatter(
init_mu = packet.mu
init_nu = packet.nu
init_energy = packet.energy
packet.initialize_line_id(verysimple_opacity_state, verysimple_numba_model, False)
packet.initialize_line_id(
verysimple_opacity_state, verysimple_numba_model, False
)
time_explosion = verysimple_numba_model.time_explosion

interaction.line_scatter(
Expand Down Expand Up @@ -94,7 +96,9 @@ def test_line_emission(
emission_line_id = test_packet["emission_line_id"]
packet.mu = test_packet["mu"]
packet.energy = test_packet["energy"]
packet.initialize_line_id(verysimple_opacity_state, verysimple_numba_model, False)
packet.initialize_line_id(
verysimple_opacity_state, verysimple_numba_model, False
)

time_explosion = verysimple_numba_model.time_explosion

Expand Down
20 changes: 15 additions & 5 deletions tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def test_opacity_state_initialize(nb_simulation_verysimple, input_params):
continuum_processes_enabled=False,
)

npt.assert_allclose(actual.electron_density, plasma.electron_densities.values)
npt.assert_allclose(
actual.electron_density, plasma.electron_densities.values
)
npt.assert_allclose(actual.line_list_nu, plasma.atomic_data.lines.nu.values)
npt.assert_allclose(actual.tau_sobolev, plasma.tau_sobolevs.values)
if line_interaction_type == "scatter":
Expand Down Expand Up @@ -67,7 +69,9 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection):
energies = [0.4, 0.1, 0.6, 1e10]
initial_mus = [0.1, 0, 1, 0.9]
initial_rs = [3e42, 4.5e45, 0, 9.0e40]
last_interaction_in_nus = np.array([3.0e15, 0.0, 1e15, 1e5], dtype=np.float64)
last_interaction_in_nus = np.array(
[3.0e15, 0.0, 1e15, 1e5], dtype=np.float64
)
last_interaction_types = np.array([1, 1, 3, 2], dtype=np.int64)
last_interaction_in_ids = np.array([100, 0, 1, 1000], dtype=np.int64)
last_interaction_out_ids = np.array([1201, 123, 545, 1232], dtype=np.int64)
Expand Down Expand Up @@ -107,11 +111,15 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection):
)

npt.assert_array_equal(
verysimple_3vpacket_collection.nus[: verysimple_3vpacket_collection.idx],
verysimple_3vpacket_collection.nus[
: verysimple_3vpacket_collection.idx
],
nus,
)
npt.assert_array_equal(
verysimple_3vpacket_collection.energies[: verysimple_3vpacket_collection.idx],
verysimple_3vpacket_collection.energies[
: verysimple_3vpacket_collection.idx
],
energies,
)
npt.assert_array_equal(
Expand All @@ -121,7 +129,9 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection):
initial_mus,
)
npt.assert_array_equal(
verysimple_3vpacket_collection.initial_rs[: verysimple_3vpacket_collection.idx],
verysimple_3vpacket_collection.initial_rs[
: verysimple_3vpacket_collection.idx
],
initial_rs,
)
npt.assert_array_equal(
Expand Down
Loading

0 comments on commit 15ebac8

Please sign in to comment.