OpTree
Optimized PyTree Utilities.
Installation
Install the latest stable version from PyPI:
pip3 install --upgrade optree
Install the latest version from GitHub:
pip3 install git+https://github.com/metaopt/optree.git#egg=optree
Or, clone this repo and install manually:
git clone --depth=1 --recurse-submodules https://github.com/metaopt/optree.git
cd optree
pip3 install .
Compiling from the source requires Python 3.6+, a compiler (gcc / clang / icc / cl.exe) that supports C++20, and a cmake installation.
PyTrees
A PyTree is a recursive structure that can be an arbitrarily nested Python container (e.g., tuple, list, dict, OrderedDict, NamedTuple, etc.) or an opaque Python object.
The key concepts of tree operations are tree flattening and its inverse (tree unflattening).
Additional tree operations can be performed based on these two basic functions (e.g., tree_map = tree_unflatten ∘ map ∘ tree_flatten), as sketched below.
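For intuition, here is a minimal sketch (a naive composition of the two primitives, not the library's actual implementation) of how a tree map can be built from tree_flatten and tree_unflatten:

import optree

def naive_tree_map(func, tree):
    # tree_unflatten ∘ map ∘ tree_flatten:
    # flatten to (leaves, treespec), map over the leaves, then rebuild the tree.
    leaves, treespec = optree.tree_flatten(tree)
    return optree.tree_unflatten(treespec, [func(leaf) for leaf in leaves])

naive_tree_map(lambda x: x + 1, {'b': (2, [3, 4]), 'a': 1})  # -> {'a': 2, 'b': (3, [4, 5])}

In practice, optree.tree_map provides this directly (and also supports mapping over multiple input trees).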
Tree flattening is traversing the entire tree in a left-to-right depth-first manner and returning the leaves of the tree in a deterministic order.
>>> import optree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
This usually implies that equal pytrees return equal lists of leaves and the same tree structure. See also section Key Ordering for Dictionaries.
>>> {'a': [1, 2], 'b': [3]} == {'b': [3], 'a': [1, 2]}
True
>>> optree.tree_leaves({'a': [1, 2], 'b': [3]}) == optree.tree_leaves({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_structure({'a': [1, 2], 'b': [3]}) == optree.tree_structure({'b': [3], 'a': [1, 2]})
True
Tree Nodes and Leaves
A tree is a collection of non-leaf nodes and leaf nodes, where the leaf nodes have no children to flatten.
optree.tree_flatten(...) will flatten the tree and return a list of leaf nodes, while the non-leaf nodes are stored in the tree specification.
Built-in PyTree Node Types
OpTree supports the following Python container types in the registry out of the box:
- tuple
- list
- dict
- collections.namedtuple and its subclasses
- collections.OrderedDict
- collections.defaultdict
- collections.deque
These types are considered non-leaf nodes in the tree.
Python objects whose type is not registered will be treated as leaf nodes.
The registration lookup uses the is operator to determine whether the type matches.
Therefore, subclasses need to be registered explicitly; otherwise, an object of that type will be considered as a leaf.
The NoneType is a special case discussed in section None is Non-leaf Node vs. None is Leaf.
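For example (a hedged sketch; MyMapping is a hypothetical, unregistered dict subclass), the exact-type lookup means such an instance is treated as a single leaf:

import optree

class MyMapping(dict):
    # A hypothetical subclass that is NOT registered with OpTree.
    pass

leaves, treespec = optree.tree_flatten({'x': MyMapping(a=1, b=2), 'y': [3, 4]})
# Expected: the MyMapping instance stays whole as one leaf, i.e.
# leaves == [MyMapping(a=1, b=2), 3, 4], and the treespec shows a single `*` for key 'x'.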
Registering a Custom Container-like Type as Non-leaf Nodes
A container-like Python type can be registered in the container registry with a pair of functions that specify:
- flatten_func(container) -> (children, metadata): convert an instance of the container type to a (children, metadata) pair, where children is an iterable of subtrees.
- unflatten_func(metadata, children) -> container: convert such a pair back to an instance of the container type.
The metadata is some necessary data, apart from the children, to reconstruct the container, e.g., the keys of the dictionary (the children are the values).
>>> import torch
>>> optree.register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... (tensor.cpu().numpy(),),
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
... ),
... unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
... )
<class 'torch.Tensor'>
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> leaves, treespec = optree.tree_flatten(tree)
>>> leaves, treespec
(
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
PyTreeSpec({
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
})
)
>>> optree.tree_unflatten(treespec, leaves)
{'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
Users can also extend the pytree registry by decorating the custom class and defining an instance method tree_flatten and a class method tree_unflatten.
>>> from collections import UserDict
...
... @optree.register_pytree_node_class
... class MyDict(UserDict):
...     def tree_flatten(self):
...         reversed_keys = sorted(self.keys(), reverse=True)
...         return [self[key] for key in reversed_keys], reversed_keys
...
...     @classmethod
...     def tree_unflatten(cls, metadata, children):
...         return MyDict(zip(metadata, children))
>>> optree.tree_flatten(MyDict(b=2, a=1, c=3))
([3, 2, 1], PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *])))
Limitations of the PyTree Type Registry
There are several limitations of the pytree type registry:
- The type registry is per-interpreter-dependent. This means registering a custom type in the registry affects all modules that use OpTree. The type registry does not support per-module isolation such as namespaces.
- The elements in the type registry are immutable. Users can neither register the same type twice (i.e., update the type registry) nor remove a type from it (see the sketch after this list).
- Users cannot modify the behavior of already registered built-in types listed in Built-in PyTree Node Types, such as key order sorting for dict and collections.defaultdict.
- Inherited subclasses are not implicitly registered. The registration lookup uses type(obj) is registered_type rather than isinstance(obj, registered_type). Users need to register all custom classes explicitly.
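As a sketch of the second limitation (assuming only that re-registration of an existing entry is rejected; the exact error type and message are not specified here), attempting to register the built-in list again fails:

import optree

# list is already registered as a built-in pytree node type, so a second
# registration is expected to be rejected (the exact exception may vary).
try:
    optree.register_pytree_node(
        list,
        flatten_func=lambda lst: (lst, None),
        unflatten_func=lambda metadata, children: list(children),
    )
except Exception as ex:
    print(f'Re-registration rejected: {ex!r}')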
None is Non-leaf Node vs. None is Leaf
The None object is a special object in the Python language.
It serves some of the same purposes as null (a pointer that does not point to anything) in other programming languages, denoting that a variable is empty or marking default parameters.
However, the None object is a singleton object rather than a pointer.
It may also serve as a sentinel value.
In addition, if a function returns without an explicit return value, it implicitly returns the None object.
By default, the None object is considered a non-leaf node in the tree with arity 0, i.e., a non-leaf node that has no children.
This is slightly different from the definition of a non-leaf node as discussed above.
While flattening a tree, it will remain in the tree structure definitions rather than in the leaves list.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec(NoneIsLeaf, {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(NoneIsLeaf, *))
OpTree provides a keyword argument none_is_leaf to determine whether to treat the None object as a leaf, like other opaque objects.
If none_is_leaf=True, the None object will be placed in the leaves list.
Otherwise, the None object will remain in the tree specification (structure).
>>> import torch
>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters
OrderedDict([
('weight', Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
[-0.4876, -0.3142, 0.1785]], requires_grad=True)),
('bias', None)
])
>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', None)
])
>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType
>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', 0)
])
Key Ordering for Dictionaries
The built-in Python dictionary (i.e., builtins.dict) is an unordered mapping that holds keys and values.
The leaves of a dictionary are its values. Although the built-in dictionary has been insertion ordered since Python 3.6 (PEP 468), the dictionary equality operator (==) does not check for key ordering.
To ensure that "equal dict" implies "equal ordering of leaves", the values of the dictionary are sorted by the keys.
This behavior also applies to collections.defaultdict.
>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_flatten({'b': [3], 'a': [1, 2]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
Note that there is no restriction requiring the dict keys to be comparable (sortable); a dictionary can contain keys of multiple types.
The keys are sorted in ascending order with key=lambda k: k first if possible, otherwise the sort falls back to key=lambda k: (k.__class__.__qualname__, k). This handles most cases.
>>> sorted({1: 2, 1.5: 1}.keys())
[1, 1.5]
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys())
TypeError: '<' not supported between instances of 'int' and 'str'
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys(), key=lambda k: (k.__class__.__qualname__, k))
[1.5, 1, 'a']
If users want to keep the values in the insertion order, they should use collections.OrderedDict, which takes the order of keys into consideration:
>>> from collections import OrderedDict
>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict([('a', [*, *]), ('b', [*])])))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict([('b', [*]), ('a', [*, *])])))
Benchmark
We benchmark the performance of:
- tree flatten
- tree unflatten
- tree copy (i.e., unflatten(flatten(...)))
- tree map

compared with the following libraries:
- OpTree (@44f7410)
- JAX XLA (jax[cpu] == 0.3.23)
- PyTorch (torch == 1.12.1)
All results are reported on a workstation with an AMD Ryzen 9 5950X CPU @ 4.45GHz in an isolated virtual environment with Python 3.10.8. Run with the following command:
python3 benchmark.py --number=10000 --repeat=5
The test inputs are nested containers (i.e., pytrees) extracted from torch.nn.Module objects. They are:
tiny_custom = nn.Sequential(
nn.Linear(1, 1, bias=True),
nn.BatchNorm1d(1, affine=True, track_running_stats=True),
nn.ReLU(),
nn.Linear(1, 1, bias=False),
nn.Sigmoid(),
)
and AlexNet, ResNet18, ResNet50, ResNet101, ResNet152, VisionTransformerH14 (ViT-H/14), and SwinTransformerB (Swin-B) from torchvision.
Please refer to benchmark.py for more details.
TinyCustom(num_leaves=16, num_nodes=53, treespec=PyTreeSpec([OrderedDict([('tenso...), buffers=OrderedDict([])))])]))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 27.18us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 27.38us -- x1.01 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 27.09us -- x1.00 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 76.18us -- x2.80 <= jax.tree_util.tree_leaves(x)
PyTorch: 671.56us -- x24.71 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
✔ OpTree : 63.18us <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 63.30us -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 133.89us -- x2.12 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 248.43us -- x3.93 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
✔ OpTree : 91.51us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 94.17us -- x1.03 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 94.23us -- x1.03 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 216.89us -- x2.37 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 940.32us -- x10.28 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
✔ OpTree : 99.84us <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 102.82us -- x1.03 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 102.21us -- x1.02 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 236.62us -- x2.37 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 984.33us -- x9.86 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
✔ OpTree : 139.50us <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 142.06us -- x1.02 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 143.19us -- x1.03 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 377.20us -- x2.70 <= jax.tree_util.tree_map(fn3, x, y, z)
AlexNet(num_leaves=32, num_nodes=188, treespec=PyTreeSpec(OrderedDict([('featur...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
✔ OpTree : 80.37us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 87.58us -- x1.09 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 87.06us -- x1.08 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 297.95us -- x3.71 <= jax.tree_util.tree_leaves(x)
PyTorch: 2650.24us -- x32.98 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
✔ OpTree : 245.71us <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 247.18us -- x1.01 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 545.17us -- x2.22 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 984.27us -- x4.01 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
✔ OpTree : 332.03us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 338.27us -- x1.02 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 337.91us -- x1.02 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 866.38us -- x2.61 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 3654.14us -- x11.01 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
✔ OpTree : 347.63us <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 353.01us -- x1.02 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 353.95us -- x1.02 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 882.86us -- x2.54 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 3703.29us -- x10.65 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
✔ OpTree : 498.91us <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 499.80us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 499.17us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 1447.79us -- x2.90 <= jax.tree_util.tree_map(fn3, x, y, z)
ResNet18(num_leaves=244, num_nodes=698, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
✔ OpTree : 283.47us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 290.72us -- x1.03 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 288.77us -- x1.02 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 928.19us -- x3.27 <= jax.tree_util.tree_leaves(x)
PyTorch: 9306.54us -- x32.83 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
~ OpTree : 816.54us <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 814.80us -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 1669.60us -- x2.04 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 3476.16us -- x4.26 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
~ OpTree : 1145.87us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 1149.91us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 1145.46us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 2676.30us -- x2.34 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 13040.00us -- x11.38 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
~ OpTree : 1246.56us <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 1241.41us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
✔ OpTree : 1236.65us -- x0.99 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 2842.26us -- x2.28 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 13090.70us -- x10.50 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
✔ OpTree : 1754.06us <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 1758.19us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 1763.34us -- x1.01 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 4581.00us -- x2.61 <= jax.tree_util.tree_map(fn3, x, y, z)
ResNet50(num_leaves=640, num_nodes=1702, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 749.18us <= optree.tree_leaves(x) (None is Node)
✔ OpTree : 735.95us -- x0.98 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 737.40us -- x0.98 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 2210.58us -- x2.95 <= jax.tree_util.tree_leaves(x)
PyTorch: 21962.28us -- x29.32 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
~ OpTree : 1967.52us <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 1944.32us -- x0.99 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 3909.87us -- x1.99 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 8222.80us -- x4.18 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
~ OpTree : 2782.21us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 2788.91us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 2771.00us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 6235.03us -- x2.24 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 30523.06us -- x10.97 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
~ OpTree : 2997.47us <= optree.tree_map(fn1, x) (None is Node)
✔ OpTree : 2993.21us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 3004.80us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 6582.85us -- x2.20 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 30674.16us -- x10.23 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 4190.19us <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 4187.02us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 4200.12us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 10519.68us -- x2.51 <= jax.tree_util.tree_map(fn3, x, y, z)
ResNet101(num_leaves=1252, num_nodes=3317, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 1443.28us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 1462.98us -- x1.01 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 1414.41us -- x0.98 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 4301.47us -- x2.98 <= jax.tree_util.tree_leaves(x)
PyTorch: 42706.19us -- x29.59 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
~ OpTree : 3889.41us <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 3885.40us -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 7656.81us -- x1.97 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 16058.19us -- x4.13 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
~ OpTree : 5442.14us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 5422.35us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 5407.01us -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 12184.79us -- x2.24 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 59239.08us -- x10.89 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
~ OpTree : 5857.83us <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 5845.61us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
✔ OpTree : 5819.69us -- x0.99 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 12816.78us -- x2.19 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 59487.90us -- x10.16 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 8145.57us <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 8138.75us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 8148.81us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 20070.87us -- x2.46 <= jax.tree_util.tree_map(fn3, x, y, z)
ResNet152(num_leaves=1864, num_nodes=4932, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 2180.13us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 2170.51us -- x1.00 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 2140.18us -- x0.98 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 6225.77us -- x2.86 <= jax.tree_util.tree_leaves(x)
PyTorch: 62329.75us -- x28.59 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
~ OpTree : 5734.21us <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 5715.35us -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 11297.46us -- x1.97 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 22897.60us -- x3.99 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
~ OpTree : 7997.82us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 8009.89us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 7960.10us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 17619.27us -- x2.20 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 85951.24us -- x10.75 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
~ OpTree : 8524.99us <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 8522.01us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
✔ OpTree : 8512.22us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 18695.07us -- x2.19 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 86562.20us -- x10.15 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
✔ OpTree : 11886.65us <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 11928.96us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 11902.22us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 29821.16us -- x2.51 <= jax.tree_util.tree_map(fn3, x, y, z)
VisionTransformerH14(num_leaves=784, num_nodes=3420, treespec=PyTreeSpec(OrderedDict([('conv_p...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 1651.80us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 1651.72us -- x1.00 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 1647.63us -- x1.00 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 4725.58us -- x2.86 <= jax.tree_util.tree_leaves(x)
PyTorch: 44551.83us -- x26.97 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
✔ OpTree : 4321.63us <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 4335.10us -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 9133.98us -- x2.11 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 17448.01us -- x4.04 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
~ OpTree : 6116.50us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 6100.51us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 6095.86us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 14116.93us -- x2.31 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 62494.90us -- x10.22 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
✔ OpTree : 6262.65us <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 6272.06us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 6272.94us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 14489.26us -- x2.31 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 62355.37us -- x9.96 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
✔ OpTree : 8886.77us <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 8893.93us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 8932.03us -- x1.01 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 23434.28us -- x2.64 <= jax.tree_util.tree_map(fn3, x, y, z)
SwinTransformerB(num_leaves=706, num_nodes=2867, treespec=PyTreeSpec(OrderedDict([('featur...]), buffers=OrderedDict([])))])))
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 1369.87us <= optree.tree_leaves(x) (None is Node)
~ OpTree : 1382.60us -- x1.01 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 1368.06us -- x1.00 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 4066.15us -- x2.97 <= jax.tree_util.tree_leaves(x)
PyTorch: 37490.24us -- x27.37 <= torch_utils_pytree.tree_flatten(x)[0]
### Tree UnFlatten ###
~ OpTree : 3707.21us <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 3693.14us -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 7749.16us -- x2.09 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 14828.41us -- x4.00 <= torch_utils_pytree.tree_unflatten(flat, spec)
### Tree Copy ###
~ OpTree : 5154.76us <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
✔ OpTree : 5127.40us -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 5149.86us -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 12031.65us -- x2.33 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 52536.88us -- x10.19 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
### Tree Map ###
~ OpTree : 5359.65us <= optree.tree_map(fn1, x) (None is Node)
✔ OpTree : 5334.72us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 5335.08us -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 12371.49us -- x2.31 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 52645.16us -- x9.82 <= torch_utils_pytree.tree_map(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 7535.16us <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 7520.10us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 7531.14us -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 19884.10us -- x2.64 <= jax.tree_util.tree_map(fn3, x, y, z)
License
OpTree is released under the Apache License 2.0.
OpTree is heavily based on JAX's implementation of the PyTree utility, with deep refactoring and several improvements. The original licenses can be found at JAX's Apache License 2.0 and TensorFlow's Apache License 2.0.