Serialize JAX, Flax, Haiku, or Objax model params with `safetensors`
safejax is a Python package to serialize JAX, Flax, Haiku, or Objax model params using safetensors
as the tensor storage format, instead of relying on pickle. For more details on why
safetensors is safer than pickle please check huggingface/safetensors.
Note that safejax supports serializing jax, flax, dm-haiku, and objax model
parameters and has been tested with all of those frameworks. It is still in an early
development phase, though, so there may be cases where it does not work as expected. If you have
any feedback or bug reports, please open an issue at safejax/issues.
🛠️ Requirements & Installation
safejax requires Python 3.7 or above
pip install safejax --upgrade
💻 Usage
flax
- Convert `params` to `bytes` in memory

  ```python
  from safejax.flax import serialize, deserialize

  params = model.init(...)

  encoded_bytes = serialize(params)
  decoded_params = deserialize(encoded_bytes)

  model.apply(decoded_params, ...)
  ```

- Convert `params` to `bytes` in `params.safetensors` file

  ```python
  from safejax.flax import serialize, deserialize

  params = model.init(...)

  encoded_bytes = serialize(params, filename="./params.safetensors")
  decoded_params = deserialize("./params.safetensors")

  model.apply(decoded_params, ...)
  ```
dm-haiku
- Just contains `params`

  ```python
  from safejax.haiku import serialize, deserialize

  params = model.init(...)

  encoded_bytes = serialize(params)
  decoded_params = deserialize(encoded_bytes)

  model.apply(decoded_params, ...)
  ```

- If it contains `params` and `state` e.g. ExponentialMovingAverage in BatchNorm

  ```python
  from safejax.haiku import serialize, deserialize

  params, state = model.init(...)
  params_state = {"params": params, "state": state}

  encoded_bytes = serialize(params_state)
  decoded_params_state = deserialize(encoded_bytes)  # .keys() contains `params` and `state`

  model.apply(decoded_params_state["params"], decoded_params_state["state"], ...)
  ```

- If it contains `params` and `state`, but we want to serialize those individually

  ```python
  from safejax.haiku import serialize, deserialize

  params, state = model.init(...)

  encoded_bytes = serialize(params)
  decoded_params = deserialize(encoded_bytes)

  encoded_bytes = serialize(state)
  decoded_state = deserialize(encoded_bytes)

  model.apply(decoded_params, decoded_state, ...)
  ```
objax
- Convert `params` to `bytes` in memory, and convert back to `VarCollection`

  ```python
  from safejax.objax import serialize, deserialize

  params = model.vars()

  encoded_bytes = serialize(params=params)
  decoded_params = deserialize(encoded_bytes)

  for key, value in decoded_params.items():
      if key in model.vars():
          model.vars()[key].assign(value.value)

  model(...)
  ```

- Convert `params` to `bytes` in `params.safetensors` file

  ```python
  from safejax.objax import serialize, deserialize

  params = model.vars()

  encoded_bytes = serialize(params=params, filename="./params.safetensors")
  decoded_params = deserialize("./params.safetensors")

  for key, value in decoded_params.items():
      if key in model.vars():
          model.vars()[key].assign(value.value)

  model(...)
  ```

- Convert `params` to `bytes` in `params.safetensors` and assign during deserialization

  ```python
  from safejax.objax import serialize, deserialize_with_assignment

  params = model.vars()

  encoded_bytes = serialize(params=params, filename="./params.safetensors")
  deserialize_with_assignment(filename="./params.safetensors", model_vars=params)

  model(...)
  ```
More detailed examples can be found at examples/ for flax, dm-haiku, and objax.
🤔 Why safejax?
safetensors defines an easy and fast (zero-copy) format to store tensors,
while pickle has known weaknesses and security issues. safetensors
is also a storage format intended to be independent of the framework
used to load the tensors. More in-depth information can be found at
huggingface/safetensors.
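To make the pickle concern concrete, here is a minimal sketch using only the Python standard library. The `Payload` class is hypothetical and deliberately harmless; it only illustrates that unpickling calls whatever callable the bytes name, so loading an untrusted pickle file amounts to running untrusted code.

```python
import pickle


class Payload:
    """Hypothetical class used only for this illustration."""

    def __reduce__(self):
        # Tell pickle to "rebuild" this object by calling eval("1 + 1").
        # A malicious file could name any importable callable instead.
        return (eval, ("1 + 1",))


blob = pickle.dumps(Payload())

# Loading runs the callable chosen by whoever produced the bytes,
# so the result is not a Payload instance at all.
loaded = pickle.loads(blob)
print(type(loaded).__name__, loaded)
```

A format that only stores raw tensor bytes plus a description of their dtype and shape has no equivalent hook for executing code at load time.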
jax uses pytrees to store the model parameters in memory, i.e. nested
containers (typically dictionaries) whose leaves are jnp.DeviceArray tensors.
dm-haiku uses a custom dictionary whose keys are formatted as <level_1>/~/<level_2>, where the
levels define the tree structure and /~/ is the separator between them,
e.g. res_net50/~/initial_conv. Each key maps not to a jnp.DeviceArray but to a
dictionary of key-value pairs, e.g. the weights as w and the biases as b.
objax defines a custom dictionary-like class named VarCollection that contains
some variables inheriting from BaseVar which is another custom objax type.
flax defines a dictionary-like class named FrozenDict that is used to
store the tensors in memory. It can be dumped either into bytes in MessagePack
format or as a state_dict.
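All of the structures described above are nested mappings, whereas a tensor-only storage format works with a flat mapping from names to tensors. The following is a minimal sketch of that flattening step in plain Python. The helper names `flatten_dict` and `unflatten_dict` and the `.` separator are assumptions made for illustration, not necessarily what safejax does internally, and plain lists stand in for real arrays.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_dict(tree: Mapping, separator: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested mapping into a single-level {joined_name: leaf} dict."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, Mapping):
            flat.update(flatten_dict(value, separator, name))
        else:
            flat[name] = value
    return flat


def unflatten_dict(flat: Mapping, separator: str = ".") -> Dict[str, Any]:
    """Rebuild the nested dict; assumes no original key contains the separator."""
    tree: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(separator)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


# Haiku-style params: module path -> {"w": ..., "b": ...}
params = {"res_net50/~/initial_conv": {"w": [[0.1, 0.2]], "b": [0.0]}}

flat = flatten_dict(params)
print(list(flat))
restored = unflatten_dict(flat)
```

The round trip only works when the separator never appears inside a key, which is why a sketch like this picks a separator different from the `/~/` that dm-haiku already uses.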
Of all those, flax is the only framework that defines its own functions to
serialize and deserialize the model params, under flax.serialization. But flax still
uses pickle as the format for storing the tensors, and HuggingFace has no plans
to extend safetensors to support anything more than tensors, e.g. FrozenDicts; see their
response at huggingface/safetensors/discussions/138.
The motivation for safejax is therefore to provide an easy way to serialize FrozenDicts
using safetensors as the tensor storage format instead of pickle, as well as
a common and easy way to serialize and deserialize any JAX model params (Flax, Haiku, or Objax)
using the safetensors format.
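For context on why the format only holds named tensors, here is a simplified sketch of the file layout described in the huggingface/safetensors README: an 8-byte little-endian header size, a JSON header mapping each tensor name to its dtype, shape, and byte offsets, and then the raw byte buffer. The `pack` and `unpack` helpers and the tuple representation of a tensor are hypothetical and written only for illustration; the real library performs additional validation and should be used for actual files.

```python
import json
import struct
from typing import Dict, List, Tuple

# name -> (dtype string, shape, raw little-endian bytes); a hypothetical
# in-memory representation chosen only for this sketch.
Tensor = Tuple[str, List[int], bytes]


def pack(tensors: Dict[str, Tensor]) -> bytes:
    """Write header size, JSON header, then the concatenated tensor bytes."""
    header, buffer = {}, bytearray()
    for name, (dtype, shape, data) in tensors.items():
        begin = len(buffer)
        buffer += data
        # Offsets are relative to the start of the byte buffer.
        header[name] = {"dtype": dtype, "shape": shape, "data_offsets": [begin, len(buffer)]}
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(buffer)


def unpack(blob: bytes) -> Dict[str, Tensor]:
    """Read the header, then slice each tensor out of the byte buffer."""
    (header_size,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_size].decode("utf-8"))
    data = memoryview(blob)[8 + header_size :]
    tensors: Dict[str, Tensor] = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form string metadata
            continue
        begin, end = info["data_offsets"]
        tensors[name] = (info["dtype"], info["shape"], bytes(data[begin:end]))
    return tensors


weights = {
    "dense.w": ("F32", [2], struct.pack("<2f", 1.0, 2.0)),
    "dense.b": ("F32", [1], struct.pack("<f", 0.5)),
}
blob = pack(weights)
restored = unpack(blob)
```

Because the header is a flat name-to-tensor table, container types such as FrozenDict have to be flattened before storage and rebuilt after loading.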
📄 Main differences with flax.serialization
- `flax.serialization.to_bytes` uses `pickle` as the tensor storage format, while `safejax.serialize` uses `safetensors`
- `flax.serialization.from_bytes` requires the `target` to be instantiated, while `safejax.deserialize` just needs the encoded bytes
🏋🏼 Benchmark
Benchmarks are no longer running with hyperfine,
as most of the elapsed time is not during the actual serialization but in the imports and
the model parameter initialization. So we've refactored those to run with pure
Python code using time.perf_counter to measure the elapsed time in seconds.
$ python benchmarks/resnet50.py
safejax (100 runs): 2.0974 s
flax (100 runs): 4.8734 s
This means that for ResNet50, safejax is about 2.3 times faster than flax.serialization
at serialization. Note also that safejax stores the tensors with
safetensors, while flax saves them with pickle.
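The timing approach described above, wrapping repeated runs in time.perf_counter, can be sketched as follows. This is not the content of benchmarks/resnet50.py: the `time_runs` helper is hypothetical, and a small pickle round trip over a dictionary of bytes stands in for the real serialization calls so that the example runs with the standard library alone.

```python
import pickle
import time
from typing import Callable


def time_runs(fn: Callable[[], object], runs: int = 100) -> float:
    """Return the total elapsed time in seconds for `runs` calls of `fn`."""
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return time.perf_counter() - start


# Stand-in payload; the real benchmark would use initialized model params.
payload = {"w": bytes(1024), "b": bytes(64)}


def roundtrip() -> None:
    pickle.loads(pickle.dumps(payload))


elapsed = time_runs(roundtrip, runs=100)
print(f"pickle round trip (100 runs): {elapsed:.4f} s")
```

Creating the payload outside the timed function keeps imports and parameter initialization out of the measurement, which is the reason given above for moving away from whole-script timing.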
hyperfine can still be used for benchmarking, but it needs
to be installed first, and the hatch/pyenv environment needs to be activated
(or the requirements installed). Because of the overhead of the script, the
time spent on serialization is minimal compared to the rest, so the overall
result does not reflect the efficiency difference between both approaches as well as the numbers above do.
$ hyperfine --warmup 2 "python benchmarks/hyperfine/resnet50.py serialization_safejax" "python benchmarks/hyperfine/resnet50.py serialization_flax"
Benchmark 1: python benchmarks/hyperfine/resnet50.py serialization_safejax
Time (mean ± σ): 1.778 s ± 0.038 s [User: 3.345 s, System: 0.511 s]
Range (min … max): 1.741 s … 1.877 s 10 runs
Benchmark 2: python benchmarks/hyperfine/resnet50.py serialization_flax
Time (mean ± σ): 1.790 s ± 0.011 s [User: 3.371 s, System: 0.478 s]
Range (min … max): 1.771 s … 1.810 s 10 runs
Summary
'python benchmarks/hyperfine/resnet50.py serialization_safejax' ran
1.01 ± 0.02 times faster than 'python benchmarks/hyperfine/resnet50.py serialization_flax'
As we can see, the difference is barely noticeable, since the benchmark uses a
2-tensor dictionary, which should be fast with any method. The main difference is
the use of safetensors for the tensor storage instead of pickle.