# --- Standard library ---
import io
import os
import re
from enum import Enum, auto
from numbers import Real
# --- Third-party ---
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.colors import cnames, to_rgb, rgb2hex
from PIL import Image as PILImage
from PyPDF2 import PdfMerger
from IPython.display import Image as IPythonImage, display as ipython_display
from astropy.time import Time
from erfa import gst94
# --- Local modules ---
from ssapy.utils import find_file
from ..constants import EARTH_RADIUS, MOON_RADIUS
from ..vectors import rotation_matrix_from_vectors
[docs]
class VarType(Enum):
    """Categories produced by check_type() to describe what kind of value it was given."""

    # Explicit values match what auto() assigns in a plain Enum (1, 2, ...).
    NONE = 1         # the value None
    TIME = 2         # astropy Time, or a list of Time / Time-valued arrays
    ARRAY = 3        # a single numpy ndarray
    LIST_ARRAYS = 4  # a list whose items are all ndarrays
    LIST_LISTS = 5   # a list whose items are all lists
    MIXED_LIST = 6   # any other non-empty list
    OTHER = 7        # anything else (empty list, tuple, scalar, ...)
[docs]
def is_list_of_arrays(lst):
    """Return True if every element of ``lst`` is a numpy ndarray (True for an empty iterable)."""
    for element in lst:
        if not isinstance(element, np.ndarray):
            return False
    return True
[docs]
def is_list_of_lists(lst):
    """Return True if every element of ``lst`` is a Python list (True for an empty iterable)."""
    return not any(not isinstance(element, list) for element in lst)
[docs]
def check_type(var):
    """
    Classify ``var`` into one of the VarType cases.

    Rules, checked in this order:

    * ``None``                        -> ``VarType.NONE``
    * ``astropy.time.Time``           -> ``VarType.TIME``
    * ``np.ndarray``                  -> ``VarType.ARRAY``
    * not a list (tuple, scalar, ...) -> ``VarType.OTHER``
    * empty list                      -> ``VarType.OTHER``
    * list of ndarrays                -> ``VarType.TIME`` if at least one array
      is non-empty and every non-empty array's first element is a ``Time``;
      otherwise ``VarType.LIST_ARRAYS``
    * list of lists                   -> ``VarType.LIST_LISTS``
    * list of ``Time`` objects        -> ``VarType.TIME``
    * any other list                  -> ``VarType.MIXED_LIST``
    """
    if var is None:
        return VarType.NONE
    if isinstance(var, Time):
        return VarType.TIME
    if isinstance(var, np.ndarray):
        return VarType.ARRAY
    if not isinstance(var, list) or len(var) == 0:
        return VarType.OTHER
    if is_list_of_arrays(var):
        non_empty = [item for item in var if item.size > 0]
        # all() over an empty sequence is True, so a list of only empty arrays
        # must not be reported as TIME; require at least one real element.
        if non_empty and all(isinstance(item.flat[0], Time) for item in non_empty):
            return VarType.TIME
        return VarType.LIST_ARRAYS
    if is_list_of_lists(var):
        return VarType.LIST_LISTS
    if all(isinstance(item, Time) for item in var):
        return VarType.TIME
    return VarType.MIXED_LIST
