Skip to content

Commit

Permalink
Add initial version of collision detection
Browse files Browse the repository at this point in the history
Co-authored-by: Kathleen Kiker <[email protected]>
Co-authored-by: Joachim Moeyens <[email protected]>
  • Loading branch information
3 people committed Feb 20, 2025
1 parent 20aa441 commit c5a0c74
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 65 deletions.
197 changes: 138 additions & 59 deletions src/adam_core/dynamics/impacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pyarrow.compute as pc
import quivr as qv

from adam_core.constants import KM_P_AU
from adam_core.constants import Constants as c
from adam_core.coordinates import CartesianCoordinates
from adam_core.ray_cluster import initialize_use_ray

Expand All @@ -21,6 +23,12 @@

logger = logging.getLogger(__name__)

C = c.C

# Use the Earth's equatorial radius as used in DE4XX ephemerides
# adam_core defines it in au but we need it in km
EARTH_RADIUS_KM = c.R_EARTH_EQUATORIAL * KM_P_AU

RAY_INSTALLED = False
try:
import ray
Expand All @@ -37,10 +45,11 @@
def impact_worker_ray(idx_chunk, orbits, propagator_class, num_days):
prop = propagator_class()
orbits_chunk = orbits.take(idx_chunk)
variants, impacts = prop._detect_impacts(orbits_chunk, num_days)
variants, impacts = prop._detect_collisions(orbits_chunk, num_days)
return variants, impacts


# Remove after adam-assist gets updated
class EarthImpacts(qv.Table):
#: Orbit ID
orbit_id = qv.StringColumn()
Expand All @@ -51,17 +60,9 @@ class EarthImpacts(qv.Table):
#: Earth-centered, Earth-fixed coordinates [ECEF - ITRF93] of the impact
impact_coordinates = SphericalCoordinates.as_column()

def preview(self) -> None:
"""
Plot the risk corridor for the given impacts.
"""
from .plots import plot_risk_corridor

fig = plot_risk_corridor(self, title="Risk Corridor")
fig.show()


class ImpactProbabilities(qv.Table):
condition_id = qv.LargeStringColumn()
orbit_id = qv.LargeStringColumn()
impacts = qv.Int64Column()
variants = qv.Int64Column()
Expand All @@ -72,39 +73,75 @@ class ImpactProbabilities(qv.Table):
maximum_impact_time = Timestamp.as_column(nullable=True)


class CollisionConditions(qv.Table):
condition_id = qv.LargeStringColumn()
collision_object_name = qv.LargeStringColumn()
collision_distance = qv.Float64Column()
stopping_condition = qv.BooleanColumn()


class CollisionEvent(qv.Table):
orbit_id = qv.LargeStringColumn()
variant_id = qv.LargeStringColumn(nullable=True)
coordinates = CartesianCoordinates.as_column()
condition_id = qv.LargeStringColumn()
collision_object_name = qv.LargeStringColumn()
collision_coordinates = SphericalCoordinates.as_column()
stopping_condition = qv.BooleanColumn()

def preview(self) -> None:
"""
Plot the risk corridor for the given impacts.
"""
from .plots import plot_risk_corridor

fig = plot_risk_corridor(self, title="Risk Corridor")
fig.show()


class ImpactMixin:
"""
`~adam_core.propagator.Propagator` mixin with signature for detecting Earth impacts.
Subclasses should implement the _detect_impacts method.
Subclasses should implement the _detect_collisions method.
"""

@abstractmethod
def _detect_impacts(
self, orbits: Orbits, num_days: float
) -> Tuple[OrbitType, EarthImpacts]:
def _detect_collisions(
self,
orbits: Orbits,
num_days: float,
conditions: Optional[CollisionConditions] = None,
) -> Tuple[OrbitType, CollisionEvent]:
"""
Detect impacts for the given orbits.
Detect collisions for the given orbits.
THIS FUNCTION SHOULD BE DEFINED BY THE USER.
THIS FUNCTION SHOULD NOT BE OVERRIDDEN BY THE USER.
"""
pass

