"""
SSAPy acceleration-ladder runner + accel-ladder-specific divergence dashboard.
Design in this file:
- calculate_accel_comparisons(): the ONE function that does *all* propagation + math
- make_accel_ladder_dashboard_figures(): plotting only (always called)
- compare_models(): workflow wrapper that calls exactly:
1) calculate_accel_comparisons()
2) make_accel_ladder_dashboard_figures()
(plotting is not optional; there is no `make_plots` switch)
Early-stop behavior:
- If SSAPy terminates early (e.g., collision/event), r_hist/v_hist will be shorter than `times`.
- We truncate all rungs to the common prefix length (minimum length across rungs) so
comparisons remain aligned. We also return per-rung stop indices/times.
Dashboard:
- Figure 1: time-domain comparisons (divergence vs ref, incremental per rung, worst-rung NTW)
* time-domain model lines alternate solid/dashed to help reveal overlaps
* includes a top-left text box with the initial Keplerian elements (in ax[0])
- Figure 2: rung-summary with:
(A) final ||dr|| vs ref by rung (colored points + colored rung index annotations)
(B) final ||dr_inc|| by rung (colored points + colored rung index annotations)
(C) heatmap of log10(||dr|| vs ref) with color-coded rung index y-tick labels
plus a header area above the subplots:
- LEFT: accel ladder key with color-coded indices (figure-level, left aligned)
- RIGHT: initial Keplerian elements (figure-level, top-right)
Notes:
- Log plots use a floor epsilon_m (default 1 mm = 1e-3 m).
- Heatmap colorbar is capped at max_error_m (default 1e7 m).
"""
from __future__ import annotations
import numpy as np
from ..constants import EARTH_MU
# =============================================================================
# Time helpers (used only by calculate_accel_comparisons)
# =============================================================================
def _is_astropy_time(x):
return x.__class__.__module__.startswith("astropy.time") and x.__class__.__name__ == "Time"
def _to_astropy_time(x):
    """Build an ``astropy.time.Time`` from ``x``; astropy is imported lazily."""
    from astropy import time as astropy_time

    return astropy_time.Time(x)
def _orbit_epoch_gps(orbit):
t0 = getattr(orbit, "t", None)
if t0 is None:
raise ValueError("Orbit object is missing attribute `.t` (epoch).")
if _is_astropy_time(t0):
return float(t0.gps)
return float(np.asarray(t0).reshape(()))
def _coerce_times_for_ssapy(times, orbit_epoch_gps_s, assume="auto"):
    """
    Return SSAPy-compatible times:
    - astropy Time (passed through, or built from datetime-like input), or
    - float GPS seconds (since 1980-01-06)

    Parameters
    ----------
    times : astropy Time, list/tuple, or ndarray
        Requested output times. Lists/tuples are inspected element-wise.
        - ``datetime64`` arrays and arrays of objects exposing ``.year`` are
          converted to astropy ``Time``.
        - Arrays of astropy ``Time`` objects are merged into one ``Time``.
        - ``timedelta64`` arrays are durations, so they are converted to
          seconds and added to the orbit epoch.
        - Anything else is converted to a flat float array and interpreted
          according to ``assume``.
    orbit_epoch_gps_s : float
        Orbit epoch in GPS seconds, used to shift offsets.
    assume : {"auto", "gps", "offset"}
        Interpretation of numeric ``times``:
        - "gps": values are already GPS seconds;
        - "offset": values are seconds relative to the orbit epoch;
        - "auto": values are treated as offsets when max |value| < 1e8 and the
          epoch is > 1e8 (i.e. they look "small" relative to GPS seconds).

    Raises
    ------
    ValueError
        If ``times`` is empty, or if numeric ``times`` come with an
        unrecognized ``assume`` value.
    """
    if _is_astropy_time(times):
        return times
    if isinstance(times, (list, tuple)):
        times = np.array(times, dtype=object)

    if isinstance(times, np.ndarray):
        if times.size == 0:
            raise ValueError("`times` is empty.")
        if times.dtype.kind == "M":
            return _to_astropy_time(times)
        if times.dtype.kind == "m":
            # Durations cannot be represented by astropy Time; treat them as
            # offsets from the orbit epoch, expressed in seconds.
            offsets_s = np.asarray(times / np.timedelta64(1, "s"), dtype=float).ravel()
            return float(orbit_epoch_gps_s) + offsets_s
        if times.dtype == object:
            sample = times.flat[0]
            if _is_astropy_time(sample):
                return sample.__class__(times)
            if hasattr(sample, "year") or isinstance(sample, np.datetime64):
                return _to_astropy_time(times)

    t = np.asarray(times, dtype=float).ravel()
    if t.size == 0:
        raise ValueError("`times` is empty.")
    if assume == "gps":
        return t
    if assume == "offset":
        return float(orbit_epoch_gps_s) + t
    if assume != "auto":
        raise ValueError(
            f"Unrecognized assume={assume!r}; expected 'auto', 'gps' or 'offset'."
        )
    # Heuristic: GPS seconds for any modern epoch exceed 1e8, so values well
    # below that alongside a large epoch are almost certainly offsets.
    if (np.nanmax(np.abs(t)) < 1e8) and (orbit_epoch_gps_s > 1e8):
        return float(orbit_epoch_gps_s) + t
    return t
