
A package for saving JAX-compatible dataclasses with Orbax

Project description

Orbax 🤝 Dataclasses

A convenient way to serialize dataclasses with Orbax in an easier-to-read format (no Pickle!)

Usage:

In JAX scripts it's common to store the parameters and the model information in the same object, e.g.:

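import dataclasses

import jax
from flax import linen as nn

@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Model:
    module: nn.Module = dataclasses.field(metadata=dict(static=True))  # static: kept out of the pytree leaves
    params: dict  # the arrays that actually get checkpointed

# or e.g. a flax.training.train_state.TrainState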

If we try saving it using standard Orbax:

import orbax.checkpoint as ocp

ckptr = ocp.StandardCheckpointer()
ckptr.save('/tmp/checkpoint', model)
ckptr.wait_until_finished()  # StandardCheckpointer saves asynchronously
ckptr.restore('/tmp/checkpoint')  # -> a plain dict {'params': ...}; the Model object is lost

This library aims to store and restore the full object seamlessly:

import flax_orbax

ckptr = flax_orbax.Checkpointer()
ckptr.save('/tmp/checkpoint', model)
ckptr.restore('/tmp/checkpoint')  # -> a Model instance
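One way to get the object back without Pickle is to record the class's import path next to its field values and re-import the class on restore. The sketch below is a hypothetical, stdlib-only illustration of that idea (`to_json`, `from_json` and `TrainConfig` are made-up names), not flax_orbax's actual on-disk format; it only handles flat dataclasses with JSON-compatible fields, whereas array leaves would be left to Orbax:

```python
import dataclasses
import importlib
import json


def to_json(obj) -> str:
    """Serialize a flat dataclass instance as JSON tagged with its import path."""
    cls = type(obj)
    payload = {
        "__class__": f"{cls.__module__}:{cls.__qualname__}",
        "fields": dataclasses.asdict(obj),
    }
    return json.dumps(payload)


def from_json(text: str):
    """Re-import the recorded class and rebuild the instance from its fields."""
    payload = json.loads(text)
    module_name, qualname = payload["__class__"].split(":")
    cls = importlib.import_module(module_name)
    for attr in qualname.split("."):  # walk nested names such as "Outer.Inner"
        cls = getattr(cls, attr)
    return cls(**payload["fields"])


@dataclasses.dataclass
class TrainConfig:
    learning_rate: float
    steps: int


text = to_json(TrainConfig(learning_rate=1e-3, steps=100))
restored = from_json(text)
print(restored)  # TrainConfig(learning_rate=0.001, steps=100)
```

The class must be importable by name on restore: classes defined inside functions (whose `__qualname__` contains `<locals>`) cannot be recovered this way, and nested dataclass fields come back as plain dicts because `dataclasses.asdict` flattens them.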

A few sharp edges

Some objects in a typical JAX workflow are not serializable, e.g. an optax gradient transformation. flax_orbax.wrap handles this:

import optax

import flax_orbax

# instead of optax.adam(1e-3), do:
tx = flax_orbax.wrap(optax.adam)(1e-3)
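How `flax_orbax.wrap` works internally isn't described here. One plausible design, sketched below with hypothetical names (`Wrapped`, `recipe_wrap`, `scale_by`), is to serialize the recipe (the factory's import path plus its arguments) instead of the closure-based object the factory returns, and to rebuild that object on restore. It uses only the standard library, and `scale_by` stands in for `optax.adam`:

```python
import dataclasses
import importlib
import json
from typing import Any, Callable


def _resolve(path: str) -> Callable:
    """Import "module:qualname" and return the object it names."""
    module_name, qualname = path.split(":")
    obj = importlib.import_module(module_name)
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj


@dataclasses.dataclass
class Wrapped:
    """Hypothetical wrapper: stores the recipe, rebuilds the value on construction."""
    factory: str  # "module:qualname" of the factory function
    args: tuple
    kwargs: dict
    # The built object (e.g. a gradient transformation) is never serialized.
    value: Any = dataclasses.field(init=False, repr=False, compare=False)

    def __post_init__(self):
        self.value = _resolve(self.factory)(*self.args, **self.kwargs)

    def to_json(self) -> str:
        return json.dumps({"factory": self.factory, "args": list(self.args), "kwargs": self.kwargs})

    @classmethod
    def from_json(cls, text: str) -> "Wrapped":
        data = json.loads(text)
        return cls(data["factory"], tuple(data["args"]), data["kwargs"])


def recipe_wrap(fn: Callable) -> Callable[..., Wrapped]:
    """recipe_wrap(f)(*args, **kwargs) builds f(*args, **kwargs) and remembers how."""
    def factory(*args, **kwargs) -> Wrapped:
        return Wrapped(f"{fn.__module__}:{fn.__qualname__}", args, kwargs)
    return factory


def scale_by(factor: float) -> Callable[[float], float]:
    """Stand-in for optax.adam: returns a closure, which JSON cannot store directly."""
    return lambda x: x * factor


tx = recipe_wrap(scale_by)(2.0)
restored = Wrapped.from_json(tx.to_json())
print(restored.value(3.0))  # 6.0
```

The trade-off of this design is that the factory must be importable by name and its arguments must be JSON-compatible, so lambdas or locally defined factories cannot be wrapped this way.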

To restore onto different devices (i.e. with a different sharding), pass the target sharding via RestoreArgs:

ckptr = flax_orbax.Checkpointer()
# `sharding`: the target sharding, defined beforehand
ckptr.restore('/tmp/checkpoint', args=flax_orbax.RestoreArgs(sharding=sharding))
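The snippet above leaves `sharding` undefined. As a sketch, assuming `RestoreArgs` accepts a standard `jax.sharding.Sharding` (the README does not state which type it expects), such a sharding can be built for the devices visible to the restoring process with JAX's own sharding API:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over every device visible to this process (one device on a CPU-only host).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Fully replicated: each device holds a complete copy of every array.
replicated = NamedSharding(mesh, PartitionSpec())

# Alternatively, split the leading axis of each array across the "data" mesh axis.
row_sharded = NamedSharding(mesh, PartitionSpec("data"))

# With row sharding, the leading dimension must be divisible by the device count.
x = jax.device_put(jnp.zeros((2 * len(jax.devices()), 3)), row_sharded)
print(x.shape, replicated.is_fully_replicated)
```

Replication is the simplest choice when moving a checkpoint to a smaller machine; row sharding spreads memory across devices but constrains array shapes.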



Download files


Source Distribution

flax_orbax-0.1.4.tar.gz (73.6 kB)

Built Distribution

flax_orbax-0.1.4-py3-none-any.whl (6.8 kB)

File details

Details for the file flax_orbax-0.1.4.tar.gz.

File metadata

  • Download URL: flax_orbax-0.1.4.tar.gz
  • Size: 73.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.4

File hashes

Hashes for flax_orbax-0.1.4.tar.gz
Algorithm    Hash digest
SHA256       270632eea3c1d78a07d6bd7a7243405534fd86e7f617628d82deb6c20e913a21
MD5          f6187256262865c9a01012214c91d53c
BLAKE2b-256  29968c55e4dd263e81ad971b07c891b1d840820a7023be91e867f9d00e461da1
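A downloaded archive can be checked against the published SHA256 digest with the standard library alone. This is a minimal sketch; the local file path is an assumption about where the download was saved:

```python
import hashlib

# SHA256 digest published above for flax_orbax-0.1.4.tar.gz.
EXPECTED_SHA256 = "270632eea3c1d78a07d6bd7a7243405534fd86e7f617628d82deb6c20e913a21"


def sha256_of(path: str) -> str:
    """Hash a file in 64 KiB chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 equals the expected hex digest."""
    return sha256_of(path) == expected.lower()


# Usage, after downloading the sdist into the working directory:
# matches_published_hash("flax_orbax-0.1.4.tar.gz")
```

pip can enforce the same check at install time via `--require-hashes` with `--hash=sha256:...` entries in a requirements file.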


File details

Details for the file flax_orbax-0.1.4-py3-none-any.whl.

File hashes

Hashes for flax_orbax-0.1.4-py3-none-any.whl
Algorithm    Hash digest
SHA256       dc3fc0e71d854468426e5ed208f1e872b58604b3a617016217479209af4c6c6c
MD5          08e9b2109365cb20ada591ab1737116e
BLAKE2b-256  e79e2871c797cc03826b3ca990b717d772149e51c9c41e09e572e28f23434392

