Source code for ssapy_toolkit.Plots.misc_plotting

######################################################################
# COLLECTION OF ALL PLOTTING AND MEDIA
######################################################################
# flake8: noqa: E501
import numpy as np
from ssapy.body import get_body
from ..constants import RGEO, EARTH_MU, MOON_MU
from ..Time_Functions import Time
from .plotutils import make_black, make_white, save_plot

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LogNorm, Normalize

# Semi-major axis of the Moon's orbit about the Earth, in metres.
lunar_semi_major = 384_399_000  # m


def koe_plot(r: np.ndarray, v: np.ndarray, t=None, elements=None, save_path=None, body: str = 'Earth'):
    """
    Plot orbital elements (eccentricity, inclination, semi-major axis) derived
    from position and velocity vectors.

    Eccentricity and inclination share the left y-axis; the semi-major axis,
    in units of ``RGEO``, is drawn on a twin right y-axis.

    Parameters
    ----------
    r : np.ndarray
        Position vectors of the satellite in 3D space (shape: [n, 3]).
    v : np.ndarray
        Velocity vectors of the satellite in 3D space (shape: [n, 3]).
    t : Optional[Time], optional
        Time instance(s) associated with the samples; converted with
        ``Time(t).decimalyear`` for the x-axis. If None, the x-axis is the
        sample index.
    elements : list, optional
        Subset of elements to plot among ['a', 'e', 'i']. Default: all three.
        An element is drawn only if it is also returned by
        ``calculate_orbital_elements``.
    save_path : str or None, optional
        If given, the figure is written with ``fig.savefig(save_path)``.
    body : str, optional
        Central body. Any string containing 'earth' (case-insensitive) uses
        ``EARTH_MU``; anything else uses ``MOON_MU``.

    Returns
    -------
    (plt.Figure, plt.Axes)
        The figure and the left-hand axes. The twin axes holding the
        semi-major axis is not returned.
    """
    if elements is None:
        elements = ['a', 'e', 'i']

    # Calculate orbital elements for Earth or Moon
    from ..Orbital_Mechanics import calculate_orbital_elements
    mu = EARTH_MU if 'earth' in body.lower() else MOON_MU
    orbital_elements = calculate_orbital_elements(r, v, mu_barycenter=mu)

    # An element is drawn only if it was requested AND actually computed.
    show = {key: (key in elements and key in orbital_elements) for key in ('a', 'e', 'i')}

    fig, ax1 = plt.subplots(dpi=100)
    ax2 = ax1.twinx()
    make_white(fig, ax1, ax2)

    if t is not None:
        x_values = Time(t).decimalyear
    else:
        # Index axis: its length comes from the first element that is present.
        n_points = next(
            (len(orbital_elements[key]) for key in ('e', 'i', 'a') if key in orbital_elements),
            0,
        )
        x_values = np.arange(n_points)

    # Eccentricity and inclination on the left axis.
    left_series = []
    if show['e']:
        ecc = np.asarray(orbital_elements['e'], dtype=float)
        ax1.plot(x_values, ecc, label='eccentricity', c='C1')
        left_series.append(ecc)
    if show['i']:
        inc = np.asarray(orbital_elements['i'], dtype=float)
        ax1.plot(x_values, inc, label='inclination [rad]', c='C2')
        left_series.append(inc)

    ax1.set_xlabel('Year' if t is not None else 'Index')

    # Default upper limit is pi/2, but widen it when any finite value exceeds
    # that (e.g. retrograde inclinations > pi/2) so data is not clipped.
    y_upper = np.pi / 2
    for series in left_series:
        flat = np.ravel(series)
        finite = flat[np.isfinite(flat)]
        if finite.size and finite.max() > y_upper:
            y_upper = max(np.pi, float(finite.max()))
    ax1.set_ylim((0, y_upper))

    ylabel = ax1.set_ylabel('', color='black')

    # Hand-built two-colour y-label; annotate only what was actually plotted.
    xlab = ylabel.get_position()[0] + 0.05
    ylab = ylabel.get_position()[1]
    if show['e']:
        fig.text(xlab - 0.001, ylab - 0.225, 'Eccentricity', color='C1', rotation=90)
    if show['e'] and show['i']:
        fig.text(xlab, ylab - 0.05, '/', color='k', rotation=90)
    if show['i']:
        fig.text(xlab, ylab - 0.025, 'Inclination [Radians]', color='C2', rotation=90)

    # Skip the legend when nothing labelled is on ax1 (e.g. elements=['a']);
    # matplotlib otherwise warns "No artists with labels found".
    if ax1.get_legend_handles_labels()[0]:
        ax1.legend(loc='upper left')

    # Semi-major axis on the right axis, in GEO radii.
    if show['a']:
        a = np.asarray(orbital_elements['a'], dtype=float) / RGEO
        ax2.plot(x_values, a, label='semi-major axis [GEO]', c='C0', linestyle='-')
        ax2.set_ylabel('semi-major axis [GEO]', color='C0')
        ax2.yaxis.label.set_color('C0')
        ax2.tick_params(axis='y', colors='C0')
        ax2.spines['right'].set_color('C0')
        # Pad a narrow range (< 2 GEO) by 0.5 GEO on each side; guard empty input.
        if a.size and np.abs(np.max(a) - np.min(a)) < 2:
            ax2.set_ylim((np.min(a) - 0.5, np.max(a) + 0.5))

    # Optionally save the plot
    if save_path:
        fig.savefig(save_path)
    return fig, ax1
