A package for saving JAX-compatible dataclasses with Orbax
Orbax 🤝 Dataclasses
A convenient way to serialize dataclasses with Orbax in an easier-to-read format (no Pickle!)
Usage:
In JAX scripts it's common to store parameters and model information in the same object, e.g.:
```python
import dataclasses

import flax.linen as nn
import jax

@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Model:
  module: nn.Module = dataclasses.field(metadata=dict(static=True))  # static: not saved as an array leaf
  params: dict
  # or, e.g., a flax.training.train_state.TrainState
```
If we try saving it using standard Orbax:
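```python
import orbax.checkpoint as ocp

ckptr = ocp.StandardCheckpointer()
ckptr.save('/tmp/checkpoint', model)
ckptr.wait_until_finished()  # StandardCheckpointer saves asynchronously
ckptr.restore('/tmp/checkpoint')  # -> we get a dict {'params': ...} back; the Model object is lost
```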
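A rough way to see why the object is lost: once the dataclass is registered as a pytree, a checkpointer only sees the array leaves and their key paths, while the static field lives in the tree structure (`treedef`). The sketch below uses a hypothetical `ToyModel`, with a plain string standing in for the static `nn.Module` so it runs without Flax, and passes the field lists explicitly to `register_dataclass` (recent JAX versions can infer them from `metadata=dict(static=True)` as in `Model` above).

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class ToyModel:
  name: str     # hypothetical static field, standing in for the nn.Module
  params: dict  # dynamic field: the arrays a checkpointer actually writes


jax.tree_util.register_dataclass(
    ToyModel, data_fields=['params'], meta_fields=['name'])

model = ToyModel(name='mlp', params={'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)})

# Only the arrays under `params` become leaves; `name` is kept in `treedef`.
path_leaves, treedef = jax.tree_util.tree_flatten_with_path(model)
for path, leaf in path_leaves:
  print(jax.tree_util.keystr(path), leaf.shape)

# Rebuilding the dataclass needs `treedef`, i.e. the part the arrays alone don't carry.
leaves = [leaf for _, leaf in path_leaves]
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(rebuilt.name)
```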
The idea of this library is to store and restore the full object seamlessly:
```python
import flax_orbax

ckptr = flax_orbax.Checkpointer()
ckptr.save('/tmp/checkpoint', model)
ckptr.restore('/tmp/checkpoint')  # -> we get a Model() back
```
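To illustrate the general idea of a readable alternative to Pickle (a minimal sketch under stated assumptions, not how flax_orbax is implemented): the non-array part of an object can be written as JSON tagged with its class name, and rebuilt on load only from an explicit registry of allowed classes, so restoring never runs arbitrary code the way unpickling can. `serializable`, `to_json` and `from_json` are hypothetical names; in such a scheme the arrays themselves would be left to Orbax.

```python
import dataclasses
import json

# Hypothetical registry: only classes registered here can be rebuilt on load.
_REGISTRY: dict[str, type] = {}


def serializable(cls):
  """Hypothetical decorator: register a dataclass under its qualified name."""
  _REGISTRY[cls.__qualname__] = cls
  return cls


def _encode(obj):
  # Called by json.dumps for objects it can't serialize natively.
  name = type(obj).__qualname__
  if dataclasses.is_dataclass(obj) and _REGISTRY.get(name) is type(obj):
    fields = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj) if f.init}
    return {'__class__': name, **fields}
  raise TypeError(f'{name} is not a registered dataclass')


def _decode(d: dict):
  # Called by json.loads on every JSON object, innermost first.
  if '__class__' not in d:
    return d
  cls = _REGISTRY.get(d.pop('__class__'))
  if cls is None:
    raise ValueError('refusing to rebuild an unregistered class')
  return cls(**d)


def to_json(obj) -> str:
  return json.dumps(obj, default=_encode, indent=2)


def from_json(text: str):
  return json.loads(text, object_hook=_decode)


@serializable
@dataclasses.dataclass
class MLPConfig:
  features: list
  activation: str = 'relu'


@serializable
@dataclasses.dataclass
class ModelSpec:
  config: MLPConfig
  name: str


spec = ModelSpec(config=MLPConfig(features=[64, 10]), name='mlp')
text = to_json(spec)
print(text)  # human-readable, unlike a pickle
```

One limitation of this design: JSON has no tuple type, so tuples come back as lists unless the decoder converts them.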
A few sharp edges
Some objects in a typical JAX workflow are not serializable, e.g. an optax gradient transformation. `flax_orbax.wrap` handles this:
```python
import flax_orbax
import optax
# instead of optax.adam(1e-3), do:
tx = flax_orbax.wrap(optax.adam)(1e-3)
```
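Optax transformations are `NamedTuple`s of `init`/`update` functions (closures), so there is nothing data-like to write to disk. One plausible approach for a `wrap`-style helper, sketched here as an assumption rather than a description of `flax_orbax.wrap`, is to record *which factory was called with which arguments* and re-run it on restore. `wrap_factory`, `Wrapped` and the toy `sgd` factory (standing in for `optax.adam` so the sketch runs without optax) are hypothetical.

```python
import json
from typing import Any, Callable, NamedTuple

# Hypothetical registry of factories that may be re-run when restoring.
_FACTORIES: dict[str, Callable[..., Any]] = {}


class Wrapped:
  """Hypothetical wrapper: keeps the factory name and arguments next to the built object."""

  def __init__(self, name: str, args: tuple, kwargs: dict):
    self.name, self.args, self.kwargs = name, args, kwargs
    self.obj = _FACTORIES[name](*args, **kwargs)  # the real, unserializable object

  def __getattr__(self, attr):
    # Only reached when normal lookup fails: delegate to the wrapped object.
    if attr == 'obj':
      raise AttributeError(attr)
    return getattr(self.obj, attr)

  def to_spec(self) -> dict:
    return {'factory': self.name, 'args': list(self.args), 'kwargs': self.kwargs}

  @classmethod
  def from_spec(cls, spec: dict) -> 'Wrapped':
    return cls(spec['factory'], tuple(spec['args']), spec['kwargs'])


def wrap_factory(fn: Callable[..., Any]) -> Callable[..., Wrapped]:
  """Register `fn`, then record every call to it instead of returning a bare closure."""
  _FACTORIES[fn.__qualname__] = fn
  return lambda *args, **kwargs: Wrapped(fn.__qualname__, args, kwargs)


# Toy stand-in for optax.adam: a NamedTuple of closures, like a GradientTransformation.
class Transform(NamedTuple):
  init: Callable
  update: Callable


def sgd(learning_rate: float) -> Transform:
  def init(params):
    return ()

  def update(grads, state):
    return [-learning_rate * g for g in grads], state

  return Transform(init, update)


tx = wrap_factory(sgd)(1e-3)
updates, state = tx.update([1.0, 2.0], tx.init(None))

# The spec is plain JSON; restoring re-runs the registered factory.
restored = Wrapped.from_spec(json.loads(json.dumps(tx.to_spec())))
```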
If you want to restore onto a different device (or with a different sharding):
```python
import flax_orbax
ckptr = flax_orbax.Checkpointer()
ckptr.restore('/tmp/checkpoint', args=flax_orbax.RestoreArgs(sharding=sharding))  # `sharding`: target placement
```
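The `sharding` value is not defined in the snippet above. Assuming `flax_orbax.RestoreArgs` accepts a standard `jax.sharding.Sharding` (not confirmed here), one common choice is a `NamedSharding` that replicates every array across all local devices. The sketch below builds one and, for comparison, applies it with `jax.device_put` to arrays already in memory; it also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over all visible devices.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

# An empty PartitionSpec splits no array axis, i.e. full replication.
replicated = NamedSharding(mesh, PartitionSpec())

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
placed = jax.device_put(params, replicated)  # applies the sharding to every leaf
print(placed['w'].sharding)
```

To split an axis instead, name it in the spec, e.g. `PartitionSpec('data')` shards the first array axis across the `data` mesh axis.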
Download files
File details
Details for the file flax_orbax-0.1.4.tar.gz.
File metadata
- Download URL: flax_orbax-0.1.4.tar.gz
- Upload date:
- Size: 73.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 270632eea3c1d78a07d6bd7a7243405534fd86e7f617628d82deb6c20e913a21 |
| MD5 | f6187256262865c9a01012214c91d53c |
| BLAKE2b-256 | 29968c55e4dd263e81ad971b07c891b1d840820a7023be91e867f9d00e461da1 |
File details
Details for the file flax_orbax-0.1.4-py3-none-any.whl.
File metadata
- Download URL: flax_orbax-0.1.4-py3-none-any.whl
- Upload date:
- Size: 6.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dc3fc0e71d854468426e5ed208f1e872b58604b3a617016217479209af4c6c6c |
| MD5 | 08e9b2109365cb20ada591ab1737116e |
| BLAKE2b-256 | e79e2871c797cc03826b3ca990b717d772149e51c9c41e09e572e28f23434392 |