From 17a03f65d604dafec7151c87837318abf6cf1af8 Mon Sep 17 00:00:00 2001
From: Joachim Moeyens
Date: Thu, 25 Jan 2024 15:12:04 -0800
Subject: [PATCH] Modify generate_test_orbits to accept a path to a parquet file

Parallelize test orbit generation over chunks of unique healpixels. Test
orbits are now generated at the midnight of the day of the first observation.

Co-authored-by: Alec Koumjian
---
 thor/observations/filters.py |   2 +-
 thor/orbit_selection.py      | 233 ++++++++++++++++++++++++++++-------
 2 files changed, 188 insertions(+), 47 deletions(-)

diff --git a/thor/observations/filters.py b/thor/observations/filters.py
index d4998968..92f8cf67 100644
--- a/thor/observations/filters.py
+++ b/thor/observations/filters.py
@@ -263,7 +263,7 @@ def filter_observations(
     test_orbit: TestOrbits,
     config: Config,
     filters: Optional[List[ObservationFilter]] = None,
-    chunk_size: int = 100,
+    chunk_size: int = 10,
 ) -> Observations:
     """
     Filter observations by applying a list of filters. The input observations
diff --git a/thor/orbit_selection.py b/thor/orbit_selection.py
index c14b0099..2ddf2a29 100644
--- a/thor/orbit_selection.py
+++ b/thor/orbit_selection.py
@@ -1,15 +1,22 @@
 import logging
+import multiprocessing as mp
 import time
 from dataclasses import dataclass
+from typing import Optional, Union
 
 import numpy as np
 import pyarrow as pa
 import pyarrow.compute as pc
+import pyarrow.parquet as pq
 import quivr as qv
+import ray
 from adam_core.coordinates import KeplerianCoordinates
 from adam_core.observers import Observers
 from adam_core.orbits import Ephemeris, Orbits
 from adam_core.propagator import PYOORB, Propagator
+from adam_core.propagator.utils import _iterate_chunks
+from adam_core.ray_cluster import initialize_use_ray
+from adam_core.time import Timestamp
 
 from thor.observations import Observations
 from thor.orbit import TestOrbits
@@ -185,12 +192,88 @@ def select_test_orbits(ephemeris: Ephemeris, orbits: Orbits) -> Orbits:
     return Orbits.empty()
 
 
+def generate_test_orbits_worker(
+    healpixel_chunk: pa.Array,
+    ephemeris_healpixels: pa.Array,
+    propagated_orbits: Union[Orbits, ray.ObjectRef],
+    ephemeris: Union[Ephemeris, ray.ObjectRef],
+) -> TestOrbits:
+    """
+    Worker function for generating test orbits.
+
+    Parameters
+    ----------
+    healpixel_chunk
+        Healpixels to generate test orbits for.
+    ephemeris_healpixels
+        Healpixels for the ephemeris.
+    propagated_orbits
+        Propagated orbits.
+    ephemeris
+        Ephemeris for the propagated orbits.
+
+    Returns
+    -------
+    test_orbits
+        Test orbits generated from the propagated orbits.
+    """
+    test_orbits_list = []
+
+    # Filter the ephemerides to only those in the observations
+    ephemeris_mask = pc.is_in(ephemeris_healpixels, healpixel_chunk)
+    ephemeris_filtered = ephemeris.apply_mask(ephemeris_mask)
+    ephemeris_healpixels = pc.filter(ephemeris_healpixels, ephemeris_mask)
+    logger.info(
+        f"{len(ephemeris_filtered)} orbit ephemerides overlap with the observations."
+    )
+
+    # Filter the orbits to only those in the ephemeris
+    orbits_filtered = propagated_orbits.apply_mask(
+        pc.is_in(propagated_orbits.orbit_id, ephemeris_filtered.orbit_id)
+    )
+
+    logger.info("Selecting test orbits from the orbit catalog...")
+    for healpixel in healpixel_chunk:
+        healpixel_mask = pc.equal(ephemeris_healpixels, healpixel)
+        ephemeris_healpixel = ephemeris_filtered.apply_mask(healpixel_mask)
+
+        if len(ephemeris_healpixel) == 0:
+            logger.debug(f"No ephemerides in healpixel {healpixel}.")
+            continue
+
+        test_orbits_healpixel = select_test_orbits(ephemeris_healpixel, orbits_filtered)
+
+        if len(test_orbits_healpixel) > 0:
+            test_orbits_list.append(
+                TestOrbits.from_kwargs(
+                    orbit_id=test_orbits_healpixel.orbit_id,
+                    object_id=test_orbits_healpixel.object_id,
+                    coordinates=test_orbits_healpixel.coordinates,
+                    bundle_id=[healpixel for _ in range(len(test_orbits_healpixel))],
+                )
+            )
+        else:
+            logger.debug(f"No orbits in healpixel {healpixel}.")
+
+    if len(test_orbits_list) > 0:
+        test_orbits = qv.concatenate(test_orbits_list)
+    else:
+        test_orbits = TestOrbits.empty()
+
+    return test_orbits
+
+
+generate_test_orbits_worker_remote = ray.remote(generate_test_orbits_worker)
+generate_test_orbits_worker_remote.options(num_cpus=1, num_returns=1)
+
+
 def generate_test_orbits(
-    observations: Observations,
+    observations: Union[str, Observations],
     catalog: Orbits,
     nside: int = 32,
     propagator: Propagator = PYOORB(),
-    max_processes: int = 1,
+    max_processes: Optional[int] = None,
+    chunk_size: int = 100,
 ) -> TestOrbits:
     """
     Given observations and a catalog of known orbits generate test orbits
@@ -205,7 +288,9 @@
     Parameters
     ----------
     observations
-        Observations to generate test orbits for.
+        Observations for which to generate test orbits. These observations can
+        be an in-memory Observations object or a path to a parquet file containing the
+        observations.
     catalog
         Catalog of known orbits.
     nside
@@ -215,14 +300,43 @@
     max_processes
        Maximum number of processes to use while propagating orbits and
        generating ephemerides.
+    chunk_size
+        The maximum number of unique healpixels for which to generate test orbits per
+        process. This function will dynamically compute the chunk size based on the
+        number of unique healpixels and the number of processes. The dynamic chunk
+        size will never exceed the given value.
 
     Returns
     -------
     test_orbits
         Test orbits generated from the catalog.
     """
-    # Extract the minimum time from the observations
-    start_time = observations.coordinates.time.min()
+    time_start = time.perf_counter()
+    logger.info("Generating test orbits...")
+
+    # If the input file is a string, read in the days column to
+    # extract the minimum time
+    if isinstance(observations, str):
+        table = pq.read_table(
+            observations, columns=["coordinates.time.days"], memory_map=True
+        )
+
+        min_day = pc.min(table["days"]).as_py()
+        # Set the start time to the midnight of the first night of observations
+        start_time = Timestamp.from_kwargs(days=[min_day], nanos=[0], scale="utc")
+
+    elif isinstance(observations, Observations):
+        # Extract the minimum time from the observations
+        earliest_time = observations.coordinates.time.min()
+
+        # Set the start time to the midnight of the first night of observations
+        start_time = Timestamp.from_kwargs(
+            days=earliest_time.days, nanos=[0], scale="utc"
+        )
+    else:
+        raise ValueError(
+            f"observations must be a path to a parquet file or an Observations object. Got {type(observations)}."
+        )
 
     # Propagate the orbits to the minimum time
     logger.info("Propagating orbits to the start time of the observations...")
@@ -258,67 +372,94 @@
         f"Ephemeris generation completed in {ephemeris_end_time - ephemeris_start_time:.3f} seconds."
     )
 
+    if isinstance(observations, str):
+        table = pq.read_table(
+            observations,
+            columns=["coordinates.lon", "coordinates.lat"],
+            memory_map=True,
+        )
+        lon = table["lon"].to_numpy(zero_copy_only=False)
+        lat = table["lat"].to_numpy(zero_copy_only=False)
+        del table
+
+    else:
+        lon = observations.coordinates.lon.to_numpy(zero_copy_only=False)
+        lat = observations.coordinates.lat.to_numpy(zero_copy_only=False)
+
     # Calculate the healpixels for observations and ephemerides
+    # Here we want the unique healpixels so we can cross match against our
+    # catalog's predicted ephemeris
     observations_healpixels = calculate_healpixels(
-        observations.coordinates.lon.to_numpy(zero_copy_only=False),
-        observations.coordinates.lat.to_numpy(zero_copy_only=False),
+        lon,
+        lat,
         nside=nside,
     )
-    observations_healpixels = np.unique(observations_healpixels)
+    observations_healpixels = pc.unique(pa.array(observations_healpixels))
     logger.info(
         f"Observations occur in {len(observations_healpixels)} unique healpixels."
     )
 
-    # Calculate the healpixels for the ephemerides
+    # Calculate the healpixels for each ephemeris
+    # We do not want unique healpixels here because we want to
+    # select orbits from the same healpixel as the observations
     ephemeris_healpixels = calculate_healpixels(
         ephemeris.coordinates.lon.to_numpy(zero_copy_only=False),
         ephemeris.coordinates.lat.to_numpy(zero_copy_only=False),
         nside=nside,
     )
+    ephemeris_healpixels = pa.array(ephemeris_healpixels)
 
-    # Filter the ephemerides to only those in the observations
-    ephemeris_mask = pa.array(np.in1d(ephemeris_healpixels, observations_healpixels))
-    ephemeris_filtered = ephemeris.apply_mask(ephemeris_mask)
-    ephemeris_healpixels = ephemeris_healpixels[
-        ephemeris_mask.to_numpy(zero_copy_only=False)
-    ]
-    logger.info(
-        f"{len(ephemeris_filtered)} orbit ephemerides overlap with the observations."
-    )
+    # Dynamically compute the chunk size based on the number of healpixels
+    # and the number of processes
+    if max_processes is None:
+        max_processes = mp.cpu_count()
 
-    # Filter the orbits to only those in the ephemeris
-    orbits_filtered = propagated_orbits.apply_mask(
-        pc.is_in(propagated_orbits.orbit_id, ephemeris_filtered.orbit_id)
+    chunk_size = np.minimum(
+        np.ceil(len(observations_healpixels) / max_processes).astype(int), chunk_size
    )
-
-    logger.info("Selecting test orbits from the orbit catalog...")
-    test_orbits_list = []
-    for healpixel in observations_healpixels:
-        healpixel_mask = pc.equal(ephemeris_healpixels, healpixel)
-        ephemeris_healpixel = ephemeris_filtered.apply_mask(healpixel_mask)
-
-        if len(ephemeris_healpixel) == 0:
-            logger.debug(f"No ephemerides in healpixel {healpixel}.")
-            continue
-
-        test_orbits_healpixel = select_test_orbits(ephemeris_healpixel, orbits_filtered)
-
-        if len(test_orbits_healpixel) > 0:
-            test_orbits_list.append(
-                TestOrbits.from_kwargs(
-                    orbit_id=test_orbits_healpixel.orbit_id,
-                    object_id=test_orbits_healpixel.object_id,
-                    coordinates=test_orbits_healpixel.coordinates,
-                    bundle_id=[healpixel for _ in range(len(test_orbits_healpixel))],
+    logger.info(f"Generating test orbits with a chunk size of {chunk_size} healpixels.")
+
+    test_orbits = TestOrbits.empty()
+    use_ray = initialize_use_ray(num_cpus=max_processes)
+    if use_ray:
+
+        ephemeris_ref = ray.put(ephemeris)
+        ephemeris_healpixels_ref = ray.put(ephemeris_healpixels)
+        propagated_orbits_ref = ray.put(propagated_orbits)
+
+        futures = []
+        for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
+            futures.append(
+                generate_test_orbits_worker_remote.remote(
+                    healpixel_chunk,
+                    ephemeris_healpixels_ref,
+                    propagated_orbits_ref,
+                    ephemeris_ref,
                 )
             )
-        else:
-            logger.debug(f"No orbits in healpixel {healpixel}.")
 
-    if len(test_orbits_list) > 0:
-        test_orbits = qv.concatenate(test_orbits_list)
+        while futures:
+            finished, futures = ray.wait(futures, num_returns=1)
+            test_orbits = qv.concatenate([test_orbits, ray.get(finished[0])])
+            if test_orbits.fragmented():
+                test_orbits = qv.defragment(test_orbits)
+
     else:
-        test_orbits = TestOrbits.empty()
+        for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
+            test_orbits_chunk = generate_test_orbits_worker(
+                healpixel_chunk,
+                ephemeris_healpixels,
+                propagated_orbits,
+                ephemeris,
+            )
+            test_orbits = qv.concatenate([test_orbits, test_orbits_chunk])
+            if test_orbits.fragmented():
+                test_orbits = qv.defragment(test_orbits)
+
+    time_end = time.perf_counter()
     logger.info(f"Selected {len(test_orbits)} test orbits.")
+    logger.info(
+        f"Test orbit generation completed in {time_end - time_start:.3f} seconds."
+    )
 
     return test_orbits
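
Usage note (illustrative sketch, not part of the patch): with this change, generate_test_orbits accepts either an in-memory Observations table or a path to a parquet file; the latter lets the function read only the columns it needs. The file paths below are placeholders, and the parquet round-trip assumes the quivr from_parquet/to_parquet helpers on these tables.

    from adam_core.orbits import Orbits

    from thor.observations import Observations
    from thor.orbit_selection import generate_test_orbits

    # Orbit catalog and observations saved earlier (placeholder paths).
    catalog = Orbits.from_parquet("orbit_catalog.parquet")

    # Option 1: pass the observations in memory.
    observations = Observations.from_parquet("observations.parquet")
    test_orbits = generate_test_orbits(observations, catalog, nside=32, max_processes=4)

    # Option 2: pass the parquet path directly; only the time and lon/lat
    # columns are read, which keeps memory usage down for large inputs.
    test_orbits = generate_test_orbits("observations.parquet", catalog, nside=32, max_processes=4)

    test_orbits.to_parquet("test_orbits.parquet")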
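
A small sketch of the new start-epoch rule: the integer MJD day of the earliest observation is kept and the nanoseconds are zeroed, so propagation starts at midnight UTC of the first night. The from_mjd and mjd helpers are assumed from adam_core; the rest mirrors the Timestamp.from_kwargs call in the patch.

    import pyarrow.compute as pc
    from adam_core.time import Timestamp

    # Three observation times spread over two nights (MJD, UTC).
    times = Timestamp.from_mjd([60310.73, 60310.91, 60311.20], scale="utc")

    # Keep only the integer day of the earliest observation and zero the
    # sub-day part, i.e. midnight of the first night.
    min_day = pc.min(times.days).as_py()
    start_time = Timestamp.from_kwargs(days=[min_day], nanos=[0], scale="utc")
    print(start_time.mjd())  # 60310.0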
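
The cross-match itself is driven by THOR's calculate_healpixels; a rough healpy-based equivalent is sketched below to show what the nside parameter controls (pixel ordering and conventions in the real helper may differ).

    import healpy as hp
    import numpy as np

    nside = 32  # same default as generate_test_orbits

    lon = np.array([10.0, 10.1, 250.0])   # degrees
    lat = np.array([-5.0, -5.05, 30.0])   # degrees

    # Map each (lon, lat) position to a healpixel; nearby observations fall
    # into the same pixel, which is what the unique() call then exploits.
    pixels = hp.ang2pix(nside, lon, lat, lonlat=True)
    print(np.unique(pixels))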
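
The dynamic chunk size is the new tuning knob: the unique healpixels are spread across the available processes, but a chunk never exceeds the chunk_size argument. Below is a standalone mirror of that arithmetic, with a simplified stand-in for adam_core's _iterate_chunks (the helper names here are hypothetical).

    import numpy as np

    def dynamic_chunk_size(n_healpixels: int, max_processes: int, chunk_size: int = 100) -> int:
        # Spread the unique healpixels across the processes, capped at chunk_size,
        # mirroring the np.minimum/np.ceil expression in generate_test_orbits.
        return int(np.minimum(np.ceil(n_healpixels / max_processes), chunk_size))

    def iterate_chunks(values, chunk_size):
        # Simplified stand-in for adam_core's _iterate_chunks helper.
        for i in range(0, len(values), chunk_size):
            yield values[i : i + chunk_size]

    # e.g. 1,000 unique healpixels on 16 processes -> chunks of 63 healpixels
    print(dynamic_chunk_size(1000, 16))      # 63
    print(dynamic_chunk_size(1000, 16, 25))  # 25 (capped by chunk_size)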
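
Finally, the parallel path follows the usual ray fan-out pattern: put the large shared tables in the object store once, submit one task per healpixel chunk, and concatenate results as they finish via ray.wait. A toy version of that collection loop is below; THOR's initialize_use_ray handles cluster setup in the actual code.

    import ray

    @ray.remote
    def square(x: int) -> int:
        return x * x

    ray.init(num_cpus=4)

    futures = [square.remote(i) for i in range(10)]
    results = []
    while futures:
        # Consume results as soon as any task finishes instead of blocking on
        # the whole batch, mirroring the loop in generate_test_orbits.
        finished, futures = ray.wait(futures, num_returns=1)
        results.append(ray.get(finished[0]))

    ray.shutdown()
    print(sorted(results))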