def _koe_hist_panel(ax, x, y, *, bins, norm, cmap, xlabel, ylabel, xticks, yticks, xlim=None):
    """Draw one 2D-histogram panel of the orbital-element grid and return its QuadMesh."""
    *_, mappable = ax.hist2d(x, y, bins=bins, norm=norm, cmap=cmap)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    if xlim is not None:
        ax.set_xlim(xlim)
    return mappable


def koe_2dhist(stable_data, title: str = "Initial orbital elements of\n1 year stable cislunar orbits", limits: tuple = (1, 50), bins: int = 200, logscale: bool = False, cmap: str = 'coolwarm', save_path: str = None) -> plt.Figure:
    """
    Generate a corner-style grid of 2D histograms of orbital elements.

    The lower triangle of a 3x3 grid compares semi-major axis [GEO],
    eccentricity, inclination [deg] and true anomaly [deg]; the upper
    triangle is left empty. All panels share one colour normalisation and a
    single colourbar.

    Parameters
    ----------
    stable_data : object
        Object exposing array-like attributes ``a``, ``e``, ``i`` and ``ta``.
        ``a`` is divided by ``RGEO``; ``i`` and ``ta`` are in radians.
    title : str
        Figure super-title.
    limits : sequence of two floats
        ``(vmin, vmax)`` of the colour normalisation.
    bins : int
        Number of bins per axis passed to ``hist2d``.
    logscale : bool or str
        ``True`` or ``'log'`` (any case) selects logarithmic colour scaling;
        ``False`` or any other string (e.g. ``'linear'``) selects linear.
    cmap : str
        Matplotlib colormap name.
    save_path : str, optional
        If given, the figure is saved with ``save_plot``.

    Returns
    -------
    plt.Figure

    Raises
    ------
    ValueError
        If any inclination or true anomaly lies outside [0, 2π] (NaN also fails).
    """
    # asarray lets plain Python lists be validated with vectorised comparisons.
    inc = np.asarray(stable_data.i, dtype=float)
    ta = np.asarray(stable_data.ta, dtype=float)

    # Validate angle data ranges
    if not np.all((inc >= 0) & (inc <= 2 * np.pi)):
        raise ValueError("Inclination (`i`) must be in the range [0, 2π] radians.")
    if not np.all((ta >= 0) & (ta <= 2 * np.pi)):
        raise ValueError("True Anomaly (`ta`) must be in the range [0, 2π] radians.")

    # Strings are matched explicitly so that e.g. 'linear' does not enable log scaling.
    if isinstance(logscale, str):
        use_log = logscale.lower() == 'log'
    else:
        use_log = bool(logscale)
    norm_cls = LogNorm if use_log else Normalize
    norm = norm_cls(limits[0], limits[1])

    fig, axes = plt.subplots(dpi=100, figsize=(10, 8), nrows=3, ncols=3)
    st = fig.suptitle(title, fontsize=12)
    st.set_x(0.46)
    st.set_y(0.9)

    a_geo = np.asarray(stable_data.a, dtype=float) / RGEO
    ecc = np.asarray(stable_data.e, dtype=float)
    inc_deg = np.degrees(inc)
    ta_deg = np.degrees(ta)

    a_ticks = np.arange(1, 20, 2)
    e_ticks = np.arange(0, 1, 0.2)
    i_ticks = np.arange(0, 91, 15)
    ta_ticks = np.arange(0, 361, 60)
    a_lim = (1, 18)

    # Lower-triangle layout:
    # (flat axes index, x data, y data, xlabel, ylabel, xticks, yticks, xlim)
    panels = (
        (0, a_geo, ecc, "", "eccentricity", a_ticks, e_ticks, a_lim),
        (3, a_geo, inc_deg, "", "inclination [deg]", a_ticks, i_ticks, a_lim),
        (4, ecc, inc_deg, "", "", e_ticks, i_ticks, None),
        (6, a_geo, ta_deg, "semi-major axis [GEO]", "True Anomaly [deg]", a_ticks, ta_ticks, a_lim),
        (7, ecc, ta_deg, "eccentricity", "", e_ticks, ta_ticks, None),
        (8, inc_deg, ta_deg, "inclination [deg]", "", i_ticks, ta_ticks, None),
    )
    mappables = [
        _koe_hist_panel(
            axes.flat[idx], x, y,
            bins=bins, norm=norm, cmap=cmap,
            xlabel=xlabel, ylabel=ylabel, xticks=xticks, yticks=yticks, xlim=xlim,
        )
        for idx, x, y, xlabel, ylabel, xticks, yticks, xlim in panels
    ]

    # Upper triangle is unused.
    for idx in (1, 2, 5):
        axes.flat[idx].set_axis_off()

    # All panels share the same norm/cmap, so any panel's QuadMesh drives the colourbar.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.82, 0.15, 0.01, 0.7])
    fig.colorbar(mappables[-1], cax=cbar_ax)

    # Apply the white theme to every drawn panel, not just the last one.
    make_white(fig, *(axes.flat[panel[0]] for panel in panels))
    if save_path:
        save_plot(fig, save_path)
    return fig
