Source code for ssapy_toolkit.Plots.gifify

import os
import tempfile
import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio

def _gifify_spans(mode, start, end, chunk_size, step):
    """Return the ``(i0, i1)`` half-open slice bounds for every frame.

    Expects validated inputs: ``0 <= start < end``, ``step > 0`` and, for
    ``"chunks"``/``"sliding"``, ``chunk_size > 0``.
    """
    if mode == "cumulative":
        spans = [(start, i1) for i1 in range(start + 1, end + 1, step)]
        # Make sure the final frame always shows the full [start:end] range.
        if spans[-1][1] != end:
            spans.append((start, end))
        return spans
    # "chunks" and "sliding" walk the same windows; overlap is controlled only
    # by `step` (step == chunk_size -> disjoint, step < chunk_size -> overlap).
    return [(i0, min(i0 + chunk_size, end)) for i0 in range(start, end, step)]


def _gifify_resolve_fig(ret):
    """Return the Figure a plot function drew into.

    Accepts a Figure, an Axes (3D axes included, since they subclass Axes), or
    a tuple/list holding one of those (e.g. ``(fig, ax)``). Anything else,
    including ``None``, falls back to ``plt.gcf()``.
    """
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    for obj in ret if isinstance(ret, (tuple, list)) else (ret,):
        if isinstance(obj, Figure):
            return obj
        if isinstance(obj, Axes):
            return obj.figure
    return plt.gcf()


def _gifify_inject_ax(args, kwargs, ax, ax_arg_index, ax_kw_key):
    """Place *ax* into the call arguments (in place): keyword, position, or ``ax=``."""
    if ax_kw_key is not None:
        kwargs[ax_kw_key] = ax
    elif ax_arg_index is not None:
        args.insert(ax_arg_index, ax)
    else:
        kwargs["ax"] = ax


def _gifify_render(plot_func, args, kwargs, inject_ax, ax_arg_index, ax_kw_key):
    """Close all figures, optionally inject a fresh Axes, call *plot_func*, return its Figure.

    Closing happens *before* the injected Axes is created so the figure the
    plot draws into is never closed underneath it.
    """
    plt.close("all")
    if inject_ax:
        _, injected_ax = plt.subplots()
        _gifify_inject_ax(args, kwargs, injected_ax, ax_arg_index, ax_kw_key)
    return _gifify_resolve_fig(plot_func(*args, **kwargs))


def _gifify_keyed_axes(fig, rnd=3):
    """Yield ``(ax, key)`` for every axes of *fig*, keyed by geometry.

    The key is the rounded *original* position (not the active one, which a
    fixed aspect may shrink differently per frame depending on the data), a
    "2d"/"3d" tag, and an occurrence counter so axes sharing one position
    (e.g. ``twinx``) get distinct keys.
    """
    seen = {}
    for ax in fig.axes:
        try:
            pos = ax.get_position(original=True)
            geom = (
                round(pos.x0, rnd),
                round(pos.y0, rnd),
                round(pos.width, rnd),
                round(pos.height, rnd),
            )
        except (AttributeError, TypeError, ValueError):
            geom = (0.0, 0.0, 1.0, 1.0)
        geom += ("3d" if hasattr(ax, "get_zlim") else "2d",)
        dup = seen.get(geom, 0)
        seen[geom] = dup + 1
        yield ax, geom + (dup,)


def _gifify_capture_layout(fig):
    """Snapshot limits, axis labels and titles of every axes in *fig*, keyed by geometry."""
    layout = {}
    for ax, key in _gifify_keyed_axes(fig):
        try:
            xlim = getattr(ax, "get_xlim3d", ax.get_xlim)()
        except Exception:
            xlim = ax.get_xlim()
        try:
            ylim = getattr(ax, "get_ylim3d", ax.get_ylim)()
        except Exception:
            ylim = ax.get_ylim()
        layout[key] = {
            "x": xlim,
            "y": ylim,
            "z": ax.get_zlim() if hasattr(ax, "get_zlim") else None,
            "xlabel": ax.get_xlabel() or "",
            "ylabel": ax.get_ylabel() or "",
            "zlabel": (ax.get_zlabel() or "") if hasattr(ax, "get_zlabel") else None,
            "title": ax.get_title() or "",
        }
    return layout


