Skip to content

Commit

Permalink
Enable option to use ASSIST integrate instead of integrate_or_interpo…
Browse files Browse the repository at this point in the history
…late
  • Loading branch information
akoumjian committed Jan 15, 2025
1 parent fda3d1e commit 0e5f105
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 56 deletions.
37 changes: 20 additions & 17 deletions src/sorcha/ephemeris/pixel_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
from collections import defaultdict

import healpy as hp
import numba
import numpy as np

from collections import defaultdict

from sorcha.ephemeris.simulation_geometry import *
from sorcha.ephemeris.simulation_constants import *
from sorcha.ephemeris.simulation_geometry import *


@numba.njit(fastmath=True)
Expand Down Expand Up @@ -62,6 +62,7 @@ def __init__(
nside=128,
nested=True,
n_sub_intervals=101,
use_integrate=False,
):
"""
Initialization function for the class. Computes the initial positions required for the ephemerides interpolation
Expand All @@ -87,6 +88,8 @@ def __init__(
Defines the ordering scheme for the healpix ordering. True (default) means a NESTED ordering
n_sub_intervals: int
Number of sub-intervals for the Lagrange interpolation (default: 101)
use_integrate: boolean
Whether to use the integrator to compute the ephemerides (default: False)
"""
self.nside = nside
self.picket_interval = picket_interval
Expand All @@ -96,7 +99,7 @@ def __init__(
self.sim_dict = sim_dict
self.ephem = ephem
self.observatory = observatory

self.use_integrate = use_integrate
# Set the three times and compute the observatory position
# at those times
# Using a quadratic isn't very general, but that can be
Expand All @@ -115,9 +118,9 @@ def __init__(

self.pixel_dict = defaultdict(list)

self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0, use_integrate=self.use_integrate)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate)

self.compute_pixel_traversed()

Expand All @@ -138,7 +141,7 @@ def get_observatory_position(self, t):
r_obs = self.observatory.barycentricObservatory(et, self.obsCode) / AU_KM
return r_obs

def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01):
def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01, use_integrate=False):
"""
Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector
for a list of objects, at a given time
Expand All @@ -165,13 +168,13 @@ def get_object_unit_vectors(self, desigs, r_obs, t, lt0=0.01):

# Get the topocentric unit vectors
rho, rho_mag, lt, r_ast, v_ast = integrate_light_time(
sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0
sim, ex, t - self.ephem.jd_ref, r_obs, lt0=lt0, use_integrate=use_integrate
)
rho_hat = rho / rho_mag
rho_hat_dict[k] = rho_hat
return rho_hat_dict

def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01):
def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01, use_integrate=False):
"""
Computes the unit vector (in the equatorial sphere) that point towards the object - observatory vector
for *all* objects, at a given time
Expand All @@ -191,7 +194,7 @@ def get_all_object_unit_vectors(self, r_obs, t, lt0=0.01):
"""

desigs = self.sim_dict.keys()
return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0)
return self.get_object_unit_vectors(desigs, r_obs, t, lt0=lt0, use_integrate=use_integrate)

def get_interp_factors(self, tm, t0, tp, n_sub_intervals):
"""
Expand Down Expand Up @@ -313,7 +316,7 @@ def update_pickets(self, jd_tdb):

self.tm = self.t0 - self.picket_interval
self.r_obs_m = self.get_observatory_position(self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate)

else:
# shift later
Expand All @@ -327,7 +330,7 @@ def update_pickets(self, jd_tdb):

self.tp = self.t0 + self.picket_interval
self.r_obs_p = self.get_observatory_position(self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate)

else:
# Need to compute three new sets
Expand All @@ -336,15 +339,15 @@ def update_pickets(self, jd_tdb):
# This is repeated code
self.t0 += n * self.picket_interval
self.r_obs_0 = self.get_observatory_position(self.t0)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0)
self.rho_hat_0_dict = self.get_all_object_unit_vectors(self.r_obs_0, self.t0, use_integrate=self.use_integrate)

self.tp = self.t0 + self.picket_interval
self.r_obs_p = self.get_observatory_position(self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp)
self.rho_hat_p_dict = self.get_all_object_unit_vectors(self.r_obs_p, self.tp, use_integrate=self.use_integrate)

self.tm = self.t0 - self.picket_interval
self.r_obs_m = self.get_observatory_position(self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm)
self.rho_hat_m_dict = self.get_all_object_unit_vectors(self.r_obs_m, self.tm, use_integrate=self.use_integrate)

