Optimized PyTree Utilities.

OpTree



Installation

Install from PyPI:

pip3 install --upgrade optree

Install from conda-forge:

conda install -c conda-forge optree

Install the latest version from GitHub:

pip3 install git+https://github.com/metaopt/optree.git#egg=optree

Or, clone this repo and install manually:

git clone --depth=1 https://github.com/metaopt/optree.git
cd optree
pip3 install .

Compiling from source requires Python 3.7+, a compiler (gcc / clang / icc / cl.exe) that supports C++20, and a CMake installation.


PyTrees

A PyTree is a recursive structure that can be an arbitrarily nested Python container (e.g., tuple, list, dict, OrderedDict, NamedTuple, etc.) or an opaque Python object. The key concepts of tree operations are tree flattening and its inverse (tree unflattening). Additional tree operations can be performed based on these two basic functions (e.g., tree_map = tree_unflatten ∘ map ∘ tree_flatten).

Tree flattening is traversing the entire tree in a left-to-right depth-first manner and returning the leaves of the tree in a deterministic order.

>>> import optree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))

This usually implies that equal pytrees return equal lists of leaves and the same tree structure. See also the section Key Ordering for Dictionaries.

>>> {'a': [1, 2], 'b': [3]} == {'b': [3], 'a': [1, 2]}
True
>>> optree.tree_leaves({'a': [1, 2], 'b': [3]}) == optree.tree_leaves({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_structure({'a': [1, 2], 'b': [3]}) == optree.tree_structure({'b': [3], 'a': [1, 2]})
True
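The composition tree_map = tree_unflatten ∘ map ∘ tree_flatten can be illustrated with a minimal pure-Python sketch. This is not OpTree's implementation: the names toy_flatten, toy_unflatten, and toy_map and the tuple-based spec are hypothetical, and only tuple, list, dict, and None are handled.

```python
def toy_flatten(tree):
    """Return (leaves, spec) via a left-to-right depth-first traversal."""
    kind = type(tree)
    if tree is None:
        # None is a node with no children: it stays in the spec
        return [], ("none", None, [])
    if kind in (tuple, list):
        children, metadata = list(tree), None
    elif kind is dict:
        keys = sorted(tree)  # assumes the keys are mutually comparable
        children, metadata = [tree[k] for k in keys], keys
    else:
        # Any other object is an opaque leaf
        return [tree], ("leaf", None, [])
    leaves, child_specs = [], []
    for child in children:
        child_leaves, child_spec = toy_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, (kind, metadata, child_specs)


def toy_unflatten(spec, leaves):
    """Rebuild a tree from a spec and an iterable of leaves."""
    leaf_iter = iter(leaves)

    def build(node_spec):
        kind, metadata, child_specs = node_spec
        if kind == "leaf":
            return next(leaf_iter)
        if kind == "none":
            return None
        children = [build(child_spec) for child_spec in child_specs]
        if kind is dict:
            # This toy rebuilds dicts in sorted key order
            return dict(zip(metadata, children))
        return kind(children)  # tuple(...) or list(...)

    return build(spec)


def toy_map(func, tree):
    """tree_map = unflatten ∘ map ∘ flatten."""
    leaves, spec = toy_flatten(tree)
    return toy_unflatten(spec, map(func, leaves))


tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
print(toy_flatten(tree)[0])
print(toy_map(lambda x: x * 10, tree))
```

Keeping the structure in a separate spec is what lets any leaf-wise operation be written as an ordinary function over a flat list.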

Tree Nodes and Leaves

A tree is a collection of non-leaf nodes and leaf nodes, where the leaf nodes have no children to flatten. optree.tree_flatten(...) flattens the tree and returns a list of leaf nodes, while the non-leaf nodes are stored in the tree specification.

Built-in PyTree Node Types

OpTree supports a number of Python container types in the registry out of the box (e.g., tuple, list, dict, NamedTuple, collections.OrderedDict, and collections.defaultdict). These types are considered non-leaf nodes in the tree. Python objects whose type is not registered are treated as leaf nodes. The registry lookup uses the is operator to determine whether the type matches, so subclasses need to be registered explicitly; otherwise, an object of that type is considered a leaf. The NoneType is a special case discussed in the section None is Non-leaf Node vs. None is Leaf.
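The difference between an exact-type lookup and an isinstance check can be seen in a small sketch. The set TOY_REGISTRY and the helper is_registered_node are hypothetical stand-ins for illustration only; they are not part of OpTree.

```python
from collections import OrderedDict

# Hypothetical stand-in for a type registry: a plain set of registered types.
TOY_REGISTRY = {tuple, list, dict, OrderedDict}


def is_registered_node(obj):
    """Exact-type lookup: a subclass does not match its parent's entry."""
    # Type objects hash and compare by identity by default,
    # so this behaves like `type(obj) is registered_type`.
    return type(obj) in TOY_REGISTRY


class MyList(list):
    """A subclass that has not been registered explicitly."""


plain = [1, 2]
custom = MyList([1, 2])

print(is_registered_node(plain))   # the exact type `list` is registered
print(is_registered_node(custom))  # `MyList` itself is not registered
print(isinstance(custom, list))    # isinstance would have matched
```

With this lookup rule, an instance of an unregistered subclass is treated as a single leaf even though it is a list in the isinstance sense.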

Registering a Container-like Custom Type as Non-leaf Nodes

A container-like Python type can be registered in the type registry with a pair of functions that specify:

  • flatten_func(container) -> (children, metadata, entries): convert an instance of the container type to a (children, metadata, entries) triple, where children is an iterable of subtrees and entries is an iterable of path entries of the container (e.g., indices or keys).
  • unflatten_func(metadata, children) -> container: convert such a pair back to an instance of the container type.

The metadata is some necessary data apart from the children to reconstruct the container, e.g., the keys of the dictionary (the children are values).

The entries are optional: flatten_func can omit them (returning only a pair) or return None in their place. In either case, range(len(children)) (i.e., flat indices) is used as the path entries of the current node. The function signature can therefore be flatten_func(container) -> (children, metadata) or flatten_func(container) -> (children, metadata, None).
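The pair and triple return conventions can be sketched in plain Python. The helper normalize_flatten_result and the example flatten functions are hypothetical; they only illustrate the defaulting rule for entries and are not OpTree internals.

```python
def normalize_flatten_result(result):
    """Accept (children, metadata) or (children, metadata, entries)."""
    result = tuple(result)
    if len(result) == 2:
        children, metadata = result
        entries = None
    elif len(result) == 3:
        children, metadata, entries = result
    else:
        raise ValueError('flatten_func must return a pair or a triple')
    children = list(children)
    if entries is None:
        # Default path entries are the flat indices of the children
        entries = range(len(children))
    entries = list(entries)
    if len(entries) != len(children):
        raise ValueError('entries and children must have the same length')
    return children, metadata, entries


def flatten_set(s):
    # (set) -> (children, metadata, None)
    return sorted(s), None, None


def flatten_mapping(d):
    # (dict) -> (children, metadata, entries); the keys double as path entries
    keys = sorted(d)
    return [d[k] for k in keys], keys, keys


def unflatten_mapping(metadata, children):
    # (metadata, children) -> dict
    return dict(zip(metadata, children))


print(normalize_flatten_result(flatten_set({3, 1, 2})))
children, metadata, entries = normalize_flatten_result(flatten_mapping({'b': 2, 'a': 1}))
print(entries)
print(unflatten_mapping(metadata, children))
```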

The following examples show how to register custom types and utilize them for tree_flatten and tree_map. Please refer to section Notes about the PyTree Type Registry for more information.

import optree

# Register a Python type with lambda functions
optree.register_pytree_node(
    set,
    # (set) -> (children, metadata, None)
    lambda s: (sorted(s), None, None),
    # (metadata, children) -> (set)
    lambda _, children: set(children),
    namespace='set',
)

# Register a Python type into a namespace
import torch

class Torch2NumpyEntry(optree.PyTreeEntry):
    def __call__(self, obj):
        assert self.entry == 0
        return obj.cpu().detach().numpy()

    def codify(self, node=''):
        assert self.entry == 0
        return f'{node}.cpu().detach().numpy()'

optree.register_pytree_node(
    torch.Tensor,
    # (tensor) -> (children, metadata)
    flatten_func=lambda tensor: (
        (tensor.cpu().detach().numpy(),),
        {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
    ),
    # (metadata, children) -> tensor
    unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
    path_entry_type=Torch2NumpyEntry,
    namespace='torch2numpy',
)
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

# Flatten without specifying the namespace
>>> optree.tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

# Flatten with the namespace
>>> leaves, treespec = optree.tree_flatten(tree, namespace='torch2numpy')
>>> leaves, treespec
(
    [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
    PyTreeSpec(
        {
            'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
            'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
        },
        namespace='torch2numpy'
    )
)

# `entries` are not defined and use `range(len(children))`
>>> optree.tree_paths(tree, namespace='torch2numpy')
[('bias', 0), ('weight', 0)]

# Custom path entry type defines the pytree access behavior
>>> optree.tree_accessors(tree, namespace='torch2numpy')
[
    PyTreeAccessor(*['bias'].cpu().detach().numpy(), (MappingEntry(key='bias', type=<class 'dict'>), Torch2NumpyEntry(entry=0, type=<class 'torch.Tensor'>))),
    PyTreeAccessor(*['weight'].cpu().detach().numpy(), (MappingEntry(key='weight', type=<class 'dict'>), Torch2NumpyEntry(entry=0, type=<class 'torch.Tensor'>)))
]

# Unflatten back to a copy of the original object
>>> optree.tree_unflatten(treespec, leaves)
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

Users can also extend the pytree registry by decorating the custom class and defining an instance method tree_flatten and a class method tree_unflatten.

from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
    TREE_PATH_ENTRY_TYPE = optree.MappingEntry  # used by accessor APIs

    def tree_flatten(self):  # -> (children, metadata, entries)
        reversed_keys = sorted(self.keys(), reverse=True)
        return (
            [self[key] for key in reversed_keys],  # children
            reversed_keys,  # metadata
            reversed_keys,  # entries
        )

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(zip(metadata, children))
>>> tree = MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))

# Flatten without specifying the namespace
>>> optree.tree_flatten_with_path(tree)  # `MyDict`s are leaf nodes
(
    [()],
    [MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))],
    PyTreeSpec(*)
)

# Flatten with the namespace
>>> optree.tree_flatten_with_path(tree, namespace='mydict')
(
    [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)
>>> optree.tree_flatten_with_accessor(tree, namespace='mydict')
(
    [
        PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyDict'>))),
        PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyDict'>))),
        PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
        PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
    ],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)
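The way a namespace isolates a registration can be mimicked with a toy registry keyed by (namespace, type). Everything in this sketch (the dictionary _TOY_REGISTRY, the decorator toy_register_class, the helper toy_namespace_leaves, and the class Pair) is hypothetical and only mirrors the behavior shown above: without the namespace the object is a leaf, with the namespace it is flattened.

```python
# Hypothetical toy registry: maps (namespace, type) -> (flatten, unflatten).
_TOY_REGISTRY = {}


def toy_register_class(namespace):
    """Class decorator that records tree_flatten / tree_unflatten."""
    def decorator(cls):
        key = (namespace, cls)
        if key in _TOY_REGISTRY:
            raise ValueError(f'{cls.__name__} is already registered in {namespace!r}')
        _TOY_REGISTRY[key] = (cls.tree_flatten, cls.tree_unflatten)
        return cls
    return decorator


def toy_namespace_leaves(tree, namespace=''):
    """Collect leaves, expanding only types registered in `namespace`."""
    handlers = _TOY_REGISTRY.get((namespace, type(tree)))
    if handlers is None:
        return [tree]  # unregistered in this namespace: a leaf
    flatten, _unflatten = handlers
    children = flatten(tree)[0]
    leaves = []
    for child in children:
        leaves.extend(toy_namespace_leaves(child, namespace))
    return leaves


@toy_register_class('pairs')
class Pair:
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def tree_flatten(self):  # -> (children, metadata)
        return (self.left, self.right), None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(*children)


nested = Pair(1, Pair(2, 3))
print(toy_namespace_leaves(nested, namespace='pairs'))
print(len(toy_namespace_leaves(nested)))  # a single opaque leaf
```

Keying the lookup on the namespace means two libraries can register the same type with different behavior without affecting each other.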

Notes about the PyTree Type Registry

There are several key attributes of the pytree type registry:

  1. The type registry is per-interpreter. This means registering a custom type in the registry affects all modules that use OpTree.

[!WARNING] For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.

  2. The elements in the type registry are immutable. Users can neither register the same type twice in the same namespace (i.e., update the type registry), nor remove a type from the type registry. To update the behavior of an already registered type, simply register it again with another namespace.

  3. Users cannot modify the behavior of already registered built-in types listed in Built-in PyTree Node Types, such as key order sorting for dict and collections.defaultdict.

  4. Inherited subclasses are not implicitly registered. The registry lookup uses type(obj) is registered_type rather than isinstance(obj, registered_type). Users need to register the subclasses explicitly. To register all subclasses automatically, use a metaclass or __init_subclass__, for example:

    from collections import UserDict
    
    @optree.register_pytree_node_class(namespace='mydict')
    class MyDict(UserDict):
        TREE_PATH_ENTRY_TYPE = optree.MappingEntry  # used by accessor APIs
    
        def __init_subclass__(cls):  # define this in the base class
            super().__init_subclass__()
            # Register a subclass to namespace 'mydict'
            optree.register_pytree_node_class(cls, namespace='mydict')
    
        def tree_flatten(self):  # -> (children, metadata, entries)
            reversed_keys = sorted(self.keys(), reverse=True)
            return (
                [self[key] for key in reversed_keys],  # children
                reversed_keys,  # metadata
                reversed_keys,  # entries
            )
    
        @classmethod
        def tree_unflatten(cls, metadata, children):
            return cls(zip(metadata, children))
    
    # Subclasses will be automatically registered in namespace 'mydict'
    class MyAnotherDict(MyDict):
        pass
    
    >>> tree = MyDict(b=4, a=(2, 3), c=MyAnotherDict({'d': 5, 'f': 6}))
    >>> optree.tree_flatten_with_path(tree, namespace='mydict')
    (
        [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
        [6, 5, 4, 2, 3],
        PyTreeSpec(
            CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]),
            namespace='mydict'
        )
    )
    >>> optree.tree_accessors(tree, namespace='mydict')
    [
        PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyAnotherDict'>))),
        PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyAnotherDict'>))),
        PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
        PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
    ]
    
  5. Be careful about potential infinite recursion in a custom flatten function. The children returned by the custom flatten function are considered subtrees and are flattened recursively. Because the children can have the same type as the current node, users must design the termination condition carefully.

    import numpy as np
    import torch
    
    optree.register_pytree_node(
        np.ndarray,
        # Children are nested lists of Python objects
        lambda array: (np.atleast_1d(array).tolist(), array.ndim == 0),
        lambda scalar, rows: np.asarray(rows) if not scalar else np.asarray(rows[0]),
        namespace='numpy1',
    )
    
    optree.register_pytree_node(
        np.ndarray,
        # Children are Python objects
        lambda array: (
            list(array.ravel()),  # list(1DArray[T]) -> List[T]
            dict(shape=array.shape, dtype=array.dtype)
        ),
        lambda metadata, children: np.asarray(children, dtype=metadata['dtype']).reshape(metadata['shape']),
        namespace='numpy2',
    )
    
    optree.register_pytree_node(
        np.ndarray,
        # Returns a list of `np.ndarray`s without termination condition
        lambda array: ([array.ravel()], array.dtype),
        lambda shape, children: children[0].reshape(shape),
        namespace='numpy3',
    )
    
    optree.register_pytree_node(
        torch.Tensor,
        # Children are nested lists of Python objects
        lambda tensor: (torch.atleast_1d(tensor).tolist(), tensor.ndim == 0),
        lambda scalar, rows: torch.tensor(rows) if not scalar else torch.tensor(rows[0]),
        namespace='torch1',
    )
    
    optree.register_pytree_node(
        torch.Tensor,
        # Returns a list of `torch.Tensor`s without termination condition
        lambda tensor: (
            list(tensor.view(-1)),  # list(1DTensor[T]) -> List[0DTensor[T]] (STILL TENSORS!)
            tensor.shape
        ),
        lambda shape, children: torch.stack(children).reshape(shape),
        namespace='torch2',
    )
    
    >>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy1')
    (
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        PyTreeSpec(
            CustomTreeNode(ndarray[False], [[*, *, *], [*, *, *], [*, *, *]]),
            namespace='numpy1'
        )
    )
    # Implicitly casts `float`s to `np.float64`
    >>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy1')
    array([[1.5, 2.5, 3.5],
           [4.5, 5.5, 6.5],
           [7.5, 8.5, 9.5]])
    
    >>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy2')
    (
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        PyTreeSpec(
            CustomTreeNode(ndarray[{'shape': (3, 3), 'dtype': dtype('int64')}], [*, *, *, *, *, *, *, *, *]),
            namespace='numpy2'
        )
    )
    # Explicitly casts `float`s to `np.int64`
    >>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy2')
    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]])
    
    # Children are also `np.ndarray`s, recurse without termination condition.
    >>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy3')
    Traceback (most recent call last):
        ...
    RecursionError: Maximum recursion depth exceeded during flattening the tree.
    
    >>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch1')
    (
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        PyTreeSpec(
            CustomTreeNode(Tensor[False], [[*, *, *], [*, *, *], [*, *, *]]),
            namespace='torch1'
        )
    )
    # Implicitly casts `float`s to `torch.float32`
    >>> optree.tree_map(lambda x: x + 1.5, torch.arange(9).reshape(3, 3), namespace='torch1')
    tensor([[1.5000, 2.5000, 3.5000],
            [4.5000, 5.5000, 6.5000],
            [7.5000, 8.5000, 9.5000]])
    
    # Children are also `torch.Tensor`s, recurse without termination condition.
    >>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch2')
    Traceback (most recent call last):
        ...
    RecursionError: Maximum recursion depth exceeded during flattening the tree.
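The same failure mode can be reproduced without NumPy or PyTorch. In this sketch the class Box, the helper box_leaves, and both flatten functions are hypothetical; the point is only that a flatten function whose children always have the node's own type never reaches a leaf.

```python
class Box:
    """A tiny container used only for this illustration."""

    def __init__(self, items):
        self.items = list(items)


def box_leaves(tree, flatten_func):
    """Toy flatten: recurse into objects whose exact type is Box."""
    if type(tree) is not Box:
        return [tree]
    children, _metadata = flatten_func(tree)
    leaves = []
    for child in children:
        leaves.extend(box_leaves(child, flatten_func))
    return leaves


def flatten_ok(box):
    # Children are the raw items, so recursion stops at non-Box objects
    return box.items, None


def flatten_endless(box):
    # The only child is again a Box: there is no termination condition
    return [Box(box.items)], None


print(box_leaves(Box([1, Box([2, 3])]), flatten_ok))

try:
    box_leaves(Box([1, 2, 3]), flatten_endless)
    hit_recursion_limit = False
except RecursionError:
    hit_recursion_limit = True
print(hit_recursion_limit)
```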
    

None is Non-leaf Node vs. None is Leaf

The None object is a special object in the Python language. It serves some of the same purposes as null (a pointer that does not point to anything) in other programming languages: it denotes that a variable is empty or marks a default parameter. However, the None object is a singleton object rather than a pointer. It may also serve as a sentinel value. In addition, if a function returns without a return value or omits the return statement, it implicitly returns the None object.

By default, the None object is considered a non-leaf node in the tree with arity 0, i.e., a non-leaf node that has no children. This is like the behavior of an empty tuple. While flattening a tree, it will remain in the tree structure definitions rather than in the leaves list.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(*, NoneIsLeaf))

OpTree provides a keyword argument none_is_leaf to determine whether to consider the None object as a leaf, like other opaque objects. If none_is_leaf=True, the None object will be placed in the leaves list. Otherwise, the None object will remain in the tree specification (structure).

>>> import torch

>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters  # a container that holds None
OrderedDict({
    'weight': Parameter containing:
              tensor([[-0.6677,  0.5209,  0.3295],
                      [-0.4876, -0.3142,  0.1785]], requires_grad=True),
    'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict({
    'weight': tensor([[0., 0., 0.],
                      [0., 0., 0.]]),
    'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
Traceback (most recent call last):
    ...
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType

>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict({
    'weight': tensor([[0., 0., 0.],
                      [0., 0., 0.]]),
    'bias': 0
})
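The effect of none_is_leaf on the leaves list can be mimicked with a small pure-Python sketch. The function leaves_with_none is hypothetical and only handles tuple, list, and dict; it illustrates the rule rather than reproducing OpTree's traversal.

```python
def leaves_with_none(tree, none_is_leaf=False):
    """Collect leaves; None is a childless node unless none_is_leaf=True."""
    if tree is None:
        return [None] if none_is_leaf else []
    if type(tree) in (tuple, list):
        children = list(tree)
    elif type(tree) is dict:
        children = [tree[key] for key in sorted(tree)]
    else:
        return [tree]
    leaves = []
    for child in children:
        leaves.extend(leaves_with_none(child, none_is_leaf))
    return leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(leaves_with_none(tree))
print(leaves_with_none(tree, none_is_leaf=True))
```

With the default setting, a function mapped over the leaves never sees None, which is why torch.zeros_like works on the parameter dictionary above without a guard.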

Key Ordering for Dictionaries

The built-in Python dictionary (i.e., builtins.dict) is an unordered mapping that holds keys and values. The leaves of a dictionary are its values. Although the built-in dictionary has been insertion-ordered since Python 3.6 (PEP 468), the dictionary equality operator (==) does not check key ordering. To ensure the referential transparency property that "equal dict" implies "equal ordering of leaves", the values of a dictionary are ordered by their sorted keys during flattening. This behavior also applies to collections.defaultdict.

>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_flatten({'b': [3], 'a': [1, 2]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))

If users want to keep the values in insertion order during pytree traversal, they should use collections.OrderedDict, which takes the order of keys into account:

>>> from collections import OrderedDict
>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict({'a': [*, *], 'b': [*]})))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))

Since OpTree v0.9.0, the key order of the reconstructed output dictionaries from tree_unflatten is guaranteed to be consistent with the key order of the input dictionaries in tree_flatten.

>>> leaves, treespec = optree.tree_flatten({'b': [3], 'a': [1, 2]})
>>> leaves, treespec
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_unflatten(treespec, leaves)
{'b': [3], 'a': [1, 2]}
>>> optree.tree_map(lambda x: x, {'b': [3], 'a': [1, 2]})
{'b': [3], 'a': [1, 2]}
>>> optree.tree_map(lambda x: x + 1, {'b': [3], 'a': [1, 2]})
{'b': [4], 'a': [2, 3]}

This property is also preserved during serialization/deserialization.

>>> leaves, treespec = optree.tree_flatten({'b': [3], 'a': [1, 2]})
>>> leaves, treespec
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> import pickle
>>> restored_treespec = pickle.loads(pickle.dumps(treespec))
>>> optree.tree_unflatten(treespec, leaves)
{'b': [3], 'a': [1, 2]}
>>> optree.tree_unflatten(restored_treespec, leaves)
{'b': [3], 'a': [1, 2]}
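One way to reconcile sorted leaf order with insertion-ordered reconstruction is to keep both key orders in the node metadata. The sketch below is an assumption about how such a scheme could work, not a description of OpTree's internals; flatten_dict and unflatten_dict are hypothetical names.

```python
def flatten_dict(d):
    """Return values in sorted-key order, remembering the insertion order."""
    original_keys = list(d)            # insertion order of the input
    sorted_keys = sorted(original_keys)
    leaves = [d[key] for key in sorted_keys]
    return leaves, (sorted_keys, original_keys)


def unflatten_dict(metadata, leaves):
    """Rebuild the dict with the key order of the original input."""
    sorted_keys, original_keys = metadata
    by_key = dict(zip(sorted_keys, leaves))
    return {key: by_key[key] for key in original_keys}


leaves, metadata = flatten_dict({'b': 3, 'a': 1})
print(leaves)
rebuilt = unflatten_dict(metadata, [x + 1 for x in leaves])
print(list(rebuilt.items()))
```

Because the metadata is ordinary Python data, a scheme like this also survives a pickle round trip.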

[!NOTE] There is no requirement that the keys of a dict be comparable (sortable), and a dictionary can contain keys of multiple types. The keys are sorted in ascending order with key=lambda k: k if possible; otherwise, the sort falls back to key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k). This handles most cases.

>>> sorted({1: 2, 1.5: 1}.keys())
[1, 1.5]
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys())
Traceback (most recent call last):
    ...
TypeError: '<' not supported between instances of 'int' and 'str'
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys(), key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k))
[1.5, 1, 'a']
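The two-step rule from the note can be written as a short helper. The function name sort_dict_keys is hypothetical, and this is a sketch of the described behavior rather than OpTree's code; the fallback can still raise TypeError when keys of the same type are not comparable with each other.

```python
def sort_dict_keys(keys):
    """Sort keys directly if possible, else group them by qualified type name."""
    keys = list(keys)
    try:
        return sorted(keys)
    except TypeError:
        # Mixed, mutually incomparable key types: sort by type name first
        return sorted(
            keys,
            key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k),
        )


print(sort_dict_keys({1: 2, 1.5: 1}))
print(sort_dict_keys({'a': 3, 1: 2, 1.5: 1}))
```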

Benchmark

We benchmark the performance of:

  • tree flatten
  • tree unflatten
  • tree copy (i.e., unflatten(flatten(...)))
  • tree map

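compared with JAX XLA, PyTorch, and DM-Tree. The table reports the average time cost of each library relative to OpTree (lower is better):

| Average Time Cost (↓) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | DM-Tree (v0.1.8) |
| :--- | ---: | ---: | ---: | ---: |
| Tree Flatten | x1.00 | 2.33 | 22.05 | 1.12 |
| Tree UnFlatten | x1.00 | 2.69 | 4.28 | 16.23 |
| Tree Flatten with Path | x1.00 | 16.16 | Not Supported | 27.59 |
| Tree Copy | x1.00 | 2.56 | 9.97 | 11.02 |
| Tree Map | x1.00 | 2.56 | 9.58 | 10.62 |
| Tree Map (nargs) | x1.00 | 2.89 | Not Supported | 31.33 |
| Tree Map with Path | x1.00 | 7.23 | Not Supported | 19.66 |
| Tree Map with Path (nargs) | x1.00 | 6.56 | Not Supported | 29.61 |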

All results are reported on a workstation with an AMD Ryzen 9 5950X CPU @ 4.45GHz in an isolated virtual environment with Python 3.10.9. Run with the following commands:

conda create --name optree-benchmark anaconda::python=3.10 --yes --no-default-packages
conda activate optree-benchmark
python3 -m pip install --editable '.[benchmark]' --extra-index-url https://download.pytorch.org/whl/cpu
python3 benchmark.py --number=10000 --repeat=5

The test inputs are nested containers (i.e., pytrees) extracted from torch.nn.Module objects. They are:

tiny_mlp = nn.Sequential(
    nn.Linear(1, 1, bias=True),
    nn.BatchNorm1d(1, affine=True, track_running_stats=True),
    nn.ReLU(),
    nn.Linear(1, 1, bias=False),
    nn.Sigmoid(),
)

and AlexNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, VisionTransformerH14 (ViT-H/14), and SwinTransformerB (Swin-B) from torchvision. Please refer to benchmark.py for more details.

Tree Flatten

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 29.70 | 71.06 | 583.66 | 31.32 | 2.39 | 19.65 | 1.05 |
| AlexNet | 188 | 103.92 | 262.56 | 2304.36 | 119.61 | 2.53 | 22.17 | 1.15 |
| ResNet18 | 698 | 368.06 | 852.69 | 8440.31 | 420.43 | 2.32 | 22.93 | 1.14 |
| ResNet34 | 1242 | 644.96 | 1461.55 | 14498.81 | 712.81 | 2.27 | 22.48 | 1.11 |
| ResNet50 | 1702 | 919.95 | 2080.58 | 20995.96 | 1006.42 | 2.26 | 22.82 | 1.09 |
| ResNet101 | 3317 | 1806.36 | 3996.90 | 40314.12 | 1955.48 | 2.21 | 22.32 | 1.08 |
| ResNet152 | 4932 | 2656.92 | 5812.38 | 57775.53 | 2826.92 | 2.19 | 21.75 | 1.06 |
| ViT-H/14 | 3420 | 1863.50 | 4418.24 | 41334.64 | 2128.71 | 2.37 | 22.18 | 1.14 |
| Swin-B | 2881 | 1631.06 | 3944.13 | 36131.54 | 2032.77 | 2.42 | 22.15 | 1.25 |
| Average | | | | | | 2.33 | 22.05 | 1.12 |
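The speedup columns can be reproduced from the raw timings. The sketch below assumes that each speedup is the other library's time divided by OpTree's time and that the Average row is the arithmetic mean of the per-module speedups; both assumptions are consistent with the Tree Flatten numbers above. The helper speedup_over_optree is a hypothetical name.

```python
# Times in microseconds copied from the TinyMLP row of the Tree Flatten table.
tiny_mlp_us = {'OpTree': 29.70, 'JAX XLA': 71.06, 'PyTorch': 583.66, 'DM-Tree': 31.32}


def speedup_over_optree(times_us):
    """Ratio of each library's time to OpTree's time, rounded to 2 places."""
    baseline = times_us['OpTree']
    return {
        name: round(time_us / baseline, 2)
        for name, time_us in times_us.items()
        if name != 'OpTree'
    }


tiny_mlp_speedup = speedup_over_optree(tiny_mlp_us)
print(tiny_mlp_speedup)

# Per-module "Speedup (J / O)" values copied from the Tree Flatten table.
jax_speedups = [2.39, 2.53, 2.32, 2.27, 2.26, 2.21, 2.19, 2.37, 2.42]
jax_average = round(sum(jax_speedups) / len(jax_speedups), 2)
print(jax_average)
```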

Tree UnFlatten

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 55.13 | 152.07 | 231.94 | 940.11 | 2.76 | 4.21 | 17.05 |
| AlexNet | 188 | 226.29 | 678.29 | 972.90 | 4195.04 | 3.00 | 4.30 | 18.54 |
| ResNet18 | 698 | 766.54 | 1953.26 | 3137.86 | 12049.88 | 2.55 | 4.09 | 15.72 |
| ResNet34 | 1242 | 1309.22 | 3526.12 | 5759.16 | 20966.75 | 2.69 | 4.40 | 16.01 |
| ResNet50 | 1702 | 1914.96 | 5002.83 | 8369.43 | 29597.10 | 2.61 | 4.37 | 15.46 |
| ResNet101 | 3317 | 3672.61 | 9633.29 | 15683.16 | 57240.20 | 2.62 | 4.27 | 15.59 |
| ResNet152 | 4932 | 5407.58 | 13970.88 | 23074.68 | 82072.54 | 2.58 | 4.27 | 15.18 |
| ViT-H/14 | 3420 | 4013.18 | 11146.31 | 17633.07 | 66723.58 | 2.78 | 4.39 | 16.63 |
| Swin-B | 2881 | 3595.34 | 9505.31 | 15054.88 | 57310.03 | 2.64 | 4.19 | 15.94 |
| Average | | | | | | 2.69 | 4.28 | 16.23 |

Tree Flatten with Path

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 36.49 | 543.67 | N/A | 919.13 | 14.90 | N/A | 25.19 |
| AlexNet | 188 | 115.44 | 2185.21 | N/A | 3752.11 | 18.93 | N/A | 32.50 |
| ResNet18 | 698 | 431.84 | 7106.55 | N/A | 12286.70 | 16.46 | N/A | 28.45 |
| ResNet34 | 1242 | 845.61 | 13431.99 | N/A | 22860.48 | 15.88 | N/A | 27.03 |
| ResNet50 | 1702 | 1166.27 | 18426.52 | N/A | 31225.05 | 15.80 | N/A | 26.77 |
| ResNet101 | 3317 | 2312.77 | 34770.49 | N/A | 59346.86 | 15.03 | N/A | 25.66 |
| ResNet152 | 4932 | 3304.74 | 50557.25 | N/A | 85847.91 | 15.30 | N/A | 25.98 |
| ViT-H/14 | 3420 | 2235.25 | 37473.53 | N/A | 64105.24 | 16.76 | N/A | 28.68 |
| Swin-B | 2881 | 1970.25 | 32205.83 | N/A | 55177.50 | 16.35 | N/A | 28.01 |
| Average | | | | | | 16.16 | N/A | 27.59 |

Tree Copy

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 89.81 | 232.26 | 845.20 | 981.48 | 2.59 | 9.41 | 10.93 |
| AlexNet | 188 | 334.58 | 959.32 | 3360.46 | 4316.05 | 2.87 | 10.04 | 12.90 |
| ResNet18 | 698 | 1128.11 | 2840.71 | 11471.07 | 12297.07 | 2.52 | 10.17 | 10.90 |
| ResNet34 | 1242 | 2160.57 | 5333.10 | 20563.06 | 21901.91 | 2.47 | 9.52 | 10.14 |
| ResNet50 | 1702 | 2746.84 | 6823.88 | 29705.99 | 28927.88 | 2.48 | 10.81 | 10.53 |
| ResNet101 | 3317 | 5762.05 | 13481.45 | 56968.78 | 60115.93 | 2.34 | 9.89 | 10.43 |
| ResNet152 | 4932 | 8151.21 | 20805.61 | 81024.06 | 84079.57 | 2.55 | 9.94 | 10.31 |
| ViT-H/14 | 3420 | 5963.61 | 15665.91 | 59813.52 | 68377.82 | 2.63 | 10.03 | 11.47 |
| Swin-B | 2881 | 5401.59 | 14255.33 | 53361.77 | 62317.07 | 2.64 | 9.88 | 11.54 |
| Average | | | | | | 2.56 | 9.97 | 11.02 |

Tree Map

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 95.13 | 243.86 | 867.34 | 1026.99 | 2.56 | 9.12 | 10.80 |
| AlexNet | 188 | 348.44 | 987.57 | 3398.32 | 4354.81 | 2.83 | 9.75 | 12.50 |
| ResNet18 | 698 | 1190.62 | 2982.66 | 11719.94 | 12559.01 | 2.51 | 9.84 | 10.55 |
| ResNet34 | 1242 | 2205.87 | 5417.60 | 20935.72 | 22308.51 | 2.46 | 9.49 | 10.11 |
| ResNet50 | 1702 | 3128.48 | 7579.55 | 30372.71 | 31638.67 | 2.42 | 9.71 | 10.11 |
| ResNet101 | 3317 | 6173.05 | 14846.57 | 59167.85 | 60245.42 | 2.41 | 9.58 | 9.76 |
| ResNet152 | 4932 | 8641.22 | 22000.74 | 84018.65 | 86182.21 | 2.55 | 9.72 | 9.97 |
| ViT-H/14 | 3420 | 6211.79 | 17077.49 | 59790.25 | 69763.86 | 2.75 | 9.63 | 11.23 |
| Swin-B | 2881 | 5673.66 | 14339.69 | 53309.17 | 59764.61 | 2.53 | 9.40 | 10.53 |
| Average | | | | | | 2.56 | 9.58 | 10.62 |

Tree Map (nargs)

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 137.06 | 389.96 | N/A | 3908.77 | 2.85 | N/A | 28.52 |
| AlexNet | 188 | 467.24 | 1496.96 | N/A | 15395.13 | 3.20 | N/A | 32.95 |
| ResNet18 | 698 | 1603.79 | 4534.01 | N/A | 50323.76 | 2.83 | N/A | 31.38 |
| ResNet34 | 1242 | 2907.64 | 8435.33 | N/A | 90389.23 | 2.90 | N/A | 31.09 |
| ResNet50 | 1702 | 4183.77 | 11382.51 | N/A | 121777.01 | 2.72 | N/A | 29.11 |
| ResNet101 | 3317 | 7721.13 | 22247.85 | N/A | 238755.17 | 2.88 | N/A | 30.92 |
| ResNet152 | 4932 | 11508.05 | 31429.39 | N/A | 360257.74 | 2.73 | N/A | 31.30 |
| ViT-H/14 | 3420 | 8294.20 | 24524.86 | N/A | 270514.87 | 2.96 | N/A | 32.61 |
| Swin-B | 2881 | 7074.62 | 20854.80 | N/A | 241120.41 | 2.95 | N/A | 34.08 |
| Average | | | | | | 2.89 | N/A | 31.33 |

Tree Map with Path

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 109.82 | 778.30 | N/A | 2186.40 | 7.09 | N/A | 19.91 |
| AlexNet | 188 | 365.16 | 2939.36 | N/A | 8355.37 | 8.05 | N/A | 22.88 |
| ResNet18 | 698 | 1308.26 | 9529.58 | N/A | 25758.24 | 7.28 | N/A | 19.69 |
| ResNet34 | 1242 | 2527.21 | 18084.89 | N/A | 45942.32 | 7.16 | N/A | 18.18 |
| ResNet50 | 1702 | 3226.03 | 22935.53 | N/A | 61275.34 | 7.11 | N/A | 18.99 |
| ResNet101 | 3317 | 6663.52 | 46878.89 | N/A | 126642.14 | 7.04 | N/A | 19.01 |
| ResNet152 | 4932 | 9378.19 | 66136.44 | N/A | 176981.01 | 7.05 | N/A | 18.87 |
| ViT-H/14 | 3420 | 7033.69 | 50418.37 | N/A | 142508.11 | 7.17 | N/A | 20.26 |
| Swin-B | 2881 | 6078.15 | 43173.22 | N/A | 116612.71 | 7.10 | N/A | 19.19 |
| Average | | | | | | 7.23 | N/A | 19.66 |

Tree Map with Path (nargs)

| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
| :--- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| TinyMLP | 53 | 146.05 | 917.00 | N/A | 3940.61 | 6.28 | N/A | 26.98 |
| AlexNet | 188 | 489.27 | 3560.76 | N/A | 15434.71 | 7.28 | N/A | 31.55 |
| ResNet18 | 698 | 1712.79 | 11171.44 | N/A | 50219.86 | 6.52 | N/A | 29.32 |
| ResNet34 | 1242 | 3112.83 | 21024.58 | N/A | 95505.71 | 6.75 | N/A | 30.68 |
| ResNet50 | 1702 | 4220.70 | 26600.82 | N/A | 121897.57 | 6.30 | N/A | 28.88 |
| ResNet101 | 3317 | 8631.34 | 54372.37 | N/A | 236555.54 | 6.30 | N/A | 27.41 |
| ResNet152 | 4932 | 12710.49 | 77643.13 | N/A | 353600.32 | 6.11 | N/A | 27.82 |
| ViT-H/14 | 3420 | 8753.09 | 58712.71 | N/A | 286365.36 | 6.71 | N/A | 32.72 |
| Swin-B | 2881 | 7359.29 | 50112.23 | N/A | 228866.66 | 6.81 | N/A | 31.10 |
| Average | | | | | | 6.56 | N/A | 29.61 |

Changelog

See CHANGELOG.md.


License

OpTree is released under the Apache License 2.0.

OpTree is heavily based on JAX's implementation of the PyTree utility, with deep refactoring and several improvements. The original licenses can be found at JAX's Apache License 2.0 and TensorFlow's Apache License 2.0.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

optree-0.13.1.tar.gz (155.7 kB view details)

Uploaded Source

Built Distributions

optree-0.13.1-pp310-pypy310_pp73-win_amd64.whl (285.2 kB view details)

Uploaded PyPy Windows x86-64

optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (384.2 kB view details)

Uploaded PyPy manylinux: glibc 2.17+ x86-64

optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl (394.6 kB view details)

Uploaded PyPy manylinux: glibc 2.17+ i686

optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (350.1 kB view details)

Uploaded PyPy manylinux: glibc 2.17+ ARM64

optree-0.13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl (319.1 kB view details)

Uploaded PyPy macOS 11.0+ ARM64

optree-0.13.1-pp39-pypy39_pp73-win_amd64.whl (285.2 kB view details)

Uploaded PyPy Windows x86-64

optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (383.9 kB view details)

Uploaded PyPy manylinux: glibc 2.17+ x86-64

optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl (394.7 kB view details)

Uploaded PyPy manylinux: glibc 2.17+ i686

optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (349.9 kB view details)

Uploaded PyPy manylinux: glibc 2.17+ ARM64

optree-0.13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl (318.9 kB view details)

Uploaded PyPy macOS 11.0+ ARM64

optree-0.13.1-cp313-cp313t-win_arm64.whl (331.5 kB view details)

Uploaded CPython 3.13t Windows ARM64

optree-0.13.1-cp313-cp313t-win_amd64.whl (331.5 kB view details)

Uploaded CPython 3.13t Windows x86-64

optree-0.13.1-cp313-cp313t-win32.whl (292.7 kB view details)

Uploaded CPython 3.13t Windows x86

optree-0.13.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (405.4 kB view details)

Uploaded CPython 3.13t manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl (389.2 kB view details)

Uploaded CPython 3.13t manylinux: glibc 2.17+ s390x

optree-0.13.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (415.0 kB view details)

Uploaded CPython 3.13t manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl (418.7 kB view details)

Uploaded CPython 3.13t manylinux: glibc 2.17+ i686

optree-0.13.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (374.7 kB view details)

Uploaded CPython 3.13t manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp313-cp313t-macosx_11_0_arm64.whl (370.7 kB)

Uploaded CPython 3.13t macOS 11.0+ ARM64

optree-0.13.1-cp313-cp313t-macosx_10_13_universal2.whl (702.9 kB)

Uploaded CPython 3.13t macOS 10.13+ universal2 (ARM64, x86-64)

optree-0.13.1-cp313-cp313-win_arm64.whl (293.7 kB)

Uploaded CPython 3.13 Windows ARM64

optree-0.13.1-cp313-cp313-win_amd64.whl (293.7 kB)

Uploaded CPython 3.13 Windows x86-64

optree-0.13.1-cp313-cp313-win32.whl (264.3 kB)

Uploaded CPython 3.13 Windows x86

optree-0.13.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (390.2 kB)

Uploaded CPython 3.13 manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl (369.4 kB)

Uploaded CPython 3.13 manylinux: glibc 2.17+ s390x

optree-0.13.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (397.0 kB)

Uploaded CPython 3.13 manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl (402.6 kB)

Uploaded CPython 3.13 manylinux: glibc 2.17+ i686

optree-0.13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (355.8 kB)

Uploaded CPython 3.13 manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp313-cp313-macosx_11_0_arm64.whl (325.7 kB)

Uploaded CPython 3.13 macOS 11.0+ ARM64

optree-0.13.1-cp313-cp313-macosx_10_13_universal2.whl (608.3 kB)

Uploaded CPython 3.13 macOS 10.13+ universal2 (ARM64, x86-64)

optree-0.13.1-cp312-cp312-win_arm64.whl (292.0 kB)

Uploaded CPython 3.12 Windows ARM64

optree-0.13.1-cp312-cp312-win_amd64.whl (292.0 kB)

Uploaded CPython 3.12 Windows x86-64

optree-0.13.1-cp312-cp312-win32.whl (261.6 kB)

Uploaded CPython 3.12 Windows x86

optree-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (385.5 kB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl (365.2 kB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ s390x

optree-0.13.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (392.9 kB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl (399.3 kB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ i686

optree-0.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (352.7 kB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp312-cp312-macosx_11_0_arm64.whl (322.3 kB)

Uploaded CPython 3.12 macOS 11.0+ ARM64

optree-0.13.1-cp312-cp312-macosx_10_13_universal2.whl (601.0 kB)

Uploaded CPython 3.12 macOS 10.13+ universal2 (ARM64, x86-64)

optree-0.13.1-cp311-cp311-win_arm64.whl (292.2 kB)

Uploaded CPython 3.11 Windows ARM64

optree-0.13.1-cp311-cp311-win_amd64.whl (292.2 kB)

Uploaded CPython 3.11 Windows x86-64

optree-0.13.1-cp311-cp311-win32.whl (260.9 kB)

Uploaded CPython 3.11 Windows x86

optree-0.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (391.8 kB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl (368.7 kB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ s390x

optree-0.13.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (398.4 kB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl (402.5 kB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ i686

optree-0.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (357.3 kB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp311-cp311-macosx_11_0_arm64.whl (318.6 kB)

Uploaded CPython 3.11 macOS 11.0+ ARM64

optree-0.13.1-cp311-cp311-macosx_10_9_universal2.whl (589.6 kB)

Uploaded CPython 3.11 macOS 10.9+ universal2 (ARM64, x86-64)

optree-0.13.1-cp310-cp310-win_arm64.whl (282.7 kB)

Uploaded CPython 3.10 Windows ARM64

optree-0.13.1-cp310-cp310-win_amd64.whl (282.7 kB)

Uploaded CPython 3.10 Windows x86-64

optree-0.13.1-cp310-cp310-win32.whl (255.1 kB)

Uploaded CPython 3.10 Windows x86

optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381.3 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl (359.5 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ s390x

optree-0.13.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (387.6 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl (389.8 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ i686

optree-0.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (346.2 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp310-cp310-macosx_11_0_arm64.whl (311.8 kB)

Uploaded CPython 3.10 macOS 11.0+ ARM64

optree-0.13.1-cp310-cp310-macosx_10_9_universal2.whl (576.6 kB)

Uploaded CPython 3.10 macOS 10.9+ universal2 (ARM64, x86-64)

optree-0.13.1-cp39-cp39-win_arm64.whl (277.7 kB)

Uploaded CPython 3.9 Windows ARM64

optree-0.13.1-cp39-cp39-win_amd64.whl (277.7 kB)

Uploaded CPython 3.9 Windows x86-64

optree-0.13.1-cp39-cp39-win32.whl (255.1 kB)

Uploaded CPython 3.9 Windows x86

optree-0.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381.2 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl (359.9 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ s390x

optree-0.13.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (387.7 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl (390.0 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ i686

optree-0.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (347.3 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp39-cp39-macosx_11_0_arm64.whl (311.8 kB)

Uploaded CPython 3.9 macOS 11.0+ ARM64

optree-0.13.1-cp39-cp39-macosx_10_9_universal2.whl (577.0 kB)

Uploaded CPython 3.9 macOS 10.9+ universal2 (ARM64, x86-64)

optree-0.13.1-cp38-cp38-win_amd64.whl (282.6 kB)

Uploaded CPython 3.8 Windows x86-64

optree-0.13.1-cp38-cp38-win32.whl (254.9 kB)

Uploaded CPython 3.8 Windows x86

optree-0.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (380.8 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl (359.4 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ s390x

optree-0.13.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (388.0 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl (389.4 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ i686

optree-0.13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (345.8 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp38-cp38-macosx_11_0_arm64.whl (311.6 kB)

Uploaded CPython 3.8 macOS 11.0+ ARM64

optree-0.13.1-cp38-cp38-macosx_10_9_universal2.whl (576.7 kB)

Uploaded CPython 3.8 macOS 10.9+ universal2 (ARM64, x86-64)

optree-0.13.1-cp37-cp37m-win_amd64.whl (281.0 kB)

Uploaded CPython 3.7m Windows x86-64

optree-0.13.1-cp37-cp37m-win32.whl (257.9 kB)

Uploaded CPython 3.7m Windows x86

optree-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (384.3 kB)

Uploaded CPython 3.7m manylinux: glibc 2.17+ x86-64

optree-0.13.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl (368.8 kB)

Uploaded CPython 3.7m manylinux: glibc 2.17+ s390x

optree-0.13.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl (399.3 kB)

Uploaded CPython 3.7m manylinux: glibc 2.17+ ppc64le

optree-0.13.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl (397.5 kB)

Uploaded CPython 3.7m manylinux: glibc 2.17+ i686

optree-0.13.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (356.8 kB)

Uploaded CPython 3.7m manylinux: glibc 2.17+ ARM64

optree-0.13.1-cp37-cp37m-macosx_10_9_x86_64.whl (333.5 kB)

Uploaded CPython 3.7m macOS 10.9+ x86-64
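
The wheel filenames listed above follow the standard wheel naming scheme (name, version, optional build tag, Python tag, ABI tag, platform tag, separated by hyphens). As a rough sketch of how to read them, the helper below splits a filename into those parts using only the standard library. The function name parse_wheel_filename is hypothetical and is not part of optree or pip.

```python
def parse_wheel_filename(filename):
    """Split a wheel filename into its components (illustrative helper only)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python_tag, abi_tag, platform_tag = parts
        build_tag = None
    elif len(parts) == 6:
        # The optional build tag sits between the version and the Python tag.
        name, version, build_tag, python_tag, abi_tag, platform_tag = parts
    else:
        raise ValueError(f"unexpected number of fields in {filename!r}")
    return {
        "name": name,
        "version": version,
        "build": build_tag,
        "python": python_tag,
        "abi": abi_tag,
        # A compressed tag set lists several platform tags joined by dots.
        "platforms": platform_tag.split("."),
    }


info = parse_wheel_filename(
    "optree-0.13.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
)
print(info["python"], info["abi"], info["platforms"])
```

Here the ABI tag cp313t corresponds to the "CPython 3.13t" entries in the list, and the two platform tags name the same glibc 2.17 baseline under the newer and the legacy manylinux spelling.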

File details

Details for the file optree-0.13.1.tar.gz.

File metadata

  • Download URL: optree-0.13.1.tar.gz
  • Upload date:
  • Size: 155.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1.tar.gz
Algorithm Hash digest
SHA256 af67856aa8073d237fe67313d84f8aeafac32c1cef7239c628a2768d02679c43
MD5 89e2f51295b0c6cbe8e6347c3f3bd934
BLAKE2b-256 f7f256afdaeaae36b076659be7db8e72be0924dd64ebd1c131675c77f7e704a6

See more details on using hashes here.
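
One way to check a downloaded file against the digests listed on this page is to recompute its SHA256 locally and compare. The following is a minimal sketch using Python's standard hashlib module; the helper names sha256_of_file and matches_sha256 are illustrative, and the local path in the usage note is an assumption about where the file was saved.

```python
import hashlib
import hmac


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_sha256(path, expected_hex):
    """Compare the file's digest with an expected hex digest (case-insensitive)."""
    return hmac.compare_digest(sha256_of_file(path), expected_hex.strip().lower())


# SHA256 listed above for optree-0.13.1.tar.gz.
EXPECTED_SDIST_SHA256 = "af67856aa8073d237fe67313d84f8aeafac32c1cef7239c628a2768d02679c43"
```

For example, matches_sha256("optree-0.13.1.tar.gz", EXPECTED_SDIST_SHA256) should return True for an intact copy of the source distribution saved under that name in the current directory. pip can also enforce digests at install time through --hash entries in a requirements file together with --require-hashes.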

File details

Details for the file optree-0.13.1-pp310-pypy310_pp73-win_amd64.whl.

File hashes

Hashes for optree-0.13.1-pp310-pypy310_pp73-win_amd64.whl
Algorithm Hash digest
SHA256 5c6aed6c5eabda59a91376aca08ba508a06f1c68850216a98743b5f8f55af841
MD5 32b6f761bd98d116ef15323aa17fa411
BLAKE2b-256 1fc699a4454403211cfc5853ef45a00d8238447ca779d1adddadd64154b5a355

File details

Details for the file optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 0f9707547635cfede8d79e4161c066021ffefc401d98bbf8eba452b1355a42c7
MD5 beab5002f652d6d8115c9f3997d16b48
BLAKE2b-256 42be2f9c2f1c646a9b501a7daf79a810c56fd93c87201537c1e7a76dc99bce36

File details

Details for the file optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl.

File hashes

Hashes for optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 ce962f0dd387137817dcda600bd6cf2e1b65103411807b6cdbbd9ffddf1061f6
MD5 9352d1e90331289786cb7e5f17268923
BLAKE2b-256 7fe0c5c6011f52d5dd0c10390003d02ca3c1423da32aa9760e1c0c312aefa206

File details

Details for the file optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File hashes

Hashes for optree-0.13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 fafeda2e35e3270532132e27b471ea3e3aeac18f7966a4d0469137d1f36046ec
MD5 5a5579a9c2133cf13a926cbb3cd75b40
BLAKE2b-256 f282d1da61d4cdcc9004a95e44227525ac74dc5012f0ca1c39becf841c57acc3

File details

Details for the file optree-0.13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl.

File hashes

Hashes for optree-0.13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 0f1bde49e41a158af28d99fae1bd425fbd664907c53cf595106fb5b35e5cbe26
MD5 792a5d7627c3e0163524b6df5350a734
BLAKE2b-256 b876a8ca4e1866f56c1119b2e3eb6b0c0e248e9c45126fa9f6de6a64e4e39e49

File details

Details for the file optree-0.13.1-pp39-pypy39_pp73-win_amd64.whl.

File hashes

Hashes for optree-0.13.1-pp39-pypy39_pp73-win_amd64.whl
Algorithm Hash digest
SHA256 3d0161012d80e4865017e10298ac55652cc3ad9a3eae9440229d4bf00b140e01
MD5 ad8909a2f020aa5677bccc369504e807
BLAKE2b-256 e870d5805700831b3d2327ee4ff348f4e72cb9d2f814d1ae3b8c39cdbe0e6c02

File details

Details for the file optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 652287e43fcbb29b8d1821144987e3bc558be4e5eec0d42fce7007cc3ee8e574
MD5 205e3dd9d2da3b9d128aaad61c265a32
BLAKE2b-256 7169c414399938bbbaa72ef0193af27f9bbb0ba808dcd8a64abc1b3a295a025a

File details

Details for the file optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl.

File hashes

Hashes for optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 395ac2eb69528613fd0f2ee8706890b7921b8ff3159df53b6e9f67eaf519c5cb
MD5 452c11f4d9ec71f3df2e720e0f704525
BLAKE2b-256 de129052e735667db7e82b8a60b81443f570191555a5c60f9ce3d71136d02be1

File details

Details for the file optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File hashes

Hashes for optree-0.13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 37948e2d796db23d6ccd07105b709b827eba26549d34dd2149e95887c89fe9b4
MD5 d749a40dea2004bd6e691c7a5eba35ae
BLAKE2b-256 caf44f4aab08bc444de41171337f9b0b7062d0fbfefc27b0f1630c1577f0959a

File details

Details for the file optree-0.13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl.

File hashes

Hashes for optree-0.13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 95298846c057cce2e7d114c03c645e86a5381b72388c8c390986bdefe69a759c
MD5 ba63f8d09364b56c2227bd14474367db
BLAKE2b-256 ca6f37cddb7bf5079afcb2acb90b33c3040ed2c9b93987738ba51dca14d27091

File details

Details for the file optree-0.13.1-cp313-cp313t-win_arm64.whl.

File metadata

  • Download URL: optree-0.13.1-cp313-cp313t-win_arm64.whl
  • Upload date:
  • Size: 331.5 kB
  • Tags: CPython 3.13t, Windows ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp313-cp313t-win_arm64.whl
Algorithm Hash digest
SHA256 c4d13f55dbd509d27be3af54d53b4ca0751bc518244ced6d0567e518e51452a2
MD5 874fe654079f96224cc0fb20e9274c15
BLAKE2b-256 8b36c01a5bc34660d46c6a3b1fe090bbdc8c76af7b5c1a6613cc671aa6df8349

File details

Details for the file optree-0.13.1-cp313-cp313t-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp313-cp313t-win_amd64.whl
  • Upload date:
  • Size: 331.5 kB
  • Tags: CPython 3.13t, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp313-cp313t-win_amd64.whl
Algorithm Hash digest
SHA256 d580f1bf23bb352c4db6b3544f282f1ac08dcb0d9ab537d25e56220353438cf7
MD5 fe7578cbdea8614491b02a8221a98a34
BLAKE2b-256 2f59d7601959ad0b90d309794c0975a256304488b4c5671f24e3e12101ade7ef

File details

Details for the file optree-0.13.1-cp313-cp313t-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp313-cp313t-win32.whl
  • Upload date:
  • Size: 292.7 kB
  • Tags: CPython 3.13t, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp313-cp313t-win32.whl
Algorithm Hash digest
SHA256 e40f018f522fcfd244688d1b3a360518e636ba7f636385aae0566eae3e7d29bc
MD5 28f0a04590a4fef1bc113f39b1668bd4
BLAKE2b-256 06993eb53829c4c0b6dc20115d957d2d8e945630ddf40c656dc4e39c5a6e51f2

File details

Details for the file optree-0.13.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 55e82426bef151149cfa41d68ac957730fcd420996c0db8324fca81aa6a810ba
MD5 77e40f7c37d5dad0894e91ae7975e5b7
BLAKE2b-256 3d84bb521a66d3a84fe2f1500ef67d245c2cc1a26277fcaaf4bc70b22c06e99b

File details

Details for the file optree-0.13.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 025d23400b8b579462a251420f0a9ae77d3d3593f84276f3465985731d79d722
MD5 9acffa4a5432a24c9f3d0c5f0af8c7ff
BLAKE2b-256 8f377bf815f4da7234e387863228b17246b42b8c02553882581a4013a64a88d0

File details

Details for the file optree-0.13.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 48c29d9c6c64c8dc48c8ee97f7c1d5cdb83e37320f0be0857c06ce4b97994aea
MD5 50ae78e4e09e0383145dc45f4f084c40
BLAKE2b-256 f07ca08191e0c9202f2be9c415057eea3cf3a5af18e9a6d81f4c7b0e6faf0a1f

File details

Details for the file optree-0.13.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 01819c3df950696f32c91faf8d376ae6b695ffdba18f330f1cab6b8e314e4612
MD5 9f3f7e92f5a23f04f7665218e2db6885
BLAKE2b-256 19f251a63a799f6dce31813d7e02a7547394aebcb39f407e62038ecbd999d490

File details

Details for the file optree-0.13.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 1935639dd498a42367633e3877797e1330e39d44d48bbca1a136bb4dbe4c1bc9
MD5 e03fd59e5b3dfd4114db5dc38d8f87c0
BLAKE2b-256 63374ddf05267467809236203e2007e9443519c4d55e0744ce7eea1aa74dffee

File details

Details for the file optree-0.13.1-cp313-cp313t-macosx_11_0_arm64.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 5b5626c38d4a18a144063db5c1dbb558431d83ca10682324f74665a12214801f
MD5 0eea29d2b787665cef1568aa2d8bcb05
BLAKE2b-256 087f70a2d02110ccb245bc57bd9ad57668acfea0ff364c27d7dfe1735ede79ed

File details

Details for the file optree-0.13.1-cp313-cp313t-macosx_10_13_universal2.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313t-macosx_10_13_universal2.whl
Algorithm Hash digest
SHA256 3010ae24e994f6e00071098d34e98e78eb995b7454a2ef629a0bf7df17441b24
MD5 96b9871347105550f7b727c8412c77d3
BLAKE2b-256 0dd6f81e6748bcc3f35a2f570a814014e3418b0ed425d7cbc2b42d88d12863d5

File details

Details for the file optree-0.13.1-cp313-cp313-win_arm64.whl.

File metadata

  • Download URL: optree-0.13.1-cp313-cp313-win_arm64.whl
  • Upload date:
  • Size: 293.7 kB
  • Tags: CPython 3.13, Windows ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp313-cp313-win_arm64.whl
Algorithm Hash digest
SHA256 f39c7174a3f3cdc3f5fe6fb4b832f608c40ac174d7567ed6734b2ee952094631
MD5 247a7e55f7688d7f730d8eab6e065c13
BLAKE2b-256 8034d1b1849a6240385c4a3af5da9425b11912204d0b1cf142d802815319b73a

File details

Details for the file optree-0.13.1-cp313-cp313-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp313-cp313-win_amd64.whl
  • Upload date:
  • Size: 293.7 kB
  • Tags: CPython 3.13, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 27d81dc43b522ba47ba7d2e7d91dbb486940348b1bf85caeb0afc2815c0aa492
MD5 6eb94e4984b17418bda1fa769d960cab
BLAKE2b-256 eff98a1421181c5eb0c0f81d1423a900baeb3faba68a48747bbdffb7581239ac

File details

Details for the file optree-0.13.1-cp313-cp313-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp313-cp313-win32.whl
  • Upload date:
  • Size: 264.3 kB
  • Tags: CPython 3.13, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp313-cp313-win32.whl
Algorithm Hash digest
SHA256 5b6531cd4eb23fadbbf77faf834e1119da06d7af3154f55786b59953cd87bb8a
MD5 a0672a5dc85760370308a3bf5a9f38f2
BLAKE2b-256 06023a701d6307fdfefe4fcecbac644803e2a4314ab2406ff465e03129cc85f6

File details

Details for the file optree-0.13.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 9c8ee1e988c634a451146b87d9ebdbf650a75dc1f52a9cffcd89fabb7289321c
MD5 e912c00f23ca8a06ef825f2922ad1810
BLAKE2b-256 9f428c08ce4ebb3d9a6e4415f1a97830c84879e2d1a43710a7c8a18b2c3e169d

File details

Details for the file optree-0.13.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 b5e5f09c85ae558a6bdaea57e63168082e728e777391393e9e2792f0d15b7b59
MD5 4fbc840afe2e94a20cc95e098862f033
BLAKE2b-256 16fafc2a8183e14f0d195d25824bf65095ff32b34bd469614a6c30d0a596a30f

File details

Details for the file optree-0.13.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 6c4ab1d391b89cb88eb3c63383d5eb0930bc21141de9d5acd277feed9e38eb65
MD5 7f3e94e79f7b50aee74aadf2974dbfc9
BLAKE2b-256 8a1d0d5bbab8c99580b732b89ef2c5fcdd6ef410478295949fdf2984fa1bfc28

File details

Details for the file optree-0.13.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 4711f5cac5a2a49c3d6c9f0eca7b77c22b452170bb33ea01c3214ebb17931db9
MD5 5f0d9d225462d92ca98b8a9173130d32
BLAKE2b-256 e5e3587e0d28dc2cee064902adfebca97db124e12b275dbe9c2b05a70a22345f

File details

Details for the file optree-0.13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 bbc5fa2ff5090389f3a906567446f01d692bd6fe5cfcc5ae2d5861f24e8e0e4d
MD5 9fa433e0437b07b1db5bb3f183a1671e
BLAKE2b-256 45db08921e56f3425bf649eb593eb28775263c935d029985d35572dc5690cc1a

File details

Details for the file optree-0.13.1-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 2909cb42add6bb1a5a2b0243bdd8c4b861bf072f3741e26239481907ac8ad4e6
MD5 f1d6cabd658ae51d0f288794114e3514
BLAKE2b-256 64f268beb9da2dd52baa50e7a589ed2bd8434fdd70cdba06754aa5910263da06

File details

Details for the file optree-0.13.1-cp313-cp313-macosx_10_13_universal2.whl.

File hashes

Hashes for optree-0.13.1-cp313-cp313-macosx_10_13_universal2.whl
Algorithm Hash digest
SHA256 f788b2ad120deb73b4908a74473cd6de79cfb9f33bbe9dcb59cea2e2477d4e28
MD5 64ddbe33aa26daf480c308b06be565cd
BLAKE2b-256 3f53f3727cad24f16a06666f328f1212476988cadac9b9e7919ddfb2c22eb662

File details

Details for the file optree-0.13.1-cp312-cp312-win_arm64.whl.

File metadata

  • Download URL: optree-0.13.1-cp312-cp312-win_arm64.whl
  • Upload date:
  • Size: 292.0 kB
  • Tags: CPython 3.12, Windows ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp312-cp312-win_arm64.whl
Algorithm Hash digest
SHA256 5da0fd26325a07354915cc4e3a9aee797cb75dff07c60d24b3f309457069abd3
MD5 2780e96e2166a8c305b7c67c399ed016
BLAKE2b-256 9fd75dec5d97c0a0c7951f0c8f5d24b4c6c8529d41ee69d0705f06bfa8b4874f

File details

Details for the file optree-0.13.1-cp312-cp312-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp312-cp312-win_amd64.whl
  • Upload date:
  • Size: 292.0 kB
  • Tags: CPython 3.12, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 587fb8de8e75e80fe7c7240e269630876bec3ee2038724893370976207813e4b
MD5 1f0e78f14d31f09048ff5fbbb1e24916
BLAKE2b-256 e3deb114d999746f9a9fb64476c8520ad499c11651912cecffe77aee1d5bec18

File details

Details for the file optree-0.13.1-cp312-cp312-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp312-cp312-win32.whl
  • Upload date:
  • Size: 261.6 kB
  • Tags: CPython 3.12, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp312-cp312-win32.whl
Algorithm Hash digest
SHA256 bc9c396f64f9aacdf852713bd75f1b9a83f118660fd82e87c937c081b7ddccd1
MD5 4cbea88572b4a0eb520933f24df95b2f
BLAKE2b-256 9d58f7430d613197260fc38fead8bc974a0069c4513ea3c04f11a771daf8b20f

File details

Details for the file optree-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 d866f707b9f3a9f0e670a73fe8feee4993b2dbdbf9eef598e1cf2e5cb2876413
MD5 ace8cf837833f9804a76fa8150904b29
BLAKE2b-256 9810087a684c7b5029e3be1f335d9df422b406cbfd842c77abfa7b17085adce5

File details

Details for the file optree-0.13.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 3da76fc43dcc22fe58d11634a04672ca7cc270aed469ac35fd5c78b7b9bc9125
MD5 6847d0f947f97b0192a2522e5b6d95da
BLAKE2b-256 e3ec6041c3ffe04af5890af7ab2b5f0ca48253032dce32aa5cddf8188ad4cc4b

File details

Details for the file optree-0.13.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 a408a43f16840475612c7058eb80b53791bf8b8266c5b3cd07f69697958fd97d
MD5 6088c21413a672a2be3f1b9b7d191700
BLAKE2b-256 01be56f946d3af013561d46c95f75880302cab03f1490ef939569852af6331c0

File details

Details for the file optree-0.13.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 0aec6da79a6130b4c76073241c0f31c11b96a38e70c7a00f9ed918d7464394ab
MD5 bc7bd713fe42331055d7a8b4705e2534
BLAKE2b-256 6f22c65ef2b6b191119a90223226b4a02100a9c9dd3a38e8410e473bd1653eff

File details

Details for the file optree-0.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 28f083ede9be89503357a6b9e5d304826701596abe13d33e8f6fa2cd85b407fc
MD5 01c7758e72ac5b6abc5f10764c38a8b8
BLAKE2b-256 7105ea228c1677a53855572a0ebb0c4e2a3e5d8e792d59e2b536ef50a9a02495

File details

Details for the file optree-0.13.1-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 111172446e8a4f0d3be13a853fa28cb46b5679a1c7ca15b2e6db2b43dbbf9efb
MD5 3f1329fd0a7101c7992f44f5c0c5471d
BLAKE2b-256 fa7cb7bedf44dbc54c55b8a408a4f978d9bb1ffbfb376093c33fc8576b1848dd

File details

Details for the file optree-0.13.1-cp312-cp312-macosx_10_13_universal2.whl.

File hashes

Hashes for optree-0.13.1-cp312-cp312-macosx_10_13_universal2.whl
Algorithm Hash digest
SHA256 0914ba436d6c0781dc9b04e3b95e06fe5c4fc6a87e94893da971805a3790efe8
MD5 c72dfa69a49bda4cbbe2ddf6c4d0826b
BLAKE2b-256 c6e7f605320e064ba54078f2966a9034fa2b3fc47db1e728e07a2a38b2e9075f

File details

Details for the file optree-0.13.1-cp311-cp311-win_arm64.whl.

File metadata

  • Download URL: optree-0.13.1-cp311-cp311-win_arm64.whl
  • Upload date:
  • Size: 292.2 kB
  • Tags: CPython 3.11, Windows ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp311-cp311-win_arm64.whl
Algorithm Hash digest
SHA256 cf85ba1a7d80b6dc19ef5ca4c17d2ff0290dc9306c5b8b468d51cede287f3c8d
MD5 6cc60374dd138093eea14b59e8c21e4d
BLAKE2b-256 beccb6dcb33954a95ad7c3b643175778b46ce25629bba038e1a1fd5ae3d4803b

File details

Details for the file optree-0.13.1-cp311-cp311-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp311-cp311-win_amd64.whl
  • Upload date:
  • Size: 292.2 kB
  • Tags: CPython 3.11, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 0adc896018f34b5f37f6c92c35ae639877578725c5281cc9d4a0ac2ab2c46f77
MD5 034d901c8c533540509ad74a476708e0
BLAKE2b-256 2b14c76d594bf85178d5d616bca143610619174dc3acd097a29e427b8ddd3fd2

File details

Details for the file optree-0.13.1-cp311-cp311-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp311-cp311-win32.whl
  • Upload date:
  • Size: 260.9 kB
  • Tags: CPython 3.11, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp311-cp311-win32.whl
Algorithm Hash digest
SHA256 f74fb880472572d550d85d2f1563365b6f194e2157a7703790cbd54d9ab5cf29
MD5 a2ee3a7eb8d59c236e09715b4d355658
BLAKE2b-256 4c37400aa4a413a4886ae92221224e73e474f527203d5031f1041b1e5a5082dd

File details

Details for the file optree-0.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for optree-0.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 efbffeec15e4a79ed9921dc2227cbba1b64db353c4b72ce4ce83e62fbce9e652
MD5 e5eb07215a542e479131f502bd65ac1e
BLAKE2b-256 57d760f5e9ca2b94face19b9e5ba6ded59eddf94e349cdd26a317f1a8f1aef3b

File details

Details for the file optree-0.13.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File hashes

Hashes for optree-0.13.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 5dec0785bc4bbcabecd7e82be3f189b21f3ce8a1244b243009736912a6d8f737
MD5 5b4229973d1aaae2a7b49bf11a87682b
BLAKE2b-256 c78c42b3c398b2096dbffc4f6c3319def09884e688a0e3339c4c8a42e74c8e43

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 360f2e8f7eb22ff131bc7e3e241035908e6b47d41372eb3d68d77bc7036ddb30
MD5 bcc57a0259b2fdc6af29db064c87eaed
BLAKE2b-256 f5a5ab3d146ecf4b34ec4f660ee56ef05eff7a79663b0cf16bfac02dac12455c

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 b21ac55473476007e317500fd5851d0a0d695a0c51742bd65fe7347d18530da2
MD5 43a0a6454332c9feae88a91f3500ad39
BLAKE2b-256 3738958677663cd988af5d401d7280d2756cc1de5e5c3139327981cc10900e5a

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 5f94a627c5a2fb776bbfa8f7558db5b918916d37586ba943e74e5f22789c4301
MD5 ffb816fa977eb4e12f76eda26bb637a8
BLAKE2b-256 17871b8b457b5e0446421d3001c4f64a76ae63f0fbe4847365123fc2fb087c8c

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp311-cp311-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 6bc9aae5ee17a38e3657c8c5db1a60923cc10debd177f6781f352362a846feeb
MD5 c15749aef8ccf242c126b73371a23ee9
BLAKE2b-256 27306fe920c811b19dc1465ab0627db4133993f58fdc38ed929f4afb308a61fe

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp311-cp311-macosx_10_9_universal2.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp311-cp311-macosx_10_9_universal2.whl
Algorithm Hash digest
SHA256 c84ecb6977ba7f5d4ba24d0312cbffb74c6860237572701c2716bd811ca9b226
MD5 69fe3cd1578c70c9f981eb850d3e0b00
BLAKE2b-256 d9c15723dcb9e065f1fdff996b16c958013185e3b2c0f9da0a199b0ca5851f05

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-win_arm64.whl.

File metadata

  • Download URL: optree-0.13.1-cp310-cp310-win_arm64.whl
  • Upload date:
  • Size: 282.7 kB
  • Tags: CPython 3.10, Windows ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp310-cp310-win_arm64.whl
Algorithm Hash digest
SHA256 d0c5a389c108367007151bcfef494f8c2674e4aa23d80ac9163876f5b213dfb6
MD5 006b2f426d97be751f398aa9871f6587
BLAKE2b-256 208556267883d1ee350be1561781e0fd4d16e1ee80825a9633e58926ca92647d

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp310-cp310-win_amd64.whl
  • Upload date:
  • Size: 282.7 kB
  • Tags: CPython 3.10, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 64032b77420410c3d315a4b9bcbece15853432c155613bb4261d87809b3ee357
MD5 54cee74bce1c7a1c351bbf10f4bad290
BLAKE2b-256 5c25dcc520ece35026e44ca7dc75f246eb132f6e29b3a52b989e180e1b05846c

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp310-cp310-win32.whl
  • Upload date:
  • Size: 255.1 kB
  • Tags: CPython 3.10, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp310-cp310-win32.whl
Algorithm Hash digest
SHA256 135e29e0a69149958003443d43f49af0ebb65f03ae52cddf4142e94d5a36b0c8
MD5 b74b1db93c05d3dc898020f67652342b
BLAKE2b-256 77eca0c5dad12c097d543fa376d02d1a12294ab1e1cc78fde9c83587f0676189

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 cfdf7f5cfb5f9b1c0188c667a3dc56551e60a52a918cb8600f84e2f0ad882106
MD5 5b97c8d8672b37a56a7a1452a61d8a9c
BLAKE2b-256 b1c970e551db94823262c520149d278aea055a46ac46c9dc23e08831a15ee1a7

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 940c739c9957404a9bbe40ed9289792adaf476cece59eca4fe2f32137fa15a8d
MD5 32840d74a3866680b9282f4be4defce2
BLAKE2b-256 0ebf5b08425c74494f1a61c70361c28da4040e94fd0595ed5687ed69f7981c76

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 1d74ff3dfe8599935d52b26a2fe5a43242b4d3f47be6fc1c5ce34c25e116d616
MD5 accb2315011e2683ae8d9cd6bc3ac390
BLAKE2b-256 e88a9d0af81397f511d713b53ac7fdc758d82e39ed8f743bea0a316936d5c3b7

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 d1844b966bb5c95b64af5c6f92f99e4037452b92b18d060fbd80097b5b773d86
MD5 00d5e37aaca4df9f3cc77a0bb697fb09
BLAKE2b-256 695b11bdfeb4d4580f794f58c94a3c4c77c0732bfccbbdd206924e983270f78a

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 34b4dd0f5d73170c7740726cadfca973220ccbed9559beb51fab446d9e584d0a
MD5 8a016d48218f5169a86515ccedbb8420
BLAKE2b-256 467e2e4bc5920078bbafa9a640d99e262c3a003455cfcd2e646f3c339fd48bdb

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 a3058e2d6a6a7d6362d40f7826258204d9fc2cc4cc8f72eaa3dbff14b6622025
MD5 eee5c8b167e41f629380beed66e53d4a
BLAKE2b-256 b81b1689753799ff4fd0f72148673e2bed3d99596f14ad17b93490c5b1885588

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp310-cp310-macosx_10_9_universal2.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp310-cp310-macosx_10_9_universal2.whl
Algorithm Hash digest
SHA256 f8e2a546cecc5077ec7d4fe24ec8aede43ca8555b832d115f1ebbb4f3b35bc78
MD5 34df6bcbc5aa6f2acf338e328859a1ba
BLAKE2b-256 26dbaf7430add026a1eed470492441c6eba891ae4d0e44e9ebb1ab7d829e4e83

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-win_arm64.whl.

File metadata

  • Download URL: optree-0.13.1-cp39-cp39-win_arm64.whl
  • Upload date:
  • Size: 277.7 kB
  • Tags: CPython 3.9, Windows ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp39-cp39-win_arm64.whl
Algorithm Hash digest
SHA256 04252b5f24e5dae716647848b302f5f7849ecb028f8c617666d1b89a42eb988b
MD5 9cfd9f4a0357ab2320f3afb5d6d29c3c
BLAKE2b-256 5e4c3202e0f4e438e9cb238ca2dd12b0fba1c44bb222ff1966b667939e39ad1e

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp39-cp39-win_amd64.whl
  • Upload date:
  • Size: 277.7 kB
  • Tags: CPython 3.9, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp39-cp39-win_amd64.whl
Algorithm Hash digest
SHA256 2cba7ca4cf991270a9fdd080b091d2cbdbcbf27858acebda6af40ff57312d1ea
MD5 6f38f987ebcc61bbad3160a6ad0cfc86
BLAKE2b-256 b192ce6af6dc4e896215a48b068d20ee2ab0353c1d28217a6f7ab8f020a7f5b2

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp39-cp39-win32.whl
  • Upload date:
  • Size: 255.1 kB
  • Tags: CPython 3.9, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp39-cp39-win32.whl
Algorithm Hash digest
SHA256 363939b255a9fa0e077d8297a8301857c859592fc581cee19ec9238e0c145c4a
MD5 05a06dd447238843936cdc0f00e2a4aa
BLAKE2b-256 67ea095ee811749b6af620bf4df8c0a4668abf7c06b819dfd6c70d98039fc4e9

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 1b291aed475ca5992a0c587ca4b72f074724209e01afca9d015c9a5b2089c68d
MD5 c59badb15add8af733bab7d85ac1f05f
BLAKE2b-256 4466083e4396b0ced924cc920cfe86f00cf140e0e58c93b91a64d06f77205f07

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 30b02951c48ecca6fbeb6a3cc7a858267c4d82d1c874481a639061e845168da5
MD5 d13d68f3f3292c4dc0097325debf33da
BLAKE2b-256 db80f4a6a7f06601c196272131a42c7f7773fd823167525c7936f6a4d0fb05a8

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 100d70cc57af5284649f881e6b266fee3a3e86e82024484eaa64ee18d1587e42
MD5 87836ab3ab4cd7f56e752c64b44f0898
BLAKE2b-256 8cc4df3255e3c3767e07be7077bf0b49f41241a333fee06daf5b678f19e56f60

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 5569b95e214d20a1b7acb7d9477fabbd709d334bc34f3257368ea1418b811a44
MD5 a121bf6b0d1846263f319773acb28c52
BLAKE2b-256 a6a96656a6ab74a81c6ca8a18d95de1e02747e4aca7f71b4fb8d1627650b6720

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 aee696272eece657c2b9e3cf079d8fc7cbbcc8a5c8199dbcd0960ddf7e672fe9
MD5 9d90751f429a2be0c4962554c11be45c
BLAKE2b-256 46564d3eb5656ac7cf941be52f4453e2beb1dea15f40e0208700de985446fd72

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 7abf1c6fe42cb112f0fb169f80d7b26476fa44226d2caf3727b49d210bdc3343
MD5 2f6f7f8ac653a06679b8512747ef9724
BLAKE2b-256 652878f90ea721e4b880305a98e958a33194d52572fa53f7578f58ca30373c8f

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp39-cp39-macosx_10_9_universal2.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp39-cp39-macosx_10_9_universal2.whl
Algorithm Hash digest
SHA256 b94f9081cd810a59faae4dbac8f0447e59ce0fb2d70cfb388dc123c33a9fd1a8
MD5 68a71dbea9c6a13f96acfe4266f989e1
BLAKE2b-256 5b7c227664ad981e2c6f17cee6c05a62728e15cf77bc743b50d450f661b8cf97

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp38-cp38-win_amd64.whl
  • Upload date:
  • Size: 282.6 kB
  • Tags: CPython 3.8, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp38-cp38-win_amd64.whl
Algorithm Hash digest
SHA256 f2a9eadcab78ccc04114a6916e9decdbc886bbe04c1b7a7bb32e723209162998
MD5 028ea51cb790f3c97118d510e83d27c3
BLAKE2b-256 55eda470e2c7d1290592f47dee6e1c1a89645f8005be07ee94b96b0c83250784

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp38-cp38-win32.whl
  • Upload date:
  • Size: 254.9 kB
  • Tags: CPython 3.8, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp38-cp38-win32.whl
Algorithm Hash digest
SHA256 5c950c85561c47efb3b1a3771ed1b2b2339bd5e28a0ca42bdcedadccc645eeac
MD5 dd426e4cfbe72f0f64bbcc75c72f89cd
BLAKE2b-256 991d7858a55c2e1cda63453d75b677590eb19735c1b3a315a2d6a114728dc5c2

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 2063234ef4d58f11277e157d1cf066a8bd07be911da226bff84fc9761b8c1a25
MD5 688c710b3e26fd26193c843f884e460a
BLAKE2b-256 e19faed0f65bf7bda78837b9a319f369514efe1a52a535ed05912686c4e31ed0

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 8d89891e11a55ad83ab3e2810f8571774b2117a6198b4044fa44e0f37f72855e
MD5 963cb25728d8071e0f1fa00cf4ae869b
BLAKE2b-256 a0368e384691492fd1a65efda5477630ed9fcbaf540e7475f9ca78ade5633010

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 de1ae16ea0410497e50fe2b4d48a83c37bfc87da76e1e82f9cc8c800b4fc8be6
MD5 7d38633b010ec295eaf4b43ffb4ae233
BLAKE2b-256 9f2f542cc3c4e985e00df1e85e25163a20ee3a4e39b64192765d0429d79961f9

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 1891267f9dc76e9ddfed947ff7b755ad438ad483de0537a6b5bcf38478d5a33c
MD5 54bd49fc8b573836602f1e3ec7628730
BLAKE2b-256 67238ad4292a384491becc73a0742924146f6776c632df8b3797458e1b59bc26

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 84a6a974aa9dc4119fe502865c8e1755090ac17dbb53a964619a8ece1130831e
MD5 726a0bfd241aab67bf75c3fbfec2c887
BLAKE2b-256 605d1df478febd326316c2fb3f2759237139730099f8294f5f703bdefaada7b1

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 536ecf0e555432cc939d958590e33e00e75cc254ab0dd269e84fc9de8352db61
MD5 f800d299cba9175bdcf023cbc9d13f9c
BLAKE2b-256 5fd182e34422cdbbfaf4dc8bf2b6c7dee99a3636a1bb72703fd9915a53f5e2c0

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp38-cp38-macosx_10_9_universal2.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp38-cp38-macosx_10_9_universal2.whl
Algorithm Hash digest
SHA256 50dd6a9c8ccef267ab4941f07eac53faf6a00666dce4d209da20525570ffaca3
MD5 cd31172b9fe128b378361e69e4a7ac0e
BLAKE2b-256 7359125198280db2b9e090e25283344d2f03c08f0290f29ca9a8cc75a0608fea

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-win_amd64.whl.

File metadata

  • Download URL: optree-0.13.1-cp37-cp37m-win_amd64.whl
  • Upload date:
  • Size: 281.0 kB
  • Tags: CPython 3.7m, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp37-cp37m-win_amd64.whl
Algorithm Hash digest
SHA256 7e1c1da6574d59073b6a6b9a13633217f584ec271ddee4e014c7e422f171e9b4
MD5 b3d9f550f4dc7353ae0ea0c593140321
BLAKE2b-256 9a0f6d0115c5dece1ae5aa51c934fffe43b49dd367160d37805d7b837ce35270

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-win32.whl.

File metadata

  • Download URL: optree-0.13.1-cp37-cp37m-win32.whl
  • Upload date:
  • Size: 257.9 kB
  • Tags: CPython 3.7m, Windows x86
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for optree-0.13.1-cp37-cp37m-win32.whl
Algorithm Hash digest
SHA256 63b2749504fe0b9ac3892e26bf55a040ae2973bcf8da1476afe9266a4624be9d
MD5 770e9088d0cbf3b670e262098265c7de
BLAKE2b-256 2bb78fea9082dfb15ce92ceb7b88380316b9ae5845bcdd6a67f6a4be54c7acae

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 1496f29d5b9633fed4b3f1fd4b7e772d77200eb2370c08ef8e14404309c669b9
MD5 59be38d76f4cc1c92e1854c70caa2e21
BLAKE2b-256 3b353551871e0c8c494b0661b8a977eb90f02ca705d3e188292ad195d769f681

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl
Algorithm Hash digest
SHA256 c99891c2ea6050738f7e3de5ab4038736cf33555a752b34a06922ebc9bf0488e
MD5 4a6131247b4d1d3b6251bd5883ec3f3f
BLAKE2b-256 0d604a459167761185de5d39234aa84c70d19b9a227576f127d50b054e469eac

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl
Algorithm Hash digest
SHA256 2521840d6aded4dac62c787f50bcb1cacbfcda86b9319d666b4025fa0ba5545a
MD5 41466db8319d6538573dab540cd5d919
BLAKE2b-256 1620504dd7162c7fd835ecb9af3a4f0cede07515161f80d034bcbb8d2c2f4eaf

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 22ce30c9d733c2214fa321c8370e4dfc8c7829970364618b2b5cacffbc9e8949
MD5 3bdc887976dec02ff0643ceba56fa9dc
BLAKE2b-256 a6bfe5ce34e0eab4fa2412ef5621983a18e3fd32d2ee93574317af38299f74bd

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 5d21a8b449e47fdbf118ac1938cf6f97d8a60258bc45c6eba3e61f79feeb1ea8
MD5 3bb013785b3770bbd4741f8e64be1723
BLAKE2b-256 572f1215cc464126f130d2c6bd0147514391faf1781ed7379877c7d08aa462a7

See more details on using hashes here.

File details

Details for the file optree-0.13.1-cp37-cp37m-macosx_10_9_x86_64.whl.

File metadata

File hashes

Hashes for optree-0.13.1-cp37-cp37m-macosx_10_9_x86_64.whl
Algorithm Hash digest
SHA256 9824a4258b058282eeaee1b388c8dfc704e49beda957b99177db8bd8249a3abe
MD5 40336bb822c61dc96d1496a652bc43ea
BLAKE2b-256 6a44b9f9d280339f4a55e7dc7d22dc0ddca6c3dc02b913175259e7a8b9e5a0de

See more details on using hashes here.
