Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attributes specific to line interaction in the Tracker classes. #2662

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
216fec0
Create and Initialize RPacketLastInteractionClass
Sumit112192 Jun 9, 2024
2978a53
Add default value test
Sumit112192 Jun 10, 2024
b9fbbaf
Add basic tests
Sumit112192 Jun 10, 2024
6a520f0
Fix some typos
Sumit112192 Jun 10, 2024
6bd9339
Compare all elements of numpy array
Sumit112192 Jun 10, 2024
51b0002
Write more tests using actual data
Sumit112192 Jun 10, 2024
abe4f6d
Resolve some issues
Sumit112192 Jun 11, 2024
8a48791
Run Black
Sumit112192 Jun 11, 2024
1fa863f
Break the test to seperate out fixtures from the acutal test
Sumit112192 Jun 11, 2024
cf9e3ee
Change the way last interaction is updated
Sumit112192 Jun 14, 2024
5311c94
Use either RPacketTracker or RPacketLastInteractionTracker
Sumit112192 Jun 14, 2024
48e70c9
Store both trackers in same variable using if else
Sumit112192 Jun 14, 2024
3d4412a
Remove duplicate assignment
Sumit112192 Jun 14, 2024
593ee05
Fix a Typo
Sumit112192 Jun 14, 2024
ff61a2e
Add a flag for last interaction and update the tests
Sumit112192 Jun 16, 2024
2632443
Update schemas to include Last Interaction Flag
Sumit112192 Jun 16, 2024
f288f23
Explicitly set track_rpacket to False
Sumit112192 Jun 16, 2024
e829d0d
Fix a small typo
Sumit112192 Jun 17, 2024
93a1c1b
Not modifying config fixture since other tests might use it
Sumit112192 Jun 17, 2024
1fb5819
Remove config flag for RPacketLastInteractionTracker
Sumit112192 Jun 19, 2024
b4ebac5
Remove unnecessary imports
Sumit112192 Jun 19, 2024
e6c6e2d
Remove unnecessary imports
Sumit112192 Jun 19, 2024
be1a349
Add line interaction attributes to Tracker Classes
Sumit112192 Jun 19, 2024
301b1b0
Update Line interaction attributes using the rpacket
Sumit112192 Jun 20, 2024
2d15a4a
Remove Reassignment of Variables
Sumit112192 Jun 20, 2024
9cbe537
Remove Reassignment of Variables
Sumit112192 Jun 20, 2024
efa97be
Distinct last interaction tracker classes as old and new
Sumit112192 Jun 20, 2024
26686b8
Remove Outdated comment
Sumit112192 Jun 20, 2024
4367f7f
Add Parametrize Test
Sumit112192 Jun 21, 2024
b90e7ff
Remove Old tests
Sumit112192 Jun 21, 2024
b390fa7
Add shell_id test
Sumit112192 Jun 24, 2024
2141bad
Fix some typos
Sumit112192 Jun 24, 2024
bf44d33
Fix typo
Sumit112192 Jun 24, 2024
33f62c7
Add tests and segregate them into appropriate fixtures and test funct…
Sumit112192 Jun 24, 2024
5797157
tracker bug fix
Sumit112192 Jun 24, 2024
55a89d1
Add Enum to tests
Sumit112192 Jun 25, 2024
02b845a
Add Enum to packet_trackers
Sumit112192 Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def run(
last_interaction_tracker,
vpacket_tracker,
rpacket_trackers,
rpacket_last_interaction_trackers,
) = montecarlo_main_loop(
transport_state.packet_collection,
transport_state.geometry_state,
Expand Down Expand Up @@ -212,13 +213,14 @@ def run(
# Condition for Checking if RPacket Tracking is enabled
if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
self.transport_state.rpacket_tracker_df = (
rpacket_trackers_to_dataframe(
self.transport_state.rpacket_tracker
)
)
else:
transport_state.rpacket_tracker = rpacket_last_interaction_trackers

transport_state.virt_logging = (
self.montecarlo_configuration.ENABLE_VPACKET_TRACKING
)
Expand Down
12 changes: 11 additions & 1 deletion tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from numba.typed import List

from tardis.transport.montecarlo import njit_dict
from tardis.transport.montecarlo.packet_trackers import RPacketTracker
from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
RPacketLastInteractionTracker,
)
from tardis.transport.montecarlo.packet_collections import (
VPacketCollection,
consolidate_vpacket_tracker,
Expand Down Expand Up @@ -72,6 +75,7 @@ def montecarlo_main_loop(
vpacket_collections = List()
# Configuring the Tracking for R_Packets
rpacket_trackers = List()
rpacket_last_interaction_trackers = List()
for i in range(no_of_packets):
vpacket_collections.append(
VPacketCollection(
Expand All @@ -88,6 +92,9 @@ def montecarlo_main_loop(
montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH
)
)
rpacket_last_interaction_trackers.append(
RPacketLastInteractionTracker()
)

# Get the ID of the main thread and the number of threads
main_thread_id = get_thread_id()
Expand Down Expand Up @@ -130,6 +137,7 @@ def montecarlo_main_loop(

# RPacket Tracker for this thread
rpacket_tracker = rpacket_trackers[i]
rpacket_last_interaction_tracker = rpacket_last_interaction_trackers[i]

loop = single_packet_loop(
r_packet,
Expand All @@ -139,6 +147,7 @@ def montecarlo_main_loop(
local_estimators,
vpacket_collection,
rpacket_tracker,
rpacket_last_interaction_tracker,
montecarlo_configuration,
)
packet_collection.output_nus[i] = r_packet.nu
Expand Down Expand Up @@ -194,4 +203,5 @@ def montecarlo_main_loop(
last_interaction_tracker,
vpacket_tracker,
rpacket_trackers,
rpacket_last_interaction_trackers,
)
124 changes: 123 additions & 1 deletion tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pandas as pd

from tardis.transport.montecarlo.r_packet import InteractionType

rpacket_tracker_spec = [
("length", int64),
("seed", int64),
Expand All @@ -14,6 +16,9 @@
("energy", float64[:]),
("shell_id", int64[:]),
("interaction_type", int64[:]),
("interaction_in_line_nu", float64[:]),
("interaction_in_line_id", int64[:]),
("interaction_out_line_id", int64[:]),
("num_interactions", int64),
]

Expand All @@ -35,7 +40,7 @@ class RPacketTracker(object):
r : float
Radius of the shell where the RPacket is present
nu : float
Luminosity of the RPacket
Frequency of the RPacket
mu : float
Cosine of the angle made by the direction of movement of the RPacket from its original direction
energy : float
Expand All @@ -44,6 +49,12 @@ class RPacketTracker(object):
Current Shell No in which the RPacket is present
interaction_type: int
Type of interaction the rpacket undergoes
interaction_in_line_nu : float
frequency corresponding to the absroption line
interaction_in_line_id : int
Id of the absorption line
interaction_out_line_id : int
Id of the transmission line
num_interactions : int
Internal counter for the interactions that a particular RPacket undergoes
"""
Expand All @@ -59,6 +70,9 @@ def __init__(self, length):
self.energy = np.empty(self.length, dtype=np.float64)
self.shell_id = np.empty(self.length, dtype=np.int64)
self.interaction_type = np.empty(self.length, dtype=np.int64)
self.interaction_in_line_nu = np.empty(self.length, dtype=np.float64)
self.interaction_in_line_id = np.empty(self.length, dtype=np.int64)
self.interaction_out_line_id = np.empty(self.length, dtype=np.int64)
self.num_interactions = 0

def track(self, r_packet):
Expand All @@ -71,6 +85,11 @@ def track(self, r_packet):
temp_energy = np.empty(temp_length, dtype=np.float64)
temp_shell_id = np.empty(temp_length, dtype=np.int64)
temp_interaction_type = np.empty(temp_length, dtype=np.int64)
temp_interaction_in_line_nu = np.empty(
temp_length, dtype=np.float64
)
temp_interaction_in_line_id = np.empty(temp_length, dtype=np.int64)
temp_interaction_out_line_id = np.empty(temp_length, dtype=np.int64)

temp_status[: self.length] = self.status
temp_r[: self.length] = self.r
Expand All @@ -79,6 +98,15 @@ def track(self, r_packet):
temp_energy[: self.length] = self.energy
temp_shell_id[: self.length] = self.shell_id
temp_interaction_type[: self.length] = self.interaction_type
temp_interaction_in_line_nu[
: self.length
] = self.interaction_in_line_nu
temp_interaction_in_line_id[
: self.length
] = self.interaction_in_line_id
temp_interaction_out_line_id[
: self.length
] = self.interaction_out_line_id

self.status = temp_status
self.r = temp_r
Expand All @@ -87,6 +115,9 @@ def track(self, r_packet):
self.energy = temp_energy
self.shell_id = temp_shell_id
self.interaction_type = temp_interaction_type
self.interaction_in_line_nu = temp_interaction_in_line_nu
self.interaction_in_line_id = temp_interaction_in_line_id
self.interaction_out_line_id = temp_interaction_out_line_id
self.length = temp_length

self.index = r_packet.index
Expand All @@ -100,6 +131,21 @@ def track(self, r_packet):
self.interaction_type[
self.num_interactions
] = r_packet.last_interaction_type
# Only set if last interaction is line interaction, else -1 or 0.
if r_packet.last_interaction_type == InteractionType.LINE:
self.interaction_in_line_nu[
self.num_interactions
] = r_packet.last_interaction_in_nu
self.interaction_in_line_id[
self.num_interactions
] = r_packet.last_line_interaction_in_id
self.interaction_out_line_id[
self.num_interactions
] = r_packet.last_line_interaction_out_id
else:
self.interaction_in_line_nu[self.num_interactions] = 0.0
self.interaction_in_line_id[self.num_interactions] = -1
self.interaction_out_line_id[self.num_interactions] = -1
self.num_interactions += 1

def finalize_array(self):
Expand All @@ -110,6 +156,15 @@ def finalize_array(self):
self.energy = self.energy[: self.num_interactions]
self.shell_id = self.shell_id[: self.num_interactions]
self.interaction_type = self.interaction_type[: self.num_interactions]
self.interaction_in_line_nu = self.interaction_in_line_nu[
: self.num_interactions
]
self.interaction_in_line_id = self.interaction_in_line_id[
: self.num_interactions
]
self.interaction_out_line_id = self.interaction_out_line_id[
: self.num_interactions
]


def rpacket_trackers_to_dataframe(rpacket_trackers):
Expand Down Expand Up @@ -156,3 +211,70 @@ def rpacket_trackers_to_dataframe(rpacket_trackers):
index=pd.MultiIndex.from_arrays(index_array, names=["index", "step"]),
columns=df_dtypes.names,
)


rpacket_last_interaction_tracker_spec = [
("index", int64),
("r", float64),
("nu", float64),
("energy", float64),
("shell_id", int64),
("interaction_type", int64),
("interaction_in_line_nu", float64),
("interaction_in_line_id", int64),
("interaction_out_line_id", int64),
]


@jitclass(rpacket_last_interaction_tracker_spec)
class RPacketLastInteractionTracker(object):
"""
Numba JITCLASS for storing the last interaction the RPacket undergoes.
Parameters
----------
index : int
Index position of each RPacket
r : float
Radius of the shell where the RPacket is present
nu : float
Frequency of the RPacket
energy : float
Energy possessed by the RPacket
shell_id : int
Current Shell No in which the last interaction happened
interaction_type: int
Type of interaction the rpacket undergoes
interaction_in_line_nu : float
frequency corresponding to the absroption line
interaction_in_line_id : int
Id of the absorption line
interaction_out_line_id : int
Id of the transmission line
"""

def __init__(self):
self.index = -1
self.r = -1.0
self.nu = 0.0
self.energy = 0.0
self.shell_id = -1
self.interaction_type = -1
self.interaction_in_line_id = -1
self.interaction_out_line_id = -1
self.interaction_in_line_nu = 0.0

def track(self, r_packet):
self.index = r_packet.index
self.r = r_packet.r
self.nu = r_packet.nu
self.energy = r_packet.energy
self.shell_id = r_packet.current_shell_id
self.interaction_type = r_packet.last_interaction_type
if r_packet.last_interaction_type == InteractionType.LINE:
self.interaction_in_line_nu = r_packet.last_interaction_in_nu
self.interaction_in_line_id = r_packet.last_line_interaction_in_id
self.interaction_out_line_id = r_packet.last_line_interaction_out_id
else:
self.interaction_in_line_nu = 0.0
self.interaction_in_line_id = -1
self.interaction_out_line_id = -1
10 changes: 8 additions & 2 deletions tardis/transport/montecarlo/single_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def single_packet_loop(
estimators,
vpacket_collection,
rpacket_tracker,
rpacket_last_interaction_tracker,
montecarlo_configuration,
):
"""
Expand Down Expand Up @@ -85,6 +86,8 @@ def single_packet_loop(

if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
rpacket_tracker.track(r_packet)
else:
rpacket_last_interaction_tracker.track(r_packet)

# this part of the code is temporary and will be better incorporated
while r_packet.status == PacketStatus.IN_PROCESS:
Expand Down Expand Up @@ -264,8 +267,11 @@ def single_packet_loop(
)
else:
pass
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
rpacket_tracker.track(r_packet)
if interaction_type != InteractionType.BOUNDARY:
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
rpacket_tracker.track(r_packet)
else:
rpacket_last_interaction_tracker.track(r_packet)


@njit
Expand Down
33 changes: 21 additions & 12 deletions tardis/transport/montecarlo/tests/test_r_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,42 @@
import numpy.testing as npt
import pytest

from tardis.io.configuration.config_reader import Configuration
from tardis.base import run_tardis
from tardis.transport.montecarlo.packet_trackers import (
rpacket_trackers_to_dataframe,
)


@pytest.fixture(scope="module")
def simulation_rpacket_tracking_enabled(config_verysimple, atomic_dataset):
config_verysimple.montecarlo.iterations = 3
config_verysimple.montecarlo.no_of_packets = 4000
config_verysimple.montecarlo.last_no_of_packets = -1
config_verysimple.spectrum.virtual.virtual_packet_logging = True
config_verysimple.montecarlo.no_of_virtual_packets = 1
config_verysimple.montecarlo.tracking.track_rpacket = True
config_verysimple.spectrum.num = 2000
@pytest.fixture()
def config_rpacket_tracker(example_configuration_dir):
"""Config object for rpacket tracker"""
return Configuration.from_yaml(
example_configuration_dir / "tardis_configv1_verysimple.yml"
)


@pytest.fixture()
def simulation_rpacket_tracking_enabled(config_rpacket_tracker, atomic_dataset):
"""Simulation object with track_rpacket enabled"""
config_rpacket_tracker.montecarlo.iterations = 3
config_rpacket_tracker.montecarlo.no_of_packets = 4000
config_rpacket_tracker.montecarlo.last_no_of_packets = -1
config_rpacket_tracker.montecarlo.tracking.track_rpacket = True
config_rpacket_tracker.spectrum.num = 2000
atomic_data = deepcopy(atomic_dataset)
sim = run_tardis(
config_verysimple,
config_rpacket_tracker,
atom_data=atomic_data,
show_convergence_plots=False,
)
return sim


def test_rpacket_trackers_to_dataframe(simulation_rpacket_tracking_enabled):
sim = simulation_rpacket_tracking_enabled
transport_state = sim.transport.transport_state
transport_state = (
simulation_rpacket_tracking_enabled.transport.transport_state
)
rtracker_df = rpacket_trackers_to_dataframe(transport_state.rpacket_tracker)

# check df shape and column names
Expand Down
Loading
Loading