Skip to content

Commit

Permalink
Modify generate_test_orbits to accept a path to a parquet file
Browse files Browse the repository at this point in the history
Parallelize test orbit generation over chunks of unique healpixels.
Test orbits are now generated at the midnight of the day of the first
observation.

Co-authored-by: Alec Koumjian <[email protected]>
  • Loading branch information
moeyensj and akoumjian committed Jan 25, 2024
1 parent 72dbd9e commit 17a03f6
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 47 deletions.
2 changes: 1 addition & 1 deletion thor/observations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def filter_observations(
test_orbit: TestOrbits,
config: Config,
filters: Optional[List[ObservationFilter]] = None,
chunk_size: int = 100,
chunk_size: int = 10,
) -> Observations:
"""
Filter observations by applying a list of filters. The input observations
Expand Down
233 changes: 187 additions & 46 deletions thor/orbit_selection.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import logging
import multiprocessing as mp
import time
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import quivr as qv
import ray
from adam_core.coordinates import KeplerianCoordinates
from adam_core.observers import Observers
from adam_core.orbits import Ephemeris, Orbits
from adam_core.propagator import PYOORB, Propagator
from adam_core.propagator.utils import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray
from adam_core.time import Timestamp

from thor.observations import Observations
from thor.orbit import TestOrbits
Expand Down Expand Up @@ -185,12 +192,88 @@ def select_test_orbits(ephemeris: Ephemeris, orbits: Orbits) -> Orbits:
return Orbits.empty()


def generate_test_orbits_worker(
healpixel_chunk: pa.Array,
ephemeris_healpixels: pa.Array,
propagated_orbits: Union[Orbits, ray.ObjectRef],
ephemeris: Union[Ephemeris, ray.ObjectRef],
) -> TestOrbits:
"""
Worker function for generating test orbits.
Parameters
----------
healpixel_chunk
Healpixels to generate test orbits for.
ephemeris_healpixels
Healpixels for the ephemeris.
propagated_orbits
Propagated orbits.
ephemeris
Ephemeris for the propagated orbits.
Returns
-------
test_orbits
Test orbits generated from the propagated orbits.
"""
test_orbits_list = []

# Filter the ephemerides to only those in the observations
ephemeris_mask = pc.is_in(ephemeris_healpixels, healpixel_chunk)
ephemeris_filtered = ephemeris.apply_mask(ephemeris_mask)
ephemeris_healpixels = pc.filter(ephemeris_healpixels, ephemeris_mask)
logger.info(
f"{len(ephemeris_filtered)} orbit ephemerides overlap with the observations."
)

# Filter the orbits to only those in the ephemeris
orbits_filtered = propagated_orbits.apply_mask(
pc.is_in(propagated_orbits.orbit_id, ephemeris_filtered.orbit_id)
)

logger.info("Selecting test orbits from the orbit catalog...")
for healpixel in healpixel_chunk:
healpixel_mask = pc.equal(ephemeris_healpixels, healpixel)
ephemeris_healpixel = ephemeris_filtered.apply_mask(healpixel_mask)

if len(ephemeris_healpixel) == 0:
logger.debug(f"No ephemerides in healpixel {healpixel}.")
continue

test_orbits_healpixel = select_test_orbits(ephemeris_healpixel, orbits_filtered)

if len(test_orbits_healpixel) > 0:
test_orbits_list.append(
TestOrbits.from_kwargs(
orbit_id=test_orbits_healpixel.orbit_id,
object_id=test_orbits_healpixel.object_id,
coordinates=test_orbits_healpixel.coordinates,
bundle_id=[healpixel for _ in range(len(test_orbits_healpixel))],
)
)
else:
logger.debug(f"No orbits in healpixel {healpixel}.")

if len(test_orbits_list) > 0:
test_orbits = qv.concatenate(test_orbits_list)
else:
test_orbits = TestOrbits.empty()

return test_orbits


generate_test_orbits_worker_remote = ray.remote(generate_test_orbits_worker)
generate_test_orbits_worker_remote.options(num_cpus=1, num_returns=1)


def generate_test_orbits(
observations: Observations,
observations: Union[str, Observations],
catalog: Orbits,
nside: int = 32,
propagator: Propagator = PYOORB(),
max_processes: int = 1,
max_processes: Optional[int] = None,
chunk_size: int = 100,
) -> TestOrbits:
"""
Given observations and a catalog of known orbits generate test orbits
Expand All @@ -205,7 +288,9 @@ def generate_test_orbits(
Parameters
----------
observations
Observations to generate test orbits for.
Observations for which to generate test orbits. These observations can
be an in-memory Observations object or a path to a parquet file containing the
observations.
catalog
Catalog of known orbits.
nside
Expand All @@ -215,14 +300,43 @@ def generate_test_orbits(
max_processes
Maximum number of processes to use while propagating orbits and
generating ephemerides.
chunk_size
The maximum number of unique healpixels for which to generate test orbits per
process. This function will dynamically compute the chunk size based on the
number of unique healpixels and the number of processes. The dynamic chunk
size will never exceed the given value.
Returns
-------
test_orbits
Test orbits generated from the catalog.
"""
# Extract the minimum time from the observations
start_time = observations.coordinates.time.min()
time_start = time.perf_counter()
logger.info("Generating test orbits...")

# If the input file is a string, read in the days column to
# extract the minimum time
if isinstance(observations, str):
table = pq.read_table(
observations, columns=["coordinates.time.days"], memory_map=True
)

min_day = pc.min(table["days"]).as_py()
# Set the start time to the midnight of the first night of observations
start_time = Timestamp.from_kwargs(days=[min_day], nanos=[0], scale="utc")

elif isinstance(observations, Observations):
# Extract the minimum time from the observations
earliest_time = observations.coordinates.time.min()

# Set the start time to the midnight of the first night of observations
start_time = Timestamp.from_kwargs(
days=earliest_time.days, nanos=[0], scale="utc"
)
else:
raise ValueError(
f"observations must be a path to a parquet file or an Observations object. Got {type(observations)}."
)

# Propagate the orbits to the minimum time
logger.info("Propagating orbits to the start time of the observations...")
Expand Down Expand Up @@ -258,67 +372,94 @@ def generate_test_orbits(
f"Ephemeris generation completed in {ephemeris_end_time - ephemeris_start_time:.3f} seconds."
)

if isinstance(observations, str):
table = pq.read_table(
observations,
columns=["coordinates.lon", "coordinates.lat"],
memory_map=True,
)
lon = table["lon"].to_numpy(zero_copy_only=False)
lat = table["lat"].to_numpy(zero_copy_only=False)
del table

else:
lon = observations.coordinates.lon.to_numpy(zero_copy_only=False)
lat = observations.coordinates.lat.to_numpy(zero_copy_only=False)

# Calculate the healpixels for observations and ephemerides
# Here we want the unique healpixels so we can cross match against our
# catalog's predicted ephemeris
observations_healpixels = calculate_healpixels(
observations.coordinates.lon.to_numpy(zero_copy_only=False),
observations.coordinates.lat.to_numpy(zero_copy_only=False),
lon,
lat,
nside=nside,
)
observations_healpixels = np.unique(observations_healpixels)
observations_healpixels = pc.unique(pa.array(observations_healpixels))
logger.info(
f"Observations occur in {len(observations_healpixels)} unique healpixels."
)

# Calculate the healpixels for the ephemerides
# Calculate the healpixels for each ephemeris
# We do not want unique healpixels here because we want to
# select orbits from the same healpixel as the observations
ephemeris_healpixels = calculate_healpixels(
ephemeris.coordinates.lon.to_numpy(zero_copy_only=False),
ephemeris.coordinates.lat.to_numpy(zero_copy_only=False),
nside=nside,
)
ephemeris_healpixels = pa.array(ephemeris_healpixels)

# Filter the ephemerides to only those in the observations
ephemeris_mask = pa.array(np.in1d(ephemeris_healpixels, observations_healpixels))
ephemeris_filtered = ephemeris.apply_mask(ephemeris_mask)
ephemeris_healpixels = ephemeris_healpixels[
ephemeris_mask.to_numpy(zero_copy_only=False)
]
logger.info(
f"{len(ephemeris_filtered)} orbit ephemerides overlap with the observations."
)
# Dynamically compute the chunk size based on the number of healpixels
# and the number of processes
if max_processes is None:
max_processes = mp.cpu_count()

# Filter the orbits to only those in the ephemeris
orbits_filtered = propagated_orbits.apply_mask(
pc.is_in(propagated_orbits.orbit_id, ephemeris_filtered.orbit_id)
chunk_size = np.minimum(
np.ceil(len(observations_healpixels) / max_processes).astype(int), chunk_size
)

logger.info("Selecting test orbits from the orbit catalog...")
test_orbits_list = []
for healpixel in observations_healpixels:
healpixel_mask = pc.equal(ephemeris_healpixels, healpixel)
ephemeris_healpixel = ephemeris_filtered.apply_mask(healpixel_mask)

if len(ephemeris_healpixel) == 0:
logger.debug(f"No ephemerides in healpixel {healpixel}.")
continue

test_orbits_healpixel = select_test_orbits(ephemeris_healpixel, orbits_filtered)

if len(test_orbits_healpixel) > 0:
test_orbits_list.append(
TestOrbits.from_kwargs(
orbit_id=test_orbits_healpixel.orbit_id,
object_id=test_orbits_healpixel.object_id,
coordinates=test_orbits_healpixel.coordinates,
bundle_id=[healpixel for _ in range(len(test_orbits_healpixel))],
logger.info(f"Generating test orbits with a chunk size of {chunk_size} healpixels.")

test_orbits = TestOrbits.empty()
use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:

ephemeris_ref = ray.put(ephemeris)
ephemeris_healpixels_ref = ray.put(ephemeris_healpixels)
propagated_orbits_ref = ray.put(propagated_orbits)

futures = []
for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
futures.append(
generate_test_orbits_worker_remote.remote(
healpixel_chunk,
ephemeris_healpixels_ref,
propagated_orbits_ref,
ephemeris_ref,
)
)
else:
logger.debug(f"No orbits in healpixel {healpixel}.")

if len(test_orbits_list) > 0:
test_orbits = qv.concatenate(test_orbits_list)
while futures:
finished, futures = ray.wait(futures, num_returns=1)
test_orbits = qv.concatenate([test_orbits, ray.get(finished[0])])
if test_orbits.fragmented():
test_orbits = qv.defragment(test_orbits)

else:
test_orbits = TestOrbits.empty()

for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
test_orbits_chunk = generate_test_orbits_worker(
healpixel_chunk,
ephemeris_healpixels,
propagated_orbits,
ephemeris,
)
test_orbits = qv.concatenate([test_orbits, test_orbits_chunk])
if test_orbits.fragmented():
test_orbits = qv.defragment(test_orbits)

time_end = time.perf_counter()
logger.info(f"Selected {len(test_orbits)} test orbits.")
logger.info(
f"Test orbit generation completed in {time_end - time_start:.3f} seconds."
)
return test_orbits

0 comments on commit 17a03f6

Please sign in to comment.