[docs]
def valid_orbits(r, t, drop_empty=True, warn=True, verbose=True):
    """
    Normalize r and t into parallel lists of shape-(n,3) ndarrays and astropy Time objects.

    Accepts:
      r:
        - (3,), (1,3), (3,1), (N,3), (B,N,3) ndarray
          ((B,1,3) is kept as B single-point tracks)
        - list/tuple of any of the above
        - an empty ndarray (treated as one empty track)
      t:
        - None (placeholder GPS times of 0 s for every point)
        - scalar (float/int) interpreted as GPS seconds (broadcast)
        - ndarray of GPS seconds: (N,) or (B,N)
        - astropy.time.Time (scalar or array)
        - list/tuple of scalars/ndarrays/Time, matching number of tracks

    Additionally (if drop_empty=True):
      * removes empty r-tracks (Ni==0)
      * if t is a per-track list/tuple with the same length as the normalized
        r track list, removes the corresponding t entries too
      * prints a warning when any are removed (only if warn=True)

    Parameters
    ----------
    r, t : see above
    drop_empty : bool
        Remove zero-length tracks before pairing with t.
    warn : bool
        Print a message when tracks are dropped or everything is empty.
    verbose : bool
        Print a one-line summary of the returned shapes (default True,
        matching the historical behavior).

    Returns
    -------
    r_list : list[np.ndarray]
        Each entry has shape (Ni, 3).
    t_list : list[astropy.time.Time]
        Each entry has length Ni.

    Raises
    ------
    ValueError
        If r cannot be interpreted or t lengths do not match r.
    TypeError
        If t has an unsupported type or a non-numeric ndarray.
    """

    def _to_track_list_r(r_in):
        # Flatten r_in into a list of (N, 3) float arrays, one per track.
        if isinstance(r_in, (list, tuple)):
            out = []
            for item in r_in:
                out.extend(_to_track_list_r(item))
            return out
        arr = np.asarray(r_in, dtype=float)
        # Single position vector (3,)
        if arr.ndim == 1 and arr.size == 3:
            return [arr.reshape(1, 3)]
        if arr.ndim == 2:
            # Column vector (3, 1) -> one position
            if arr.shape == (3, 1):
                return [arr.reshape(1, 3)]
            # Standard track (N, 3), including a single row (1, 3)
            if arr.shape[1] == 3:
                return [arr]
        # Batched tracks (B, N, 3). Do NOT squeeze before this check: squeezing
        # would merge B single-point tracks (B, 1, 3) into one B-point track.
        if arr.ndim == 3 and arr.shape[2] == 3:
            return list(arr)
        # Fallback: tolerate other layouts that only differ by singleton axes,
        # e.g. (N, 3, 1) or (1, 1, N, 3).
        squeezed = arr.squeeze()
        if squeezed.shape != arr.shape:
            return _to_track_list_r(squeezed)
        # An empty ndarray such as np.array([]) becomes one empty track so that
        # drop_empty can remove it instead of raising.
        if arr.size == 0:
            return [np.empty((0, 3), dtype=float)]
        raise ValueError(f"valid_orbits: cannot interpret r with shape {arr.shape}")

    def _to_time_track_list(t_in, r_list):
        # Return list[Time] matching r_list length.
        n_tracks = len(r_list)
        # None -> dummy gps time arrays (zeros)
        if t_in is None:
            return [Time(np.zeros(len(rr), dtype=float), format="gps") for rr in r_list]
        # Single Time object -> broadcast if needed
        if isinstance(t_in, Time):
            if t_in.isscalar:
                return [Time(np.full(len(rr), t_in.gps, dtype=float), format="gps") for rr in r_list]
            if n_tracks == 1:
                if len(t_in) != len(r_list[0]):
                    raise ValueError("valid_orbits: Time length must match r length")
                return [t_in]
            # Broadcast same Time array to all tracks (must match all)
            if not all(len(t_in) == len(rr) for rr in r_list):
                raise ValueError("valid_orbits: single Time array length must match all r tracks")
            return [t_in for _ in r_list]
        # Scalar numeric -> GPS seconds broadcast
        if isinstance(t_in, (int, float, np.integer, np.floating)):
            val = float(t_in)
            return [Time(np.full(len(rr), val, dtype=float), format="gps") for rr in r_list]
        # ndarray of gps seconds
        if isinstance(t_in, np.ndarray):
            arr = np.asarray(t_in)
            if not np.issubdtype(arr.dtype, np.number):
                raise TypeError("valid_orbits: ndarray t must be numeric GPS seconds")
            if arr.ndim == 0:
                val = float(arr)
                return [Time(np.full(len(rr), val, dtype=float), format="gps") for rr in r_list]
            if arr.ndim == 1:
                if n_tracks == 1:
                    if arr.shape[0] != len(r_list[0]):
                        raise ValueError("valid_orbits: t length must match r length")
                    return [Time(arr.astype(float), format="gps")]
                # Broadcast one time vector to all tracks if it matches all
                if not all(arr.shape[0] == len(rr) for rr in r_list):
                    raise ValueError("valid_orbits: single t-array length must match all r tracks")
                tt = Time(arr.astype(float), format="gps")
                return [tt for _ in r_list]
            if arr.ndim == 2:
                # Batched times: (B, N)
                if arr.shape[0] != n_tracks:
                    raise ValueError("valid_orbits: batched t must have same number of tracks as r")
                out = []
                for rr, row in zip(r_list, arr):
                    if row.shape[0] != len(rr):
                        raise ValueError("valid_orbits: each batched t row must match its r track length")
                    out.append(Time(np.asarray(row, dtype=float), format="gps"))
                return out
            raise ValueError("valid_orbits: ndarray t must be 0D, 1D, or 2D")
        # list/tuple: per-track specification, or a single element to broadcast
        if isinstance(t_in, (list, tuple)):
            if len(t_in) == 1 and n_tracks > 1:
                return _to_time_track_list(t_in[0], r_list)
            if len(t_in) != n_tracks:
                raise ValueError("valid_orbits: number of t entries must equal number of r tracks")
            out = []
            for rr, ti in zip(r_list, t_in):
                # recurse per item but ensure it yields exactly 1 track
                ti_list = _to_time_track_list(ti, [rr])
                if len(ti_list) != 1:
                    raise ValueError("valid_orbits: each per-track t entry must map to exactly one Time array")
                out.append(ti_list[0])
            return out
        raise TypeError(f"valid_orbits: unsupported type for t: {type(t_in)}")

    # 1) normalize r
    r_list = _to_track_list_r(r)

    # 2) optionally drop empty r tracks and corresponding per-track t entries
    if drop_empty:
        empty_idx = [i for i, rr in enumerate(r_list) if len(rr) == 0]
        if empty_idx:
            if warn:
                print(f"valid_orbits warning: removed {len(empty_idx)} empty orbit track(s) at indices {empty_idx}")
            empty = set(empty_idx)  # O(1) membership for the filters below
            # If t is per-track (same length as r_list), drop corresponding t entries too
            if isinstance(t, (list, tuple)) and len(t) == len(r_list):
                t = [ti for i, ti in enumerate(t) if i not in empty]
            r_list = [rr for i, rr in enumerate(r_list) if i not in empty]

    # If everything is empty, return early (avoids t-shape errors)
    if len(r_list) == 0:
        if warn:
            print("valid_orbits warning: all orbit tracks were empty; returning empty lists.")
        return [], []

    # 3) normalize t against the filtered r_list
    t_list = _to_time_track_list(t, r_list)

    # 4) final length sanity
    for rr, tt in zip(r_list, t_list):
        if len(rr) != len(tt):
            raise ValueError("valid_orbits: length mismatch after normalization")

    # 5) shape summary, built from per-track .shape attributes. np.shape() on a
    #    list of Time objects would coerce every element into an object array.
    if verbose:
        r_shapes = {rr.shape for rr in r_list}
        t_shapes = {tt.shape for tt in t_list}
        if len(r_shapes) == 1 and len(t_shapes) == 1:
            print(
                "Returning arrays shaped: "
                f"{(len(r_list),) + r_shapes.pop()}, {(len(t_list),) + t_shapes.pop()}"
            )
        else:
            print(
                "Returning arrays with varying shapes: "
                f"track lengths={[len(rr) for rr in r_list]}"
            )
    return r_list, t_list
