Serialize JAX/Flax models with `safetensors`
`safejax` is a Python package to serialize JAX and Flax models using `safetensors` as the tensor storage format, instead of relying on `pickle`. For more details on why `safetensors` is safer than `pickle`, please check https://github.com/huggingface/safetensors.
🛠️ Requirements & Installation
`safejax` requires Python 3.7 or above.
pip install safejax --upgrade
💻 Usage
import jax
from flax import linen as nn
from jax import numpy as jnp

from safejax.flax import serialize


class SingleLayerModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x


model = SingleLayerModel(features=1)

# Initialize the parameters with a dummy input of shape (1, 1)
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))

# Serialize the params to safetensors-encoded bytes
serialized = serialize(frozen_or_unfrozen_dict=params)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
More examples can be found at examples/.
🤔 Why `safejax`?
`safetensors` defines an easy and fast (zero-copy) format to store tensors, while `pickle` has some known weaknesses and security issues. `safetensors` is also designed to be agnostic to the framework used to load the tensors. More in-depth information can be found at https://github.com/huggingface/safetensors.
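To make the security concern concrete, here is a minimal sketch using only the standard library. It shows that `pickle.loads` will call whatever callable the byte stream names. The `Payload` class is a hypothetical stand-in for a tampered checkpoint, and `eval` on a harmless arithmetic string stands in for arbitrary code.

```python
import pickle


class Payload:
    """Hypothetical stand-in for a tampered checkpoint object."""

    def __reduce__(self):
        # pickle stores this (callable, args) pair and invokes it at load time.
        # A harmless expression is used here; an attacker could name any callable.
        return (eval, ("6 * 7",))


untrusted_bytes = pickle.dumps(Payload())

# Loading does not rebuild a Payload: it runs eval("6 * 7") and returns the result.
result = pickle.loads(untrusted_bytes)
print(type(result).__name__, result)
```

A format that only describes tensor names, dtypes, shapes, and raw bytes has no equivalent hook, which is the property the `safetensors` repository emphasizes.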
`flax` defines a dictionary-like class named `FrozenDict` that is used to store the tensors in memory. It can be dumped either into `bytes` in `MessagePack` format or as a `state_dict`.
However, `flax` still uses `pickle` as the format for storing the tensors, and HuggingFace has no plans to extend `safetensors` to support anything other than tensors, e.g. `FrozenDict`s; see their response at https://github.com/huggingface/safetensors/discussions/138.
So `safejax` was created to provide an easy way to serialize `FrozenDict`s using `safetensors` as the tensor storage format instead of `pickle`.
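Since `safetensors` stores a flat mapping from string names to tensors, a nested structure such as a `FrozenDict` has to be flattened before it can be stored and rebuilt after loading. The sketch below illustrates one way this could be done with plain Python mappings. The `flatten` and `unflatten` helpers and the `"."` separator are assumptions for illustration, not necessarily how `safejax` implements it, and nested lists stand in for real arrays.

```python
from collections.abc import Mapping


def flatten(tree, sep=".", prefix=""):
    """Turn a nested mapping into a flat {"a.b.c": leaf} dictionary."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, Mapping):
            flat.update(flatten(value, sep, name))
        else:
            flat[name] = value
    return flat


def unflatten(flat, sep="."):
    """Rebuild the nested dictionary from the flat, separator-joined keys."""
    tree = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


# Nested lists stand in for the arrays a Flax model would hold.
params = {"params": {"Dense_0": {"kernel": [[1.0]], "bias": [0.0]}}}
flat = flatten(params)
print(sorted(flat))
restored = unflatten(flat)
```

The flat keys double as tensor names in the stored file, so the nesting can be recovered from the names alone, provided no original key contains the separator.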
📄 Main differences with `flax.serialization`
- `flax.serialization.to_bytes` uses `pickle` as the tensor storage format, while `safejax.flax.serialize` uses `safetensors`
- `flax.serialization.from_bytes` requires the `target` to be instantiated, while `safejax.flax.deserialize` just needs the encoded bytes
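Deserializing without a `target` is plausible because a `safetensors` buffer is self-describing: it begins with an 8-byte little-endian header length, followed by a JSON header that lists each tensor's name, dtype, shape, and byte offsets, followed by the raw data. The following is a simplified, standard-library-only sketch of that layout. The `pack` and `read_header` helpers are hypothetical, and real files should be written and read with the `safetensors` library itself.

```python
import json
import struct


def pack(tensors):
    """Pack {name: (dtype, shape, raw_bytes)} into a safetensors-style buffer."""
    header, chunks, offset = {}, [], 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": shape,
            # Offsets are relative to the start of the data section.
            "data_offsets": [offset, offset + len(raw)],
        }
        chunks.append(raw)
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    # Layout: 8-byte little-endian header length, JSON header, raw tensor data.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + b"".join(chunks)


def read_header(buffer):
    """Recover names, dtypes, shapes, and offsets from the bytes alone."""
    (header_len,) = struct.unpack("<Q", buffer[:8])
    return json.loads(buffer[8 : 8 + header_len])


kernel = struct.pack("<1f", 0.5)  # one float32 value, logical shape (1, 1)
bias = struct.pack("<1f", 0.0)    # one float32 value, logical shape (1,)
encoded = pack({
    "params.Dense_0.kernel": ("F32", [1, 1], kernel),
    "params.Dense_0.bias": ("F32", [1], bias),
})
header = read_header(encoded)
print(header)
```

Everything needed to rebuild the tensors travels inside the bytes, so no pre-built template object is required on the loading side.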
🏋🏼 Benchmark
Benchmarks use `hyperfine`, so it needs to be installed first.
$ hyperfine --warmup 2 "hatch run python benchmark.py benchmark_safejax" "hatch run python benchmark.py benchmark_flax"
Benchmark 1: hatch run python benchmark.py benchmark_safejax
Time (mean ± σ): 671.3 ms ± 7.5 ms [User: 2169.9 ms, System: 391.4 ms]
Range (min … max): 652.2 ms … 680.7 ms 10 runs
Benchmark 2: hatch run python benchmark.py benchmark_flax
Time (mean ± σ): 676.0 ms ± 12.8 ms [User: 2245.6 ms, System: 324.0 ms]
Range (min … max): 650.3 ms … 690.3 ms 10 runs
Summary
'hatch run python benchmark.py benchmark_safejax' ran
1.01 ± 0.02 times faster than 'hatch run python benchmark.py benchmark_flax'
As we can see, the difference is barely noticeable, since the benchmark uses a 2-tensor dictionary, which should be fast with any method. The main difference is the use of `safetensors` for tensor storage instead of `pickle`.
More detailed and complex benchmarks will be prepared soon!