def _gifify_apply_layout(fig, layout, fix_box_aspect_3d, verbose=False):
    """Re-apply a captured layout to the matching axes of *fig* (best effort per axes)."""
    for ax, key in _gifify_keyed_axes(fig):
        ent = layout.get(key)
        if not ent:
            continue
        try:
            if ent["x"] is not None:
                (ax.set_xlim3d if hasattr(ax, "set_xlim3d") else ax.set_xlim)(ent["x"])
            if ent["y"] is not None:
                (ax.set_ylim3d if hasattr(ax, "set_ylim3d") else ax.set_ylim)(ent["y"])
            if ent["z"] is not None and hasattr(ax, "set_zlim3d"):
                ax.set_zlim3d(ent["z"])
            # Re-set labels/titles so text present on the full plot stays on every frame.
            ax.set_xlabel(ent["xlabel"])
            ax.set_ylabel(ent["ylabel"])
            if ent["zlabel"] is not None and hasattr(ax, "set_zlabel"):
                ax.set_zlabel(ent["zlabel"])
            if ent["title"]:
                ax.set_title(ent["title"])
            # Cube box + orthographic projection only make sense on 3D axes.
            if fix_box_aspect_3d and hasattr(ax, "get_zlim"):
                ax.set_box_aspect((1, 1, 1))
                ax.set_proj_type("ortho")
        except Exception as e:  # deliberate best effort: never abort a frame over styling
            if verbose:
                print(f"[gifify] Could not lock layout for one axes: {e}")


