import inspect
import pickle
from datetime import datetime
from pathlib import Path
import types
import numpy as np
from .yudata import yudata
try:
import h5py
except Exception as e:
h5py = None
_H5PY_IMPORT_ERROR = e
_DEFAULT_EXCLUDE_NAMES = {
"__name__", "__file__", "__package__", "__spec__", "__builtins__",
"__loader__", "__cached__", "__doc__", "__annotations__",
}
# ---- key encoding (safe for HDF5 paths) ----
def _enc_key(k):
# percent-escape '%' first, then '/'
return k.replace("%", "%25").replace("/", "%2F")
def _dec_key(k):
return k.replace("%2F", "/").replace("%25", "%")
def _is_scalar_number(x):
return isinstance(x, (int, float, bool, np.integer, np.floating, np.bool_))
def _is_nonsaved_symbol(v):
"""
Skip symbols that are usually "code" not "data":
- modules (np, inspect, etc.)
- functions/methods/builtins
- classes/types (datetime, Path, custom classes)
"""
if isinstance(v, types.ModuleType):
return True
if isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.BuiltinMethodType,
),
):
return True
if isinstance(v, type):
return True
return False
def _write_value(h5group, key, value, *, compression="gzip", compression_opts=4):
"""
Write a single value under h5group using name=key (encoded).
For dicts -> subgroup recursion
For numpy/array-like -> dataset
For str/bytes -> dataset with attrs
Fallback -> pickle bytes dataset with attrs
"""
name = _enc_key(key) if isinstance(key, str) else _enc_key(str(key))
# nested dict -> subgroup
if isinstance(value, dict):
grp = h5group.require_group(name)
grp.attrs["__kind__"] = "dict"
for k2, v2 in value.items():
_write_value(
grp,
str(k2),
v2,
compression=compression,
compression_opts=compression_opts,
)
return
# strings
if isinstance(value, str):
dt = h5py.string_dtype(encoding="utf-8")
ds = h5group.create_dataset(name, data=value, dtype=dt)
ds.attrs["__kind__"] = "str"
return
# bytes
if isinstance(value, (bytes, bytearray)):
arr = np.frombuffer(bytes(value), dtype=np.uint8)
ds = h5group.create_dataset(name, data=arr)
ds.attrs["__kind__"] = "bytes"
return
# scalar numbers/bools
if _is_scalar_number(value):
ds = h5group.create_dataset(name, data=value)
ds.attrs["__kind__"] = "scalar"
return
# numpy arrays / array-like
try:
arr = value if isinstance(value, np.ndarray) else np.asarray(value)
# object arrays are not reliably HDF5-portable; store via pickle fallback
if isinstance(arr, np.ndarray) and arr.dtype == object:
raise TypeError("object arrays stored via pickle fallback")
ds = h5group.create_dataset(
name,
data=arr,
compression=compression if arr.size > 0 else None,
compression_opts=compression_opts if arr.size > 0 else None,
)
ds.attrs["__kind__"] = "ndarray"
return
except Exception:
pass
# fallback: pickle
payload = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
arr = np.frombuffer(payload, dtype=np.uint8)
ds = h5group.create_dataset(name, data=arr)
ds.attrs["__kind__"] = "pickle"
ds.attrs["__py_type__"] = str(type(value))
return
def _read_node(node):
"""
Read from an h5py Group or Dataset and reconstruct the Python object.
"""
if isinstance(node, h5py.Group):
out = {}
for name, child in node.items():
out[_dec_key(name)] = _read_node(child)
return out
# Dataset
kind = node.attrs.get("__kind__", "")
if kind == "str":
v = node[()]
if isinstance(v, bytes):
return v.decode("utf-8", errors="replace")
if isinstance(v, np.ndarray) and v.shape == ():
v = v.item()
if isinstance(v, bytes):
return v.decode("utf-8", errors="replace")
return v
if kind == "bytes":
arr = np.array(node[()], dtype=np.uint8).ravel()
return arr.tobytes()
if kind == "scalar":
v = node[()]
if isinstance(v, np.generic):
return v.item()
return v
if kind == "ndarray":
return np.array(node[()])
if kind == "pickle":
arr = np.array(node[()], dtype=np.uint8).ravel()
payload = arr.tobytes()
return pickle.loads(payload)
# default (older files): best-effort
v = node[()]
if isinstance(v, np.ndarray):
return v
if isinstance(v, np.generic):
return v.item()
return v
[docs]
def yucache(
data=None,
filename=None,
*,
exclude_names=None,
exclude_private=True,
include_locals=True,
only_picklable=False,
compression="gzip",
compression_opts=4,
add_timestamp=True,
):
"""
Save a dictionary to an HDF5 cache at your yudata() location.
If data is None:
- captures caller globals (+ locals if include_locals=True)
- skips modules/functions/classes (so it's mostly "data variables")
If data is a dict:
- saves only that dict (after filtering)
Values:
dict -> subgroup recursion
ndarray/array-like -> dataset (compressed)
str/bytes/scalars -> dataset
fallback -> pickled bytes dataset
"""
if h5py is None:
raise ImportError(
f"h5py is required for yucache() HDF5 caches: {_H5PY_IMPORT_ERROR}"
)
if exclude_names is None:
exclude_names = set(_DEFAULT_EXCLUDE_NAMES)
else:
exclude_names = set(exclude_names) | set(_DEFAULT_EXCLUDE_NAMES)
# collect
if data is None:
frame = inspect.currentframe()
if frame is None or frame.f_back is None:
raise RuntimeError("yucache(): could not access caller frame.")
caller = frame.f_back
collected = dict(caller.f_globals)
if include_locals:
collected.update(caller.f_locals)
data = collected
else:
if not isinstance(data, dict):
raise TypeError("yucache(data=...): data must be a dict or None.")
# filter
filtered = {}
for k, v in data.items():
if not isinstance(k, str):
k = str(k)
if k in exclude_names:
continue
if exclude_private and k.startswith("_"):
continue
# skip code-ish symbols so we cache "variables", not imports/definitions
if _is_nonsaved_symbol(v):
continue
if only_picklable:
try:
pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL)
except Exception:
continue
filtered[k] = v
# filename
if filename is None:
filename = "workspace_cache"
if add_timestamp:
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{filename}_{ts}"
out = yudata(f"{filename}.h5")
# write
with h5py.File(out, "w") as f:
f.attrs["__yucache__"] = "1"
f.attrs["__key_encoding__"] = "percent(%25) and slash(%2F)"
root = f.require_group("vars")
root.attrs["__kind__"] = "dict"
for k, v in filtered.items():
_write_value(
root,
k,
v,
compression=compression,
compression_opts=compression_opts,
)
return out
[docs]
def yuload(name_or_path, *, summary=True, max_items=200):
"""
Load a yucache() HDF5 cache and return the reconstructed dictionary.
- If name_or_path exists as a path, use it.
- Otherwise treat it as a name under yudata(...).
If summary=True, prints a path summary of the HDF5 contents.
"""
if h5py is None:
raise ImportError(f"h5py is required for yuload(): {_H5PY_IMPORT_ERROR}")
if not isinstance(name_or_path, (str, Path)):
raise TypeError("yuload(name_or_path): must be str or pathlib.Path")
p = Path(name_or_path)
# If it's not an existing path, resolve under yudata
if not p.exists():
s = str(name_or_path)
if not s.lower().endswith((".h5", ".hdf5", ".hdf")):
s = s + ".h5"
p = Path(yudata(s))
def _print_summary(f):
print(f"[yuload] file: {str(p)}")
print("[yuload] tree:")
n = 0
def _visit(name, obj):
nonlocal n
if n >= max_items:
return
# name is path relative to root, e.g. 'vars/x'
if isinstance(obj, h5py.Group):
kind = obj.attrs.get("__kind__", "group")
print(f" /{name} <group> kind={kind} n_children={len(obj)}")
else:
kind = obj.attrs.get("__kind__", "dataset")
shape = getattr(obj, "shape", None)
dtype = getattr(obj, "dtype", None)
print(f" /{name} <dataset> kind={kind} shape={shape} dtype={dtype}")
n += 1
f.visititems(_visit)
if n >= max_items:
print(f"[yuload] ... truncated at {max_items} items")
with h5py.File(str(p), "r") as f:
if summary:
_print_summary(f)
if "vars" in f:
return _read_node(f["vars"])
return _read_node(f)