Serialize JAX/Flax models with `safetensors`

🔐 Serialize JAX/Flax models with safetensors

safejax is a Python package to serialize JAX and Flax models using safetensors as the tensor storage format, instead of relying on pickle. For more details on why safetensors is safer than pickle, please check https://github.com/huggingface/safetensors.

🛠️ Requirements & Installation

safejax requires Python 3.7 or above.

pip install safejax --upgrade

💻 Usage

import jax
from flax import linen as nn
from jax import numpy as jnp

from safejax.flax import serialize


class SingleLayerModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x


model = SingleLayerModel(features=1)

rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))

serialized = serialize(frozen_or_unfrozen_dict=params)
assert isinstance(serialized, bytes)
assert len(serialized) > 0

More examples can be found at examples/.

🤔 Why safejax?

safetensors defines an easy and fast (zero-copy) format to store tensors, while pickle has some known weaknesses and security issues. safetensors is also intended to be independent of the framework used to load the tensors. More in-depth information can be found at https://github.com/huggingface/safetensors.
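As a rough illustration of the pickle issue mentioned above, the sketch below uses only the standard library to show that unpickling can call an arbitrary callable chosen by whoever produced the bytes. The `Payload` class is hypothetical and the expression is deliberately harmless; a malicious file could run any code instead.

```python
import pickle


class Payload:
    """Hypothetical object whose pickled form triggers a call on load."""

    def __reduce__(self):
        # pickle stores "call eval('40 + 2')" instead of the object's state.
        return (eval, ("40 + 2",))


blob = pickle.dumps(Payload())

# Loading the bytes runs eval; the result is not a Payload instance at all.
result = pickle.loads(blob)
print(result)
```

A format that only stores tensor metadata and raw buffers, as safetensors does, has no equivalent hook for running code at load time.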

flax defines a dictionary-like class named FrozenDict that is used to store the tensors in memory; it can be dumped either into bytes in MessagePack format or as a state_dict.

However, flax still uses pickle as the format for storing the tensors, and HuggingFace has no plans to extend safetensors to support anything more than tensors, e.g. FrozenDicts; see their response at https://github.com/huggingface/safetensors/discussions/138.

safejax was therefore created to provide an easy way to serialize FrozenDicts using safetensors as the tensor storage format instead of pickle.
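Since safetensors stores a flat mapping from names to tensors while a FrozenDict is nested, some flattening step is needed in between. The following is a minimal sketch of that idea, not safejax's actual implementation: the helpers `flatten` and `unflatten` are hypothetical, plain lists stand in for arrays, and keys are assumed not to contain the separator.

```python
def flatten(tree, sep=".", prefix=""):
    """Turn a nested dict into a flat dict with joined key paths."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, sep, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat, sep="."):
    """Rebuild the nested dict from the joined key paths."""
    tree = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


# Stand-in for a parameter tree; lists replace the real arrays.
params = {"params": {"Dense_0": {"kernel": [[0.5]], "bias": [0.0]}}}

flat = flatten(params)
print(sorted(flat))

restored = unflatten(flat)
print(restored == params)
```

The flat dictionary is the shape of data a tensor-only format can store, and the key paths carry enough information to rebuild the nesting on load.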

📄 Main differences with flax.serialization

  • flax.serialization.to_bytes uses pickle as the tensor storage format, while safejax.flax.serialize uses safetensors
  • flax.serialization.from_bytes requires the target to be instantiated, while safejax.flax.deserialize just needs the encoded bytes

🏋🏼 Benchmark

The benchmarks use hyperfine, so it needs to be installed first.

$ hyperfine --warmup 2 "hatch run python benchmark.py benchmark_safejax" "hatch run python benchmark.py benchmark_flax"
Benchmark 1: hatch run python benchmark.py benchmark_safejax
  Time (mean ± σ):     671.3 ms ±   7.5 ms    [User: 2169.9 ms, System: 391.4 ms]
  Range (min … max):   652.2 ms … 680.7 ms    10 runs
 
Benchmark 2: hatch run python benchmark.py benchmark_flax
  Time (mean ± σ):     676.0 ms ±  12.8 ms    [User: 2245.6 ms, System: 324.0 ms]
  Range (min … max):   650.3 ms … 690.3 ms    10 runs
 
Summary
  'hatch run python benchmark.py benchmark_safejax' ran
    1.01 ± 0.02 times faster than 'hatch run python benchmark.py benchmark_flax'

As we can see, the difference is barely noticeable, since the benchmark uses a 2-tensor dictionary, which should be fast with any method. The main difference is the use of safetensors instead of pickle for tensor storage.

More detailed and complex benchmarks will be prepared soon!