def _gifify_accepts_kwarg(func, name):
    """Return True if *func* declares a parameter *name* that can be passed by keyword.

    If the signature cannot be inspected, returns True, i.e. the argument is
    passed as before.
    """
    import inspect

    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return True
    param = params.get(name)
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def gifify(
    plot_func,
    *fargs,
    save_path=None,
    array_arg_indices=(0, 1),
    array_kw_keys=None,
    mode="cumulative",  # "cumulative", "chunks", "sliding"
    chunk_size=None,
    step=None,
    start=0,
    end=None,
    fps=12,
    loop=0,
    dpi=None,
    verbose=False,
    inject_ax=False,
    ax_arg_index=None,
    ax_kw_key=None,
    fixed_limits=True,  # lock 2D/3D axes across frames
    fix_box_aspect_3d=True,  # keep cube aspect on 3D panes
    **fkwargs,
):
    """
    Animate any Matplotlib plot function by slicing two array args and
    compiling the frames into a GIF.

    Each frame calls ``plot_func`` with the two selected arrays replaced by
    ``array[i0:i1]``. The spans come from ``mode``:

    - ``"cumulative"``: ``[start:start+step]``, ``[start:start+2*step]``, ...,
      always ending with ``[start:end]`` (default ``step=1``).
    - ``"chunks"`` / ``"sliding"``: windows of ``chunk_size`` samples starting
      every ``step`` samples (default ``step=chunk_size``), clipped at ``end``.

    Works with pyplot-style functions (return None; the current figure is
    used), functions returning an Axes or a Figure, or a tuple such as
    ``(fig, ax)``. If your function expects an Axes, set ``inject_ax=True``
    and specify where with ``ax_arg_index`` (positional) or ``ax_kw_key``
    (keyword); otherwise it is passed as ``ax=``.

    Parameters
    ----------
    plot_func : callable
        Plotting function called once per frame (plus once for the probe).
    *fargs, **fkwargs
        Arguments forwarded to ``plot_func``. A ``show`` keyword, if given, is
        forced to False on every call so rendering never blocks.
    save_path : str, optional
        Output GIF path. Defaults to ``<cwd>/figures/animation.gif``; the
        parent directory is created when rendering starts.
    array_arg_indices : tuple of two ints
        Positions in ``fargs`` of the two arrays to slice (ignored when
        ``array_kw_keys`` is given).
    array_kw_keys : tuple of two str, optional
        Keyword names in ``fkwargs`` of the two arrays to slice.
    mode, chunk_size, step, start, end
        Frame span selection (see above). ``end`` defaults to, and is clipped
        to, the shorter array's length.
    fps : float
        Frames per second; each frame's duration is passed as ``1 / fps``.
    loop : int
        Passed through to the imageio writer (0 loops forever).
    dpi : float, optional
        Frame resolution; defaults to the figure's own dpi.
    verbose : bool
        Print progress and diagnostic messages.
    inject_ax, ax_arg_index, ax_kw_key
        Axes injection controls (see above).
    fixed_limits : bool
        Render the full ``[start:end]`` range once ("probe") and reuse its
        axis limits, labels and titles on every frame. Also, when one array
        has shape (N, >=3) and ``plot_func`` declares a ``limit`` parameter
        that the caller did not set, pass
        ``limit = max(10, max|xyz| / 1e3) * 1.02``. The ``/ 1e3`` assumes
        coordinates in meters and a ``limit`` in km.
    fix_box_aspect_3d : bool
        When limits are re-applied, also force a 1:1:1 box aspect and an
        orthographic projection on 3D axes.

    Returns
    -------
    dict
        ``frames``, ``path``, ``mode``, ``chunk_size``, ``step``, ``range``
        (``(start, end)``) and ``fixed_limits``.

    Raises
    ------
    ValueError
        On invalid array selectors, mode, start/end range, chunk_size or step.
    """
    # ---------- defaults & validation (no filesystem side effects yet) ----------
    if save_path is None:
        from pathlib import Path

        save_path = str(Path.cwd() / "figures" / "animation.gif")

    if array_kw_keys is not None and len(array_kw_keys) != 2:
        raise ValueError("array_kw_keys must be a tuple of two names or None.")
    if array_kw_keys is None and (
        not isinstance(array_arg_indices, (list, tuple)) or len(array_arg_indices) != 2
    ):
        raise ValueError("array_arg_indices must be a tuple/list of two indices.")
    if mode not in {"cumulative", "chunks", "sliding"}:
        raise ValueError("mode must be 'cumulative', 'chunks', or 'sliding'.")

    args = list(fargs)
    kwargs = dict(fkwargs)

    # Resolve the two arrays to animate once; frames slice these copies.
    if array_kw_keys is not None:
        a_full = np.asarray(kwargs[array_kw_keys[0]])
        b_full = np.asarray(kwargs[array_kw_keys[1]])
    else:
        a_full = np.asarray(args[array_arg_indices[0]])
        b_full = np.asarray(args[array_arg_indices[1]])

    n = min(len(a_full), len(b_full))
    end = n if end is None else min(end, n)
    if not (0 <= start < end):
        raise ValueError("Invalid start/end range.")

    if mode in {"chunks", "sliding"} and (chunk_size is None or chunk_size <= 0):
        raise ValueError("chunk_size must be a positive integer for 'chunks' or 'sliding'.")
    if step is None:
        step = 1 if mode == "cumulative" else chunk_size
    if step <= 0:
        raise ValueError("step must be positive.")

    # ---------- optional 3D `limit` precompute (meters -> km) ----------
    if fixed_limits and "limit" not in kwargs:
        cand = next(
            (arr for arr in (a_full, b_full) if arr.ndim == 2 and arr.shape[1] >= 3),
            None,
        )
        if cand is not None:
            if not _gifify_accepts_kwarg(plot_func, "limit"):
                if verbose:
                    print("[gifify] plot_func takes no `limit` argument; not precomputing 3D limits.")
            else:
                try:
                    maxabs_km = float(np.max(np.abs(cand[:, :3])) / 1e3)  # meters -> km
                    limit_km = max(10.0, maxabs_km) * 1.02  # small pad, min cube
                except (TypeError, ValueError) as e:
                    if verbose:
                        print(f"[gifify] Could not precompute 3D limits: {e}")
                else:
                    kwargs["limit"] = limit_km
                    if verbose:
                        print(f"[gifify] Precomputed 3D limits (km): ±{limit_km:.3f}")

    spans = _gifify_spans(mode, start, end, chunk_size, step)

    def _apply_slice(i0, i1):
        """Fresh call arguments with both arrays sliced to [i0:i1] and `show` disabled."""
        args_s = list(args)
        kwargs_s = dict(kwargs)
        if array_kw_keys is not None:
            kwargs_s[array_kw_keys[0]] = a_full[i0:i1]
            kwargs_s[array_kw_keys[1]] = b_full[i0:i1]
        else:
            args_s[array_arg_indices[0]] = a_full[i0:i1]
            args_s[array_arg_indices[1]] = b_full[i0:i1]
        if "show" in kwargs_s:
            kwargs_s["show"] = False
        return args_s, kwargs_s

    # ---------- probe render (capture fixed limits/layout) ----------
    saved_layout = None
    if fixed_limits:
        try:
            probe_args, probe_kwargs = _apply_slice(start, end)
            fig_probe = _gifify_render(
                plot_func, probe_args, probe_kwargs, inject_ax, ax_arg_index, ax_kw_key
            )
            try:
                # Synchronous draw so autoscaling settles, without the
                # window-raising side effects of plt.pause.
                fig_probe.canvas.draw()
            except Exception:
                pass
            # NOTE(review): layout engines that move axes at draw time (e.g.
            # constrained layout) can still make probe/frame geometry keys differ.
            saved_layout = _gifify_capture_layout(fig_probe)
            plt.close(fig_probe)
            if verbose:
                print(f"[gifify] Captured {len(saved_layout)} axes for fixed limits/layout.")
        except Exception as e:
            saved_layout = None
            if verbose:
                print(f"[gifify] Probe failed; continuing without fixed limits (reason: {e})")

    # ---------- frame render ----------
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    with tempfile.TemporaryDirectory() as tmpdir:
        if verbose:
            print(f"Saving frames to temp dir: {tmpdir}")
        # NOTE(review): depending on the installed imageio version/plugin,
        # `duration` may be interpreted in milliseconds rather than seconds,
        # and frames of differing pixel size (bbox_inches="tight") may not be
        # accepted - confirm against the installed imageio.
        with imageio.get_writer(save_path, mode="I", duration=1.0 / fps, loop=loop) as writer:
            for frame_idx, (i0, i1) in enumerate(spans):
                args_s, kwargs_s = _apply_slice(i0, i1)
                fig = _gifify_render(plot_func, args_s, kwargs_s, inject_ax, ax_arg_index, ax_kw_key)
                try:
                    if saved_layout:
                        _gifify_apply_layout(fig, saved_layout, fix_box_aspect_3d, verbose)
                    if dpi is not None:
                        try:
                            fig.set_dpi(dpi)
                        except Exception:
                            pass
                    frame_path = os.path.join(tmpdir, f"frame_{frame_idx:06d}.png")
                    fig.savefig(
                        frame_path,
                        dpi=dpi if dpi is not None else fig.get_dpi(),
                        bbox_inches="tight",
                        pad_inches=0.2,
                    )
                    writer.append_data(imageio.imread(frame_path))
                finally:
                    # Release the figure even if saving/appending fails.
                    plt.close(fig)
                if verbose:
                    print(f"Rendered frame {frame_idx + 1}/{len(spans)} slice=[{i0}:{i1}]")

    if verbose:
        print(f"GIF saved to {save_path}")

    return {
        "frames": len(spans),
        "path": save_path,
        "mode": mode,
        "chunk_size": chunk_size,
        "step": step,
        "range": (start, end),
        "fixed_limits": fixed_limits,
    }