[docs]
def load_earth_file():
    """
    Open the Earth texture image located by ``find_file("earth", ext=".png")``.

    Returns the image resized to (5400 // 5, 2700 // 5) = (1080, 540) pixels.
    """
    texture = PILImage.open(find_file("earth", ext=".png"))
    target_size = (5400 // 5, 2700 // 5)
    return texture.resize(target_size)
[docs]
def drawEarth(time, ngrid=100, R=EARTH_RADIUS, rfactor=1):
    """
    Draw a textured Earth sphere with ipyvolume, rotated for the given time(s).

    Parameters
    ----------
    time : float, array_like or astropy.time.Time (n,)
        If float (array), then should correspond to GPS seconds;
        i.e., seconds since 1980-01-06 00:00:00 UTC. Scalars, lists and
        0-d arrays are accepted; the values are flattened to shape (n,).
    ngrid : int
        Number of grid points in Earth model.
    R : float
        Earth radius in meters. Default is EARTH_RADIUS.
    rfactor : float
        Factor by which to enlarge Earth (for visualization purposes)

    Returns
    -------
    The object returned by ``ipyvolume.plot_mesh``. One texture offset is
    supplied per entry of ``time``.
    """
    import ipyvolume as ipv

    earth = load_earth_file()
    # Unit-sphere surface on an ngrid x ngrid latitude/longitude grid.
    lat = np.linspace(-np.pi / 2, np.pi / 2, ngrid)
    lon = np.linspace(-np.pi, np.pi, ngrid)
    lat, lon = np.meshgrid(lat, lon)
    x = np.cos(lat) * np.cos(lon)
    y = np.cos(lat) * np.sin(lon)
    z = np.sin(lat)
    # Texture coordinates in [0, 1].
    u = np.linspace(0, 1, ngrid)
    v, u = np.meshgrid(u, u)

    # Earth rotation angle for t (approximate, visualization only)
    if isinstance(time, Time):
        time = time.gps
    # Coerce scalars, lists and 0-d arrays to a 1-D float array; otherwise the
    # arithmetic and the [:, None, None] indexing below fail for them.
    time = np.asarray(time, dtype=float).ravel()
    # GPS seconds -> MJD(TT): GPS epoch is MJD 44244.0 and TT - GPS = 51.184 s.
    mjd_tt = 44244.0 + (time + 51.184) / 86400
    # gst94 takes UT1 as a two-part Julian date; TT is used as a stand-in here.
    gst = gst94(2400000.5, mjd_tt)
    # Shift the longitude texture coordinate by GST as a fraction of a full turn.
    u = u - (gst / (2 * np.pi))[:, None, None]
    v = np.broadcast_to(v, u.shape)
    return ipv.plot_mesh(
        x * R * rfactor, y * R * rfactor, z * R * rfactor,
        u=u, v=v, wireframe=False, texture=earth
    )
[docs]
def load_moon_file():
    """
    Open the Moon texture image located by ``find_file("moon", ext=".png")``.

    Returns the image resized to (5400 // 5, 2700 // 5) = (1080, 540) pixels.
    """
    texture = PILImage.open(find_file("moon", ext=".png"))
    target_size = (5400 // 5, 2700 // 5)
    return texture.resize(target_size)
[docs]
def drawMoon(time, ngrid=100, R=MOON_RADIUS, rfactor=1):
    """
    Draw a textured Moon sphere with ipyvolume, rotated for the given time(s).

    Parameters
    ----------
    time : float, array_like or astropy.time.Time (n,)
        If float (array), then should correspond to GPS seconds;
        i.e., seconds since 1980-01-06 00:00:00 UTC. Scalars, lists and
        0-d arrays are accepted; the values are flattened to shape (n,).
    ngrid : int
        Number of grid points in Moon model.
    R : float
        Moon radius in meters.
    rfactor : float
        Factor by which to enlarge Moon (for visualization purposes)

    Returns
    -------
    The object returned by ``ipyvolume.plot_mesh``. One texture offset is
    supplied per entry of ``time``.
    """
    import ipyvolume as ipv

    moon = load_moon_file()
    # Unit-sphere surface on an ngrid x ngrid latitude/longitude grid.
    lat = np.linspace(-np.pi / 2, np.pi / 2, ngrid)
    lon = np.linspace(-np.pi, np.pi, ngrid)
    lat, lon = np.meshgrid(lat, lon)
    x = np.cos(lat) * np.cos(lon)
    y = np.cos(lat) * np.sin(lon)
    z = np.sin(lat)
    # Texture coordinates in [0, 1].
    u = np.linspace(0, 1, ngrid)
    v, u = np.meshgrid(u, u)

    if isinstance(time, Time):
        time = time.gps
    # Coerce scalars, lists and 0-d arrays to a 1-D float array; otherwise the
    # arithmetic and the [:, None, None] indexing below fail for them.
    time = np.asarray(time, dtype=float).ravel()
    # GPS seconds -> MJD(TT): GPS epoch is MJD 44244.0 and TT - GPS = 51.184 s.
    mjd_tt = 44244.0 + (time + 51.184) / 86400
    # NOTE(review): this rotates the Moon texture by Earth's Greenwich sidereal
    # time, same as drawEarth; confirm that is the intended orientation.
    gst = gst94(2400000.5, mjd_tt)
    u = u - (gst / (2 * np.pi))[:, None, None]
    v = np.broadcast_to(v, u.shape)
    return ipv.plot_mesh(
        x * R * rfactor, y * R * rfactor, z * R * rfactor,
        u=u, v=v, wireframe=False, texture=moon
    )
# Module-level counter incremented on every save_plot_to_pdf() call; its value
# is reported in that function's "Saved figure N to ..." message.
save_plot_to_pdf_call_count = 0
[docs]
def save_plot_to_pdf(figure, pdf_path):
    """
    Save a Matplotlib figure as a PNG embedded in a PDF file.

    The figure is rendered to a 300-dpi PNG in memory and placed on a
    single-page temporary PDF. If ``pdf_path`` already exists, that page is
    appended to it; otherwise the temporary PDF is renamed to ``pdf_path``.
    ``figure`` is closed afterwards.

    Parameters
    ----------
    figure : matplotlib.figure.Figure
        Figure to save.
    pdf_path : str or os.PathLike
        Destination PDF. A leading ``~`` is expanded to the user directory.
    """
    global save_plot_to_pdf_call_count
    save_plot_to_pdf_call_count += 1

    # Accept str or PathLike; expanduser only touches a leading "~".
    pdf_path = os.path.expanduser(os.fspath(pdf_path))
    # Derive the temp name from the file extension only. A regex over the whole
    # string would also match a dot in a directory name (e.g. "./out/plot").
    temp_pdf_path = os.path.splitext(pdf_path)[0] + "_temp.pdf"

    # Save the figure as a PNG in-memory using BytesIO
    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format='png', dpi=300, bbox_inches='tight')
    png_buffer.seek(0)

    # Create the temporary PDF with the PNG image; always release the helper figure.
    img_fig, img_ax = plt.subplots()
    try:
        with PILImage.open(png_buffer) as png_image:
            img_ax.imshow(png_image)
            img_ax.axis('off')
            with PdfPages(temp_pdf_path) as pdf:
                pdf.savefig(img_fig, dpi=300, bbox_inches='tight')
    finally:
        plt.close(img_fig)

    # Merge or move into place
    if os.path.exists(pdf_path):
        merger = PdfMerger()
        try:
            with open(pdf_path, "rb") as main_pdf, open(temp_pdf_path, "rb") as temp_pdf:
                merger.append(main_pdf)
                merger.append(temp_pdf)
            with open(pdf_path, "wb") as merged_pdf:
                merger.write(merged_pdf)
        finally:
            merger.close()
        os.remove(temp_pdf_path)
    else:
        os.rename(temp_pdf_path, pdf_path)

    plt.close(figure)
    print(f"Saved figure {save_plot_to_pdf_call_count} to {pdf_path}")
[docs]
def save_plot(figure, save_path, dpi=200):
    """
    Save a Matplotlib figure as a JPG, or append it to a PDF.

    A path ending in '.pdf' (any case) is handed to save_plot_to_pdf. Any
    other path has its extension replaced by '.jpg' unless it already is
    '.jpg' (any case). Errors raised while writing the JPG are printed, not
    propagated.
    """
    if save_path.lower().endswith('.pdf'):
        save_plot_to_pdf(figure, save_path)
        return
    try:
        root, ext = os.path.splitext(save_path)
        target = save_path if ext.lower() == '.jpg' else root + '.jpg'
        figure.savefig(target, dpi=dpi, bbox_inches=None)
        plt.close(figure)
        print(f"Figure saved at: {target}")
    except Exception as e:
        print(f"Error occurred while saving the figure: {e}")
[docs]
def yufig(figure, save_path, dpi=200):
    """
    Save a Matplotlib figure, choosing the writer from the file extension.

    The path is first resolved through the package-local ``figpath`` helper.
    Then:
      * no extension      -> '.jpg' is appended and the figure saved as JPG;
      * '.pdf' (any case) -> the figure is appended/written via save_plot_to_pdf;
      * any other suffix  -> the path is passed unchanged to ``figure.savefig``.
    Errors raised by ``figure.savefig`` are printed, not propagated.
    """
    from .figpath import figpath

    save_path = figpath(save_path)
    root, ext = os.path.splitext(save_path)
    if not ext:
        ext = ".jpg"
        save_path = root + ext
    elif ext.lower() == ".pdf":
        save_plot_to_pdf(figure, save_path)
        return

    try:
        figure.savefig(save_path, dpi=dpi, bbox_inches=None)
        plt.close(figure)
        print(f"Figure saved at: {save_path}")
    except Exception as e:
        print(f"Error occurred while saving the figure: {e}")
[docs]
def make_white(fig, *axes):
    """
    Give ``fig`` and each axes a white background with black text and ticks.

    Recolors the figure patch and axes faces, plus each axes' title, axis
    labels, tick labels and tick lines. The z-axis equivalents are included
    when the axes provides them (3D axes).

    Returns
    -------
    (fig, axes) where ``axes`` is the tuple of axes passed in.
    """
    fig.patch.set_facecolor('white')
    for ax in axes:
        ax.set_facecolor('white')
        labels = [ax.title, ax.xaxis.label, ax.yaxis.label]
        if hasattr(ax, 'zaxis'):
            labels.append(ax.zaxis.label)
        labels += ax.get_xticklabels() + ax.get_yticklabels()
        if hasattr(ax, 'get_zticklabels'):
            labels += ax.get_zticklabels()
        tick_lines = ax.get_xticklines() + ax.get_yticklines()
        if hasattr(ax, 'get_zticklines'):
            tick_lines += ax.get_zticklines()
        for artist in labels + tick_lines:
            artist.set_color('black')
    return fig, axes
[docs]
def make_black(fig, *axes):
    """
    Give ``fig`` and each axes a black background with white text and ticks.

    Recolors the figure patch and axes faces, plus each axes' title, axis
    labels, tick labels and tick lines. The z-axis equivalents are included
    when the axes provides them (3D axes).

    Returns
    -------
    (fig, axes) where ``axes`` is the tuple of axes passed in.
    """
    fig.patch.set_facecolor('black')
    for ax in axes:
        ax.set_facecolor('black')
        labels = [ax.title, ax.xaxis.label, ax.yaxis.label]
        if hasattr(ax, 'zaxis'):
            labels.append(ax.zaxis.label)
        labels += ax.get_xticklabels() + ax.get_yticklabels()
        if hasattr(ax, 'get_zticklabels'):
            labels += ax.get_zticklabels()
        tick_lines = ax.get_xticklines() + ax.get_yticklines()
        if hasattr(ax, 'get_zticklines'):
            tick_lines += ax.get_zticklines()
        for artist in labels + tick_lines:
            artist.set_color('white')
    return fig, axes
[docs]
def draw_dashed_circle(ax, normal_vector, radius, dashes, dash_length=0.1, label='Dashed Circle'):
    """
    Draw a dashed circle of ``radius`` centred at the origin on a 3D axes.

    The circle is sampled at 1000 points in the xy-plane. It is rotated so its
    plane normal points along ``normal_vector``, then drawn as ``dashes``
    separate segments, each in black with the '--' line style.

    Parameters
    ----------
    ax : axes
        Target axes; must accept ``ax.plot(x, y, z, fmt, label=...)``.
    normal_vector : array_like, shape (3,)
        Normal of the circle's plane. Need not be unit length; lists and
        tuples are accepted.
    radius : float
        Circle radius.
    dashes : int
        Number of dash segments. Each starts ``1000 // dashes`` samples
        after the previous one.
    dash_length : float
        Length of each dash as a fraction of the full circle (in samples).
    label : str or None
        Legend label; applied to the first segment only.

    Raises
    ------
    ValueError
        If ``normal_vector`` has zero length.
    """
    n_samples = 1000
    # Define the circle in the xy-plane
    theta = np.linspace(0, 2 * np.pi, n_samples)
    circle_points = np.vstack(
        (radius * np.cos(theta), radius * np.sin(theta), np.zeros_like(theta))
    ).T

    # Accept any array_like normal (a plain list cannot be divided by a float).
    normal_vector = np.asarray(normal_vector, dtype=float)
    norm = np.linalg.norm(normal_vector)
    if norm == 0:
        raise ValueError("draw_dashed_circle: normal_vector must be non-zero")
    # Rotation taking +z onto the requested normal
    rot = rotation_matrix_from_vectors(np.array([0, 0, 1]), normal_vector / norm)
    rotated_points = circle_points @ rot.T

    # Slice the sampled circle into evenly spaced dashes. The last dash is
    # truncated at the end of the samples rather than wrapping around.
    dash_gap = int(n_samples / dashes)
    dash_span = int(dash_length * n_samples)
    for i in range(dashes):
        start = i * dash_gap
        segment = rotated_points[start:start + dash_span]
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], 'k--', label=label)
        label = None  # Only one label