def detect_impacts(
def detect_collisions(
self,
orbits: OrbitType,
num_days: int,
conditions: Optional[CollisionConditions] = None,
max_processes: Optional[int] = 1,
chunk_size: int = 100,
) -> Tuple[OrbitType, EarthImpacts]:
chunk_size: Optional[int] = 100,
) -> Tuple[OrbitType, CollisionConditions]:
"""
Detect impacts for each orbit in orbits after num_days.
Detect collisions for each orbit in orbits after num_days.
Parameters
----------
orbits : `~adam_core.orbits.orbits.Orbits` (N)
Orbits for which to detect impacts.
num_days : int
Number of days after which to detect impacts.
conditions : `~adam_core.orbits.earth_impacts.CollisionConditions`
Conditions for detecting collisions, including:
- condition_id: Unique identifier for the condition.
- collision_object_name: Name of the object with which to detect collisions.
- collision_distance: Distance from the object at which to detect collisions.
- stopping_condition: Whether to stop propagation after a collision.
max_processes : int or None, optional
Maximum number of processes to launch. If None then the number of
processes will be equal to the number of cores on the machine. If 1
Expand All @@ -114,11 +151,19 @@ def detect_impacts(
-------
propagated : `~adam_core.orbits.OrbitType`
The input orbits propagated to the end of simulation.
impacts : `~adam_core.orbits.earth_impacts.EarthImpacts`
Impacts detected for the orbits.
impacts : `~adam_core.orbits.earth_impacts.CollisionEvent`
Impacts/collisions detected for the orbits. Includes:
- orbit_id: Unique identifier for the orbit.
- distance: Distance from the collision object.
- coordinates: Cartesian coordinates of the impact.
- variant_id: Unique identifier for the variant.
- condition_id: Unique identifier for the condition.
- collision_object_name: Name of the object with which collisions were detected.
- collision_distance: Distance from the object at which collisions were detected.
- stopping_condition: Whether the propagation was stopped after a collision.
"""
if max_processes is None or max_processes > 1:
impact_list: List[EarthImpacts] = []
impact_list: List[CollisionConditions] = []
propagated_list: List[OrbitType] = []