def _times_to_relative_seconds(times_ssapy):
    """Return seconds elapsed since the first sample as a flat float array.

    Accepts either an astropy ``Time`` array (differences are taken in
    astropy and converted to seconds) or numeric seconds.
    """
    if not _is_astropy_time(times_ssapy):
        secs = np.asarray(times_ssapy, dtype=float).ravel()
        return secs - float(secs[0])
    elapsed = (times_ssapy - times_ssapy[0]).to_value("s")
    return np.asarray(elapsed, dtype=float).ravel()
# =============================================================================
# Math helpers (used only by calculate_accel_comparisons)
# =============================================================================
def _keplerian_elements_from_rv(r_m, v_mps, mu_m3s2=3.986004418e14):
"""
Classical Keplerian elements from inertial r,v. Returns degrees for angles.
"""
r = np.asarray(r_m, dtype=float).reshape(3)
v = np.asarray(v_mps, dtype=float).reshape(3)
mu = float(mu_m3s2)
tiny = 1e-14
rnorm = float(np.linalg.norm(r))
vnorm = float(np.linalg.norm(v))
h = np.cross(r, v)
hnorm = float(np.linalg.norm(h))
k = np.array([0.0, 0.0, 1.0], dtype=float)
n = np.cross(k, h)
nnorm = float(np.linalg.norm(n))
e_vec = (np.cross(v, h) / mu) - (r / rnorm)
e = float(np.linalg.norm(e_vec))
energy = 0.5 * (vnorm * vnorm) - mu / rnorm
if abs(energy) > tiny:
a = -mu / (2.0 * energy)
else:
a = np.inf
if hnorm > tiny:
inc = float(np.arccos(np.clip(h[2] / hnorm, -1.0, 1.0)))
else:
inc = 0.0
if nnorm > tiny:
raan = float(np.arccos(np.clip(n[0] / nnorm, -1.0, 1.0)))
if n[1] < 0.0:
raan = 2.0 * np.pi - raan
else:
raan = 0.0
if nnorm > tiny and e > 1e-10:
argp = float(np.arccos(np.clip(np.dot(n, e_vec) / (nnorm * e), -1.0, 1.0)))
if e_vec[2] < 0.0:
argp = 2.0 * np.pi - argp
else:
argp = 0.0
if e > 1e-10:
nu = float(np.arccos(np.clip(np.dot(e_vec, r) / (e * rnorm), -1.0, 1.0)))
if np.dot(r, v) < 0.0:
nu = 2.0 * np.pi - nu
else:
if nnorm > tiny:
u = float(np.arccos(np.clip(np.dot(n, r) / (nnorm * rnorm), -1.0, 1.0)))
if r[2] < 0.0:
u = 2.0 * np.pi - u
nu = u
else:
lam = float(np.arccos(np.clip(r[0] / rnorm, -1.0, 1.0)))
if r[1] < 0.0:
lam = 2.0 * np.pi - lam
nu = lam
if np.isfinite(a) and a > 0.0:
period_s = float(2.0 * np.pi * np.sqrt((a * a * a) / mu))
else:
period_s = np.nan
return {
"a_m": float(a),
"e": float(e),
"i_deg": float(np.degrees(inc)),
"raan_deg": float(np.degrees(raan)),
"argp_deg": float(np.degrees(argp)),
"nu_deg": float(np.degrees(nu)),
"period_s": period_s,
}
def _format_oe_text(oe):
a_km = oe["a_m"] / 1e3 if np.isfinite(oe["a_m"]) else np.nan
e = oe["e"]
inc = oe["i_deg"]
raan = oe["raan_deg"]
argp = oe["argp_deg"]
nu = oe["nu_deg"]
T = oe["period_s"]
if np.isfinite(T):
T_min = T / 60.0
tline = f"T = {T_min:.2f} min"
else:
tline = "T = n/a"
return (
"Initial Keplerian OE\n"
f"a = {a_km:.3f} km\n"
f"e = {e:.6f}\n"
f"i = {inc:.3f} deg\n"
f"RAAN = {raan:.3f} deg\n"
f"argp = {argp:.3f} deg\n"
f"nu = {nu:.3f} deg\n"
f"{tline}"
)
def _ntw_components(r_ref, v_ref, dr):
"""
Project dr into an NTW basis built from reference (r_ref, v_ref).
T along velocity
W along (r x v) (orbit normal)
N along v x (r x v) (in-plane, perpendicular to T)
Returns (N,T,W) components in meters.
"""
r_ref = np.asarray(r_ref, dtype=float)
v_ref = np.asarray(v_ref, dtype=float)
dr = np.asarray(dr, dtype=float)
def _normed(x):
n = np.linalg.norm(x, axis=-1, keepdims=True)
n = np.where(n == 0.0, 1.0, n)
return x / n
t_hat = _normed(v_ref)
w_hat = _normed(np.cross(r_ref, v_ref))
n_hat = _normed(np.cross(v_ref, np.cross(r_ref, v_ref)))
n = np.einsum("...i,...i->...", n_hat, dr)
t = np.einsum("...i,...i->...", t_hat, dr)
w = np.einsum("...i,...i->...", w_hat, dr)
return np.stack([n, t, w], axis=-1)
# =============================================================================
# The single calculation function (all propagation + divergence math)
# =============================================================================
def calculate_accel_comparisons(
    orbit=None,
    r=None,
    v=None,
    t0=None,
    times=None,
    assume_times="auto",
    ode_kwargs=None,
    reference=None,
    mu_m3s2=3.986004418e14,
):
    """
    All propagation + all divergence math in ONE function.

    Parameters
    ----------
    orbit : ssapy.Orbit, optional
        Initial state. If omitted, one is built from ``r``, ``v`` and ``t0``.
    r, v : array_like, optional
        Initial position / velocity, used only when ``orbit`` is None.
    t0 : optional
        Epoch passed to ``ssapy.Orbit`` when ``orbit`` is None.
    times : required
        Output times (astropy Time, datetime-like, or numeric seconds;
        see ``assume_times``).
    assume_times : {"auto", "gps", "offset"}
        Interpretation of numeric ``times`` (GPS seconds vs. offsets from the
        orbit epoch).
    ode_kwargs : mapping, optional
        Copied and forwarded to every ``SciPyPropagator``.
    reference : int, optional
        Index of the reference rung; defaults to the last rung.
    mu_m3s2 : float
        Gravitational parameter, used only for the initial Keplerian elements.

    Returns a dict containing:
    - labels, reference
    - times_ssapy, t_rel_s (aligned to common prefix), t_rel_s_full
    - r_list, v_list (aligned)
    - early-stop metadata: stop_idx, stop_t_s, common_len
    - orbit elements + text
    - divergence arrays/metrics: drn_vs_ref, drn_inc, final_drn_vs_ref, final_drn_inc, worst_idx, ntw_worst, final_ntw_abs

    Raises
    ------
    ValueError
        Missing ``times``; neither ``orbit`` nor a full (r, v, t0) triple; an
        orbit without ``.r``/``.v``; an out-of-range ``reference``; or a rung
        returning an empty history.
    """
    if times is None:
        raise ValueError("You must provide `times`.")
    import ssapy
    from ..SSAPy_wrappers.accel_ladder import ssapy_accel_ladder
    from ssapy.compute import rv as ssapy_rv
    from ssapy.propagator import SciPyPropagator

    # --- Orbit construction if needed ---------------------------------------
    if orbit is None:
        if r is None or v is None or t0 is None:
            raise ValueError("Provide either `orbit` OR (`r`, `v`, `t0`).")
        orbit = ssapy.Orbit(r=np.asarray(r, dtype=float), v=np.asarray(v, dtype=float), t=t0)
    r0 = getattr(orbit, "r", None)
    v0 = getattr(orbit, "v", None)
    if r0 is None or v0 is None:
        raise ValueError("Orbit object is missing `.r` and/or `.v` needed for initial elements.")

    # --- Initial elements ---------------------------------------------------
    oe = _keplerian_elements_from_rv(r0, v0, mu_m3s2=mu_m3s2)
    oe_text = _format_oe_text(oe)

    # --- Ladder models ------------------------------------------------------
    ladder = ssapy_accel_ladder()
    labels = list(ladder.keys())
    if reference is None:
        reference = max(0, len(labels) - 1)
    reference = int(reference)
    if reference < 0 or reference >= len(labels):
        raise ValueError("reference index out of range.")

    # --- Time coercion ------------------------------------------------------
    epoch_gps = _orbit_epoch_gps(orbit)
    times_ssapy = _coerce_times_for_ssapy(times, epoch_gps, assume=assume_times)
    t_rel_s_full = _times_to_relative_seconds(times_ssapy)

    # --- Propagate each rung ------------------------------------------------
    r_list = []
    v_list = []
    for name in labels:
        # A fresh dict per propagator, so propagators never share one mapping.
        prop = SciPyPropagator(ladder[name], ode_kwargs=None if ode_kwargs is None else dict(ode_kwargs))
        r_hist, v_hist = ssapy_rv(orbit, times_ssapy, prop)
        r_hist = np.asarray(r_hist, dtype=float)
        v_hist = np.asarray(v_hist, dtype=float)
        if r_hist.ndim == 3:
            # NOTE(review): assumed to be a leading per-orbit axis for a single orbit.
            r_hist = r_hist[0]
            v_hist = v_hist[0]
        r_list.append(r_hist)
        v_list.append(v_hist)

    # --- Early-stop alignment (common prefix) -------------------------------
    lengths = np.array([int(rh.shape[0]) for rh in r_list], dtype=int)
    if np.any(lengths <= 0):
        raise ValueError("One or more rungs returned empty histories (length 0).")
    common_len = int(np.min(lengths))
    r_list = [rh[:common_len] for rh in r_list]
    v_list = [vh[:common_len] for vh in v_list]
    t_rel_s = np.asarray(t_rel_s_full, dtype=float)[:common_len]
    stop_idx = lengths - 1
    stop_idx_clip = np.clip(stop_idx, 0, t_rel_s_full.size - 1)
    stop_t_s = np.asarray(t_rel_s_full, dtype=float)[stop_idx_clip]

    # --- Divergence metrics -------------------------------------------------
    n_models = len(r_list)
    r_all = np.stack(r_list)  # (n_models, common_len, 3)
    r_ref = r_list[reference]
    v_ref = v_list[reference]

    drn_vs_ref = np.linalg.norm(r_all - r_ref, axis=-1)
    # Incremental divergence: rung i vs rung i-1; rung 0 has no predecessor (0).
    drn_inc = np.zeros_like(drn_vs_ref)
    drn_inc[1:] = np.linalg.norm(np.diff(r_all, axis=0), axis=-1)

    final_drn_vs_ref = drn_vs_ref[:, -1]
    final_drn_inc = drn_inc[:, -1]

    idx_candidates = np.arange(n_models)
    idx_candidates = idx_candidates[idx_candidates != reference]
    worst_idx = int(idx_candidates[np.argmax(final_drn_vs_ref[idx_candidates])]) if idx_candidates.size else reference
    ntw_worst = _ntw_components(r_ref, v_ref, (r_list[worst_idx] - r_ref))

    # Only the final epoch is needed here, so project just that row per rung
    # instead of the full history.
    dr_final = r_all[:, -1, :] - r_ref[-1]
    final_ntw_abs = np.abs(
        _ntw_components(
            np.broadcast_to(r_ref[-1], dr_final.shape),
            np.broadcast_to(v_ref[-1], dr_final.shape),
            dr_final,
        )
    )

    return {
        # Core ladder outputs
        "labels": labels,
        "reference": reference,
        "times_ssapy": times_ssapy,
        "t_rel_s": t_rel_s,
        "t_rel_s_full": t_rel_s_full,
        "r_list": r_list,
        "v_list": v_list,
        # Early-stop metadata
        "stop_idx": stop_idx,
        "stop_t_s": stop_t_s,
        "common_len": common_len,
        # Orbit elements
        "orbit_elements": oe,
        "orbit_elements_text": oe_text,
        # Divergence products (for any plot)
        "drn_vs_ref": drn_vs_ref,
        "drn_inc": drn_inc,
        "final_drn_vs_ref": final_drn_vs_ref,
        "final_drn_inc": final_drn_inc,
        "worst_idx": worst_idx,
        "ntw_worst": ntw_worst,
        "final_ntw_abs": final_ntw_abs,
    }
