
A simple package to save and load JAX PyTrees to and from Safetensors


Pytree2Safetensors

Pytree2Safetensors is a simple package to save and load JAX PyTrees to and from Safetensors, a popular file format for saving neural network weights.

To install, run

pip install --upgrade pytree2safetensors

Pytree2Safetensors depends on jax, safetensors, and jaxtyping, and requires Python 3.10 or later.

Specification

Serializing/Deserializing

keypath2string(path: KeyPath) -> str

Serializes a JAX key path (i.e., the path to a leaf in a pytree) to a string by concatenating a string representation of each key in the path. Each representation carries a prefix that indicates the key's type: a GetAttrKey is prefixed with ".", a DictKey with "@", and a SequenceKey with "#". If the first key is a GetAttrKey, its leading "." is omitted.

Examples:

from jax.tree_util import DictKey, GetAttrKey, SequenceKey

keypath2string((GetAttrKey("layers"), SequenceKey(10), DictKey("query")))
# => "layers#10@query"
keypath2string((SequenceKey(2), GetAttrKey("layers"), SequenceKey(10), DictKey("query")))
# => "#2.layers#10@query"
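To make the prefix rule concrete, here is a minimal sketch of the encoding. It is not the package's source: keypath2string_sketch is a hypothetical name, and the sketch assumes that only GetAttrKey, DictKey and SequenceKey appear in a path and that no attribute name or dict key contains ".", "@" or "#".

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def keypath2string_sketch(path) -> str:
    """Encode a JAX key path with the "." / "@" / "#" prefix scheme.

    Assumption: only GetAttrKey, DictKey and SequenceKey occur, and no
    name or dict key contains ".", "@" or "#".
    """
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append("." + key.name)
        elif isinstance(key, DictKey):
            parts.append("@" + str(key.key))
        elif isinstance(key, SequenceKey):
            parts.append("#" + str(key.idx))
        else:
            raise TypeError(f"unsupported key type: {type(key).__name__}")
    encoded = "".join(parts)
    # A leading GetAttrKey drops its "." prefix.
    if path and isinstance(path[0], GetAttrKey):
        encoded = encoded[1:]
    return encoded


print(keypath2string_sketch((GetAttrKey("layers"), SequenceKey(10), DictKey("query"))))
# layers#10@query
print(keypath2string_sketch((SequenceKey(2), GetAttrKey("layers"), SequenceKey(10), DictKey("query"))))
# #2.layers#10@query
```

Because a GetAttrKey is the only key that loses its prefix, and only in first position, a string that starts without a prefix character can always be read back as beginning with an attribute.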

string2keypath(string: str) -> KeyPath

Inverse of keypath2string: parses such a string back into a key path.
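A minimal sketch of the inverse, under the same assumptions as the encoder (no prefix characters inside names) plus two more: dict keys are decoded as strings and sequence indices as ints. string2keypath_sketch is a hypothetical name, and the regex tokenizer is an illustrative choice, not necessarily how the package parses.

```python
import re

from jax.tree_util import DictKey, GetAttrKey, SequenceKey

# One token = an optional prefix character followed by a run of non-prefix characters.
_TOKEN = re.compile(r"([.@#]?)([^.@#]+)")


def string2keypath_sketch(string: str) -> tuple:
    """Decode a "." / "@" / "#" encoded string back into a JAX key path.

    Assumptions: dict keys are strings, sequence indices are ints, and no
    name contains ".", "@" or "#".
    """
    path = []
    for prefix, body in _TOKEN.findall(string):
        if prefix in ("", "."):
            # An empty prefix can only occur first: it is a GetAttrKey whose "." was dropped.
            path.append(GetAttrKey(body))
        elif prefix == "@":
            path.append(DictKey(body))
        else:  # prefix == "#"
            path.append(SequenceKey(int(body)))
    return tuple(path)


path = string2keypath_sketch("#2.layers#10@query")
```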

pytree2dict(tree: PyTree) -> dict

Returns a dictionary of serialized key paths mapping to leaves of the tree.
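A sketch of how such a dictionary can be built with jax.tree_util.tree_flatten_with_path, which yields every leaf together with its key path. pytree2dict_sketch and its inline encoder are hypothetical stand-ins for the package's functions, with the same assumptions as the encoding sketch above.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def _key_to_str(key) -> str:
    # Prefix scheme from keypath2string: "." attribute, "@" dict key, "#" index.
    if isinstance(key, GetAttrKey):
        return "." + key.name
    if isinstance(key, DictKey):
        return "@" + str(key.key)
    if isinstance(key, SequenceKey):
        return "#" + str(key.idx)
    raise TypeError(f"unsupported key type: {type(key).__name__}")


def pytree2dict_sketch(tree) -> dict:
    """Map each leaf's encoded key path to the leaf itself."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    out = {}
    for path, leaf in leaves_with_paths:
        encoded = "".join(_key_to_str(k) for k in path)
        # A leading GetAttrKey drops its "." prefix.
        if path and isinstance(path[0], GetAttrKey):
            encoded = encoded[1:]
        out[encoded] = leaf
    return out


params = {
    "embed": jnp.zeros((4, 8)),
    "layers": [{"query": jnp.ones((8, 8))}, {"query": jnp.ones((8, 8))}],
}
print(sorted(pytree2dict_sketch(params)))
# ['@embed', '@layers#0@query', '@layers#1@query']
```

Since the top level here is a dict rather than an object with attributes, every key begins with "@".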

dict2pytree(dictionary: dict) -> PyTree

Inverse of pytree2dict, except that attribute (GetAttrKey) nodes are rebuilt as PyTreeContainer objects rather than as the original classes, because the serialized dict does not record which class each node came from. To load weights into an already initialized pytree of the right type, use load_into_pytree instead.

PyTreeContainer

A minimal class that implements just enough for JAX to treat it as a pytree node. dict2pytree uses it in place of the original objects for attribute paths.
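A sketch of what "the bare minimum" can look like: a class whose children are stored as attributes and registered with jax.tree_util.register_pytree_with_keys, so that key paths through it use GetAttrKey. AttrContainerSketch is a hypothetical stand-in; the real PyTreeContainer may store and register its children differently.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey


class AttrContainerSketch:
    """Hypothetical stand-in for PyTreeContainer: holds children as attributes."""

    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)


def _flatten_with_keys(node):
    names = sorted(vars(node))
    # Report each child with a GetAttrKey so key paths use attribute access.
    children = [(GetAttrKey(name), getattr(node, name)) for name in names]
    return children, tuple(names)  # attribute names are the static aux data


def _unflatten(names, children):
    return AttrContainerSketch(**dict(zip(names, children)))


jax.tree_util.register_pytree_with_keys(AttrContainerSketch, _flatten_with_keys, _unflatten)

tree = AttrContainerSketch(layers=[AttrContainerSketch(weight=jnp.ones(3))])
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)
```

Once registered, the container works with tree_map and tree_flatten_with_path like any built-in node; the flattener alone is enough, since JAX derives the plain flatten function from it.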

Saving

save_pytree(tree: PyTree, path: str) -> None

Saves the pytree as a safetensors file at the given path. Equivalent to:

safetensors.flax.save_file(pytree2dict(tree), path)

Loading

load_file

Alias of safetensors.flax.load_file.

load_pytree(path: str) -> PyTree

Loads the safetensors file at path and uses dict2pytree to convert the resulting dict to a pytree.

set_weights(module: PyTree, dictionary: dict) -> PyTree

Given a pytree module and a safetensors dict, loads the weights from the dict into module, using string2keypath to map each key to the path of the leaf it replaces. Returns a new pytree.
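A minimal sketch of the same idea, with one swapped-in technique: instead of decoding each dict key with string2keypath as described above, it walks the module's leaves with jax.tree_util.tree_map_with_path, encodes each leaf's path, and looks that string up in the dict. set_weights_sketch is a hypothetical name, and its choices to keep leaves that have no matching entry and to reject shape mismatches are assumptions, not documented package behavior.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def _encode(path) -> str:
    # Same "." / "@" / "#" scheme as keypath2string.
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append("." + key.name)
        elif isinstance(key, DictKey):
            parts.append("@" + str(key.key))
        elif isinstance(key, SequenceKey):
            parts.append("#" + str(key.idx))
        else:
            raise TypeError(f"unsupported key type: {type(key).__name__}")
    s = "".join(parts)
    return s[1:] if path and isinstance(path[0], GetAttrKey) else s


def set_weights_sketch(module, dictionary: dict):
    """Return a new pytree with leaves replaced by matching dict entries."""

    def replace(path, leaf):
        new = dictionary.get(_encode(path))
        if new is None:
            return leaf  # assumption: unmatched leaves are kept as-is
        if new.shape != leaf.shape:
            raise ValueError(f"shape mismatch at {_encode(path)}: {new.shape} vs {leaf.shape}")
        return new

    return jax.tree_util.tree_map_with_path(replace, module)


module = {"embed": jnp.zeros((2, 3)), "bias": jnp.zeros(3)}
loaded = set_weights_sketch(module, {"@embed": jnp.ones((2, 3))})
```

Because tree_map_with_path builds a new tree, the original module is left untouched, matching the "returns a new pytree" contract.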

load_into_pytree(module: PyTree, path: str) -> PyTree

Equivalent to set_weights(module, load_file(path)).



Download files


Source Distribution

pytree2safetensors-0.1.4.tar.gz (4.4 kB)

Uploaded Source

Built Distribution


pytree2safetensors-0.1.4-py3-none-any.whl (6.2 kB)

Uploaded Python 3

File details

Details for the file pytree2safetensors-0.1.4.tar.gz.

File metadata

  • Download URL: pytree2safetensors-0.1.4.tar.gz
  • Upload date:
  • Size: 4.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for pytree2safetensors-0.1.4.tar.gz
Algorithm Hash digest
SHA256 4d7c9bed9ed5b4bf6cc5dc6615b87c8bc96ed4ede82b7e9c1c22531e47babc65
MD5 470d04195338b4140cf13159556f59ec
BLAKE2b-256 057df4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5


File details

Details for the file pytree2safetensors-0.1.4-py3-none-any.whl.

File hashes

Hashes for pytree2safetensors-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 5713fcb9b5f1fab0717c88bcfc3d7fe72b2dc39d865288f26e66fe5fad391657
MD5 3323d611078558552e0063bc43d4f902
BLAKE2b-256 a9b1a6ca7d1339a88126199b07fb02863832d96b2a08af6870e58a798cf1695a