if RAY_INSTALLED is False:
Expand Down Expand Up @@ -161,7 +206,9 @@ def detect_impacts(
impacts = qv.concatenate(impact_list)

else:
propagated, impacts = self._detect_impacts(orbits, num_days)
propagated, impacts = self._detect_collisions(
orbits, num_days, conditions=conditions
)

return propagated, impacts

Expand All @@ -173,7 +220,8 @@ def calculate_impacts(
num_samples: int = 1000,
processes: Optional[int] = None,
seed: Optional[int] = None,
) -> Tuple[Orbits, EarthImpacts]:
impact_conditions: Optional[CollisionConditions] = None,
) -> Tuple[OrbitType, CollisionEvent]:
"""
Calculate the impacts for each variant orbit generated from the input orbits.
Expand Down Expand Up @@ -203,18 +251,28 @@ def calculate_impacts(
variants = VariantOrbits.create(
orbits, method="monte-carlo", num_samples=num_samples, seed=seed
)
if impact_conditions is None:
impact_conditions = CollisionConditions.from_kwargs(
condition_id=["Default"],
collision_object_name=["Earth"],
collision_distance=[EARTH_RADIUS_KM],
stopping_condition=[True],
)
logger.info("Detecting impacts...")
results, impacts = propagator.detect_impacts(
results, collisions = propagator.detect_collisions(
variants,
num_days,
impact_conditions,
max_processes=processes,
)

return results, impacts
return results, collisions


def calculate_impact_probabilities(
variants: VariantOrbits, impacts: EarthImpacts
variants: VariantOrbits,
collision_events: CollisionEvent,
conditions: Optional[CollisionConditions] = None,
) -> ImpactProbabilities:
"""
Calculate the impact probabilities for each variant orbit generated from the input orbits.
Expand All @@ -230,49 +288,70 @@ def calculate_impact_probabilities(
Impact probabilities for the variant orbits.
"""

if conditions is None:
conditions = CollisionConditions.from_kwargs(
condition_id=["Default"],
collision_object_name=["Earth"],
collision_distance=[EARTH_RADIUS_KM],
stopping_condition=[True],
)

# Loop through the unique set of orbit_ids within variants using quivr
unique_orbits = pc.unique(variants.orbit_id).to_pylist()

earth_impact_probabilities = ImpactProbabilities.empty()
impact_probabilities = None

for orbit_id in unique_orbits:
# mask = pc.equal(variants.orbit_id, orbit_id)
variant_masked = variants.select("orbit_id", orbit_id)
variant_count = len(variant_masked)
impacts_masked = collision_events.select("orbit_id", orbit_id)

impacts_masked = impacts.select("orbit_id", orbit_id)
impact_count = len(impacts_masked)
for unique_condition in conditions:
condition_id = unique_condition.condition_id[0]
impacts_per_condition = impacts_masked.select("condition_id", condition_id)
impact_count = len(impacts_per_condition)

if len(impacts_masked) > 0:
impact_mjds = impacts_masked.coordinates.time.mjd().to_numpy(
zero_copy_only=False
)
mean_mjd = Timestamp.from_mjd(
[np.mean(impact_mjds)], scale=impacts_masked.coordinates.time.scale
if len(impacts_per_condition) > 0:
impact_mjds = impacts_per_condition.coordinates.time.mjd().to_numpy(
zero_copy_only=False
)
mean_mjd = Timestamp.from_mjd(
[np.mean(impact_mjds)],
scale=impacts_per_condition.coordinates.time.scale,
)
stddev = np.std(impact_mjds)
min_mjd = impacts_per_condition.coordinates.time.min()
max_mjd = impacts_per_condition.coordinates.time.max()
else:
mean_mjd = Timestamp.nulls(
1, scale=impacts_per_condition.coordinates.time.scale
)
stddev = None
min_mjd = Timestamp.nulls(
1, scale=impacts_per_condition.coordinates.time.scale
)
max_mjd = Timestamp.nulls(
1, scale=impacts_per_condition.coordinates.time.scale
)

ip = ImpactProbabilities.from_kwargs(
condition_id=[condition_id],
orbit_id=[orbit_id],
impacts=[impact_count],
variants=[variant_count],
cumulative_probability=[impact_count / variant_count],
mean_impact_time=mean_mjd,
stddev_impact_time=[stddev],
minimum_impact_time=min_mjd,
maximum_impact_time=max_mjd,
)
stddev = np.std(impact_mjds)
min_mjd = impacts_masked.coordinates.time.min()
max_mjd = impacts_masked.coordinates.time.max()
else:
mean_mjd = Timestamp.nulls(1, scale=impacts_masked.coordinates.time.scale)
stddev = None
min_mjd = Timestamp.nulls(1, scale=impacts_masked.coordinates.time.scale)
max_mjd = Timestamp.nulls(1, scale=impacts_masked.coordinates.time.scale)

ip = ImpactProbabilities.from_kwargs(
orbit_id=[orbit_id],
impacts=[impact_count],
variants=[variant_count],
cumulative_probability=[impact_count / variant_count],
mean_impact_time=mean_mjd,
stddev_impact_time=[stddev],
minimum_impact_time=min_mjd,
maximum_impact_time=max_mjd,
)

earth_impact_probabilities = qv.concatenate([earth_impact_probabilities, ip])
if impact_probabilities is None:
impact_probabilities = ip
else:
impact_probabilities = qv.concatenate([impact_probabilities, ip])

return earth_impact_probabilities
return impact_probabilities


def link_impacting_variants(variants, impacts):
Expand Down
Loading

0 comments on commit c5a0c74

Please sign in to comment.