[docs]
def create_sphere(cx, cy, cz, r, resolution=360):
    """
    Sample the surface of a sphere centred at (cx, cy, cz) with radius r.

    The azimuth takes ``2*resolution`` samples over [0, 2*pi] (axis 1) and the
    polar angle takes ``resolution`` samples over [0, pi] (axis 2).

    Returns
    -------
    np.ndarray of shape (3, 2*resolution, resolution) holding x, y, z.
    """
    azimuth = np.linspace(0, 2 * np.pi, 2 * resolution)
    polar = np.linspace(0, np.pi, resolution)
    az, pol = np.meshgrid(azimuth, polar, indexing="ij")
    ring_radius = r * np.sin(pol)
    return np.stack([
        cx + np.cos(az) * ring_radius,
        cy + np.sin(az) * ring_radius,
        cz + r * np.cos(pol),
    ])
[docs]
def drawSphere(xCenter, yCenter, zCenter, r, res=10j, flatten=True):
    """
    Return (x, y, z) surface coordinates of a sphere of radius ``r``.

    The sphere is centred at (xCenter, yCenter, zCenter). ``res`` is the
    number of polar samples; the azimuth uses twice as many. A value whose
    string form lacks 'j' is converted to the imaginary step ``res*1j``, which
    ``np.mgrid`` interprets as a point count. With ``flatten=True`` each
    coordinate array is flattened to 1-D (then squeezed); otherwise each has
    shape (2*res, res).
    """
    step = res if 'j' in str(res) else complex(0, res)
    azimuth, polar = np.mgrid[0:2 * np.pi:2 * step, 0:np.pi:step]
    sin_polar = np.sin(polar)
    coords = (
        r * (np.cos(azimuth) * sin_polar) + xCenter,
        r * (np.sin(azimuth) * sin_polar) + yCenter,
        r * np.cos(polar) + zCenter,
    )
    if flatten:
        coords = tuple(np.squeeze(np.ravel(c)) for c in coords)
    return coords
