A package for saving JAX-compatible dataclasses with Orbax
Project description
Orbax 🤝 Dataclasses
A convenient way to serialize dataclasses with Orbax in a human-readable format (no Pickle!)
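The general idea behind readable, Pickle-free dataclass serialization can be sketched with the standard library alone. This is a minimal illustration under our own assumptions, not flax_orbax's implementation: the helpers `register`, `to_tree` and `from_tree` and the `"__dataclass__"` tag are hypothetical, and only JSON scalars, lists, dicts and registered dataclasses are handled.

```python
import dataclasses
import json
from typing import Any, Dict, Type

# Hypothetical allow-list of classes that from_tree may construct. Unlike
# unpickling, loading a file can never instantiate anything not registered here.
_REGISTRY: Dict[str, Type[Any]] = {}


def register(cls: Type[Any]) -> Type[Any]:
    """Class decorator: make a dataclass restorable under its qualified name."""
    _REGISTRY[cls.__qualname__] = cls
    return cls


def to_tree(obj: Any) -> Any:
    """Turn nested dataclasses into JSON-compatible dicts tagged with the class name."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {
            "__dataclass__": type(obj).__qualname__,
            # Only fields accepted by __init__ are saved, so the class can be rebuilt.
            "fields": {f.name: to_tree(getattr(obj, f.name))
                       for f in dataclasses.fields(obj) if f.init},
        }
    if isinstance(obj, (list, tuple)):
        return [to_tree(x) for x in obj]  # note: tuples come back as lists
    if isinstance(obj, dict):
        return {str(k): to_tree(v) for k, v in obj.items()}
    return obj  # assumed to be a JSON scalar: int, float, str, bool or None


def from_tree(tree: Any) -> Any:
    """Rebuild registered dataclasses from the tagged dicts produced by to_tree."""
    if isinstance(tree, dict) and "__dataclass__" in tree:
        cls = _REGISTRY[tree["__dataclass__"]]  # KeyError for unregistered classes
        return cls(**{k: from_tree(v) for k, v in tree["fields"].items()})
    if isinstance(tree, list):
        return [from_tree(x) for x in tree]
    if isinstance(tree, dict):
        return {k: from_tree(v) for k, v in tree.items()}
    return tree


@register
@dataclasses.dataclass
class Hyper:
    learning_rate: float
    betas: list


@register
@dataclasses.dataclass
class State:
    step: int
    hyper: Hyper


state = State(step=3, hyper=Hyper(learning_rate=1e-3, betas=[0.9, 0.999]))
text = json.dumps(to_tree(state), indent=2)  # plain JSON you can read and diff
restored = from_tree(json.loads(text))
```

Keeping an explicit registry means a saved file can only ever rebuild classes you opted in, whereas a Pickle payload can ask the loader to run arbitrary code.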
Usage:
Suppose we have a Flax train state:
```python
import jax
import flax.linen as nn
import flax.training.train_state
import optax
import flax_orbax

model = nn.Sequential([
    nn.Dense(10, kernel_init=nn.initializers.ones),
    nn.Dense(10, kernel_init=nn.initializers.ones),
])
params = model.init(jax.random.key(0), jax.numpy.ones((1, 20)))['params']
# flax_orbax.wrap keeps track of objects that aren't serializable (here, the optax optimizer)
tx = flax_orbax.wrap(optax.adam)(1e-3)
state = flax.training.train_state.TrainState.create(apply_fn=model, params=params, tx=tx)
```
Now we can save it easily:
```python
import orbax.checkpoint as ocp

path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
ckptr = flax_orbax.Checkpointer()
ckptr.save(path / '0', state)
ckptr.wait_until_finished()  # saving may run asynchronously; block until it completes
# Unlike ocp.StandardCheckpointer, this returns a TrainState, not a dict
restored = ckptr.restore(path / '0')
```
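The comment on `flax_orbax.wrap(optax.adam)(1e-3)` points at a common trick for objects that cannot be written out directly (an optax optimizer is a pair of `init`/`update` functions): save *how* the object was built and rebuild it on restore. The sketch below illustrates only that general idea with the standard library; `Recipe`, `wrap` and `build` are hypothetical names, `datetime.timedelta` stands in for `optax.adam`, and we make no claim that this is how flax_orbax actually works. In particular, this `wrap` returns the recipe itself, and you call `build()` to obtain the object.

```python
import dataclasses
import datetime
import importlib
import json
from typing import Any, Callable, Dict, Tuple


@dataclasses.dataclass(frozen=True)
class Recipe:
    """Hypothetical record of a constructor call that can be saved as plain JSON."""
    target: str  # "module:qualname" of the factory, e.g. "datetime:timedelta"
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def build(self) -> Any:
        """Import the factory by name and call it again with the recorded arguments."""
        module_name, _, qualname = self.target.partition(":")
        factory: Any = importlib.import_module(module_name)
        for part in qualname.split("."):
            factory = getattr(factory, part)
        return factory(*self.args, **self.kwargs)


def wrap(factory: Callable[..., Any]) -> Callable[..., Recipe]:
    """Hypothetical wrapper: record how an object is built instead of storing the object."""
    def record(*args: Any, **kwargs: Any) -> Recipe:
        return Recipe(f"{factory.__module__}:{factory.__qualname__}", args, kwargs)
    return record


# A stdlib factory standing in for something like optax.adam:
recipe = wrap(datetime.timedelta)(days=1, hours=2)
text = json.dumps(dataclasses.asdict(recipe))  # readable JSON, no Pickle
loaded = Recipe(**json.loads(text))
rebuilt = loaded.build()
```

Because `build()` imports whatever name is stored in the file, a production version of this idea would want to restrict `target` to an allow-list of trusted factories.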
Download files
Source Distribution
flax_orbax-0.1.0.tar.gz (73.0 kB)
Built Distribution
flax_orbax-0.1.0-py3-none-any.whl (6.3 kB)
File details
Details for the file flax_orbax-0.1.0.tar.gz.
File metadata
- Download URL: flax_orbax-0.1.0.tar.gz
- Upload date:
- Size: 73.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6477ebb9f25369031e4311359c97ecf42b6e282496d43ded214b6e759361e27e |
| MD5 | b5fce3fef97e93f0b18b68b3033f8a44 |
| BLAKE2b-256 | 1d2661284c6c530df1b40a2c52f097a8ea4df37a1af39cc1fd907ca25cc51068 |
File details
Details for the file flax_orbax-0.1.0-py3-none-any.whl.
File metadata
- Download URL: flax_orbax-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d5b9cac5b027a483c9c6e2f50d8411a8ef23b8841f06bb5d5a02e5fd3fe3d0cc |
| MD5 | 58c2710208095611d1ca02cf025dea86 |
| BLAKE2b-256 | 37a808bd96f668abee74bc91843185392c46407dfc8a9cd1888dd895b3ebe8ab |
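To check a downloaded file against the published digests, Python's `hashlib` is enough. This is a minimal sketch; the helper names `file_digest` and `verify` are our own, and the `"BLAKE2b-256"` label is mapped to `hashlib.blake2b(digest_size=32)`, i.e. BLAKE2b with a 256-bit digest.

```python
import hashlib
from pathlib import Path
from typing import Any


def _new_hasher(algorithm: str) -> Any:
    """Map the algorithm labels used in the hash tables to hashlib objects."""
    name = algorithm.lower()
    if name == "blake2b-256":
        return hashlib.blake2b(digest_size=32)  # BLAKE2b with a 32-byte (256-bit) digest
    return hashlib.new(name)  # e.g. "sha256" or "md5"


def file_digest(path: Path, algorithm: str = "sha256", chunk_size: int = 1 << 16) -> str:
    """Hex digest of a file, read in chunks so large files need not fit in memory."""
    hasher = _new_hasher(algorithm)
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def verify(path: Path, expected_hex: str, algorithm: str = "sha256") -> bool:
    """True if the file's digest matches the published hex digest (case-insensitive)."""
    return file_digest(path, algorithm) == expected_hex.strip().lower()


if __name__ == "__main__":
    # Assumes the source distribution was downloaded into the current directory.
    sdist = Path("flax_orbax-0.1.0.tar.gz")
    expected = "6477ebb9f25369031e4311359c97ecf42b6e282496d43ded214b6e759361e27e"
    if sdist.exists():
        print("SHA256 OK" if verify(sdist, expected) else "SHA256 MISMATCH")
```

Reading in fixed-size chunks keeps memory use flat regardless of file size; for the wheel, pass its file name and its own SHA256 digest instead.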