Skip to content

Commit

Permalink
Add generic Ephemeris implementation (#115)
Browse files Browse the repository at this point in the history
Add generic ephemeris implementation that can be used with compliant propagators.
  • Loading branch information
akoumjian authored Aug 1, 2024
1 parent 7f7fcd5 commit 13165ed
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/adam_core/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.2.1.dev5+g4d3da0c.d20240731"
14 changes: 14 additions & 0 deletions src/adam_core/coordinates/origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def SOLAR_SYSTEM_BARYCENTER(cls) -> float:
class Origin(qv.Table):
code = qv.LargeStringColumn()

def as_OriginCodes(self) -> OriginCodes:
"""
Convert the origin codes to an `~adam_core.coordinates.origin.OriginCodes` object.
Returns
-------
OriginCodes
Origin codes as an `~adam_core.coordinates.origin.OriginCodes` object.
"""
assert (
len(self.code.unique()) == 1
), "Only one origin code can be converted at a time."
return OriginCodes[self.code.unique()[0].as_py()]

def __eq__(self, other: object) -> np.ndarray:
if isinstance(other, (str, np.ndarray)):
codes = self.code.to_numpy(zero_copy_only=False)
Expand Down
2 changes: 1 addition & 1 deletion src/adam_core/coordinates/residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def calculate(
raise TypeError(
f"Predicted coordinates must be one of {SUPPORTED_COORDINATES}, not {type(predicted)}."
)
if type(observed) != type(predicted):
if type(observed) is not type(predicted):
raise TypeError(
"Observed and predicted coordinates must be the same type, "
f"not {type(observed)} and {type(predicted)}."
Expand Down
2 changes: 1 addition & 1 deletion src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ def transform_coordinates(
# `~adam_core.coordinates.origin.OriginCodes` so we can compare them directly.
# If its not an OriginCodes enum then origin_out will be an array of strings which
# also can be checked for equality.
if type(coords) == representation_out_:
if type(coords) is representation_out_:
if coord_frame == frame_out and np.all(coord_origin == origin_out):
return coords

Expand Down
36 changes: 28 additions & 8 deletions src/adam_core/orbits/query/horizons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from astroquery.jplhorizons import Horizons

from ...coordinates.cartesian import CartesianCoordinates
from ...coordinates.cometary import CometaryCoordinates
from ...coordinates.keplerian import KeplerianCoordinates
from ...coordinates.origin import Origin
from ...coordinates.spherical import SphericalCoordinates
from ...observers import Observers
from ...time import Timestamp
from ..ephemeris import Ephemeris
from ..orbits import Orbits


Expand Down Expand Up @@ -53,7 +56,7 @@ def _get_horizons_vectors(
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.rescale("tdb").mjd().to_numpy(zero_copy_only=False),
epochs=times.rescale("tdb").jd().to_numpy(zero_copy_only=False),
location=location,
id_type=id_type,
)
Expand Down Expand Up @@ -157,10 +160,11 @@ def _get_horizons_ephemeris(
as seen from the observer location at the given times.
"""
dfs = []
jd_utc = times.rescale("utc").jd().to_numpy(zero_copy_only=False)
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.rescale("utc").mjd().to_numpy(zero_copy_only=False),
epochs=jd_utc,
location=location,
id_type=id_type,
)
Expand All @@ -171,7 +175,7 @@ def _get_horizons_ephemeris(
cache=False,
).to_pandas()
ephemeris.insert(0, "orbit_id", f"{i:05d}")
ephemeris.insert(2, "mjd_utc", times.utc.mjd)
ephemeris.insert(2, "jd_utc", jd_utc)
ephemeris.insert(3, "observatory_code", location)

dfs.append(ephemeris)
Expand All @@ -187,7 +191,7 @@ def _get_horizons_ephemeris(

def query_horizons_ephemeris(
object_ids: Union[List, npt.ArrayLike], observers: Observers
) -> pd.DataFrame:
) -> Ephemeris:
"""
Query JPL Horizons (through astroquery) for an object's predicted ephemeris
as seen from a given location at the given times.
Expand All @@ -208,19 +212,35 @@ def query_horizons_ephemeris(
"""
dfs = []
for observatory_code, observers_i in observers.iterate_codes():
ephemeris = _get_horizons_ephemeris(
_ephemeris = _get_horizons_ephemeris(
object_ids,
observers_i.coordinates.time,
observatory_code,
)
dfs.append(ephemeris)
dfs.append(_ephemeris)

ephemeris = pd.concat(dfs, ignore_index=True)
ephemeris.sort_values(
dfs = pd.concat(dfs, ignore_index=True)
dfs.sort_values(
by=["orbit_id", "datetime_jd", "observatory_code"],
inplace=True,
ignore_index=True,
)

ephemeris = Ephemeris.from_kwargs(
orbit_id=dfs["orbit_id"],
object_id=dfs["targetname"],
# Convert from minutes to days
light_time=dfs["lighttime"] / 1440,
alpha=dfs["alpha"],
coordinates=SphericalCoordinates.from_kwargs(
time=Timestamp.from_jd(pa.array(dfs["datetime_jd"]), scale="utc"),
lon=dfs["RA"],
lat=dfs["DEC"],
origin=Origin.from_kwargs(code=dfs["observatory_code"]),
frame="ecliptic",
),
)

return ephemeris


Expand Down
171 changes: 161 additions & 10 deletions src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
import numpy.typing as npt
import quivr as qv

from adam_core.ray_cluster import initialize_use_ray

from ..constants import Constants as c
from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.origin import Origin, OriginCodes
from ..coordinates.spherical import SphericalCoordinates
from ..coordinates.transform import transform_coordinates
from ..observers.observers import Observers
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..orbits.variants import VariantEphemeris, VariantOrbits
from ..ray_cluster import initialize_use_ray
from ..time import Timestamp
from .utils import _iterate_chunks

logger = logging.getLogger(__name__)

C = c.C

RAY_INSTALLED = False
try:
import ray
Expand Down Expand Up @@ -89,17 +95,162 @@ class EphemerisMixin:
Subclasses should implement the _generate_ephemeris method.
"""

@abstractmethod
def _add_light_time(
self,
orbits,
observers,
lt_tol: float = 1e-12,
max_iter: int = 10,
):
orbits_aberrated = orbits.empty()
lts = np.zeros(len(orbits))
for i, (orbit, observer) in enumerate(zip(orbits, observers)):
# Set the running variables
lt_prev = 0
dlt = float("inf")
orbit_i = orbit
lt = 0

# Extract the observer's position which remains
# constant for all iterations
observer_position = observer.coordinates.r

# Calculate the orbit's current epoch (the epoch from which
# the light travel time will be calculated)
t0 = orbit_i.coordinates.time.rescale("tdb").mjd()[0].as_py()

iterations = 0
while dlt > lt_tol and iterations < max_iter:
iterations += 1

# Calculate the topocentric distance
rho = np.linalg.norm(orbit_i.coordinates.r - observer_position)

# Calculate the light travel time
lt = rho / C

# Calculate the change in light travel time since the previous iteration
dlt = np.abs(lt - lt_prev)

# Calculate the new epoch and propagate the initial orbit to that epoch
orbit_i = self.propagate_orbits(
orbit, Timestamp.from_mjd([t0 - lt], scale="tdb")
)

# Update the previous light travel time to this iteration's light travel time
lt_prev = lt

orbits_aberrated = qv.concatenate([orbits_aberrated, orbit_i])
lts[i] = lt

return orbits_aberrated, lts

def _generate_ephemeris(
self, orbits: EphemerisType, observers: ObserverType
self, orbits: OrbitType, observers: ObserverType
) -> EphemerisType:
"""
Generate ephemerides for the given orbits as observed by
the observers.
THIS FUNCTION SHOULD BE DEFINED BY THE USER.
A generic ephemeris implementation, which can be used or overridden by subclasses.
"""
pass

if isinstance(orbits, Orbits):
ephemeris_total = Ephemeris.empty()
elif isinstance(orbits, VariantOrbits):
ephemeris_total = VariantEphemeris.empty()

for orbit in orbits:
propagated_orbits = self.propagate_orbits(orbit, observers.coordinates.time)

# Transform both the orbits and observers to the barycenter if they are not already.
propagated_orbits_barycentric = propagated_orbits.set_column(
"coordinates",
transform_coordinates(
propagated_orbits.coordinates,
CartesianCoordinates,
frame_out="ecliptic",
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
),
)
observers_barycentric = observers.set_column(
"coordinates",
transform_coordinates(
observers.coordinates,
CartesianCoordinates,
frame_out="ecliptic",
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
),
)
num_orbits = len(propagated_orbits_barycentric.orbit_id.unique())

observer_codes = np.tile(
observers.code.to_numpy(zero_copy_only=False), num_orbits
)

propagated_orbits_aberrated, light_time = self._add_light_time(
propagated_orbits_barycentric,
observers_barycentric,
lt_tol=1e-12,
)

topocentric_state = (
propagated_orbits_aberrated.coordinates.values
- observers_barycentric.coordinates.values
)
topocentric_coordinates = CartesianCoordinates.from_kwargs(
x=topocentric_state[:, 0],
y=topocentric_state[:, 1],
z=topocentric_state[:, 2],
vx=topocentric_state[:, 3],
vy=topocentric_state[:, 4],
vz=topocentric_state[:, 5],
covariance=None,
# The ephemeris times are at the point of the observer,
# not the aberrated orbit
time=observers.coordinates.time,
origin=Origin.from_kwargs(code=observer_codes),
frame="ecliptic",
)

spherical_coordinates = SphericalCoordinates.from_cartesian(
topocentric_coordinates
)

light_time = np.array(light_time)

spherical_coordinates = transform_coordinates(
spherical_coordinates, SphericalCoordinates, frame_out="equatorial"
)

# Ephemeris are generally compared in UTC, so rescale the time
spherical_coordinates = spherical_coordinates.set_column(
"time",
spherical_coordinates.time.rescale("utc"),
)

if isinstance(orbits, Orbits):

ephemeris = Ephemeris.from_kwargs(
orbit_id=propagated_orbits_barycentric.orbit_id,
object_id=propagated_orbits_barycentric.object_id,
coordinates=spherical_coordinates,
light_time=light_time,
aberrated_coordinates=propagated_orbits_aberrated.coordinates,
)

elif isinstance(orbits, VariantOrbits):
weights = orbits.weights
weights_cov = orbits.weights_cov

ephemeris = VariantEphemeris.from_kwargs(
orbit_id=propagated_orbits_barycentric.orbit_id,
object_id=propagated_orbits_barycentric.object_id,
coordinates=spherical_coordinates,
weights=weights,
weights_cov=weights_cov,
)

ephemeris_total = qv.concatenate([ephemeris_total, ephemeris])

return ephemeris_total

def generate_ephemeris(
self,
Expand Down Expand Up @@ -261,7 +412,7 @@ def generate_ephemeris(
)


class Propagator(ABC):
class Propagator(ABC, EphemerisMixin):
"""
Abstract class for propagating orbits and related functions.
Expand Down
6 changes: 4 additions & 2 deletions src/adam_core/time/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,10 @@ def add_fractional_days(
nano_part = pc.subtract(fractional_days, day_part)

days = pc.cast(day_part, pa.int64())
nanos = pc.cast(pc.multiply(nano_part, 86400 * 1e9), pa.int64())

nanos = pc.cast(
pc.multiply(nano_part, 86400 * 1e9),
options=pc.CastOptions(target_type=pa.int64(), allow_float_truncate=True),
)
return self.add_days(days).add_nanos(nanos)

def difference_scalar(
Expand Down

0 comments on commit 13165ed

Please sign in to comment.