self.compute_pixel_traversed()
else:
Expand Down
18 changes: 9 additions & 9 deletions src/sorcha/ephemeris/simulation_driver.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from dataclasses import dataclass
from collections import defaultdict
from csv import writer
from dataclasses import dataclass
from io import StringIO

import numpy as np
import pandas as pd
import spiceypy as spice

from sorcha.ephemeris.simulation_setup import (
create_assist_ephemeris,
furnish_spiceypy,
generate_simulations,
)
from sorcha.ephemeris.pixel_dict import PixelDict
from sorcha.ephemeris.simulation_constants import *
from sorcha.ephemeris.simulation_geometry import *
from sorcha.ephemeris.simulation_parsing import *
from sorcha.ephemeris.simulation_setup import (create_assist_ephemeris,
furnish_spiceypy,
generate_simulations)
from sorcha.modules.PPOutput import (PPOutWriteCSV, PPOutWriteHDF5,
PPOutWriteSqlite3)
from sorcha.utilities.dataUtilitiesForTests import get_data_out_filepath
from sorcha.ephemeris.pixel_dict import PixelDict
from sorcha.modules.PPOutput import PPOutWriteCSV, PPOutWriteSqlite3, PPOutWriteHDF5


@dataclass
Expand Down Expand Up @@ -177,6 +176,7 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs):
picket_interval,
nside,
n_sub_intervals=n_sub_intervals,
use_integrate=sconfigs.expert.ar_use_integrate,
)
for _, pointing in pointings_df.iterrows():
mjd_tai = float(pointing["observationMidpointMJD_TAI"])
Expand Down Expand Up @@ -208,7 +208,7 @@ def create_ephemeris(orbits_df, pointings_df, args, sconfigs):
_,
ephem_geom_params.r_ast,
ephem_geom_params.v_ast,
) = integrate_light_time(sim, ex, pointing["fieldJD_TDB"] - ephem.jd_ref, r_obs, lt0=0.01)
) = integrate_light_time(sim, ex, pointing["fieldJD_TDB"] - ephem.jd_ref, r_obs, lt0=0.01, use_integrate=sconfigs.expert.ar_use_integrate)
ephem_geom_params.rho_hat = ephem_geom_params.rho / ephem_geom_params.rho_mag

ang_from_center = 180 / np.pi * np.arccos(np.dot(ephem_geom_params.rho_hat, visit_vector))
Expand Down
15 changes: 10 additions & 5 deletions src/sorcha/ephemeris/simulation_geometry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import healpy as hp
import numpy as np
import spiceypy as spice

from sorcha.ephemeris.simulation_constants import (
RADIUS_EARTH_KM,
SPEED_OF_LIGHT,
ECL_TO_EQ_ROTATION_MATRIX,
EQ_TO_ECL_ROTATION_MATRIX,
RADIUS_EARTH_KM,
SPEED_OF_LIGHT,
)
import spiceypy as spice


def ecliptic_to_equatorial(v, rot_mat=ECL_TO_EQ_ROTATION_MATRIX):
Expand Down Expand Up @@ -45,7 +46,7 @@ def equatorial_to_ecliptic(v, rot_mat=EQ_TO_ECL_ROTATION_MATRIX):
return np.dot(v, rot_mat)


def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_OF_LIGHT):
def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_OF_LIGHT, use_integrate=False):
"""
Performs the light travel time correction between object and observatory iteratively for the object at a given reference time
Expand Down Expand Up @@ -79,8 +80,12 @@ def integrate_light_time(sim, ex, t, r_obs, lt0=0, iter=3, speed_of_light=SPEED_
Object velocity at t-lt
"""
lt = lt0

for i in range(iter):
ex.integrate_or_interpolate(t - lt)
if use_integrate:
sim.integrate(t - lt)
else:
ex.integrate_or_interpolate(t - lt)
target = np.array(sim.particles[0].xyz)
vtarget = np.array(sim.particles[0].vxyz)
rho = target - r_obs
Expand Down
37 changes: 15 additions & 22 deletions src/sorcha/ephemeris/simulation_setup.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
from functools import partial
import spiceypy as spice
from assist import Ephem
from . import simulation_parsing as sp
import rebound
from collections import defaultdict
import assist
import logging
import sys
import os
import sys
from collections import defaultdict
from functools import partial