[docs]
def darken(color, amount=0.5):
    """
    Darken a color by reducing its HLS lightness.

    Parameters
    ----------
    color : str or RGB tuple
        Named color, hex string, or any value ``matplotlib.colors.to_rgb``
        accepts.
    amount : float or iterable of floats in [0,1]
        Fraction of the lightness to remove: 0 -> no change, 1 -> black.
        Values outside [0, 1] are clipped. An iterable returns one shade per
        value.

    Returns
    -------
    list of RGB tuples in 0..1
        Always a list, even when ``amount`` is a single number.
    """
    import colorsys

    # Resolve base color (cnames maps names to hex; fall back to the raw value)
    try:
        base = cnames[color]
    except (KeyError, TypeError):
        base = color
    h, l, s = colorsys.rgb_to_hls(*to_rgb(base))

    # Normalize amount to iterable
    try:
        amounts = iter(amount)
    except TypeError:
        amounts = [amount]

    out = []
    for a in amounts:
        a = min(max(float(a), 0.0), 1.0)
        # Scale lightness toward 0. The previous formula 1 - a*(1 - l) moved
        # it toward 1 (white at a=0, unchanged at a=1), contradicting the contract.
        out.append(colorsys.hls_to_rgb(h, l * (1.0 - a), s))
    return out
