A package for saving JAX-compatible dataclasses with Orbax

Orbax 🤝 Dataclasses

A convenient way to serialize dataclasses with Orbax in a format that is easier to read, without resorting to pickle.

Usage:

Suppose we have a train state:

import jax
import flax.linen as nn
import optax
import flax.training.train_state
import flax_orbax

model = nn.Sequential([nn.Dense(10, kernel_init=nn.initializers.ones), nn.Dense(10, kernel_init=nn.initializers.ones)])
params = model.init(jax.random.key(0), jax.numpy.ones((1, 20)))['params']
tx = flax_orbax.wrap(optax.adam)(1e-3) # Add flax_orbax.wrap to keep track of objects that aren't serializable
state = flax.training.train_state.TrainState.create(apply_fn=model, params=params, tx=tx)
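
The comment on flax_orbax.wrap says it keeps track of objects that aren't serializable. One way a wrapper like that could work is to record the factory and its arguments next to the object, so the object can be rebuilt on restore. The sketch below uses only the standard library and is an illustration under assumptions, not flax_orbax's actual implementation: record_call, rebuild, and the recipe layout are hypothetical names.

```python
import datetime
import importlib
import json


def record_call(factory):
    """Hypothetical stand-in for a wrapper in the spirit of flax_orbax.wrap."""

    def wrapper(*args, **kwargs):
        obj = factory(*args, **kwargs)
        # Remember how the object was built; this part is plain data.
        recipe = {
            "module": factory.__module__,
            "name": factory.__qualname__,
            "args": list(args),
            "kwargs": dict(kwargs),
        }
        return obj, recipe

    return wrapper


def rebuild(recipe):
    """Re-import the factory by name and call it with the recorded arguments."""
    module = importlib.import_module(recipe["module"])
    factory = getattr(module, recipe["name"])
    return factory(*recipe["args"], **recipe["kwargs"])


# A timedelta cannot be written to JSON directly, but its recipe can.
original, recipe = record_call(datetime.timedelta)(seconds=90)
rebuilt = rebuild(json.loads(json.dumps(recipe)))
```

Only the recipe is written to disk here; the object itself is recreated from it, which is why no pickle is needed in this sketch.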

Now we can save it easily:

import orbax.checkpoint as ocp
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
ckptr = flax_orbax.Checkpointer()
ckptr.save(path / '0', state)
ckptr.wait_until_finished()
restored = ckptr.restore(path / '0')  # Unlike StandardCheckpointer(), this returns a train state, not a dict
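
The difference between getting back a dict and getting back the original dataclass can be shown with a toy example. This sketch uses only the standard library and JSON. ToyState, save, restore_as_dict, and restore_as_dataclass are hypothetical names chosen for illustration, and none of this reflects how flax_orbax or Orbax store data internally.

```python
import dataclasses
import json


@dataclasses.dataclass
class ToyState:
    step: int
    learning_rate: float


def save(state):
    # Store the class name next to the fields so the type survives the round trip.
    payload = {"type": type(state).__name__, "fields": dataclasses.asdict(state)}
    return json.dumps(payload)


def restore_as_dict(text):
    # The type information is ignored: the caller gets a plain dict back.
    return json.loads(text)["fields"]


def restore_as_dataclass(text, registry):
    # Look up the recorded class name and rebuild an instance of that class.
    payload = json.loads(text)
    cls = registry[payload["type"]]
    return cls(**payload["fields"])


text = save(ToyState(step=3, learning_rate=1e-3))
as_dict = restore_as_dict(text)
as_state = restore_as_dataclass(text, {"ToyState": ToyState})
```

With the dict, the caller has to know which class to rebuild and how; with the dataclass, attributes such as as_state.step are available right away.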