import assist
import numpy as np
import rebound
import spiceypy as spice
from assist import Ephem

from sorcha.ephemeris.simulation_constants import *
from sorcha.ephemeris.simulation_data_files import make_retriever

from sorcha.ephemeris.simulation_geometry import (
barycentricObservatoryRates,
get_hp_neighbors,
ra_dec2vec,
)
from sorcha.ephemeris.simulation_parsing import (
Observatory,
mjd_tai_to_epoch,
)

from sorcha.ephemeris.simulation_geometry import (barycentricObservatoryRates,
get_hp_neighbors, ra_dec2vec)
from sorcha.ephemeris.simulation_parsing import Observatory, mjd_tai_to_epoch
from sorcha.utilities.generate_meta_kernel import build_meta_kernel_file

from . import simulation_parsing as sp


def create_assist_ephemeris(args, auxconfigs) -> tuple:
"""Build the ASSIST ephemeris object
Expand Down Expand Up @@ -197,9 +191,8 @@ def precompute_pointing_information(pointings_df, args, sconfigs):
pointings_df["visit_vector_z"] = vectors[:, 2]

# use pandas `apply` (even though it's slow) instead of looping over the df in a for loop
pointings_df["fieldJD_TDB"] = pointings_df.apply(
lambda row: mjd_tai_to_epoch(row["observationMidpointMJD_TAI"]), axis=1
)
pointings_df["fieldJD_TDB"] = pointings_df["observationMidpointMJD_TAI"].apply(mjd_tai_to_epoch)

et = (pointings_df["fieldJD_TDB"] - spice.j2000()) * 24 * 60 * 60

# create a partial function since most params don't change, and it makes the lambda easier to read
Expand Down
12 changes: 9 additions & 3 deletions src/sorcha/utilities/sorchaConfigs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass
import configparser
import logging
import sys
import os
import sys
from dataclasses import dataclass

import numpy as np
from sorcha.lightcurves.lightcurve_registration import LC_METHODS

from sorcha.activity.activity_registration import CA_METHODS
from sorcha.lightcurves.lightcurve_registration import LC_METHODS
from sorcha.utilities.fileAccessUtils import FindFileOrExit


Expand Down Expand Up @@ -750,6 +752,9 @@ class expertConfigs:
brute_force: bool = None
"""brute-force ephemeris generation on all objects without running a first-pass"""

ar_use_integrate: bool = None
"""flag for using the integrate method instead of integrate_or_interpolate"""

def __post_init__(self):
"""Automagically validates the expert configs after initialisation."""
self._validate_expert_configs()
Expand Down Expand Up @@ -799,6 +804,7 @@ def _validate_expert_configs(self):
self.randomization_on = cast_as_bool_or_set_default(self.randomization_on, "randomization_on", True)
self.vignetting_on = cast_as_bool_or_set_default(self.vignetting_on, "vignetting_on", True)
self.brute_force = cast_as_bool_or_set_default(self.brute_force, "brute_force", True)
self.ar_use_integrate = cast_as_bool_or_set_default(self.ar_use_integrate, "ar_use_integrate", False)


@dataclass
Expand Down
18 changes: 18 additions & 0 deletions tests/ephemeris/test_ephemeris_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import os
import re
from numpy.testing import assert_almost_equal

from sorcha.utilities.dataUtilitiesForTests import get_test_filepath, get_demo_filepath
from sorcha.modules.PPGetLogger import PPGetLogger
Expand Down Expand Up @@ -153,6 +154,23 @@ def test_ephemeris_end2end(single_synthetic_pointing, tmp_path):
for file in files:
assert not re.match(r".+\.csv", file)

configs["ar_use_integrate"] = True

observations_integrate = create_ephemeris(
single_synthetic_pointing,
filterpointing,
args,
configs,
)

assert len(observations_integrate) == 10

assert_almost_equal(
observations_integrate["fieldMJD_TAI"].values, observations["fieldMJD_TAI"].values, decimal=6
)
assert_almost_equal(observations_integrate["RA_deg"].values, observations["RA_deg"].values, decimal=6)
assert_almost_equal(observations_integrate["Dec_deg"].values, observations["Dec_deg"].values, decimal=6)


def test_ephemeris_writeread_csv(single_synthetic_ephemeris, tmp_path):
"""Tests to ensure the ephemeris file is written out correctly AND
Expand Down

0 comments on commit 0e5f105

Please sign in to comment.