def scatter2d(x: list, y: list, cs: list, xlabel: str = 'x', ylabel: str = 'y', title: str = '', cbar_label: str = '', dotsize: int = 1, colorsMap: str = 'jet', colorscale: str = 'linear', colormin: float = None, colormax: float = None, save_path: str = None) -> None:
    """
    Draw a 2D scatter plot coloured by ``cs``, with an attached colourbar.

    The figure gets the dark theme from ``make_black``, is shown
    non-blocking, and is optionally saved.

    Parameters
    ----------
    x, y : array-like
        Point coordinates.
    cs : array-like
        Value per point, mapped to colour through ``colorsMap``.
    xlabel, ylabel, title, cbar_label : str
        Axis labels, axes title and colourbar label.
    dotsize : int
        Marker size passed to ``scatter`` as ``s``.
    colorsMap : str
        Matplotlib colormap name.
    colorscale : str
        ``'log'`` for logarithmic colour normalisation; any other value is linear.
    colormin, colormax : float, optional
        Colour range; each defaults to the NaN-ignoring min/max of ``cs``.
    save_path : str, optional
        If given, the figure is saved with ``save_plot``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    values = np.asarray(cs)
    vmax = np.nanmax(values) if colormax is None else colormax
    vmin = np.nanmin(values) if colormin is None else colormin

    colormap = plt.get_cmap(colorsMap)
    # Only 'log' switches normalisation; 'linear' and unknown values are linear.
    norm_cls = LogNorm if colorscale == 'log' else Normalize
    mapper = cm.ScalarMappable(norm=norm_cls(vmin=vmin, vmax=vmax), cmap=colormap)

    ax.scatter(x, y, c=mapper.to_rgba(values), s=dotsize)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # The mappable needs its data array before it can back a colourbar.
    mapper.set_array(values)
    fig.colorbar(mapper, shrink=.5, label=f'{cbar_label}', pad=0.04)
    plt.tight_layout()

    fig, ax = make_black(fig, ax)
    plt.show(block=False)
    if save_path:
        save_plot(fig, save_path)
    return
def scatter3d(x: list, y: list = None, z: list = None, cs: list = None, xlabel: str = 'x', ylabel: str = 'y', zlabel: str = 'z', cbar_label: str = '', dotsize: int = 1, colorsMap: str = 'jet', title: str = '', save_path: str = None):
    """
    Draw a 3D scatter plot, optionally coloured by a per-point value.

    Parameters
    ----------
    x : array-like
        x coordinates, or a 2D array whose columns 0, 1, 2 are used as x, y
        and z (in that case ``y`` and ``z`` are overwritten).
    y, z : array-like, optional
        y and z coordinates when ``x`` is one-dimensional.
    cs : array-like, optional
        Value per point. When given, points are coloured through
        ``colorsMap`` over the NaN-ignoring range of ``cs`` and a colourbar
        is added.
    xlabel, ylabel, zlabel, cbar_label, title : str
        Axis labels, colourbar label and title.
    dotsize : int
        Marker size passed to ``scatter`` as ``s``.
    colorsMap : str
        Matplotlib colormap name.
    save_path : str, optional
        If given, the figure is saved with ``save_plot``.

    Returns
    -------
    (plt.Figure, matplotlib.axes._subplots.Axes3DSubplot)
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Accept a single stacked array of points in place of separate x, y, z.
    if getattr(x, "ndim", 1) > 1:
        points = np.asarray(x)
        x, y, z = points[:, 0], points[:, 1], points[:, 2]

    if cs is not None:
        values = np.asarray(cs)
        colormap = plt.get_cmap(colorsMap)
        norm = Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
        mapper = cm.ScalarMappable(norm=norm, cmap=colormap)
        ax.scatter(x, y, z, c=mapper.to_rgba(values), s=dotsize)
        # The mappable needs its data array before it can back a colourbar.
        mapper.set_array(values)
        fig.colorbar(mapper, shrink=.5, label=f'{cbar_label}', pad=0.075)
    else:
        ax.scatter(x, y, z, s=dotsize)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    plt.title(title)
    plt.tight_layout()

    fig, ax = make_black(fig, ax)
    plt.show(block=False)
    if save_path:
        save_plot(fig, save_path)
    return fig, ax
