
A package for saving JAX-compatible dataclasses with Orbax


Orbax 🤝 Dataclasses

A convenient way to serialize dataclasses, both in general and for Orbax checkpoints, in an easy-to-read format instead of Pickle.
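This README does not describe how flax_orbax encodes dataclasses internally. To illustrate the general idea of storing a dataclass as a readable, type-tagged tree rather than as opaque pickle bytes, here is a minimal standard-library sketch. `register`, `to_tree`, `from_tree`, `AdamConfig`, and `ModelState` are hypothetical names and are not part of flax_orbax. The sketch looks up classes in a small name registry rather than by import path, which keeps it self-contained.

```python
import dataclasses
import json
from typing import Any, Dict, Type

# Hypothetical registry mapping a readable type name to its dataclass.
_REGISTRY: Dict[str, Type[Any]] = {}


def register(cls: Type[Any]) -> Type[Any]:
    """Class decorator: make a dataclass restorable by its qualified name."""
    _REGISTRY[cls.__qualname__] = cls
    return cls


def to_tree(obj: Any) -> Any:
    """Turn nested dataclasses into plain dicts tagged with their type name."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        tree: Dict[str, Any] = {"__type__": type(obj).__qualname__}
        for field in dataclasses.fields(obj):
            tree[field.name] = to_tree(getattr(obj, field.name))
        return tree
    if isinstance(obj, dict):
        return {key: to_tree(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_tree(item) for item in obj]  # tuples come back as lists
    return obj  # assumed to already be a JSON scalar (int, float, str, bool, None)


def from_tree(tree: Any) -> Any:
    """Rebuild registered dataclasses from the tagged dicts produced by to_tree."""
    if isinstance(tree, dict):
        fields = {key: from_tree(value) for key, value in tree.items() if key != "__type__"}
        return _REGISTRY[tree["__type__"]](**fields) if "__type__" in tree else fields
    if isinstance(tree, list):
        return [from_tree(item) for item in tree]
    return tree


@register
@dataclasses.dataclass
class AdamConfig:
    name: str
    learning_rate: float


@register
@dataclasses.dataclass
class ModelState:
    step: int
    optimizer: AdamConfig


state = ModelState(step=3, optimizer=AdamConfig("adam", 1e-3))
text = json.dumps(to_tree(state), indent=2)  # readable text, unlike pickle's opaque bytes
restored = from_tree(json.loads(text))
print(text)
```

Unlike `pickle.loads`, which can execute arbitrary code while loading, `from_tree` can only construct classes that were explicitly registered.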

Usage:

Suppose we have a Flax train state:

import jax
import flax.linen as nn
import flax.training.train_state
import optax

import flax_orbax

# A small two-layer MLP with all-ones kernels
model = nn.Sequential([
    nn.Dense(10, kernel_init=nn.initializers.ones),
    nn.Dense(10, kernel_init=nn.initializers.ones),
])
params = model.init(jax.random.key(0), jax.numpy.ones((1, 20)))['params']

# Wrap the optimizer factory with flax_orbax.wrap so flax_orbax can keep track of
# objects (like the optax optimizer) that aren't serializable on their own
tx = flax_orbax.wrap(optax.adam)(1e-3)
state = flax.training.train_state.TrainState.create(apply_fn=model, params=params, tx=tx)
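The README says `flax_orbax.wrap` keeps track of objects that aren't serializable, but it doesn't document how. An optax optimizer is essentially a pair of functions (`init` and `update`), which can't be stored as array data. One plausible approach is to record the factory and its arguments so the object can be rebuilt on restore instead of being saved directly. The following standard-library sketch illustrates that idea. `Recorded`, `record_calls`, `config`, and `from_config` are hypothetical names, not flax_orbax's API, and `fractions.Fraction` stands in for `optax.adam` so the sketch runs without JAX.

```python
import dataclasses
import importlib
from fractions import Fraction
from typing import Any, Callable, Dict, Tuple


@dataclasses.dataclass
class Recorded:
    """Hypothetical container: a built object plus the call that produced it."""

    factory_path: str  # "module:qualname" of the factory that was called
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]
    value: Any = dataclasses.field(repr=False, compare=False)  # the built object itself

    def config(self) -> Dict[str, Any]:
        """A small, readable recipe; only this would be written to disk, not `value`."""
        return {"factory": self.factory_path, "args": list(self.args), "kwargs": dict(self.kwargs)}

    @classmethod
    def from_config(cls, cfg: Dict[str, Any]) -> "Recorded":
        """Re-import the factory and call it again with the saved arguments."""
        module_name, qualname = cfg["factory"].split(":")
        factory: Any = importlib.import_module(module_name)
        for attr in qualname.split("."):
            factory = getattr(factory, attr)
        return record_calls(factory)(*cfg["args"], **cfg["kwargs"])


def record_calls(factory: Callable[..., Any]) -> Callable[..., Recorded]:
    """Return a version of `factory` that remembers how it was called."""
    path = f"{factory.__module__}:{factory.__qualname__}"

    def build(*args: Any, **kwargs: Any) -> Recorded:
        return Recorded(path, args, kwargs, factory(*args, **kwargs))

    return build


# Fraction stands in for optax.adam so the sketch needs only the standard library.
third = record_calls(Fraction)(1, 3)
cfg = third.config()  # {'factory': 'fractions:Fraction', 'args': [1, 3], 'kwargs': {}}
rebuilt = Recorded.from_config(cfg)
```

The trade-off of this design is that the saved recipe only reproduces the object if the factory is importable and deterministic at restore time.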

Now we can save it easily:

import orbax.checkpoint as ocp

# Wipe and recreate the demo checkpoint directory
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
ckptr = flax_orbax.Checkpointer()
ckptr.save(path / '0', state)
ckptr.wait_until_finished()  # block until the checkpoint has been fully written
# Unlike ocp.StandardCheckpointer, this returns a TrainState, not a dict
restored = ckptr.restore(path / '0')
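The save/restore snippet above relies on two behaviors. First, the call to `wait_until_finished()` suggests that `save` may return before the write is done. Second, `restore` returns the original class rather than a dict. The following toy, standard-library sketch imitates both. `ToyCheckpointer`, its `known_types` argument, and the `state.json` layout are hypothetical; they are not how flax_orbax or Orbax actually store checkpoints.

```python
import dataclasses
import json
import tempfile
import threading
from pathlib import Path
from typing import Any, Dict, Optional, Type


class ToyCheckpointer:
    """Hypothetical checkpointer: writes in a background thread, restores the original class."""

    def __init__(self, known_types: Dict[str, Type[Any]]):
        self._known_types = known_types  # type name -> dataclass, consulted on restore
        self._thread: Optional[threading.Thread] = None

    def save(self, directory: Path, obj: Any) -> None:
        self.wait_until_finished()  # never run two writes at once
        # Snapshot now, so later mutations of `obj` don't leak into the checkpoint.
        payload = {"__type__": type(obj).__qualname__, "fields": dataclasses.asdict(obj)}

        def write() -> None:
            directory.mkdir(parents=True, exist_ok=False)
            (directory / "state.json").write_text(json.dumps(payload))

        self._thread = threading.Thread(target=write)
        self._thread.start()  # save() may return before the file is written

    def wait_until_finished(self) -> None:
        """Block until the pending background write (if any) completes."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def restore(self, directory: Path) -> Any:
        payload = json.loads((directory / "state.json").read_text())
        cls = self._known_types[payload["__type__"]]
        return cls(**payload["fields"])  # flat dataclasses only in this sketch


@dataclasses.dataclass
class ToyState:
    step: int
    learning_rate: float


with tempfile.TemporaryDirectory() as tmp:
    ckptr = ToyCheckpointer({"ToyState": ToyState})
    ckptr.save(Path(tmp) / "0", ToyState(step=0, learning_rate=1e-3))
    ckptr.wait_until_finished()
    restored = ckptr.restore(Path(tmp) / "0")  # a ToyState, not a dict
```

A production checkpointer would also surface errors raised in the background thread; this sketch does not.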

Download files

Source Distribution

flax_orbax-0.1.2.tar.gz (73.3 kB)

Uploaded Source

Built Distribution

flax_orbax-0.1.2-py3-none-any.whl (6.7 kB)

Uploaded Python 3

File details

Details for the file flax_orbax-0.1.2.tar.gz.

File metadata

  • Download URL: flax_orbax-0.1.2.tar.gz
  • Upload date:
  • Size: 73.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.4

File hashes

Hashes for flax_orbax-0.1.2.tar.gz:

  • SHA256: 1efd8b521bcfe2f99c2f76418e6495d307f8c28bbb72634e4293ace9197aadc9
  • MD5: bf96a6337536278e0c0e854cae7bda99
  • BLAKE2b-256: 9c8354b323eb14f3f687bfdff62236970a9130e2ce8c3981fadfa75afe32e8b0
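The following sketch shows how to check a downloaded archive against the SHA256 digest listed above, using only Python's `hashlib`. The helper name `sha256_of` is hypothetical, and the script assumes the sdist was saved to the current working directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks so large files aren't loaded at once."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 listed above for the source distribution.
EXPECTED_SDIST_SHA256 = "1efd8b521bcfe2f99c2f76418e6495d307f8c28bbb72634e4293ace9197aadc9"

# Assumes the sdist was downloaded into the current working directory.
archive = Path("flax_orbax-0.1.2.tar.gz")
if archive.exists():
    print("hash OK" if sha256_of(archive) == EXPECTED_SDIST_SHA256 else "hash MISMATCH")
```

pip can also enforce this check at install time through its hash-checking mode. To use it, pin `flax_orbax==0.1.2 --hash=sha256:<digest>` in a requirements file (listing both the sdist and wheel digests) and install with `pip install --require-hashes -r requirements.txt`.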

File details

Details for the file flax_orbax-0.1.2-py3-none-any.whl.

File metadata

File hashes

Hashes for flax_orbax-0.1.2-py3-none-any.whl:

  • SHA256: c00b251ca9cd028c4c87ee3acd6f19e3d6890352c3f58b89248ed6418ab18e5f
  • MD5: 205ff52696c1cc786587f5d134027a9d
  • BLAKE2b-256: cbad55a6d2cba37a1d5ada5396aa08c6737f22c9d49c30997945ff1bae160c56
