JAX compatible dataclass.
Project description
Installation |Description |Quick Example |StatefulComputation |More |Acknowledgements
For previous PyTreeClass
use v0.1 branch
๐ ๏ธ Installation
pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/PyTreeClass
๐ Description
PyTreeClass
is a JAX-compatible dataclass
-like decorator to create and operate on stateful JAX PyTrees.
The package aims to achieve two goals:
- ๐ To maintain safe and correct behaviour by using immutable modules with functional API.
- To achieve the most intuitive user experience in the
JAX
ecosystem by :- ๐๏ธ Defining layers similar to
PyTorch
orTensorFlow
subclassing style. - โ๏ธ Filtering\Indexing layer values similar to
jax.numpy.at[].{get,set,apply,...}
- ๐จ Visualize defined layers in plethora of ways.
- ๐๏ธ Defining layers similar to
โฉ Quick Example
๐๏ธ Simple Tree example
Code | PyTree representation |
import jax
import jax.numpy as jnp
import pytreeclass as pytc
@pytc.treeclass
class Tree:
a:int = 1
b:tuple[float] = (2.,3.)
c:jax.Array = jnp.array([4.,5.,6.])
def __call__(self, x):
return self.a + self.b[0] + self.c + x
tree = Tree()
|
# leaves are parameters
Tree
โโโ a=1
โโโ b:tuple
โ โโโ [0]=2.0
โ โโโ [1]=3.0
โโโ c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])
|
๐จ Visualize
Visualize PyTrees
tree_summary | tree_diagram | [tree_mermaid](https://mermaid.js.org)(Native support in Github/Notion) | tree_repr | tree_str |
print(pytc.tree_summary(tree, depth=1))
โโโโโโฌโโโโโโโฌโโโโโโ
โNameโType โCountโ
โโโโโโผโโโโโโโผโโโโโโค
โa โint โ1 โ
โโโโโโผโโโโโโโผโโโโโโค
โb โtuple โ1 โ
โโโโโโผโโโโโโโผโโโโโโค
โc โf32[3]โ3 โ
โโโโโโผโโโโโโโผโโโโโโค
โฮฃ โTree โ5 โ
โโโโโโดโโโโโโโดโโโโโโ
|
print(pytc.tree_diagram(tree, depth=1))
Tree
โโโ a=1
โโโ b=(...)
โโโ c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])
|
print(pytc.tree_mermaid(tree, depth=1))
flowchart LR
id0(<b>Tree</b>)
id0 --- id1("</b>a=1</b>")
id0 --- id2("</b>b=(...)</b>")
id0 --- id3("</b>c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])</b>")
|
print(pytc.tree_repr(tree, depth=1))
Tree(a=1, b=(...), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
|
print(pytc.tree_str(tree, depth=1))
Tree(a=1, b=(...), c=[4. 5. 6.])
|
print(pytc.tree_summary(tree, depth=2))
โโโโโโฌโโโโโโโฌโโโโโโ
โNameโType โCountโ
โโโโโโผโโโโโโโผโโโโโโค
โa โint โ1 โ
โโโโโโผโโโโโโโผโโโโโโค
โb[0]โfloat โ1 โ
โโโโโโผโโโโโโโผโโโโโโค
โb[1]โfloat โ1 โ
โโโโโโผโโโโโโโผโโโโโโค
โc โf32[3]โ3 โ
โโโโโโผโโโโโโโผโโโโโโค
โฮฃ โTree โ6 โ
โโโโโโดโโโโโโโดโโโโโโ
|
print(pytc.tree_diagram(tree, depth=2))
Tree
โโโ a=1
โโโ b:tuple
โ โโโ [0]=2.0
โ โโโ [1]=3.0
โโโ c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])
|
print(pytc.tree_mermaid(tree, depth=2))
flowchart LR
id2 --- id3("</b>[0]=2.0</b>")
id2 --- id4("</b>[1]=3.0</b>")
id0(<b>Tree</b>)
id0 --- id1("</b>a=1</b>")
id0 --- id2("</b>b:tuple</b>")
id0 --- id5("</b>c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])</b>")
|
print(pytc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
|
print(pytc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
|
๐ Working with jax
transformation
Make arbitrary PyTrees work with jax transformations
Parameters are defined in Tree
at the top of class definition similar to defining
dataclasses.dataclass
field.
Lets optimize our parameters
@jax.grad
def loss_func(tree:Tree, x:jax.Array):
preds = jax.vmap(tree)(x) # <--- vectorize the tree call over the leading axis
return jnp.mean(preds**2) # <--- return the mean squared error
@jax.jit
def train_step(tree:Tree, x:jax.Array):
grads = loss_func(tree, x)
# apply a small gradient step
return jax.tree_util.tree_map(lambda x, g: x - 1e-3*g, tree, grads)
# lets freeze the non-differentiable parts of the tree
# in essence any non inexact type should be frozen to
# make the tree differentiable and work with jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)
for epoch in range(1_000):
jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))
print(jaxable_tree)
# **the `frozen` params have "#" prefix**
#Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
โ๏ธ Advanced Indexing with .at[]
Out-of-place updates using mask, attribute name or index
PyTreeClass
offers 3 means of indexing through .at[]
- Indexing by boolean mask.
- Indexing by attribute name.
- Indexing by Leaf index.
Since treeclass
wrapped class are immutable, .at[]
operations returns new instance of the tree
Index update by boolean mask
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯ=0.82, โ[4,6]))
# lets create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x>4, tree)
print(mask)
# Tree(a=False, b=(False, False), c=[False True True])
print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])
print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
Index update by attribute name
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯ=0.82, โ[4,6]))
print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)
print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
Index update by integer index
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯ=0.82, โ[4,6]))
print(tree.at[1].at[0].get())
# Tree(a=None, b=(2.0, None), c=None)
print(tree.at[1].at[0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
print(tree.at[1].at[0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
๐ Stateful computations
First, Under jax.jit jax requires states to be explicit, this means that for any class instance; variables needs to be separated from the class and be passed explictly. However when using @pytc.treeclass no need to separate the instance variables ; instead the whole instance is passed as a state.
Using the following pattern,Updating state functionally can be achieved under jax.jit
import jax
import pytreeclass as pytc
@pytc.treeclass
class Counter:
calls : int = 0
def increment(self):
self.calls += 1
counter = Counter() # Counter(calls=0)
Here, we define the update function. Since the increment method mutate the internal state, thus we need to use the functional approach to update the state by using .at
. To achieve this we can use .at[method_name].__call__(*args,**kwargs)
, this functional call will return the value of this call and a new model instance with the update state.
@jax.jit
def update(counter):
value, new_counter = counter.at["increment"]()
return new_counter
for i in range(10):
counter = update(counter)
print(counter.calls) # 10
โ More
[Advanced] Register custom user-defined classes to work with visualization and indexing tools.
Similar to jax.tree_util.register_pytree_node
, PyTreeClass
register common data structures and treeclass
wrapped classes to figure out how to define the names, types, index, and metadatas of certain leaf along its path.
Here is an example of registering
class Tree:
def __init__(self, a, b):
self.a = a
self.b = b
def __repr__(self) -> str:
return f"{self.__class__.__name__}(a={self.a}, b={self.b})"
# jax flatten rule
def tree_flatten(tree):
return (tree.a, tree.b), None
# jax unflatten rule
def tree_unflatten(_, children):
return Tree(*children)
# PyTreeClass flatten rule
def pytc_tree_flatten(tree):
names = ("a", "b")
types = (type(tree.a), type(tree.b))
indices = (0,1)
metadatas = (None, None)
return [*zip(names, types, indices, metadatas)]
# Register with `jax`
jax.tree_util.register_pytree_node(Tree, tree_flatten, tree_unflatten)
# Register the `Tree` class trace function to support indexing
pytc.register_pytree_node_trace(Tree, pytc_tree_flatten)
tree = Tree(1, 2)
# works with jax
jax.tree_util.tree_leaves(tree) # [1, 2]
# works with PyTreeClass viz tools
print(pytc.tree_summary(tree))
# โโโโโโฌโโโโโฌโโโโโโฌโโโโโโโ
# โNameโTypeโCountโSize โ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโค
# โa โint โ1 โ28.00Bโ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโค
# โb โint โ1 โ28.00Bโ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโค
# โฮฃ โTreeโ2 โ56.00Bโ
# โโโโโโดโโโโโดโโโโโโดโโโโโโโ
After registeration, you can use internal tools like
pytc.tree_map_with_trace
pytc.tree_leaves_with_trace
pytc.tree_flatten_with_trace
More details on that soon.
Validate or convert inputs using callbacks
PyTreeClass
includes callbacks
in the field
to apply a sequence of functions on input at setting the attribute stage. The callback is quite useful in several cases, for instance, to ensure a certain input type within a valid range. See example:
import jax
import pytreeclass as pytc
def positive_int_callback(value):
if not isinstance(value, int):
raise TypeError("Value must be an integer")
if value <= 0:
raise ValueError("Value must be positive")
return value
@pytc.treeclass
class Tree:
in_features:int = pytc.field(callbacks=[positive_int_callback])
tree = Tree(1)
# no error
tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive
tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer
Add leafwise math operations to PyTreeClass wrapped class
import functools as ft
import pytreeclass as pytc
@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
a:int = 1
b:tuple[float] = (2.,3.)
c:jax.Array = jnp.array([4.,5.,6.])
def __call__(self, x):
return self.a + self.b[0] + self.c + x
tree = Tree()
tree + 100
# Tree(a=101, b=(102.0, 103.0), c=f32[3](ฮผ=105.00, ฯ=0.82, โ[104.00,106.00]))
@jax.grad
def loss_func(tree:Tree, x:jax.Array):
preds = jax.vmap(tree)(x) # <--- vectorize the tree call over the leading axis
return jnp.mean(preds**2) # <--- return the mean squared error
@jax.jit
def train_step(tree:Tree, x:jax.Array):
grads = loss_func(tree, x)
return tree - grads*1e-3 # <--- eliminate `tree_map`
# lets freeze the non-differentiable parts of the tree
# in essence any non inexact type should be frozen to
# make the tree differentiable and work with jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)
for epoch in range(1_000):
jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))
print(jaxable_tree)
# **the `frozen` params have "#" prefix**
# Tree(a=#1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783 3.024264 ])
# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783 3.024264 ])
Eliminate tree_map using bcmap + treeclass(..., leafwise=True)
TDLR
import functools as ft
import pytreeclass as pytc
import jax.numpy as jnp
@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
a:int = 1
b:tuple[float] = (2.,3.)
c:jax.Array = jnp.array([4.,5.,6.])
tree = Tree()
print(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])
bcmap(func, is_leaf)
maps a function over PyTrees leaves with automatic broadcasting for scalar arguments.
bcmap
is function transformation that broadcast a scalar to match the first argument of the function this enables us to convert a function like jnp.where
to work with arbitrary tree structures without the need to write a specific function for each broadcasting case
For example, lets say we want to use jnp.where
to zeros out all values in an arbitrary tree structure that are less than 0
tree = ([1], {"a":1, "b":2}, (1,), -1,)
we can use jax.tree_util.tree_map
to apply jnp.where
to the tree but we need to write a specific function for broadcasting the scalar to the tree
def map_func(leaf):
# here we encoded the scalar `0` inside the function
return jnp.where(leaf>0, leaf, 0)
jtu.tree_map(map_func, tree)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(0, dtype=int32, weak_type=True))
However, lets say we want to use jnp.where
to set a value to a leaf value from another tree that looks like this
def map_func2(lhs_leaf, rhs_leaf):
# here we encoded the scalar `0` inside the function
return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)
tree2 = jtu.tree_map(lambda x: 1000, tree)
jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(1000, dtype=int32, weak_type=True))
Now, bcmap
makes this easier by figuring out the broadcasting case.
broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)
case 1
broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(0, dtype=int32, weak_type=True))
case 2
broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(1000, dtype=int32, weak_type=True))
lets then take this a step further to eliminate mask
from the equation
by using pytreeclass
with leafwise=True
@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
tree : tuple = ([1], {"a":1, "b":2}, (1,), -1,)
tree = Tree()
# Tree(tree=([1], {a:1, b:2}, (1), -1))
case 1: broadcast scalar to tree
print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))
case 2: broadcast tree to tree
```python
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
bcmap
also works with all kind of arguments in the wrapped function
print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
in concolusion, bcmap
is a function transformation that can be used to
to make functions work with arbitrary tree structures without the need to write
a specific function for each broadcasting case
Moreover, bcmap
can be more powerful when used with pytreeclass
to
facilitate operation of arbitrary functions on PyTree
objects
without the need to use tree_map
Use PyTreeClass vizualization tools with arbitrary PyTrees
import jax
import pytreeclass as pytc
tree = [1, [2,3], 4]
print(pytc.tree_diagram(tree,depth=1))
# list
# โโโ [0]=1
# โโโ [1]=[..., ...]
# โโโ [2]=4
print(pytc.tree_diagram(tree,depth=2))
# list
# โโโ [0]=1
# โโโ [1]:list
# โ โโโ [0]=2
# โ โโโ [1]=3
# โโโ [2]=4
print(pytc.tree_summary(tree,depth=1))
# โโโโโโฌโโโโโฌโโโโโโฌโโโโโโโโ
# โNameโTypeโCountโSize โ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[0] โint โ1 โ28.00B โ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[1] โlistโ2 โ56.00B โ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[2] โint โ1 โ28.00B โ
# โโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โฮฃ โlistโ4 โ112.00Bโ
# โโโโโโดโโโโโดโโโโโโดโโโโโโโโ
print(pytc.tree_summary(tree,depth=2))
# โโโโโโโโฌโโโโโฌโโโโโโฌโโโโโโโโ
# โName โTypeโCountโSize โ
# โโโโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[0] โint โ1 โ28.00B โ
# โโโโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[1][0]โint โ1 โ28.00B โ
# โโโโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[1][1]โint โ1 โ28.00B โ
# โโโโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โ[2] โint โ1 โ28.00B โ
# โโโโโโโโผโโโโโผโโโโโโผโโโโโโโโค
# โฮฃ โlistโ4 โ112.00Bโ
# โโโโโโโโดโโโโโดโโโโโโดโโโโโโโโ
Use PyTreeClass components with other libraries
import jax
import pytreeclass as pytc
from flax import struct
@struct.dataclass
class FlaxTree:
a:int = 1
b:tuple[float] = (2.,3.)
c:jax.Array = jax.numpy.array([4.,5.,6.])
def __repr__(self) -> str:
return pytc.tree_repr(self)
def __str__(self) -> str:
return pytc.tree_str(self)
@property
def at(self):
return pytc.tree_indexer(self)
def pytc_flatten_rule(tree):
names =("a","b","c")
types = map(type, (tree.a, tree.b, tree.c))
indices = range(3)
metadatas= (None, None, None)
return [*zip(names, types, indices, metadatas)]
pytc.register_pytree_node_trace(FlaxTree, pytc_flatten_rule)
flax_tree = FlaxTree()
print(f"{flax_tree!r}")
# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
print(f"{flax_tree!s}")
# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
print(pytc.tree_diagram(flax_tree))
# FlaxTree
# โโโ a=1
# โโโ b:tuple
# โ โโโ [0]=2.0
# โ โโโ [1]=3.0
# โโโ c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])
print(pytc.tree_summary(flax_tree))
# โโโโโโฌโโโโโโโโโฌโโโโโโฌโโโโโโโ
# โNameโType โCountโSize โ
# โโโโโโผโโโโโโโโโผโโโโโโผโโโโโโโค
# โa โint โ1 โ28.00Bโ
# โโโโโโผโโโโโโโโโผโโโโโโผโโโโโโโค
# โb[0]โfloat โ1 โ24.00Bโ
# โโโโโโผโโโโโโโโโผโโโโโโผโโโโโโโค
# โb[1]โfloat โ1 โ24.00Bโ
# โโโโโโผโโโโโโโโโผโโโโโโผโโโโโโโค
# โc โf32[3] โ3 โ12.00Bโ
# โโโโโโผโโโโโโโโโผโโโโโโผโโโโโโโค
# โฮฃ โFlaxTreeโ6 โ88.00Bโ
# โโโโโโดโโโโโโโโโดโโโโโโดโโโโโโโ
flax_tree.at[0].get()
# FlaxTree(a=1, b=(None, None), c=None)
flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
Use tree_map_with_trace
V0.2 of PyTreeClass
register common python datatypes and treeclass
wrapped class to trace
registry.
While jax
uses jax.tree_util.register_pytree_node
to define flatten_rule
for leaves, PyTreeClass
extends on this
By registering the flatten_rule
of (1) names, (2) types, (3) indexing, (4) metadata -if exists-
For demonstration , the following figure contains the 4 variants of the same Tree
instance define
import jax
import jax.numpy as jnp
import pytreeclass as pytc
@pytc.treeclass
class Tree:
a:int = 1
b:tuple[float] = (2.,3.)
c:jax.Array = jnp.array([4.,5.,6.])
- Value leaves variant.
- Name leaves variant.
- Type leaves variant.
- Indexing leaves variant.
The four variants can be accessed using pytc.tree_map_with_trace
.
Similar to jax.tree_util.tree_map
, pytc.tree_map_with_trace
accepts the map function, however the first argument must be the trace
argument.
Trace is a four item tuple consists of names,types,indices,metadatas path for each leaf.
For example for the previous tree, the reuslting trace path for each leaf is :
Named tree variant
>>> name_tree = pytc.tree_map_with_trace(lambda trace,x : trace[0], tree)
>>> print(name_tree)
Tree(a=(Tree, a), b=((Tree, b, [0]), (Tree, b, [1])), c=(Tree, c))
Typed tree variant
>>> type_tree = pytc.tree_map_with_trace(lambda trace,x : f"{trace[1]!s}", tree)
>>> print(type_tree)
Tree(
a=(<class '__main__.Tree'>, <class 'int'>),
b=(
(<class '__main__.Tree'>, <class 'tuple'>, <class 'float'>),
(<class '__main__.Tree'>, <class 'tuple'>, <class 'float'>)
),
c=(<class '__main__.Tree'>, <class 'jaxlib.xla_extension.ArrayImpl'>)
)
Index tree variant
>>> index_tree = pytc.tree_map_with_trace(lambda trace,x : trace[2], tree)
>>> print(index_tree)
Tree(a=(0, 0), b=((0, 1, 0), (0, 1, 1)), c=(0, 2))
In essence, each leaf contains information about the name path, type path, and indices path. The rules for custom types can be registered using pytc.register_pytree_node_trace
Comparison with dataclass
PyTreeClass | dataclass | |
Generated init method | โ | โ |
Generated repr method | โ | โ |
Generated str method | โ | |
Generated hash method | โ | โ |
Generated eq method | โ โ * | โ |
Support slots | โ | |
Keyword-only args | โ | โ 3.10+ |
Positional-only args | โ | |
Frozen instance | โ ** | โ |
Match args support | โ | โ |
Support callbacks | โ |
*
Either compare the whole instance and return True/False
or treating it leafwise using treeclass(..., leafwise=True)
and retrurn Tree(a=True, ....)
**
Always frozen. non-frozen is not supported.
***
treeclass
decorator is also a bit faster than dataclasses.dataclass
๐ Acknowledgements
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for pytreeclass-0.2.4-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 5831f768eb08a81105f4bd4ca920af1badcccbeaa7f51348dec3c05ef87013c1 |
|
MD5 | 6f79412d40288bea7c06b0347dd37343 |
|
BLAKE2b-256 | 43080900e4b15a5af4b28f2d5706ed89a37a12b36815ce642570201c58cad584 |