Skip to content

Commit

Permalink
parellel with intermediate arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jvshields committed Dec 5, 2024
1 parent 4d0fe56 commit 7d40a11
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions stardis/radiation_field/opacities/opacities_solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def calc_alpha_line_at_nu(
use_vald_broadening=line_opacity_config.vald_linelist.use_vald_broadening
and line_opacity_config.vald_linelist.use_linelist, # don't try to use vald broadening if you don't use vald linelists at all
)
logger.info("Calculating line opacities at spectral points.")
alpha_line_at_nu = calc_alan_entries(
stellar_model.no_of_depth_points,
tracing_nus.value,
Expand Down Expand Up @@ -483,7 +484,7 @@ def calc_molecular_alpha_line_at_nu(
return alpha_line_at_nu, gammas, doppler_widths


@numba.njit
@numba.njit(parallel=True)
def calc_alan_entries(
no_of_depth_points,
tracing_nus_values,
Expand Down Expand Up @@ -523,8 +524,17 @@ def calc_alan_entries(
tracing_nus_values[0] - tracing_nus_values[1]
) # This is a bit awkward, but not sure of a better way to do it for non-uniform grids

for line_index in range(len(line_nus)):
intermediate_arrays = np.zeros(
(
numba.config.NUMBA_DEFAULT_NUM_THREADS,
no_of_depth_points,
len(tracing_nus_values),
)
)

for line_index in numba.prange(len(line_nus)):
line_nu = line_nus[line_index]
thread_id = numba.get_thread_id()
for depth_point_index in range(no_of_depth_points):
# If gamma is not for each depth point, we need to index it differently
line_gamma = (
Expand All @@ -542,8 +552,8 @@ def calc_alan_entries(

# We want to consider grid points within a certain range of the line_nu
line_broadening = (
(line_gamma + doppler_width) * alpha # Scale by alpha of the line
) / d_nu
((line_gamma + doppler_width) * alpha) / d_nu * 20
) # Scale by alpha of the line
line_broadening_range = max(10.0, line_broadening) # Force a minimum range

lower_freq_index = max(
Expand All @@ -556,14 +566,19 @@ def calc_alan_entries(

delta_nus = tracing_nus_values[lower_freq_index:upper_freq_index] - line_nu

alpha_line_at_nu[
depth_point_index, lower_freq_index:upper_freq_index
intermediate_arrays[
thread_id, depth_point_index, lower_freq_index:upper_freq_index
] += _calc_alan_entries(
delta_nus,
doppler_width,
line_gamma,
alpha,
)

# Combine the results from the intermediate arrays into the final array
for thread_id in range(numba.config.NUMBA_DEFAULT_NUM_THREADS):
alpha_line_at_nu += intermediate_arrays[thread_id]

return alpha_line_at_nu


Expand Down

0 comments on commit 7d40a11

Please sign in to comment.