Optimized PyTree Utilities.
Reason this release was yanked:
old version
OpTree
Optimized PyTree Utilities.
Installation
Install from PyPI:
pip3 install --upgrade optree
Or install from conda-forge:
conda install -c conda-forge optree
Install the latest version from GitHub:
pip3 install git+https://github.com/metaopt/optree.git#egg=optree
Or, clone this repo and install manually:
git clone --depth=1 https://github.com/metaopt/optree.git
cd optree
pip3 install .
Compiling from the source requires Python 3.7+, a compiler (`gcc`, `clang`, `icc`, or `cl.exe`) that supports C++20, and a `cmake` installation.
PyTrees
A PyTree is a recursive structure that can be an arbitrarily nested Python container (e.g., `tuple`, `list`, `dict`, `OrderedDict`, `NamedTuple`, etc.) or an opaque Python object.
The key concepts of tree operations are tree flattening and its inverse (tree unflattening).
Additional tree operations can be performed based on these two basic functions (e.g., `tree_map = tree_unflatten ∘ map ∘ tree_flatten`).
Tree flattening traverses the entire tree in a left-to-right, depth-first manner and returns the leaves of the tree in a deterministic order.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
This usually implies that equal pytrees return equal lists of leaves and the same tree structure. See also section "Key Ordering for Dictionaries".
>>> {'a': [1, 2], 'b': [3]} == {'b': [3], 'a': [1, 2]}
True
>>> optree.tree_leaves({'a': [1, 2], 'b': [3]}) == optree.tree_leaves({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_structure({'a': [1, 2], 'b': [3]}) == optree.tree_structure({'b': [3], 'a': [1, 2]})
True
Tree Nodes and Leaves
A tree is a collection of non-leaf nodes and leaf nodes, where the leaf nodes have no children to flatten.
`optree.tree_flatten(...)` flattens the tree and returns a list of leaf nodes, while the non-leaf nodes are stored in the tree specification.
Built-in PyTree Node Types
OpTree supports the following Python container types in the registry out of the box:
- `tuple`
- `list`
- `dict`
- `collections.namedtuple` and its subclasses
- `collections.OrderedDict`
- `collections.defaultdict`
- `collections.deque`
These types are considered non-leaf nodes in the tree.
Python objects whose types are not registered are treated as leaf nodes.
The registration lookup uses the `is` operator to determine whether the type matches.
Subclasses therefore need to be registered explicitly. Otherwise, an object of such a subclass is considered a leaf.
The `NoneType` is a special case discussed in section "None is Non-leaf Node vs. None is Leaf".
Registering a Container-like Custom Type as Non-leaf Nodes
A container-like Python type can be registered in the type registry with a pair of functions that specify:
- `flatten_func(container) -> (children, metadata, entries)`: convert an instance of the container type to a `(children, metadata, entries)` triple, where `children` is an iterable of subtrees and `entries` is an iterable of path entries of the container (e.g., indices or keys).
- `unflatten_func(metadata, children) -> container`: convert such a pair back to an instance of the container type.
The `metadata` is any data, apart from the children, that is needed to reconstruct the container, e.g., the keys of a dictionary (whose children are the values).
The `entries` are optional. The flatten function can omit them and return only a pair, or it can return `None` in their place. In that case, `range(len(children))` (i.e., flat indices) is used as the path entries of the current node. So the function signature can also be `flatten_func(container) -> (children, metadata)` or `flatten_func(container) -> (children, metadata, None)`.
The following examples show how to register custom types and use them with `tree_flatten` and `tree_map`. Please refer to section "Notes about the PyTree Type Registry" for more information.
import optree

# Register a Python type with lambda functions
optree.register_pytree_node(
set,
# (set) -> (children, metadata, None)
lambda s: (sorted(s), None, None),
# (metadata, children) -> (set)
lambda _, children: set(children),
namespace='set',
)
# Register a Python type into a namespace
import torch
optree.register_pytree_node(
torch.Tensor,
# (tensor) -> (children, metadata)
flatten_func=lambda tensor: (
(tensor.detach().cpu().numpy(),),  # detach first so tensors with requires_grad=True can be converted
dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
),
# (metadata, children) -> tensor
unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
namespace='torch2numpy',
)
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
# Flatten without specifying the namespace
>>> optree.tree_flatten(tree) # `torch.Tensor`s are leaf nodes
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
# Flatten with the namespace
>>> leaves, treespec = optree.tree_flatten(tree, namespace='torch2numpy')
>>> leaves, treespec
(
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
},
namespace='torch2numpy'
)
)
# `entries` are not defined, so `range(len(children))` is used as the path entries
>>> optree.tree_paths(tree, namespace='torch2numpy')
[('bias', 0), ('weight', 0)]
# Unflatten back to a copy of the original object
>>> optree.tree_unflatten(treespec, leaves)
{'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
Users can also extend the pytree registry by decorating the custom class and defining an instance method `tree_flatten` and a class method `tree_unflatten`.
import optree
from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
    def tree_flatten(self):  # -> (children, metadata, entries)
        reversed_keys = sorted(self.keys(), reverse=True)
        return (
            [self[key] for key in reversed_keys],  # children
            reversed_keys,  # metadata
            reversed_keys,  # entries
        )

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(zip(metadata, children))
>>> tree = MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))
# Flatten without specifying the namespace
>>> optree.tree_flatten_with_path(tree) # `MyDict`s are leaf nodes
(
[()],
[MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))],
PyTreeSpec(*)
)
# Flatten with the namespace
>>> optree.tree_flatten_with_path(tree, namespace='mydict')
(
[('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
[6, 5, 4, 2, 3],
PyTreeSpec(
CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
namespace='mydict'
)
)
Notes about the PyTree Type Registry
There are several key attributes of the pytree type registry:
- The type registry is per-interpreter. This means that registering a custom type in the registry affects all modules that use OpTree.
- !!! WARNING !!! For safety reasons, a `namespace` must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.
- The elements in the type registry are immutable. Users can neither register the same type twice in the same namespace (i.e., update the type registry) nor remove a type from the type registry. To update the behavior of an already registered type, simply register it again with another `namespace`.
- Users cannot modify the behavior of already registered built-in types listed in section "Built-in PyTree Node Types", such as the key-order sorting for `dict` and `collections.defaultdict`.
- Inherited subclasses are not implicitly registered. The registration lookup uses `type(obj) is registered_type` rather than `isinstance(obj, registered_type)`. Users need to register the subclasses explicitly. To register all subclasses automatically, use a `metaclass` or `__init_subclass__`, for example:
import optree
from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
    def __init_subclass__(cls):  # define this in the base class
        super().__init_subclass__()
        # Register a subclass to namespace 'mydict'
        optree.register_pytree_node_class(cls, namespace='mydict')

    def tree_flatten(self):  # -> (children, metadata, entries)
        reversed_keys = sorted(self.keys(), reverse=True)
        return (
            [self[key] for key in reversed_keys],  # children
            reversed_keys,  # metadata
            reversed_keys,  # entries
        )

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(zip(metadata, children))

# Subclasses will be automatically registered in namespace 'mydict'
class MyAnotherDict(MyDict):
    pass

>>> tree = MyDict(b=4, a=(2, 3), c=MyAnotherDict({'d': 5, 'f': 6}))
>>> optree.tree_flatten_with_path(tree, namespace='mydict')
(
    [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)
- Be careful about potential infinite recursion in the custom flatten function. The `children` returned from the custom flatten function are considered subtrees and are further flattened recursively. The `children` can have the same type as the current node, so users must design their termination condition carefully.
import numpy as np
import optree
import torch

optree.register_pytree_node(
    np.ndarray,
    # Children are nested lists of Python objects
    lambda array: (np.atleast_1d(array).tolist(), array.ndim == 0),
    lambda scalar, rows: np.asarray(rows) if not scalar else np.asarray(rows[0]),
    namespace='numpy1',
)

optree.register_pytree_node(
    np.ndarray,
    # Children are Python objects
    lambda array: (
        list(array.ravel()),  # list(1DArray[T]) -> List[T]
        dict(shape=array.shape, dtype=array.dtype),
    ),
    lambda metadata, children: np.asarray(children, dtype=metadata['dtype']).reshape(metadata['shape']),
    namespace='numpy2',
)

optree.register_pytree_node(
    np.ndarray,
    # Returns a list of `np.ndarray`s without termination condition
    lambda array: ([array.ravel()], array.shape),
    lambda shape, children: children[0].reshape(shape),
    namespace='numpy3',
)

optree.register_pytree_node(
    torch.Tensor,
    # Children are nested lists of Python objects
    lambda tensor: (torch.atleast_1d(tensor).tolist(), tensor.ndim == 0),
    lambda scalar, rows: torch.tensor(rows) if not scalar else torch.tensor(rows[0]),
    namespace='torch1',
)

optree.register_pytree_node(
    torch.Tensor,
    # Returns a list of `torch.Tensor`s without termination condition
    lambda tensor: (
        list(tensor.view(-1)),  # list(1DTensor[T]) -> List[0DTensor[T]] (STILL TENSORS!)
        tensor.shape,
    ),
    lambda shape, children: torch.stack(children).reshape(shape),
    namespace='torch2',
)

>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy1')
(
    [0, 1, 2, 3, 4, 5, 6, 7, 8],
    PyTreeSpec(
        CustomTreeNode(ndarray[False], [[*, *, *], [*, *, *], [*, *, *]]),
        namespace='numpy1'
    )
)
# Implicitly casts `float`s to `np.float64`
>>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy1')
array([[1.5, 2.5, 3.5],
       [4.5, 5.5, 6.5],
       [7.5, 8.5, 9.5]])
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy2')
(
    [0, 1, 2, 3, 4, 5, 6, 7, 8],
    PyTreeSpec(
        CustomTreeNode(ndarray[{'shape': (3, 3), 'dtype': dtype('int64')}], [*, *, *, *, *, *, *, *, *]),
        namespace='numpy2'
    )
)
# Explicitly casts `float`s to `np.int64`
>>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy2')
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])
# Children are also `np.ndarray`s, recurse without termination condition.
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy3')
RecursionError: maximum recursion depth exceeded during flattening the tree
>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch1')
(
    [0, 1, 2, 3, 4, 5, 6, 7, 8],
    PyTreeSpec(
        CustomTreeNode(Tensor[False], [[*, *, *], [*, *, *], [*, *, *]]),
        namespace='torch1'
    )
)
# Implicitly casts `float`s to `torch.float32`
>>> optree.tree_map(lambda x: x + 1.5, torch.arange(9).reshape(3, 3), namespace='torch1')
tensor([[1.5000, 2.5000, 3.5000],
        [4.5000, 5.5000, 6.5000],
        [7.5000, 8.5000, 9.5000]])
# Children are also `torch.Tensor`s, recurse without termination condition.
>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch2')
RecursionError: maximum recursion depth exceeded during flattening the tree
None is Non-leaf Node vs. None is Leaf
The `None` object is a special object in the Python language.
It serves some of the same purposes as `null` (a pointer that does not point to anything) in other programming languages. It indicates that a variable is empty or marks default parameters.
However, the `None` object is a singleton object rather than a pointer.
It may also serve as a sentinel value.
In addition, if a function returns without a value or omits the `return` statement, it implicitly returns the `None` object.
By default, the `None` object is considered a non-leaf node in the tree with arity 0, i.e., a non-leaf node that has no children.
This is similar to the behavior of an empty tuple.
When flattening a tree, it remains in the tree structure definition rather than in the list of leaves.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(*, NoneIsLeaf))
OpTree provides a keyword argument `none_is_leaf` to determine whether to treat the `None` object as a leaf, like other opaque objects.
If `none_is_leaf=True`, the `None` object is placed in the list of leaves.
Otherwise, the `None` object remains in the tree specification (structure).
>>> import torch
>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters # a container that has None
OrderedDict([
('weight', Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
[-0.4876, -0.3142, 0.1785]], requires_grad=True)),
('bias', None)
])
>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', None)
])
>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType
>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', 0)
])
Key Ordering for Dictionaries
The built-in Python dictionary (i.e., `builtins.dict`) is an unordered mapping that holds keys and values, and the leaves of a dictionary are its values.
The built-in dictionary has been insertion-ordered since Python 3.6 (PEP 468). Even so, the dictionary equality operator (`==`) does not check for key ordering.
To ensure that "equal `dict`" implies "equal ordering of leaves", the values of the dictionary are ordered by their sorted keys.
This behavior also applies to `collections.defaultdict`.
>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_flatten({'b': [3], 'a': [1, 2]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
Note that the keys of a `dict` are not required to be comparable (sortable).
A dictionary may contain keys of multiple types.
The keys are first sorted in ascending order with `key=lambda k: k` if possible. Otherwise, sorting falls back to `key=lambda k: (k.__class__.__qualname__, k)`. This handles most cases.
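The two-step ordering rule can be written as a small pure-Python function. The `sorted_dict_keys` helper is hypothetical. It only mirrors the rule as described here and is not OpTree's internal implementation.

```python
def sorted_dict_keys(keys):
    """Hypothetical helper mirroring the documented key-ordering rule."""
    keys = list(keys)
    try:
        # First choice: natural ascending order (works when all keys are mutually comparable).
        return sorted(keys)
    except TypeError:
        # Fallback: group by class name first, then compare keys within the same class.
        return sorted(keys, key=lambda k: (k.__class__.__qualname__, k))


print(sorted_dict_keys({1: 2, 1.5: 1}))          # [1, 1.5]
print(sorted_dict_keys({'a': 3, 1: 2, 1.5: 1}))  # [1.5, 1, 'a']  ('float' < 'int' < 'str')
```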
>>> sorted({1: 2, 1.5: 1}.keys())
[1, 1.5]
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys())
TypeError: '<' not supported between instances of 'int' and 'str'
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys(), key=lambda k: (k.__class__.__qualname__, k))
[1.5, 1, 'a']
If users want to keep the values in insertion order, they should use `collections.OrderedDict`, which takes the order of keys into account:
>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict([('a', [*, *]), ('b', [*])])))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict([('b', [*]), ('a', [*, *])])))
Benchmark
We benchmark the performance of:
- tree flatten
- tree unflatten
- tree copy (i.e., `unflatten(flatten(...))`)
- tree map
compared with the following libraries:
- OpTree (`@v0.3.0`)
- JAX XLA (`jax[cpu] == 0.3.24`)
- PyTorch (`torch == 1.13.0`)
- DM-Tree (`dm-tree == 0.1.7`)
All results are reported on a workstation with an AMD Ryzen 9 5950X CPU @ 4.45GHz in an isolated virtual environment with Python 3.10.4. Run the benchmark with the following commands:
conda create --name optree-benchmark anaconda::python=3.10 --yes --no-default-packages
conda activate optree-benchmark
python3 -m pip install --editable '.[benchmark]' --extra-index-url https://download.pytorch.org/whl/cpu
python3 benchmark.py --number=10000 --repeat=5
The test inputs are nested containers (i.e., pytrees) extracted from `torch.nn.Module` objects.
They are:
from torch import nn

tiny_mlp = nn.Sequential(
    nn.Linear(1, 1, bias=True),
    nn.BatchNorm1d(1, affine=True, track_running_stats=True),
    nn.ReLU(),
    nn.Linear(1, 1, bias=False),
    nn.Sigmoid(),
)
and AlexNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, VisionTransformerH14 (ViT-H/14), and SwinTransformerB (Swin-B) from `torchvision`.
Please refer to `benchmark.py` for more details.
Tree Flatten
Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
---|---|---|---|---|---|---|---|---|
TinyMLP | 53 | 26.40 | 68.19 | 586.87 | 34.14 | 2.58 | 22.23 | 1.29 |
AlexNet | 188 | 84.28 | 259.51 | 2182.07 | 125.12 | 3.08 | 25.89 | 1.48 |
ResNet18 | 698 | 288.57 | 807.27 | 7881.69 | 429.39 | 2.80 | 27.31 | 1.49 |
ResNet34 | 1242 | 580.75 | 1564.97 | 15082.84 | 819.02 | 2.69 | 25.97 | 1.41 |
ResNet50 | 1702 | 791.18 | 2081.17 | 20982.82 | 1104.62 | 2.63 | 26.52 | 1.40 |
ResNet101 | 3317 | 1603.93 | 3939.37 | 40382.14 | 2208.63 | 2.46 | 25.18 | 1.38 |
ResNet152 | 4932 | 2446.56 | 6267.98 | 56892.36 | 3139.17 | 2.56 | 23.25 | 1.28 |
ViT-H/14 | 3420 | 1681.48 | 4488.33 | 41703.16 | 2504.86 | 2.67 | 24.80 | 1.49 |
Swin-B | 2881 | 1565.41 | 4091.10 | 34241.99 | 1936.75 | 2.61 | 21.87 | 1.24 |
Average | | | | | | 2.68 | 24.78 | 1.38 |
Tree UnFlatten
Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
---|---|---|---|---|---|---|---|---|
TinyMLP | 53 | 59.47 | 163.00 | 257.56 | 967.00 | 2.74 | 4.33 | 16.26 |
AlexNet | 188 | 234.68 | 701.56 | 1011.04 | 4000.43 | 2.99 | 4.31 | 17.05 |
ResNet18 | 698 | 758.82 | 2036.76 | 3391.87 | 12060.09 | 2.68 | 4.47 | 15.89 |
ResNet34 | 1242 | 1459.17 | 3886.79 | 6519.28 | 21435.14 | 2.66 | 4.47 | 14.69 |
ResNet50 | 1702 | 2003.60 | 5137.90 | 8341.17 | 29067.89 | 2.56 | 4.16 | 14.51 |
ResNet101 | 3317 | 4005.73 | 10203.31 | 17316.07 | 59531.47 | 2.55 | 4.32 | 14.86 |
ResNet152 | 4932 | 5644.08 | 15153.87 | 25438.67 | 88626.45 | 2.68 | 4.51 | 15.70 |
ViT-H/14 | 3420 | 4492.64 | 12544.41 | 18091.68 | 67876.19 | 2.79 | 4.03 | 15.11 |
Swin-B | 2881 | 3637.86 | 9973.78 | 15353.31 | 57655.54 | 2.74 | 4.22 | 15.85 |
Average | | | | | | 2.71 | 4.31 | 15.55 |
Tree Copy
Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
---|---|---|---|---|---|---|---|---|
TinyMLP | 53 | 90.43 | 234.08 | 871.84 | 1006.30 | 2.59 | 9.64 | 11.13 |
AlexNet | 188 | 324.33 | 931.61 | 3263.23 | 4106.04 | 2.87 | 10.06 | 12.66 |
ResNet18 | 698 | 1111.82 | 2840.68 | 11836.55 | 12564.23 | 2.55 | 10.65 | 11.30 |
ResNet34 | 1242 | 2029.24 | 5129.39 | 20888.04 | 23559.50 | 2.53 | 10.29 | 11.61 |
ResNet50 | 1702 | 2884.39 | 7118.82 | 30239.69 | 29509.25 | 2.47 | 10.48 | 10.23 |
ResNet101 | 3317 | 5773.17 | 14396.40 | 60021.18 | 62725.03 | 2.49 | 10.40 | 10.86 |
ResNet152 | 4932 | 8552.95 | 21321.48 | 85857.53 | 86037.99 | 2.49 | 10.04 | 10.06 |
ViT-H/14 | 3420 | 6116.61 | 16038.69 | 59993.87 | 70215.65 | 2.62 | 9.81 | 11.48 |
Swin-B | 2881 | 5466.03 | 14449.60 | 50528.12 | 60269.63 | 2.64 | 9.24 | 11.03 |
Average | | | | | | 2.58 | 10.07 | 11.15 |
Tree Map
Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
---|---|---|---|---|---|---|---|---|
TinyMLP | 53 | 98.24 | 255.39 | 868.24 | 1032.07 | 2.60 | 8.84 | 10.51 |
AlexNet | 188 | 337.87 | 1004.53 | 3304.64 | 4099.77 | 2.97 | 9.78 | 12.13 |
ResNet18 | 698 | 1171.69 | 3059.72 | 11921.16 | 12727.38 | 2.61 | 10.17 | 10.86 |
ResNet34 | 1242 | 2267.61 | 5793.53 | 22222.92 | 22437.44 | 2.55 | 9.80 | 9.89 |
ResNet50 | 1702 | 2961.05 | 7792.69 | 30132.32 | 31460.04 | 2.63 | 10.18 | 10.62 |
ResNet101 | 3317 | 6101.05 | 14342.22 | 56480.19 | 61830.65 | 2.35 | 9.26 | 10.13 |
ResNet152 | 4932 | 8568.48 | 21641.40 | 83021.19 | 87077.66 | 2.53 | 9.69 | 10.16 |
ViT-H/14 | 3420 | 6735.93 | 18027.05 | 63986.88 | 75742.33 | 2.68 | 9.50 | 11.24 |
Swin-B | 2881 | 5756.71 | 14528.51 | 51052.90 | 60715.06 | 2.52 | 8.87 | 10.55 |
Average | | | | | | 2.61 | 9.56 | 10.68 |
Tree Map (nargs)
Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
---|---|---|---|---|---|---|---|---|
TinyMLP | 53 | 144.61 | 391.05 | N/A | 3774.79 | 2.70 | N/A | 26.10 |
AlexNet | 188 | 480.23 | 1515.41 | N/A | 15105.87 | 3.16 | N/A | 31.46 |
ResNet18 | 698 | 1690.19 | 4997.44 | N/A | 52089.06 | 2.96 | N/A | 30.82 |
ResNet34 | 1242 | 3084.36 | 8572.54 | N/A | 93923.27 | 2.78 | N/A | 30.45 |
ResNet50 | 1702 | 4441.17 | 11962.92 | N/A | 126937.65 | 2.69 | N/A | 28.58 |
ResNet101 | 3317 | 8155.78 | 22232.67 | N/A | 251333.88 | 2.73 | N/A | 30.82 |
ResNet152 | 4932 | 12862.88 | 33714.46 | N/A | 368424.63 | 2.62 | N/A | 28.64 |
ViT-H/14 | 3420 | 9511.10 | 27920.13 | N/A | 281245.95 | 2.94 | N/A | 29.57 |
Swin-B | 2881 | 7628.29 | 22421.37 | N/A | 238211.56 | 2.94 | N/A | 31.23 |
Average | | | | | | 2.83 | N/A | 29.74 |
TinyMLP(num_nodes=53, num_leaves=16, treespec=PyTreeSpec([OrderedDict([('tenso...), buffers=OrderedDict([])))])]))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 26.40 | 68.19 | 586.87 | 34.14 | 2.58 | 22.23 | 1.29 |
| Tree UnFlatten | 59.47 | 163.00 | 257.56 | 967.00 | 2.74 | 4.33 | 16.26 |
| Tree Copy | 90.43 | 234.08 | 871.84 | 1006.30 | 2.59 | 9.64 | 11.13 |
| Tree Map | 98.24 | 255.39 | 868.24 | 1032.07 | 2.60 | 8.84 | 10.51 |
| Tree Map (nargs) | 144.61 | 391.05 | N/A | 3774.79 | 2.70 | N/A | 26.10 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 26.40μs <= optree.tree_leaves(x) (None is Node)
✔ OpTree : 25.98μs -- x0.98 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 26.19μs -- x0.99 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 68.19μs -- x2.58 <= jax.tree_util.tree_leaves(x)
PyTorch: 586.87μs -- x22.23 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 34.14μs -- x1.29 <= dm_tree.flatten(x)
### Tree UnFlatten ###
✔ OpTree : 59.47μs <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 59.71μs -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 163.00μs -- x2.74 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 257.56μs -- x4.33 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 967.00μs -- x16.26 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
✔ OpTree : 90.43μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 91.51μs -- x1.01 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 91.34μs -- x1.01 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 234.08μs -- x2.59 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 871.84μs -- x9.64 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 1006.30μs -- x11.13 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
~ OpTree : 98.24μs <= optree.tree_map(fn1, x) (None is Node)
✔ OpTree : 97.62μs -- x0.99 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 98.05μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 255.39μs -- x2.60 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 868.24μs -- x8.84 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 1032.07μs -- x10.51 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 144.61μs <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 141.03μs -- x0.98 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 142.79μs -- x0.99 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 391.05μs -- x2.70 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 3774.79μs -- x26.10 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
AlexNet(num_nodes=188, num_leaves=32, treespec=PyTreeSpec(OrderedDict([('featur...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 84.28 | 259.51 | 2182.07 | 125.12 | 3.08 | 25.89 | 1.48 |
| Tree UnFlatten | 234.68 | 701.56 | 1011.04 | 4000.43 | 2.99 | 4.31 | 17.05 |
| Tree Copy | 324.33 | 931.61 | 3263.23 | 4106.04 | 2.87 | 10.06 | 12.66 |
| Tree Map | 337.87 | 1004.53 | 3304.64 | 4099.77 | 2.97 | 9.78 | 12.13 |
| Tree Map (nargs) | 480.23 | 1515.41 | N/A | 15105.87 | 3.16 | N/A | 31.46 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
✔ OpTree : 84.28μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 84.64μs -- x1.00 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 85.54μs -- x1.01 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 259.51μs -- x3.08 <= jax.tree_util.tree_leaves(x)
PyTorch: 2182.07μs -- x25.89 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 125.12μs -- x1.48 <= dm_tree.flatten(x)
### Tree UnFlatten ###
~ OpTree : 234.68μs <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 234.02μs -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 701.56μs -- x2.99 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 1011.04μs -- x4.31 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 4000.43μs -- x17.05 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 324.33μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 324.09μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 323.18μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 931.61μs -- x2.87 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 3263.23μs -- x10.06 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 4106.04μs -- x12.66 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
✔ OpTree : 337.87μs <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 340.56μs -- x1.01 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 338.36μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 1004.53μs -- x2.97 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 3304.64μs -- x9.78 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 4099.77μs -- x12.13 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 480.23μs <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 481.45μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
✔ OpTree : 479.35μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 1515.41μs -- x3.16 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 15105.87μs -- x31.46 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
ResNet18(num_nodes=698, num_leaves=244, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 288.57 | 807.27 | 7881.69 | 429.39 | 2.80 | 27.31 | 1.49 |
| Tree UnFlatten | 758.82 | 2036.76 | 3391.87 | 12060.09 | 2.68 | 4.47 | 15.89 |
| Tree Copy | 1111.82 | 2840.68 | 11836.55 | 12564.23 | 2.55 | 10.65 | 11.30 |
| Tree Map | 1171.69 | 3059.72 | 11921.16 | 12727.38 | 2.61 | 10.17 | 10.86 |
| Tree Map (nargs) | 1690.19 | 4997.44 | N/A | 52089.06 | 2.96 | N/A | 30.82 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
✔ OpTree : 288.57μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 294.01μs -- x1.02 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 294.99μs -- x1.02 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 807.27μs -- x2.80 <= jax.tree_util.tree_leaves(x)
PyTorch: 7881.69μs -- x27.31 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 429.39μs -- x1.49 <= dm_tree.flatten(x)
### Tree UnFlatten ###
✔ OpTree : 758.82μs <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 765.90μs -- x1.01 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 2036.76μs -- x2.68 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 3391.87μs -- x4.47 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 12060.09μs -- x15.89 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 1111.82μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 1103.80μs -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 1100.03μs -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 2840.68μs -- x2.55 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 11836.55μs -- x10.65 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 12564.23μs -- x11.30 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
✔ OpTree : 1171.69μs <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 1172.46μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 1181.17μs -- x1.01 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 3059.72μs -- x2.61 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 11921.16μs -- x10.17 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 12727.38μs -- x10.86 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
✔ OpTree : 1690.19μs <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 1760.86μs -- x1.04 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 1761.76μs -- x1.04 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 4997.44μs -- x2.96 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 52089.06μs -- x30.82 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
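In each summary table, a `Speedup (X / O)` column is the other library's time divided by OpTree's time for the same row. The OpTree column reports the default call, which is the first `OpTree` line of each section below the table. That line is not always the fastest `✔` variant: for ResNet18 `Tree Copy`, the table uses 1111.82μs, while the `✔` line is 1100.03μs. The short check below recomputes the ResNet18 `Tree Flatten` ratios. The helper name `speedup` is illustrative and is not part of OpTree.

```python
def speedup(other_us: float, optree_us: float) -> float:
    """Ratio other/OpTree, rounded to two decimals as in the summary tables."""
    return round(other_us / optree_us, 2)


# ResNet18 "Tree Flatten" row, times in microseconds copied from the table above.
optree_us = 288.57
others = {"JAX XLA": 807.27, "PyTorch": 7881.69, "DM-Tree": 429.39}

for name, us in others.items():
    print(f"{name}: x{speedup(us, optree_us):.2f}")
# JAX XLA: x2.80
# PyTorch: x27.31
# DM-Tree: x1.49
```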
ResNet34(num_nodes=1242, num_leaves=436, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 580.75 | 1564.97 | 15082.84 | 819.02 | 2.69 | 25.97 | 1.41 |
| Tree UnFlatten | 1459.17 | 3886.79 | 6519.28 | 21435.14 | 2.66 | 4.47 | 14.69 |
| Tree Copy | 2029.24 | 5129.39 | 20888.04 | 23559.50 | 2.53 | 10.29 | 11.61 |
| Tree Map | 2267.61 | 5793.53 | 22222.92 | 22437.44 | 2.55 | 9.80 | 9.89 |
| Tree Map (nargs) | 3084.36 | 8572.54 | N/A | 93923.27 | 2.78 | N/A | 30.45 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 580.75μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 577.02μs -- x0.99 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 571.14μs -- x0.98 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 1564.97μs -- x2.69 <= jax.tree_util.tree_leaves(x)
PyTorch: 15082.84μs -- x25.97 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 819.02μs -- x1.41 <= dm_tree.flatten(x)
### Tree UnFlatten ###
✔ OpTree : 1459.17μs <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 1465.07μs -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 3886.79μs -- x2.66 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 6519.28μs -- x4.47 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 21435.14μs -- x14.69 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 2029.24μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
✔ OpTree : 2021.45μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 2024.29μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 5129.39μs -- x2.53 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 20888.04μs -- x10.29 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 23559.50μs -- x11.61 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
~ OpTree : 2267.61μs <= optree.tree_map(fn1, x) (None is Node)
✔ OpTree : 2257.85μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 2268.77μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 5793.53μs -- x2.55 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 22222.92μs -- x9.80 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 22437.44μs -- x9.89 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 3084.36μs <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 3080.29μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
✔ OpTree : 3072.45μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 8572.54μs -- x2.78 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 93923.27μs -- x30.45 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
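The `COPY` lines in each `### Check ###` block assert that unflattening a flattened tree returns the original tree, with both `none_is_leaf=False` and `none_is_leaf=True`. The `[::-1]` is needed because `optree.tree_flatten` returns `(leaves, treespec)`, while `optree.tree_unflatten` takes `(treespec, leaves)`.

Below is a toy pure-Python sketch of that round trip. It handles plain `tuple`, `list`, `dict` and `None`, and visits dict keys in sorted order, matching the flatten output shown earlier in this README. It is not OpTree's C++ implementation. The names `toy_flatten` and `toy_unflatten`, and the tuple-based spec encoding, are hypothetical.

```python
from typing import Any, Iterator, List, Tuple

# Hypothetical spec encoding: ("leaf",), ("none",), or (container_type, dict_keys_or_None, child_specs).
Spec = Tuple[Any, ...]


def toy_flatten(tree: Any, none_is_leaf: bool = False) -> Tuple[List[Any], Spec]:
    """Left-to-right, depth-first flatten; dict keys are visited in sorted order."""
    leaves: List[Any] = []

    def rec(node: Any) -> Spec:
        if node is None and not none_is_leaf:
            return ("none",)  # None is a node with no children, so it yields no leaf
        if type(node) in (tuple, list):
            return (type(node), None, [rec(child) for child in node])
        if type(node) is dict:
            keys = sorted(node)
            return (dict, keys, [rec(node[k]) for k in keys])
        leaves.append(node)  # anything else (including None when none_is_leaf=True) is a leaf
        return ("leaf",)

    return leaves, rec(tree)


def toy_unflatten(spec: Spec, leaves: List[Any]) -> Any:
    """Rebuild a tree from a spec produced by toy_flatten and a list of leaves."""
    it: Iterator[Any] = iter(leaves)

    def rec(s: Spec) -> Any:
        if s[0] == "none":
            return None
        if s[0] == "leaf":
            return next(it)
        kind, keys, children = s
        values = [rec(c) for c in children]
        return dict(zip(keys, values)) if kind is dict else kind(values)

    tree = rec(spec)
    if list(it):
        raise ValueError("more leaves than the spec has leaf slots")
    return tree


tree = {"b": (2, [3, None]), "a": 1}
print(toy_flatten(tree)[0])                     # [1, 2, 3]        (None is Node)
print(toy_flatten(tree, none_is_leaf=True)[0])  # [1, 2, 3, None]  (None is Leaf)
assert toy_unflatten(*toy_flatten(tree)[::-1]) == tree  # the COPY check
assert toy_unflatten(*toy_flatten(tree, none_is_leaf=True)[::-1]) == tree
```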
ResNet50(num_nodes=1702, num_leaves=640, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 791.18 | 2081.17 | 20982.82 | 1104.62 | 2.63 | 26.52 | 1.40 |
| Tree UnFlatten | 2003.60 | 5137.90 | 8341.17 | 29067.89 | 2.56 | 4.16 | 14.51 |
| Tree Copy | 2884.39 | 7118.82 | 30239.69 | 29509.25 | 2.47 | 10.48 | 10.23 |
| Tree Map | 2961.05 | 7792.69 | 30132.32 | 31460.04 | 2.63 | 10.18 | 10.62 |
| Tree Map (nargs) | 4441.17 | 11962.92 | N/A | 126937.65 | 2.69 | N/A | 28.58 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 791.18μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 791.38μs -- x1.00 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 779.75μs -- x0.99 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 2081.17μs -- x2.63 <= jax.tree_util.tree_leaves(x)
PyTorch: 20982.82μs -- x26.52 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 1104.62μs -- x1.40 <= dm_tree.flatten(x)
### Tree UnFlatten ###
~ OpTree : 2003.60μs <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 2000.10μs -- x1.00 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 5137.90μs -- x2.56 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 8341.17μs -- x4.16 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 29067.89μs -- x14.51 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 2884.39μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 2879.97μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 2868.84μs -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 7118.82μs -- x2.47 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 30239.69μs -- x10.48 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 29509.25μs -- x10.23 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
✔ OpTree : 2961.05μs <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 3079.33μs -- x1.04 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 3116.74μs -- x1.05 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 7792.69μs -- x2.63 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 30132.32μs -- x10.18 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 31460.04μs -- x10.62 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 4441.17μs <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 4430.87μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 4449.43μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 11962.92μs -- x2.69 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 126937.65μs -- x28.58 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
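`Tree Map (nargs)` times `optree.tree_map(fn3, x, y, z)`. This call applies one function leaf-wise across three trees that share the same structure.

The following is a pure-Python sketch of the idea, not OpTree's implementation. `toy_tree_map` is a hypothetical name, and it treats `None` as a node, which is the `none_is_leaf=False` behavior. The lambda stands in for the benchmark's `fn3`, whose definition is not shown here.

```python
from typing import Any, Callable


def toy_tree_map(fn: Callable[..., Any], tree: Any, *rests: Any) -> Any:
    """Apply fn leaf-wise across `tree` and `rests`, which must share tree's structure."""
    if tree is None:  # None is a node with no children (none_is_leaf=False behavior)
        if any(r is not None for r in rests):
            raise ValueError("tree structures do not match")
        return None
    if type(tree) in (tuple, list):
        if any(type(r) is not type(tree) or len(r) != len(tree) for r in rests):
            raise ValueError("tree structures do not match")
        return type(tree)(toy_tree_map(fn, *children) for children in zip(tree, *rests))
    if type(tree) is dict:
        if any(type(r) is not dict or r.keys() != tree.keys() for r in rests):
            raise ValueError("tree structures do not match")
        return {k: toy_tree_map(fn, tree[k], *(r[k] for r in rests)) for k in sorted(tree)}
    # Leaf of `tree`: call fn with the corresponding entry of every tree
    # (this toy does not verify that the other trees also hold leaves here).
    return fn(tree, *rests)


x = {"w": [1, 2], "b": (3, None)}
y = {"w": [10, 20], "b": (30, None)}
z = {"w": [100, 200], "b": (300, None)}
print(toy_tree_map(lambda a, b, c: a + b + c, x, y, z))
# {'b': (333, None), 'w': [111, 222]}
```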
ResNet101(num_nodes=3317, num_leaves=1252, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 1603.93 | 3939.37 | 40382.14 | 2208.63 | 2.46 | 25.18 | 1.38 |
| Tree UnFlatten | 4005.73 | 10203.31 | 17316.07 | 59531.47 | 2.55 | 4.32 | 14.86 |
| Tree Copy | 5773.17 | 14396.40 | 60021.18 | 62725.03 | 2.49 | 10.40 | 10.86 |
| Tree Map | 6101.05 | 14342.22 | 56480.19 | 61830.65 | 2.35 | 9.26 | 10.13 |
| Tree Map (nargs) | 8155.78 | 22232.67 | N/A | 251333.88 | 2.73 | N/A | 30.82 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 1603.93μs <= optree.tree_leaves(x) (None is Node)
✔ OpTree : 1458.25μs -- x0.91 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 1492.62μs -- x0.93 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 3939.37μs -- x2.46 <= jax.tree_util.tree_leaves(x)
PyTorch: 40382.14μs -- x25.18 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 2208.63μs -- x1.38 <= dm_tree.flatten(x)
### Tree UnFlatten ###
~ OpTree : 4005.73μs <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 3957.47μs -- x0.99 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 10203.31μs -- x2.55 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 17316.07μs -- x4.32 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 59531.47μs -- x14.86 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 5773.17μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
✔ OpTree : 5741.73μs -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 5759.01μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 14396.40μs -- x2.49 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 60021.18μs -- x10.40 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 62725.03μs -- x10.86 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
~ OpTree : 6101.05μs <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 6145.86μs -- x1.01 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
✔ OpTree : 5709.67μs -- x0.94 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 14342.22μs -- x2.35 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 56480.19μs -- x9.26 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 61830.65μs -- x10.13 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 8155.78μs <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 8144.17μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
✔ OpTree : 8113.58μs -- x0.99 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 22232.67μs -- x2.73 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 251333.88μs -- x30.82 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
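Each per-operation line reports an absolute time. After `--`, it also gives the ratio to the first `OpTree` line of the same section; for example, ResNet18 `Tree Flatten` shows 294.01 / 288.57 ≈ x1.02.

This section does not show the benchmark script's exact timing method. As an assumption, a best-of-N measurement with the standard-library `timeit` module is one way to produce numbers in the same format for your own trees. `time_us`, the two toy flatten functions and the sample tree are hypothetical stand-ins; in practice the timed callables would wrap calls such as `optree.tree_leaves(x)`.

```python
import timeit
from typing import Any, Callable, List


def time_us(fn: Callable[[], object], number: int = 200, repeat: int = 5) -> float:
    """Best of `repeat` runs, reported as mean microseconds per call of `fn`."""
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number * 1e6


def leaves_recursive(node: Any) -> List[Any]:
    """Toy flatten of nested lists: recursive version."""
    if type(node) is list:
        out: List[Any] = []
        for child in node:
            out.extend(leaves_recursive(child))
        return out
    return [node]


def leaves_iterative(node: Any) -> List[Any]:
    """Toy flatten of nested lists: explicit-stack version, same left-to-right order."""
    out: List[Any] = []
    stack = [node]
    while stack:
        cur = stack.pop()
        if type(cur) is list:
            stack.extend(reversed(cur))  # reversed so the leftmost child is popped first
        else:
            out.append(cur)
    return out


tree = [[i, [i + 1, [i + 2]]] for i in range(200)]
assert leaves_recursive(tree) == leaves_iterative(tree)

base = time_us(lambda: leaves_recursive(tree))
other = time_us(lambda: leaves_iterative(tree))
print(f"recursive: {base:.2f}μs")
print(f"iterative: {other:.2f}μs -- x{other / base:.2f}")
```

Taking the minimum over repeats reduces noise from other processes. Absolute numbers will differ from the tables above, because they depend on the machine and the Python version.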
ResNet152(num_nodes=4932, num_leaves=1864, treespec=PyTreeSpec(OrderedDict([('conv1'...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 2446.56 | 6267.98 | 56892.36 | 3139.17 | 2.56 | 23.25 | 1.28 |
| Tree UnFlatten | 5644.08 | 15153.87 | 25438.67 | 88626.45 | 2.68 | 4.51 | 15.70 |
| Tree Copy | 8552.95 | 21321.48 | 85857.53 | 86037.99 | 2.49 | 10.04 | 10.06 |
| Tree Map | 8568.48 | 21641.40 | 83021.19 | 87077.66 | 2.53 | 9.69 | 10.16 |
| Tree Map (nargs) | 12862.88 | 33714.46 | N/A | 368424.63 | 2.62 | N/A | 28.64 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 2446.56μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 2455.99μs -- x1.00 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 2429.96μs -- x0.99 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 6267.98μs -- x2.56 <= jax.tree_util.tree_leaves(x)
PyTorch: 56892.36μs -- x23.25 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 3139.17μs -- x1.28 <= dm_tree.flatten(x)
### Tree UnFlatten ###
✔ OpTree : 5644.08μs <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 5723.38μs -- x1.01 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 15153.87μs -- x2.68 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 25438.67μs -- x4.51 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 88626.45μs -- x15.70 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 8552.95μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 8531.50μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
✔ OpTree : 8528.88μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 21321.48μs -- x2.49 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 85857.53μs -- x10.04 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 86037.99μs -- x10.06 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
~ OpTree : 8568.48μs <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 8569.48μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
✔ OpTree : 8542.91μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 21641.40μs -- x2.53 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 83021.19μs -- x9.69 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 87077.66μs -- x10.16 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 12862.88μs <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 12806.09μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 12909.94μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 33714.46μs -- x2.62 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 368424.63μs -- x28.64 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
ViT-H/14(num_nodes=3420, num_leaves=784, treespec=PyTreeSpec(OrderedDict([('conv_p...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 1681.48 | 4488.33 | 41703.16 | 2504.86 | 2.67 | 24.80 | 1.49 |
| Tree UnFlatten | 4492.64 | 12544.41 | 18091.68 | 67876.19 | 2.79 | 4.03 | 15.11 |
| Tree Copy | 6116.61 | 16038.69 | 59993.87 | 70215.65 | 2.62 | 9.81 | 11.48 |
| Tree Map | 6735.93 | 18027.05 | 63986.88 | 75742.33 | 2.68 | 9.50 | 11.24 |
| Tree Map (nargs) | 9511.10 | 27920.13 | N/A | 281245.95 | 2.94 | N/A | 29.57 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
✔ OpTree : 1681.48μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 1702.21μs -- x1.01 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
~ OpTree : 1694.58μs -- x1.01 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 4488.33μs -- x2.67 <= jax.tree_util.tree_leaves(x)
PyTorch: 41703.16μs -- x24.80 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 2504.86μs -- x1.49 <= dm_tree.flatten(x)
### Tree UnFlatten ###
✔ OpTree : 4492.64μs <= optree.tree_unflatten(spec, flat) (None is Node)
~ OpTree : 4535.79μs -- x1.01 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 12544.41μs -- x2.79 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 18091.68μs -- x4.03 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 67876.19μs -- x15.11 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
~ OpTree : 6116.61μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
✔ OpTree : 6075.72μs -- x0.99 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 6104.80μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 16038.69μs -- x2.62 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 59993.87μs -- x9.81 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 70215.65μs -- x11.48 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
~ OpTree : 6735.93μs <= optree.tree_map(fn1, x) (None is Node)
✔ OpTree : 6679.19μs -- x0.99 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
~ OpTree : 6726.99μs -- x1.00 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 18027.05μs -- x2.68 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 63986.88μs -- x9.50 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 75742.33μs -- x11.24 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 9511.10μs <= optree.tree_map(fn3, x, y, z) (None is Node)
✔ OpTree : 9503.85μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
~ OpTree : 9550.25μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 27920.13μs -- x2.94 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 281245.95μs -- x29.57 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
Swin-B(num_nodes=2881, num_leaves=706, treespec=PyTreeSpec(OrderedDict([('featur...]), buffers=OrderedDict([])))])))
| Subject | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--------------- | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
| Tree Flatten | 1565.41 | 4091.10 | 34241.99 | 1936.75 | 2.61 | 21.87 | 1.24 |
| Tree UnFlatten | 3637.86 | 9973.78 | 15353.31 | 57655.54 | 2.74 | 4.22 | 15.85 |
| Tree Copy | 5466.03 | 14449.60 | 50528.12 | 60269.63 | 2.64 | 9.24 | 11.03 |
| Tree Map | 5756.71 | 14528.51 | 51052.90 | 60715.06 | 2.52 | 8.87 | 10.55 |
| Tree Map (nargs) | 7628.29 | 22421.37 | N/A | 238211.56 | 2.94 | N/A | 31.23 |
### Check ###
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=False)[::-1]) == tree
✔ COPY: optree.tree_unflatten(*optree.tree_flatten(tree, none_is_leaf=True)[::-1]) == tree
✔ FLATTEN (OpTree vs. JAX XLA): optree.tree_leaves(tree, none_is_leaf=False) == jax.tree_util.tree_leaves(tree)
✔ FLATTEN (OpTree vs. PyTorch): optree.tree_leaves(tree, none_is_leaf=True) == torch_utils_pytree.tree_flatten(tree)[0]
✔ TREEMAP (OpTree vs. JAX XLA): optree.tree_map(fn, tree, none_is_leaf=False) == jax.tree_util.tree_map(fn, tree)
✔ TREEMAP (OpTree vs. PyTorch): optree.tree_map(fn, tree, none_is_leaf=True) == torch_utils_pytree.tree_map(fn, tree)
### Tree Flatten ###
~ OpTree : 1565.41μs <= optree.tree_leaves(x) (None is Node)
~ OpTree : 1565.91μs -- x1.00 <= optree.tree_leaves(x, none_is_leaf=False) (None is Node)
✔ OpTree : 1550.64μs -- x0.99 <= optree.tree_leaves(x, none_is_leaf=True) (None is Leaf)
JAX XLA: 4091.10μs -- x2.61 <= jax.tree_util.tree_leaves(x)
PyTorch: 34241.99μs -- x21.87 <= torch_utils_pytree.tree_flatten(x)[0]
DM-Tree: 1936.75μs -- x1.24 <= dm_tree.flatten(x)
### Tree UnFlatten ###
~ OpTree : 3637.86μs <= optree.tree_unflatten(spec, flat) (None is Node)
✔ OpTree : 3596.18μs -- x0.99 <= optree.tree_unflatten(spec, flat) (None is Leaf)
JAX XLA: 9973.78μs -- x2.74 <= jax.tree_util.tree_unflatten(spec, flat)
PyTorch: 15353.31μs -- x4.22 <= torch_utils_pytree.tree_unflatten(flat, spec)
DM-Tree: 57655.54μs -- x15.85 <= dm_tree.unflatten_as(spec, flat)
### Tree Copy ###
✔ OpTree : 5466.03μs <= optree.tree_unflatten(*optree.tree_flatten(x)[::-1]) (None is Node)
~ OpTree : 5467.68μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=False)[::-1]) (None is Node)
~ OpTree : 5469.55μs -- x1.00 <= optree.tree_unflatten(*optree.tree_flatten(x, none_is_leaf=True)[::-1]) (None is Leaf)
JAX XLA: 14449.60μs -- x2.64 <= jax.tree_util.tree_unflatten(*jax.tree_util.tree_flatten(x)[::-1])
PyTorch: 50528.12μs -- x9.24 <= torch_utils_pytree.tree_unflatten(*torch_utils_pytree.tree_flatten(x))
DM-Tree: 60269.63μs -- x11.03 <= dm_tree.unflatten_as(x, dm_tree.flatten(x))
### Tree Map ###
~ OpTree : 5756.71μs <= optree.tree_map(fn1, x) (None is Node)
~ OpTree : 5712.77μs -- x0.99 <= optree.tree_map(fn1, x, none_is_leaf=False) (None is Node)
✔ OpTree : 5706.22μs -- x0.99 <= optree.tree_map(fn1, x, none_is_leaf=True) (None is Leaf)
JAX XLA: 14528.51μs -- x2.52 <= jax.tree_util.tree_map(fn1, x)
PyTorch: 51052.90μs -- x8.87 <= torch_utils_pytree.tree_map(fn1, x)
DM-Tree: 60715.06μs -- x10.55 <= dm_tree.map_structure(fn1, x)
### Tree Map (nargs) ###
~ OpTree : 7628.29μs <= optree.tree_map(fn3, x, y, z) (None is Node)
~ OpTree : 7622.97μs -- x1.00 <= optree.tree_map(fn3, x, y, z, none_is_leaf=False) (None is Node)
✔ OpTree : 7548.69μs -- x0.99 <= optree.tree_map(fn3, x, y, z, none_is_leaf=True) (None is Leaf)
JAX XLA: 22421.37μs -- x2.94 <= jax.tree_util.tree_map(fn3, x, y, z)
DM-Tree: 238211.56μs -- x31.23 <= dm_tree.map_structure_up_to(x, fn3, x, y, z)
License
OpTree is released under the Apache License 2.0.
OpTree is heavily based on JAX's implementation of the PyTree utility, with deep refactoring and several improvements. The original licenses are JAX's Apache License 2.0 and TensorFlow's Apache License 2.0.
Download files
Source Distribution
Built Distributions
Hashes for optree-0.5.1-cp311-cp311-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 4217fef7c9d36995e8864e0fdb2f77df04de71a06c3fbac5f42cf06fe5630e70 |
MD5 | 995a02d5eec7d7f5b5c26730e72c37a2 |
BLAKE2b-256 | 0398faf95d148b848b6345c059ad7d50e86956bc78cc735f362ec1eee064f273 |
Hashes for optree-0.5.1-cp311-cp311-musllinux_1_1_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | dd3fac659ab7a1b7a5f7e5ddb9ce12fe0ce077dcc88eba3625cee633e804cac8 |
MD5 | a8583756fdf2ebbb2cdf3800d596c8d1 |
BLAKE2b-256 | ee26a3591fb581d216142f0c011bc10803dddb5fcec136c7e2fd792ac92f9d7f |
Hashes for optree-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | c450fe8ede001f0e2409ed7211aef741975943663f3a19a3633851ece16fb8ce |
MD5 | 2427164e05468e0c81e822318888485e |
BLAKE2b-256 | e7dbb8b22ec85ed6031227f6f27692b51ba47f138bbf242bfa19e80b185be62b |
Hashes for optree-0.5.1-cp311-cp311-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 1cdbc4a0a6e935b9914c339976b23281eda79511e2324df7d9fb1a517de9bb13 |
MD5 | 596a74503c3559beca9377d0ccfa9fe1 |
BLAKE2b-256 | e59507c6a85fb045a4cc5a3ffa379412fbc4ba8efb1eaa3e2693b8b27bd940f4 |
Hashes for optree-0.5.1-cp311-cp311-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 5dbd77bf923c38d35c199861d1415ef23d5ad8b6622b938a409a2707f47aff15 |
MD5 | d46701ec82c7c916605c9221ce2211fd |
BLAKE2b-256 | cbadd12b6504aee558f174b704c9ecc2186772a9bb9e2f86a87c153376fc21e7 |
Hashes for optree-0.5.1-cp311-cp311-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | 5936fff9a3577082fd84c881ebe1dbef897aaeec62340d1b192ebe1948583565 |
MD5 | 742a45d77d3391dc982950e09b3dce3a |
BLAKE2b-256 | f8f6740033ba05fea72a7273e4ca31fb62a5f8aaebf57b1043be759071e1e0b7 |
Hashes for optree-0.5.1-cp310-cp310-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 77bd8251790e9399f07c2c05d0a6f96ac79b47fe1bef72567285bbe6e3802e01 |
MD5 | 2f6844af1b39a3837f41cd994e26e25a |
BLAKE2b-256 | aa9b03b45bc11777072283a9e63fe226fd4cee7098803326ebe5a2eb0a40d58a |
Hashes for optree-0.5.1-cp310-cp310-musllinux_1_1_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | f6f3fd04b0cdfde55a99ae3279a52b5567d07619c9db6b28ac3460f4372d620b |
MD5 | 225f7a82a3886473420c05a149f7d3cd |
BLAKE2b-256 | 9ee09ae3ee1693cbb12eacc69d1802471164f6db51f6ea3753929b78a8f7f0e2 |
Hashes for optree-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 7fec6e4d2eb1a5f97c1a0161d17ac074c2fc4ac9eaf1e404beb70f713807b3bc |
MD5 | cad9f06845cbcc7a19b629119884b8d5 |
BLAKE2b-256 | 77ae2ff6fb4bd5d06d002e14ad9edcbffbe04214d952a76499d120b0c411ae29 |
Hashes for optree-0.5.1-cp310-cp310-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | d48d5290116626296ca6123edef3b9ffb74c9ced66e807ff08be1afa876c56a9 |
MD5 | 875fe2652489f37b4645e0d7474c9177 |
BLAKE2b-256 | 9a96968c2dc291823e2fea3021d479ea9a5346cf62265bfae235164fec38b953 |
Hashes for optree-0.5.1-cp310-cp310-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | ce816487fa3e8958479955cd01e9d97bc026ff1a3f7377e247c8fbd2db5866cd |
MD5 | 70bf5f35f6c9fb87cf5e4c2a095c2722 |
BLAKE2b-256 | ea3d5acb31090a323a2026901cc1b56fd3204c6be4b43c8d6685a6616f7806a6 |
Hashes for optree-0.5.1-cp310-cp310-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | 98427ac443679d007b1fdb81dff5c560c563ad473259b4dea671422c0cf0c162 |
MD5 | c121d47b9c6d64baa4b1cd8fcdc9db8b |
BLAKE2b-256 | 94ee7068ff1a46d0f9a941c5e22a1435894645eb07d5391ac90bc5561c977391 |
Hashes for optree-0.5.1-cp39-cp39-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 8435d8b0e8e4cb1b87060261b988c5d33baa047830869fb568994404d791feec |
MD5 | 092eb7b1e118a4deaaf344ce225b7e83 |
BLAKE2b-256 | bf926b912349064d72fc33dd762e0c49a05f39d574190433f000cda409032006 |
Hashes for optree-0.5.1-cp39-cp39-musllinux_1_1_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 91d95d5bf96253838158dfee62de1093e307a10bfa2c6390a520e4b94be730b5 |
MD5 | 3736e58717a5be7ef474cbb2d7917a43 |
BLAKE2b-256 | afb8f6d0d334a878e17216dc737f0d85a4825920108647d2551f91c43618446b |
Hashes for optree-0.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | da37c47846e72fb3a78efe2d245e9e776b1dc93a9ddbe47cb6a1a17e6fec24d9 |
MD5 | 30a9dddde9caa70c209502ca7ede679d |
BLAKE2b-256 | 6ce6c511b20d3afb34e8713168dba85f24cae864fa205c8be3ba34e399805cdf |
Hashes for optree-0.5.1-cp39-cp39-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 9eea6b980dca402f49c7b234722aa5016cc776f8e14e56e02e80f256c07748f1 |
MD5 | 3fe47bc899dec4841e546b5fc4d0f168 |
BLAKE2b-256 | 13ce00dc2b5f4a6aa164fd8b3449b9a862a27e2274c9c66cbd1b72e887fbaa92 |
Hashes for optree-0.5.1-cp39-cp39-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 8d3c1b51d7be06f46bf042470caf60d7023f8c98441ecab1d81cb7bc300ce0b2 |
MD5 | 89b20342357d7a17303dba5de7c915ed |
BLAKE2b-256 | 67a7589c7d2f0a97d0e810a5e788d82c0a5403fdc3362004aec0942cbd26eaac |
Hashes for optree-0.5.1-cp39-cp39-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | 3d33a9fd36eaa483e005884075f01304c9e781accdad12bb7f8434a4fe61b0b3 |
MD5 | 7965702dfe09b008ebdc0c8a8f77e49e |
BLAKE2b-256 | e564f46ca33436a109ee6aab5671c114df125bcf03810ac6c57406ef87a0a06f |
Hashes for optree-0.5.1-cp38-cp38-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 7669a664c66824d40b0dcd7e670e894890c65a369c0b84616cc95e86941618f6 |
MD5 | 99f393b7a809595ceaa900936f88b707 |
BLAKE2b-256 | f5ab2d27ecc68cef386e9e935b2c0efb1f63ca16db8f025a913b8c5221374140 |
Hashes for optree-0.5.1-cp38-cp38-musllinux_1_1_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 01cd1eb717c510d47881009f8ba72d93f19289ec5395f0df12c939850ed39915 |
MD5 | e1c0bbbb0ae0dc70daace6f34c0a27f7 |
BLAKE2b-256 | 0175c0873cb8811642485da6d36b11d120fe1efd3cc506c6790850cc3a5360a0 |
Hashes for optree-0.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 8cd5561e8fb290599157a3beecdf5e82e15d0ec95f19306879ae2362181cb005 |
MD5 | dd15ecbb9717421c4355d67c1515572e |
BLAKE2b-256 | 631189f8afc65b3f3ef429c0c60eab57d7499b528f7f6846deba9f548b3ac950 |
Hashes for optree-0.5.1-cp38-cp38-macosx_11_0_arm64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 2cadc8f9b5b26929d033227ddab2a487194051f0f24466feafe9a210a288702b |
MD5 | 42c236c0731201c8d283d4a5c96718eb |
BLAKE2b-256 | f6d0daef7dc5e830d912c12b4757ddd7b222a0b6c39234458a78aedc3bfdf59f |
Hashes for optree-0.5.1-cp38-cp38-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | fa7054ba4ad7b25a49d1a412c6513a547b0aeaeb2aad5fb0d6e8d85e4d597e4c |
MD5 | 338170f791430351058d18de178cb858 |
BLAKE2b-256 | e111b7d75946d01203ce4bcb1990c91e19945946d10074f5be7642d32c382904 |
Hashes for optree-0.5.1-cp38-cp38-macosx_10_9_universal2.whl
Algorithm | Hash digest |
---|---|
SHA256 | d0cd3fa55db58c45542e864c7c588aeeb5cce91ff5a91db073b72895b27e3122 |
MD5 | 2682843e6359a823eea62170040d8622 |
BLAKE2b-256 | ebcc5dc23c386f80a2d67a24a9f04c5c958e7ad9df93c6be819fb4f142ab62da |
Hashes for optree-0.5.1-cp37-cp37m-win_amd64.whl
Algorithm | Hash digest |
---|---|
SHA256 | 2dd6375a0b924e72f3ed777f8e5875c27b2dd3ad9eb17caa04e1092a1a5ba70e |
MD5 | 42e2e94ed654356fdfb68cbad374d6ea |
BLAKE2b-256 | eb7d17848e79c94d8ba6613c7be3357d2dadf3cb954516cdd304eb00d845e43a |
Hashes for optree-0.5.1-cp37-cp37m-musllinux_1_1_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | ddc0ee6d21be9afc11c29ea819fdf5488c57df6c82fcf8d36af25494e7894f1e |
MD5 | 2a869fd513dadbacb8410f53b3d110ce |
BLAKE2b-256 | ce2fb06de9cfa38a65aff1b6594c723a38acfa96e9f59bb853e55a6f8484135a |
Hashes for optree-0.5.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | e5a890d556f4f25e371d5626d545047d40fd881a2f342f5c3709b2232c119e9f |
MD5 | 33237e0707b3727eb3f0261213de6347 |
BLAKE2b-256 | 6656260a61fc915138aa137ca72bae0e5c8a4f1d5978ac319285f7876c579a98 |
Hashes for optree-0.5.1-cp37-cp37m-macosx_10_9_x86_64.whl
Algorithm | Hash digest |
---|---|
SHA256 | cd541a4eb511220eb787a3601ae0100b712f59c9d99eec96050f95a1ac70721b |
MD5 | ae22f114ab4986f7e167887ea3f620cb |
BLAKE2b-256 | f30d9b541df1c6410ffb3ed9dc7482ccc373e21d81c958bae156e74cc6b48d45 |
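The digests above can be used to check a downloaded wheel before installing it. Below is a minimal standard-library sketch. `file_digest` and the local file name in the commented example are illustrative. The keys of `HASHERS` mirror the algorithm labels in the tables; `BLAKE2b-256` is BLAKE2b with a 32-byte digest.

```python
import hashlib
from typing import Any, Callable, Dict

# Algorithm labels as they appear in the tables above, mapped to hashlib constructors.
HASHERS: Dict[str, Callable[[], Any]] = {
    "SHA256": hashlib.sha256,
    "MD5": hashlib.md5,
    "BLAKE2b-256": lambda: hashlib.blake2b(digest_size=32),
}


def file_digest(path: str, algorithm: str = "SHA256", chunk_size: int = 1 << 20) -> str:
    """Hex digest of the file at `path`, read in chunks so large files are not loaded at once."""
    h = HASHERS[algorithm]()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# Example (assumes the wheel was downloaded to the current directory):
# expected = "4217fef7c9d36995e8864e0fdb2f77df04de71a06c3fbac5f42cf06fe5630e70"
# assert file_digest("optree-0.5.1-cp311-cp311-win_amd64.whl") == expected
```

pip can enforce the same check in hash-checking mode. Use a requirements line of the form `optree==0.5.1 --hash=sha256:<digest>` for each acceptable file.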