HDF5 serialization for JAX and NumPy pytrees with full type fidelity and reference identity
Jaxon
Jaxon is a focused Python library for saving and loading arbitrary nested data structures ("pytrees") to the Hierarchical Data Format (HDF5).
HDF5 is an open, self-describing format with native support for multidimensional
arrays and metadata. Jaxon stores enough information to reconstruct the original
Python objects, so HDF5 files produced by Jaxon can be inspected with standard
tools such as h5dump or HDFView, and can be read even when the original code
is no longer available.
Jaxon is well suited for machine learning and scientific computing. It works
particularly well with packages built on Python dataclasses and JAX, such as
Equinox.
Jaxon intentionally has a narrow scope — saving and loading pytrees is all it does.
The save and load API is stable.
Requires Python ≥ 3.12.
Installation
pip install jaxon
Quick start
from jaxon import save, load
import numpy as np
import jax.numpy as jnp
pytree = {
    "mylist": ["foo", "bar", 42],
    "myset": {"a", "b", "z", (42, b"blob")},
    "numpy_array": np.arange(3),
    "jax_array": jnp.arange(3),
}
save("data.hdf5", pytree)
print(load("data.hdf5"))
{'mylist': ['foo', 'bar', 42], 'myset': {'b', 'a', (42, b'blob'), 'z'}, 'numpy_array': array([0, 1, 2]), 'jax_array': Array([0, 1, 2], dtype=int32)}
save also accepts a file-like object instead of a path:
import tempfile
with tempfile.TemporaryFile() as f:
    save(f, pytree)
    result = load(f)
Supported types
A pytree is any nested combination of the types listed below. Dictionary keys may themselves be pytrees (as long as they are hashable). Circular references are detected and raise an error.
| Type | Stored as |
|---|---|
| list, tuple, dict, set, frozenset | HDF5 group |
| str | HDF5 UTF-8 fixed-length string |
| int, float, bool, complex | String representation (see Python numerics) |
| None, Ellipsis, slice, range | String representation |
| bytes, bytearray, memoryview | HDF5 attribute (or dataset) |
| np.ndarray, jax.Array | HDF5 attribute (or dataset) |
| np.bool_ | HDF5 attribute |
| np.int8, np.int16, np.int32, np.int64 | HDF5 attribute |
| np.uint8, np.uint16, np.uint32, np.uint64 | HDF5 attribute |
| np.float16, np.float32, np.float64 | HDF5 attribute |
| np.complex64, np.complex128 | HDF5 attribute |
| np.longdouble, np.clongdouble | HDF5 attribute (see Platform-specific types) |
| Any Python dataclass | HDF5 group containing all fields (see Dataclasses) |
Only the array contents are stored; metadata such as array titles is not preserved.
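The circular-reference check mentioned above can be pictured with a small depth-first walk. This is only a sketch of the general technique, not Jaxon's implementation: `check_no_cycles` is a hypothetical helper, it looks at built-in containers only, and it ignores dict keys for brevity.

```python
def check_no_cycles(node, _active=None):
    """Raise ValueError if a container directly or indirectly contains itself."""
    if _active is None:
        _active = set()
    if not isinstance(node, (list, tuple, dict, set, frozenset)):
        return  # leaves cannot take part in a cycle
    if id(node) in _active:
        raise ValueError("circular reference detected")
    _active.add(id(node))  # node is now on the current descent path
    children = node.values() if isinstance(node, dict) else node
    for child in children:
        check_no_cycles(child, _active)
    _active.remove(id(node))  # leaving the path keeps shared references legal


shared = [1, 2]
check_no_cycles({"a": shared, "b": shared})  # same object twice: no error

loop = []
loop.append(loop)  # the list now contains itself
try:
    check_no_cycles(loop)
except ValueError as err:
    print(err)
```

Tracking only the containers on the current path (rather than every container seen so far) is what lets the same object appear at several places without being mistaken for a cycle.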
Python numeric types
By default Jaxon preserves the Python types int, float, bool, and complex
exactly using a string representation. To store them as compact binary HDF5
attributes (which is more efficient for large numbers of scalars) pass
exact_python_numeric_types=False to save:
save("data.hdf5", {"x": 1, "y": 3.14}, exact_python_numeric_types=False)
result = load("data.hdf5")
# result["x"] is np.int64(1), result["y"] is np.float64(3.14)
To convert only specific Python types, use py_to_np_types:
save("data.hdf5", data, py_to_np_types=(int, float)) # bool and complex stay as strings
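To see why a text representation can preserve Python numbers exactly where a fixed-width binary type cannot, consider the following standard-library sketch. It does not reproduce Jaxon's actual string format; `roundtrip` is a hypothetical helper, and it handles finite values only.

```python
import ast


def roundtrip(value):
    """Convert a Python number to text and back again."""
    return ast.literal_eval(repr(value))


for value in (2**100, 0.1, True, 1 + 2j):
    restored = roundtrip(value)
    # Both the value and the exact Python type survive the text round trip
    assert restored == value and type(restored) is type(value)

# A fixed-width 64-bit integer could not hold the first value at all
print(2**100 > 2**63 - 1)  # True
```

Python's int has arbitrary precision and bool is a distinct type, which is why converting to NumPy scalars trades exactness for compactness.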
Platform-specific numeric types
np.longdouble and np.clongdouble are supported on all platforms, but their
precision depends on the hardware:
- Linux x86-64: 80-bit extended precision (stored in 128 bits); also accessible as np.float128 and np.complex256.
- Windows / macOS ARM: 64-bit (same precision as np.float64); np.float128 does not exist on these platforms.
A file containing np.longdouble scalars saved on Linux can be loaded on Windows,
but values will be truncated to 64-bit precision. Jaxon does not warn about this
because the truncation happens inside h5py.
Dataclasses
Jaxon stores the module name, class name, and all field values of a dataclass.
On load, the class is instantiated via __new__ (bypassing __init__) and each
field is set directly, which works even for frozen dataclasses.
from dataclasses import dataclass
import numpy as np
from jaxon import save, load
@dataclass
class Model:
    weights: np.ndarray
    bias: float
    name: str
m = Model(weights=np.array([1.0, 2.0]), bias=0.5, name="linear")
save("model.hdf5", m)
print(load("model.hdf5")) # Model(weights=array([1., 2.]), bias=0.5, name='linear')
Machine learning packages such as Equinox automatically make all modules Python dataclasses, so Jaxon is fully compatible with them.
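The loading strategy described in this section (create the instance with __new__, then set each field directly) can be illustrated with the standard library alone. This is a sketch of the mechanism, not Jaxon's code; `rebuild`, `Point`, and `init_calls` are names made up for the example.

```python
from dataclasses import dataclass, fields

init_calls = []


@dataclass(frozen=True)
class Point:
    x: float
    y: float

    def __post_init__(self):
        init_calls.append(self)  # records every run of the generated __init__


def rebuild(cls, stored):
    """Create an instance without calling __init__, then set each field."""
    obj = cls.__new__(cls)
    for f in fields(cls):
        # object.__setattr__ sidesteps the frozen dataclass's __setattr__ guard
        object.__setattr__(obj, f.name, stored[f.name])
    return obj


p = rebuild(Point, {"x": 1.0, "y": 2.0})
print(p)           # Point(x=1.0, y=2.0)
print(init_calls)  # []
```

One consequence of bypassing __init__ is that any validation or derived state computed in __init__ or __post_init__ does not run when an object is restored this way.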
Schema evolution
If a dataclass has changed between saving and loading (fields added or removed), the following options control behaviour:
result = load(
"model.hdf5",
allow_missing_fields=True, # fields in file but absent from the class: warn, skip
allow_unknown_fields=True, # fields in class but absent from file: use default or JAXON_NOT_LOADED
)
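Before choosing these options it can help to know exactly which fields differ. The sketch below compares a list of stored field names with a class definition using only the standard library; `field_differences` is a hypothetical helper, and the stored names are passed in by hand rather than read from a Jaxon file.

```python
from dataclasses import dataclass, fields


@dataclass
class Model:
    weights: list
    bias: float = 0.0  # imagine this field was added after the file was written


def field_differences(cls, stored_names):
    """Return (names only in the file, names only in the class)."""
    declared = {f.name for f in fields(cls)}
    stored = set(stored_names)
    return sorted(stored - declared), sorted(declared - stored)


only_in_file, only_in_class = field_differences(Model, ["weights", "name"])
print(only_in_file)   # ['name']
print(only_in_class)  # ['bias']
```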
Reference identity
Jaxon preserves reference identity across a save/load cycle. If the same object appears at multiple locations in a pytree, it will be the same object after loading:
a = np.array([1, 2, 3])
pytree = {"x": a, "y": a}
save("data.hdf5", pytree)
result = load("data.hdf5")
assert result["x"] is result["y"] # True
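One common way to preserve identity across serialization is to store each distinct object once and turn repeated occurrences into references. The sketch below shows that idea for nested lists only; it makes no claim about Jaxon's on-disk layout, and `flatten` and `rebuild` are hypothetical helpers.

```python
def flatten(node, table, seen):
    """Store each distinct list once; repeated occurrences become references."""
    if not isinstance(node, list):
        return ("leaf", node)
    if id(node) not in seen:
        seen[id(node)] = len(table)
        table.append(None)  # reserve the slot before descending
        table[seen[id(node)]] = [flatten(child, table, seen) for child in node]
    return ("ref", seen[id(node)])


def rebuild(entry, table, cache):
    """Recreate the structure; each table slot is materialized exactly once."""
    kind, payload = entry
    if kind == "leaf":
        return payload
    if payload not in cache:
        cache[payload] = [rebuild(child, table, cache) for child in table[payload]]
    return cache[payload]


a = [1, 2, 3]
table = []
root = flatten([a, a], table, {})
copy = rebuild(root, table, {})

print(copy)                # [[1, 2, 3], [1, 2, 3]]
print(copy[0] is copy[1])  # True
print(copy[0] is a)        # False
```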
Custom types
to_jaxon / from_jaxon interface
For types that are not natively supported and not dataclasses, implement
to_jaxon and from_jaxon:
class MyModel:
    def __init__(self, weights, config):
        self.weights = weights
        self.config = config

    def to_jaxon(self):
        # return a supported container
        return {"weights": self.weights, "config": self.config}

    def from_jaxon(self, data):
        # populate self from the container returned by to_jaxon
        self.weights = data["weights"]
        self.config = data["config"]
save("model.hdf5", MyModel(np.eye(3), {"lr": 0.01}))
result = load("model.hdf5") # MyModel instance
to_jaxon takes priority over the dataclass fallback if both apply. Jaxon stores
the fully-qualified class name so the correct class is instantiated on load.
Custom marshaler/unmarshaler functions
For types you cannot modify, pass callables to save and load:
def marshal(obj):
    if isinstance(obj, MyClass):
        return "mymodule.MyClass", {"value": obj.value}
    return None  # signal that this marshaler cannot handle the type

def unmarshal(qualified_name, data):
    if qualified_name == "mymodule.MyClass":
        return MyClass(data["value"])
    return None
save("data.hdf5", obj, custom_marshalers=[marshal])
result = load("data.hdf5", custom_unmarshalers=[unmarshal])
Multiple marshalers can be provided; the first one returning non-None is used.
Custom marshalers take priority over to_jaxon/from_jaxon and the dataclass
fallback.
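The "first one returning non-None is used" rule means the order of the list matters. The following sketch mimics that dispatch with plain Python; `first_match`, `Celsius`, and both marshalers are invented for illustration and are not part of Jaxon.

```python
class Celsius:
    def __init__(self, degrees):
        self.degrees = degrees


def marshal_celsius(obj):
    if isinstance(obj, Celsius):
        return "demo.Celsius", {"degrees": obj.degrees}
    return None  # not our type: let the next marshaler try


def marshal_fallback(obj):
    # Accepts anything, so it has to come last in the list
    return "demo.Repr", {"text": repr(obj)}


def first_match(obj, marshalers):
    """Return the result of the first marshaler that does not return None."""
    for marshaler in marshalers:
        result = marshaler(obj)
        if result is not None:
            return result
    return None


name, data = first_match(Celsius(21.5), [marshal_celsius, marshal_fallback])
print(name, data)  # demo.Celsius {'degrees': 21.5}
```

Placing a catch-all marshaler first would shadow every more specific one after it, so specific marshalers should be listed before general ones.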
Serialization with dill
As a last resort, Jaxon can serialize unsupported types using dill (an extended pickle) and store the result as a binary blob. This must be opted into explicitly:
save("data.hdf5", obj, allow_dill=True)
result = load("data.hdf5", allow_dill=True)
Note that dill-serialized objects are opaque binary blobs and cannot be inspected with HDF5 viewers.
Partial loading with load_filter
Large pytrees can be partially loaded by providing a filter function. The filter
receives the path to each node as a list and returns True to load it or False
to skip it (skipped nodes are replaced with JAXON_NOT_LOADED):
from jaxon import load, has_common_prefix, JAXON_NOT_LOADED
# only load pytree["weights"] and anything nested under it
result = load("model.hdf5", load_filter=lambda path: has_common_prefix(["weights"], path))
result["weights"] # np.ndarray — loaded
result["config"] # JAXON_NOT_LOADED — skipped
has_common_prefix(prefix, path) is a convenience function that returns True
when path starts with prefix. For dictionaries the path element is the dict key,
for lists and sets it is the integer index, and for dataclasses it is the field name
string. Dict keys themselves are always loaded regardless of the filter.
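A prefix test of the kind described above is easy to write by hand when a custom filter is needed. The sketch below follows the description given here (True when path starts with prefix); `path_starts_with` is a hypothetical stand-in, not the library's has_common_prefix, and the paths are made up.

```python
def path_starts_with(prefix, path):
    """True when the first len(prefix) elements of path equal prefix."""
    return list(path[:len(prefix)]) == list(prefix)


# Path elements mix dict keys, integer indices, and field names
paths = [["weights"], ["weights", 0], ["config", "lr"]]
print([path_starts_with(["weights"], p) for p in paths])  # [True, True, False]
```

A function like this can be combined with any other condition inside the callable passed as load_filter, for example to keep several subtrees at once.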
Storage hints
By default all arrays, byte buffers, and memoryviews are stored as HDF5 attributes. For very large arrays it can be preferable to use HDF5 datasets instead, which support chunking and compression:
from jaxon import save, JaxonStorageHints
big_array = np.zeros((1000, 1000))
save(
"data.hdf5",
{"array": big_array},
storage_hints=[(big_array, JaxonStorageHints(store_in_dataset=True))],
)
The hint is identified by object identity (is), so the object passed in
storage_hints must be the same object that appears in the pytree.
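The difference between identity and equality is easy to overlook, so here is a minimal sketch of an identity-based lookup. It uses plain lists and a string in place of arrays and JaxonStorageHints; `find_hint` is a hypothetical helper, not Jaxon's code.

```python
def find_hint(obj, hints):
    """Return the hint registered for this exact object, or None."""
    for target, hint in hints:
        if target is obj:  # identity, not equality
            return hint
    return None


big = [0.0] * 4
hints = [(big, "dataset")]

print(find_hint(big, hints))        # dataset
print(find_hint(list(big), hints))  # None: an equal copy is a different object
```

In practice this means a hint is silently ignored if the array is copied or converted between building the hint list and building the pytree.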
Acknowledgements
Jaxon is built on the following libraries: