From 9e21f2c06dbe80ed0e2a195c595a17e5d0998854 Mon Sep 17 00:00:00 2001
From: Alec Koumjian
Date: Fri, 10 Jan 2025 09:19:58 -0500
Subject: [PATCH] Nt/seed propagate (#135)

* add a seed for variant orbits in propagate_orbits

* Ensure seed is passed to all VariantOrbits.create and sort variants before collapse
---
 src/adam_core/propagator/propagator.py | 42 ++++++++++++++++++++------
 1 file changed, 32 insertions(+), 10 deletions(-)

diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py
index b773121c..10efce04 100644
--- a/src/adam_core/propagator/propagator.py
+++ b/src/adam_core/propagator/propagator.py
@@ -290,6 +290,7 @@ def generate_ephemeris(
         num_samples: int = 1000,
         chunk_size: int = 100,
         max_processes: Optional[int] = 1,
+        seed: Optional[int] = None,
     ) -> Ephemeris:
         """
         Generate ephemerides for each orbit in orbits as observed by each observer
@@ -376,18 +377,20 @@ def generate_ephemeris(
             # Add variants to futures (if we have any)
             if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                 variants = VariantOrbits.create(
-                    orbits, method=covariance_method, num_samples=num_samples
+                    orbits,
+                    method=covariance_method,
+                    num_samples=num_samples,
+                    seed=seed,
                 )
 
                 # Add variants to object store
                 variants_ref = ray.put(variants)
 
                 idx = np.arange(0, len(variants))
-                for variant_chunk in _iterate_chunks(idx, chunk_size):
-                    idx_chunk = ray.put(variant_chunk)
+                for variant_chunk_idx in _iterate_chunks(idx, chunk_size):
                     futures.append(
                         ephemeris_worker_ray.remote(
-                            idx_chunk,
+                            variant_chunk_idx,
                             variants_ref,
                             observers_ref,
                             self.__class__,
@@ -420,7 +423,10 @@ def generate_ephemeris(
 
             if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                 variants = VariantOrbits.create(
-                    orbits, method=covariance_method, num_samples=num_samples
+                    orbits,
+                    method=covariance_method,
+                    num_samples=num_samples,
+                    seed=seed,
                 )
                 ephemeris_variants = self._generate_ephemeris(variants, observers)
             else:
@@ -471,6 +477,7 @@ def propagate_orbits(
         num_samples: int = 1000,
         chunk_size: int = 100,
         max_processes: Optional[int] = 1,
+        seed: Optional[int] = None,
     ) -> Orbits:
         """
         Propagate each orbit in orbits to each time in times.
@@ -552,16 +559,19 @@ def propagate_orbits(
             # Add variants to propagate to futures
             if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                 variants = VariantOrbits.create(
-                    orbits, method=covariance_method, num_samples=num_samples
+                    orbits,
+                    method=covariance_method,
+                    num_samples=num_samples,
+                    seed=seed,
                 )
+
                 variants_ref = ray.put(variants)
 
                 idx = np.arange(0, len(variants))
-                for variant_chunk in _iterate_chunks(idx, chunk_size):
-                    idx_chunk = ray.put(variant_chunk)
+                for variant_chunk_idx in _iterate_chunks(idx, chunk_size):
                     futures.append(
                         propagation_worker_ray.remote(
-                            idx_chunk,
+                            variant_chunk_idx,
                             variants_ref,
                             times_ref,
                             self.__class__,
@@ -587,6 +597,10 @@ def propagate_orbits(
             propagated = qv.concatenate(propagated_list)
             if len(variants_list) > 0:
                 propagated_variants = qv.concatenate(variants_list)
+                # sort by variant_id and time
+                propagated_variants = propagated_variants.sort_by(
+                    ["variant_id", "coordinates.time.days", "coordinates.time.nanos"]
+                )
             else:
                 propagated_variants = None
 
@@ -595,9 +609,17 @@ def propagate_orbits(
 
             if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                 variants = VariantOrbits.create(
-                    orbits, method=covariance_method, num_samples=num_samples
+                    orbits,
+                    method=covariance_method,
+                    num_samples=num_samples,
+                    seed=seed,
                 )
+
                 propagated_variants = self._propagate_orbits(variants, times)
+                # sort by variant_id and time
+                propagated_variants = propagated_variants.sort_by(
+                    ["variant_id", "coordinates.time.days", "coordinates.time.nanos"]
+                )
             else:
                 propagated_variants = None
 
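
Reviewer note (not part of the patch): a minimal usage sketch of the new `seed`
keyword. It assumes `MyPropagator` is some concrete `Propagator` subclass and that
`orbits` (with covariances) and `times` are already constructed; only the keyword
arguments mirror the change above. With a fixed seed, VariantOrbits.create samples
the same variant orbits on every run, and the sort by `variant_id` and time keeps
the output order deterministic regardless of how variants are chunked across
workers.

    # Hypothetical example -- MyPropagator, orbits, and times are assumed to be
    # defined elsewhere; only the keyword arguments come from this patch.
    propagator = MyPropagator()

    # Same seed -> identical sampled variants, so the propagated variants (and
    # any covariances collapsed from them) are reproducible across runs.
    result_a = propagator.propagate_orbits(
        orbits, times, covariance=True, num_samples=1000, seed=42
    )
    result_b = propagator.propagate_orbits(
        orbits, times, covariance=True, num_samples=1000, seed=42
    )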