# =============================================================================
# Plotting (plot-only)
# =============================================================================
def _nice_vivid_colors(n):
    """Return ``n`` RGBA colors for the ladder rungs.

    Up to 10 rungs use ``tab10`` and up to 20 use ``tab20``. Beyond that,
    ``n`` colors are sampled evenly (endpoint excluded) from the cyclic
    ``hsv`` colormap.
    """
    import matplotlib.pyplot as plt

    if n <= 10:
        palette = "tab10"
    elif n <= 20:
        palette = "tab20"
    else:
        return plt.get_cmap("hsv")(np.linspace(0.0, 1.0, n, endpoint=False))
    return plt.get_cmap(palette)(np.arange(n))
def _draw_ladder_key_two_columns(
fig,
labels,
colors,
header_x,
header_y,
x_left,
x_right,
y_top,
line_height,
idx_dx=0.045,
header_fontsize=14,
idx_fontsize=12,
label_fontsize=11,
title_gap=0.030,
):
labels = list(labels)
n = len(labels)
half = int(np.ceil(n / 2.0))
fig.text(
header_x,
header_y,
"Accel ladder key (rung index -> model)",
ha="left",
va="top",
fontsize=header_fontsize,
)
def _draw_column(start_i, end_i, x0):
for row, i in enumerate(range(start_i, end_i)):
y = (y_top - title_gap) - row * line_height
idx_txt = f"{i:2d}: "
fig.text(x0, y, idx_txt, ha="left", va="top", fontsize=idx_fontsize, color=colors[i])
fig.text(x0 + idx_dx, y, str(labels[i]), ha="left", va="top", fontsize=label_fontsize, color="black")
_draw_column(0, half, x_left)
_draw_column(half, n, x_right)
# =============================================================================
# compare_models: calls calculation + plot (non-optional)
# =============================================================================
def compare_models(
    orbit=None,
    r=None,
    v=None,
    t0=None,
    times=None,
    assume_times="auto",
    ode_kwargs=None,
    reference=None,
    plot_title="SSAPy accel ladder divergences",
    show_legend=True,
    epsilon_m=1e-3,
    max_error_m=1e7,
    mu_m3s2=EARTH_MU,
):
    """
    Workflow wrapper:
    - always computes
    - always plots
    - only calls:
    1) calculate_accel_comparisons()
    2) make_accel_ladder_dashboard_figures()

    The orbit/time/propagation arguments (``orbit`` ... ``reference``,
    ``mu_m3s2``) are forwarded unchanged to ``calculate_accel_comparisons``.
    The plotting arguments (``plot_title``, ``show_legend``, ``epsilon_m``,
    ``max_error_m``) are forwarded to ``make_accel_ladder_dashboard_figures``.

    Returns
    -------
    dict
        ``labels``, ``reference`` (the resolved reference rung index),
        ``times_ssapy``, ``t_rel_s``, ``t_rel_s_full``, ``r_list``,
        ``v_list``, ``orbit_elements``, ``orbit_elements_text``, ``stop_idx``,
        ``stop_t_s`` and ``common_len``, copied from the calculation. It also
        includes ``dashboard``: ``{"calc": <full calculation dict>}`` updated
        with the mapping returned by the figure builder.
    """
    calc = calculate_accel_comparisons(
        orbit=orbit,
        r=r,
        v=v,
        t0=t0,
        times=times,
        assume_times=assume_times,
        ode_kwargs=ode_kwargs,
        reference=reference,
        mu_m3s2=mu_m3s2,
    )
    figs = make_accel_ladder_dashboard_figures(
        calc=calc,
        plot_title=plot_title,
        show_legend=show_legend,
        epsilon_m=epsilon_m,
        max_error_m=max_error_m,
    )
    dashboard = {"calc": calc}
    dashboard.update(figs)

    passthrough_keys = (
        "labels",
        "reference",
        "times_ssapy",
        "t_rel_s",
        "t_rel_s_full",
        "r_list",
        "v_list",
        "orbit_elements",
        "orbit_elements_text",
        "stop_idx",
        "stop_t_s",
        "common_len",
    )
    result = {key: calc[key] for key in passthrough_keys}
    result["dashboard"] = dashboard
    return result