[docs]
def rgb(minimum, maximum, value):
    """
    Map ``value`` on [minimum, maximum] to an (r, g, b) tuple of 0-255 ints.

    The ramp runs blue (minimum) -> green (midpoint) -> red (maximum). Values
    outside the range are clamped to the end colors, so every component stays
    within 0..255.

    Raises
    ------
    ZeroDivisionError
        If ``minimum == maximum`` and ``value`` is a plain Python number.
    """
    minimum, maximum = float(minimum), float(maximum)
    ratio = 2 * (value - minimum) / (maximum - minimum)
    # Clamp so out-of-range values give valid colors instead of e.g. (765, -510, 0).
    ratio = min(max(ratio, 0.0), 2.0)
    b = int(max(0, 255 * (1 - ratio)))
    r = int(max(0, 255 * (ratio - 1)))
    g = 255 - b - r
    return r, g, b
[docs]
def generate_rainbow_colors(num_iterations):
    """
    Return ``num_iterations`` hex color strings from Matplotlib's 'rainbow' colormap.

    Samples are taken at k / num_iterations for k = 0 .. num_iterations-1,
    so the first color is the map's start and the end point 1.0 is excluded.
    """
    cmap = plt.get_cmap('rainbow')
    fractions = (step / num_iterations for step in range(num_iterations))
    return list(map(rgb2hex, map(cmap, fractions)))