def dotcolors_scaled(num_colors: int) -> np.ndarray:
    """
    Sample ``num_colors`` colours evenly spaced along the ``rainbow`` colormap.

    Parameters
    ----------
    num_colors : int
        Number of colours to sample (first = start of the map, last = end).

    Returns
    -------
    np.ndarray
        RGBA array of shape ``(num_colors, 4)`` with values in [0, 1].
    """
    return cm.rainbow(np.linspace(0, 1, num_colors))
# Plot multiple cislunar orbits in the GCRF frame.
def orbit_divergence_plot(rs: np.ndarray, r_moon: np.ndarray = None, t=None, limits: float = None, title: str = '', save_path: str = None) -> None:
    """
    Plot several orbits in three orthogonal projections (x-y, x-z, y-z).

    Each orbit's points are coloured from start to end along the ``rainbow``
    colormap and tagged with "start"/"end" labels. A marker at the origin
    (Earth) and the Moon track are drawn once per panel. All coordinates are
    shown in units of ``RGEO``.

    Parameters
    ----------
    rs : np.ndarray
        Positions with shape (n_times, 3, n_orbits); ``rs[:, :, k]`` is orbit k.
    r_moon : np.ndarray, optional
        Moon positions with shape (N, 3) or (3, N). If None or empty they are
        computed with ``get_body("moon").position(t)``.
    t : optional
        Times passed to ``get_body("moon").position`` when ``r_moon`` is not given.
    limits : float, optional
        Half-width of every panel in GEO units. Defaults to 1.2 times the
        largest position norm in ``rs`` (NaNs ignored). The value used is printed.
    title : str
        Title placed on the middle (x-z) panel.
    save_path : str, optional
        If given, the figure is saved with ``save_plot``.

    Raises
    ------
    IndexError
        If ``rs`` is not 3-dimensional, or ``r_moon`` is not (N, 3) / (3, N).
    """
    rs = np.asarray(rs)
    if rs.ndim != 3:
        raise IndexError(f"input orbit data shape: {np.shape(rs)}, expected (N, 3, n_orbits).")

    if limits is None:
        limits = np.nanmax(np.linalg.norm(rs, axis=1) / RGEO) * 1.2
    print(f'limits: {limits}')

    # Acquire or validate moon positions
    if r_moon is None or np.size(r_moon) < 1:
        moon = get_body("moon")
        r_moon = moon.position(t)  # expect shape (N, 3) or (3, N)

    r_moon = np.asarray(r_moon)
    if r_moon.ndim != 2:
        raise IndexError(f"input moon data shape: {np.shape(r_moon)}, input should be 2 dimensions.")

    # Ensure shape (N, 3)
    if r_moon.shape[1] == 3:
        pass
    elif r_moon.shape[0] == 3:
        r_moon = r_moon.T
    else:
        raise IndexError(f"input moon data shape: {np.shape(r_moon)}, expected (N,3) or (3,N).")

    # Normalize once
    r_moon = r_moon / RGEO

    fig = plt.figure(dpi=100, figsize=(15, 4))
    # Per panel: (horizontal component, vertical component, x label, y label)
    panels = (
        (0, 1, 'x [GEO]', 'y [GEO]'),
        (0, 2, 'x [GEO]', 'z [GEO]'),
        (1, 2, 'y [GEO]', 'z [GEO]'),
    )
    axes = [fig.add_subplot(1, 3, k + 1) for k in range(len(panels))]

    # Orbit-dependent artists: coloured track plus start/end labels.
    for orbit_idx in range(rs.shape[-1]):
        r = rs[:, :, orbit_idx] / RGEO
        dotcolors = cm.rainbow(np.linspace(0, 1, len(r)))
        for ax, (h, v, _, _) in zip(axes, panels):
            ax.scatter(r[:, h], r[:, v], color=dotcolors, s=1)
            ax.text(r[0, h], r[0, v], '$\\leftarrow$ start')
            ax.text(r[-1, h], r[-1, v], '$\\leftarrow$ end')

    # Orbit-independent elements are drawn and configured once per panel
    # (after the orbits, so the Earth marker and Moon track sit on top).
    for ax, (h, v, xlabel, ylabel) in zip(axes, panels):
        ax.scatter(0, 0, color="blue", s=50)
        ax.scatter(r_moon[:, h], r_moon[:, v], color="grey", s=5)
        ax.axis('scaled')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xlim((-limits, limits))
        ax.set_ylim((-limits, limits))
    axes[1].set_title(f'{title}')

    plt.tight_layout()
    plt.show(block=False)
    if save_path:
        save_plot(fig, save_path)
    return