Skip to content

Commit

Permalink
Avoid repeated use of np.append
Browse files Browse the repository at this point in the history
  • Loading branch information
austinschneider committed Aug 29, 2024
1 parent 45bf92a commit 5164560
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions resources/Processes/DarkNewsTables/DarkNewsCrossSection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import functools
from scipy.interpolate import LinearNDInterpolator, PchipInterpolator
from typing import List, Tuple

from siren import _util

Expand All @@ -17,6 +18,7 @@
# DarkNews methods
from DarkNews import phase_space


# A class representing a single ups_case DarkNews class
# Only handles methods concerning the upscattering part
class PyDarkNewsCrossSection(DarkNewsCrossSection):
Expand Down Expand Up @@ -51,9 +53,7 @@ def load_from_table(self, table_dir):
total_xsec_file = os.path.join(table_dir, "total_cross_sections.npy")
if os.path.exists(total_xsec_file):
self.total_cross_section_table = np.load(total_xsec_file)
diff_xsec_file = os.path.join(
table_dir, "differential_cross_sections.npy"
)
diff_xsec_file = os.path.join(table_dir, "differential_cross_sections.npy")
if os.path.exists(diff_xsec_file):
self.differential_cross_section_table = np.load(diff_xsec_file)

Expand All @@ -62,9 +62,7 @@ def load_from_table(self, table_dir):
def save_to_table(self, table_dir, total=True, diff=True):
if total:
self._redefine_interpolation_objects(total=True)
with open(
os.path.join(table_dir, "total_cross_sections.npy"), "wb"
) as f:
with open(os.path.join(table_dir, "total_cross_sections.npy"), "wb") as f:
np.save(f, self.total_cross_section_table)
if diff:
self._redefine_interpolation_objects(diff=True)
Expand All @@ -91,7 +89,6 @@ def get_representation(self):
# tolerance, interp_tolerance, always_interpolate
# kwargs argument can be used to set any of these
def configure(self, **kwargs):

for k, v in kwargs.items():
self.__setattr__(k, v)

Expand Down Expand Up @@ -254,22 +251,22 @@ def _query_interpolation_table(self, inputs, mode):
else:
return -1

def FillTableAtEnergy(self, E, total=True, diff=True, factor=0.8):
def FillTableAtEnergy(
self, E: float, total: bool = True, diff: bool = True, factor: float = 0.8
) -> int:
num_added_points = 0
new_total_points: List[Tuple[float, float]] = []
new_diff_points: List[Tuple[float, float, float]] = []

if total:
xsec = self.ups_case.total_xsec(E)
self.total_cross_section_table = np.append(
self.total_cross_section_table, [[E, xsec]], axis=0
)
new_total_points.append((E, xsec))
num_added_points += 1

if diff:
interaction = dataclasses.InteractionRecord()
interaction.signature.primary_type = self.GetPossiblePrimaries()[
0
] # only one primary
interaction.signature.target_type = self.GetPossibleTargets()[
0
] # only one target
interaction.signature.primary_type = self.GetPossiblePrimaries()[0]
interaction.signature.target_type = self.GetPossibleTargets()[0]
interaction.target_mass = self.ups_case.MA
interaction.primary_momentum = [E, 0, 0, 0]
zmin, zmax = self.tolerance, 1
Expand All @@ -279,13 +276,19 @@ def FillTableAtEnergy(self, E, total=True, diff=True, factor=0.8):
while z < zmax:
Q2 = Q2min + z * (Q2max - Q2min)
dxsec = self.ups_case.diff_xsec_Q2(E, Q2).item()
self.differential_cross_section_table = np.append(
self.differential_cross_section_table,
[[E, z, dxsec]],
axis=0,
)
new_diff_points.append((E, z, dxsec))
num_added_points += 1
z *= 1 + factor * self.interp_tolerance

if new_total_points:
self.total_cross_section_table = np.vstack(
(self.total_cross_section_table, new_total_points)
)
if new_diff_points:
self.differential_cross_section_table = np.vstack(
(self.differential_cross_section_table, new_diff_points)
)

self._redefine_interpolation_objects(total=total, diff=diff)
return num_added_points

Expand Down Expand Up @@ -473,8 +476,8 @@ def TotalCrossSection(self, arg1, energy=None, target=None):

# If we have reached this block, we must compute the cross section using DarkNews
xsec = self.ups_case.total_xsec(energy)
self.total_cross_section_table = np.append(
self.total_cross_section_table, [[energy, xsec]], axis=0
self.total_cross_section_table = np.vstack(
(self.total_cross_section_table, [[energy, xsec]])
)
self._redefine_interpolation_objects(total=True)
return xsec
Expand Down

0 comments on commit 5164560

Please sign in to comment.