From 0e5f105b712b6ea8233144604239435872526041 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Wed, 15 Jan 2025 13:36:07 -0500 Subject: [PATCH] Enable option to use ASSIST integrate instead of integrate_or_interpolate --- src/sorcha/ephemeris/pixel_dict.py | 37 +++++++++++--------- src/sorcha/ephemeris/simulation_driver.py | 18 +++++----- src/sorcha/ephemeris/simulation_geometry.py | 15 +++++--- src/sorcha/ephemeris/simulation_setup.py | 37 ++++++++------------ src/sorcha/utilities/sorchaConfigs.py | 12 +++++-- tests/ephemeris/test_ephemeris_generation.py | 18 ++++++++++ 6 files changed, 81 insertions(+), 56 deletions(-) diff --git a/src/sorcha/ephemeris/pixel_dict.py b/src/sorcha/ephemeris/pixel_dict.py index 157eec8c..c93942ce 100644 --- a/src/sorcha/ephemeris/pixel_dict.py +++ b/src/sorcha/ephemeris/pixel_dict.py @@ -1,11 +1,11 @@ -import numpy as np +from collections import defaultdict + import healpy as hp import numba +import numpy as np -from collections import defaultdict - -from sorcha.ephemeris.simulation_geometry import * from sorcha.ephemeris.simulation_constants import * +from sorcha.ephemeris.simulation_geometry import * @numba.njit(fastmath=True) @@ -62,6 +62,7 @@ def __init__( nside=128, nested=True, n_sub_intervals=101, + use_integrate=False, ): """ Initialization function for the class. Computes the initial positions required for the ephemerides interpolation @@ -87,6 +88,8 @@ def __init__( Defines the ordering scheme for the healpix ordering. True (default) means a NESTED ordering n_sub_intervals: int Number of sub-intervals for the Lagrange interpolation (default: 101) + use_integrate: boolean + Whether to use the integrator to compute the ephemerides (default: False) """ self.nside = nside self.picket_interval = picket_interval @@ -96,7 +99,7 @@ def __init__( self.sim_dict = sim_dict self.ephem = ephem self.observatory = observatory - + self.use_integrate = use_integrate # Set the three times and compute the observatory position # at those times # Using a quadratic isn't very general, but that can be @@ -115,9 +118,9 @@ def __init__( self.pixel_dict = defaultdict(list) - self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm) - self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0) - self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp) + self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate) + self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0, use_integrate=self.use_integrate) + self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate) self.compute_pixel_traversed() @@ -138,7 +141,7 @@ def get_observatory_position(self, t): r_obs = self.observatory.barycentricObservatory(et, self.obsCode) / AU_KM return r_obs - def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01): + def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01, use_integrate=False): """ Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector for a list of objects, at a given time @@ -165,13 +168,13 @@ def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01): # Get the topocentric unit vectors rho, rho_mag, lt, r_ast, v_ast = integrate_light_time( - sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0 + sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0, use_integrate=use_integrate ) rho_hat = rho / rho_mag rho_hat_dict[k] = rho_hat return rho_hat_dict - def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01): + def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01, use_integrate=False): """ Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector for *all* objects, at a given time @@ -191,7 +194,7 @@ def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01): """ desigs = self.sim_dict.keys() - return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0) + return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0, use_integrate=use_integrate) def get_interp_factors(self, tm, t0, tp, n_sub_intervals): """ @@ -313,7 +316,7 @@ def update_pickets(self, jd_tdb): self.tm = self.t0 - self.picket_interval self.r_obs_m = self.get_observatory_position(self.tm) - self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm) + self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate) else: # shift later @@ -327,7 +330,7 @@ def update_pickets(self, jd_tdb): self.tp = self.t0 + self.picket_interval self.r_obs_p = self.get_observatory_position(self.tp) - self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp) + self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate) else: # Need to compute three new sets @@ -336,15 +339,15 @@ def update_pickets(self, jd_tdb): # This is repeated code self.t0 += n * self.picket_interval self.r_obs_0 = self.get_observatory_position(self.t0) - self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0) + self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0, use_integrate=self.use_integrate) self.tp = self.t0 + self.picket_interval self.r_obs_p = self.get_observatory_position(self.tp) - self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp) + self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate) self.tm = self.t0 - self.picket_interval self.r_obs_m = self.get_observatory_position(self.tm) - self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm) + self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate) self.compute_pixel_traversed() else: diff --git a/src/sorcha/ephemeris/simulation_driver.py b/src/sorcha/ephemeris/simulation_driver.py index 0417522d..748fb08b 100644 --- a/src/sorcha/ephemeris/simulation_driver.py +++ b/src/sorcha/ephemeris/simulation_driver.py @@ -1,23 +1,22 @@ -from dataclasses import dataclass from collections import defaultdict from csv import writer +from dataclasses import dataclass from io import StringIO import numpy as np import pandas as pd import spiceypy as spice -from sorcha.ephemeris.simulation_setup import ( - create_assist_ephemeris, - furnish_spiceypy, - generate_simulations, -) +from sorcha.ephemeris.pixel_dict import PixelDict from sorcha.ephemeris.simulation_constants import * from sorcha.ephemeris.simulation_geometry import * from sorcha.ephemeris.simulation_parsing import * +from sorcha.ephemeris.simulation_setup import (create_assist_ephemeris, + furnish_spiceypy, + generate_simulations) +from sorcha.modules.PPOutput import (PPOutWriteCSV, PPOutWriteHDF5, + PPOutWriteSqlite3) from sorcha.utilities.dataUtilitiesForTests import get_data_out_filepath -from sorcha.ephemeris.pixel_dict import PixelDict -from sorcha.modules.PPOutput import PPOutWriteCSV, PPOutWriteSqlite3, PPOutWriteHDF5 @dataclass @@ -177,6 +176,7 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs): picket_interval, nside, n_sub_intervals=n_sub_intervals, + use_integrate=sconfigs.expert.ar_use_integrate, ) for _, pointing in pointings_df.iterrows(): mjd_tai = float(pointing["observationMidpointMJD_TAI"]) @@ -208,7 +208,7 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs): _, ephem_geom_params.r_ast, ephem_geom_params.v_ast, - ) = integrate_light_time(sim, ex, pointing["fieldJD_TDB"] - ephem.jd_ref, r_obs, lt0=0.01) + ) = integrate_light_time(sim, ex, pointing["fieldJD_TDB"] - ephem.jd_ref, r_obs, lt0=0.01, use_integrate=sconfigs.expert.ar_use_integrate) ephem_geom_params.rho_hat = ephem_geom_params.rho / ephem_geom_params.rho_mag ang_from_center = 180 / np.pi * np.arccos(np.dot(ephem_geom_params.rho_hat, visit_vector)) diff --git a/src/sorcha/ephemeris/simulation_geometry.py b/src/sorcha/ephemeris/simulation_geometry.py index 9e3f07e4..eb11b47f 100644 --- a/src/sorcha/ephemeris/simulation_geometry.py +++ b/src/sorcha/ephemeris/simulation_geometry.py @@ -1,12 +1,13 @@ import healpy as hp import numpy as np +import spiceypy as spice + from sorcha.ephemeris.simulation_constants import ( - RADIUS_EARTH_KM, - SPEED_OF_LIGHT, ECL_TO_EQ_ROTATION_MATRIX, EQ_TO_ECL_ROTATION_MATRIX, + RADIUS_EARTH_KM, + SPEED_OF_LIGHT, ) -import spiceypy as spice def ecliptic_to_equatorial(v, rot_mat=ECL_TO_EQ_ROTATION_MATRIX): @@ -45,7 +46,7 @@ def equatorial_to_ecliptic(v, rot_mat=EQ_TO_ECL_ROTATION_MATRIX): return np.dot(v, rot_mat) -def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_OF_LIGHT): +def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_OF_LIGHT, use_integrate=False): """ Performs the light travel time correction between object and observatory iteratively for the object at a given reference time @@ -79,8 +80,12 @@ def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_ Object velocity at t-lt """ lt = lt0 + for i in range(iter): - ex.integrate_or_interpolate(t - lt) + if use_integrate: + sim.integrate(t - lt) + else: + ex.integrate_or_interpolate(t - lt) target = np.array(sim.particles[0].xyz) vtarget = np.array(sim.particles[0].vxyz) rho = target - r_obs diff --git a/src/sorcha/ephemeris/simulation_setup.py b/src/sorcha/ephemeris/simulation_setup.py index 6a75d396..eef893ac 100644 --- a/src/sorcha/ephemeris/simulation_setup.py +++ b/src/sorcha/ephemeris/simulation_setup.py @@ -1,30 +1,24 @@ -from functools import partial -import spiceypy as spice -from assist import Ephem -from . import simulation_parsing as sp -import rebound -from collections import defaultdict -import assist import logging -import sys import os +import sys +from collections import defaultdict +from functools import partial + +import assist import numpy as np +import rebound +import spiceypy as spice +from assist import Ephem from sorcha.ephemeris.simulation_constants import * from sorcha.ephemeris.simulation_data_files import make_retriever - -from sorcha.ephemeris.simulation_geometry import ( - barycentricObservatoryRates, - get_hp_neighbors, - ra_dec2vec, -) -from sorcha.ephemeris.simulation_parsing import ( - Observatory, - mjd_tai_to_epoch, -) - +from sorcha.ephemeris.simulation_geometry import (barycentricObservatoryRates, + get_hp_neighbors, ra_dec2vec) +from sorcha.ephemeris.simulation_parsing import Observatory, mjd_tai_to_epoch from sorcha.utilities.generate_meta_kernel import build_meta_kernel_file +from . import simulation_parsing as sp + def create_assist_ephemeris(args, auxconfigs) -> tuple: """Build the ASSIST ephemeris object @@ -197,9 +191,8 @@ def precompute_pointing_information(pointings_df, args, sconfigs): pointings_df["visit_vector_z"] = vectors[:, 2] # use pandas `apply` (even though it's slow) instead of looping over the df in a for loop - pointings_df["fieldJD_TDB"] = pointings_df.apply( - lambda row: mjd_tai_to_epoch(row["observationMidpointMJD_TAI"]), axis=1 - ) + pointings_df["fieldJD_TDB"] = pointings_df["observationMidpointMJD_TAI"].apply(mjd_tai_to_epoch) + et = (pointings_df["fieldJD_TDB"] - spice.j2000()) * 24 * 60 * 60 # create a partial function since most params don't change, and it makes the lambda easier to read diff --git a/src/sorcha/utilities/sorchaConfigs.py b/src/sorcha/utilities/sorchaConfigs.py index 192dda1a..032f6bd0 100644 --- a/src/sorcha/utilities/sorchaConfigs.py +++ b/src/sorcha/utilities/sorchaConfigs.py @@ -1,11 +1,13 @@ -from dataclasses import dataclass import configparser import logging -import sys import os +import sys +from dataclasses import dataclass + import numpy as np -from sorcha.lightcurves.lightcurve_registration import LC_METHODS + from sorcha.activity.activity_registration import CA_METHODS +from sorcha.lightcurves.lightcurve_registration import LC_METHODS from sorcha.utilities.fileAccessUtils import FindFileOrExit @@ -750,6 +752,9 @@ class expertConfigs: brute_force: bool = None """brute-force ephemeris generation on all objects without running a first-pass""" + ar_use_integrate: bool = None + """flag for using the integrate method instead of integrate_or_interpolate""" + def __post_init__(self): """Automagically validates the expert configs after initialisation.""" self._validate_expert_configs() @@ -799,6 +804,7 @@ def _validate_expert_configs(self): self.randomization_on = cast_as_bool_or_set_default(self.randomization_on, "randomization_on", True) self.vignetting_on = cast_as_bool_or_set_default(self.vignetting_on, "vignetting_on", True) self.brute_force = cast_as_bool_or_set_default(self.brute_force, "brute_force", True) + self.ar_use_integrate = cast_as_bool_or_set_default(self.ar_use_integrate, "ar_use_integrate", False) @dataclass diff --git a/tests/ephemeris/test_ephemeris_generation.py b/tests/ephemeris/test_ephemeris_generation.py index af4e86f8..a1733c13 100644 --- a/tests/ephemeris/test_ephemeris_generation.py +++ b/tests/ephemeris/test_ephemeris_generation.py @@ -2,6 +2,7 @@ import pytest import os import re +from numpy.testing import assert_almost_equal from sorcha.utilities.dataUtilitiesForTests import get_test_filepath, get_demo_filepath from sorcha.modules.PPGetLogger import PPGetLogger @@ -153,6 +154,23 @@ def test_ephemeris_end2end(single_synthetic_pointing, tmp_path): for file in files: assert not re.match(r".+\.csv", file) + configs["ar_use_integrate"] = True + + observations_integrate = create_ephemeris( + single_synthetic_pointing, + filterpointing, + args, + configs, + ) + + assert len(observations_integrate) == 10 + + assert_almost_equal( + observations_integrate["fieldMJD_TAI"].values, observations["fieldMJD_TAI"].values, decimal=6 + ) + assert_almost_equal(observations_integrate["RA_deg"].values, observations["RA_deg"].values, decimal=6) + assert_almost_equal(observations_integrate["Dec_deg"].values, observations["Dec_deg"].values, decimal=6) + def test_ephemeris_writeread_csv(single_synthetic_ephemeris, tmp_path): """Tests to ensure the ephemeris file is written out correctly AND