
JAX-compatible dataclass.

Project description



Installation | Description | Quick Example | Stateful Computation | More | Acknowledgements


For the previous version of PyTreeClass, use the v0.1 branch.

🛠️ Installation

pip install pytreeclass

Install development version

pip install git+https://github.com/ASEM000/PyTreeClass

📖 Description

PyTreeClass is a JAX-compatible dataclass-like decorator to create and operate on stateful JAX PyTrees.

The package aims to achieve two goals:

  1. 🔒 To maintain safe and correct behaviour by using immutable modules with a functional API.
  2. To achieve the most intuitive user experience in the JAX ecosystem by:
    • 🏗️ Defining layers in a style similar to PyTorch or TensorFlow subclassing.
    • ☝️ Filtering/indexing layer values similar to jax.numpy's .at[].{get,set,apply,...}
    • 🎨 Visualizing defined layers in a plethora of ways.

⏩ Quick Example

🏗️ Simple Tree example

Code, followed by the resulting PyTree representation:
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
# leaves are parameters

Tree
├── a=1
├── b:tuple
│   ├── [0]=2.0
│   └── [1]=3.0
└── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])

🎨 Visualize

Visualize PyTrees
Available functions: tree_summary, tree_diagram, tree_mermaid (https://mermaid.js.org, with native support in GitHub/Notion), tree_repr and tree_str. Each is shown below at depth=1 and then at depth=2.
print(pytc.tree_summary(tree, depth=1))
┌────┬──────┬─────┐
│Name│Type  │Count│
├────┼──────┼─────┤
│a   │int   │1    │
├────┼──────┼─────┤
│b   │tuple │1    │
├────┼──────┼─────┤
│c   │f32[3]│3    │
├────┼──────┼─────┤
│Σ   │Tree  │5    │
└────┴──────┴─────┘
print(pytc.tree_diagram(tree, depth=1))
Tree
├── a=1
├── b=(...)
└── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
print(pytc.tree_mermaid(tree, depth=1))
flowchart LR
    id0(<b>Tree</b>)
    id0 --- id1("</b>a=1</b>")
    id0 --- id2("</b>b=(...)</b>")
    id0 --- id3("</b>c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])</b>")
print(pytc.tree_repr(tree, depth=1))
Tree(a=1, b=(...), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
print(pytc.tree_str(tree, depth=1))
Tree(a=1, b=(...), c=[4. 5. 6.])
print(pytc.tree_summary(tree, depth=2))
┌────┬──────┬─────┐
│Name│Type  │Count│
├────┼──────┼─────┤
│a   │int   │1    │
├────┼──────┼─────┤
│b[0]│float │1    │
├────┼──────┼─────┤
│b[1]│float │1    │
├────┼──────┼─────┤
│c   │f32[3]│3    │
├────┼──────┼─────┤
│Σ   │Tree  │6    │
└────┴──────┴─────┘
print(pytc.tree_diagram(tree, depth=2))
Tree
├── a=1
├── b:tuple
│   ├── [0]=2.0
│   └── [1]=3.0
└── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
print(pytc.tree_mermaid(tree, depth=2))
flowchart LR
    id2 --- id3("</b>[0]=2.0</b>")
    id2 --- id4("</b>[1]=3.0</b>")
    id0(<b>Tree</b>)
    id0 --- id1("</b>a=1</b>")
    id0 --- id2("</b>b:tuple</b>")
    id0 --- id5("</b>c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])</b>")
print(pytc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
print(pytc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

🏃 Working with JAX transformations

Make arbitrary PyTrees work with JAX transformations

Parameters are declared at the top of the Tree class definition, just like dataclasses.dataclass fields. Let's optimize these parameters:

@jax.grad
def loss_func(tree:Tree, x:jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error

@jax.jit
def train_step(tree:Tree, x:jax.Array):
    grads = loss_func(tree, x)
    # apply a small gradient step
    return jax.tree_util.tree_map(lambda x, g: x - 1e-3*g, tree, grads)

# freeze the non-differentiable parts of the tree: in essence, any leaf that is
# not of an inexact type (floating-point or complex) should be frozen so that the
# tree can be differentiated and passed through JAX transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))

print(jaxable_tree)
# frozen leaves are displayed with a "#" prefix
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])


# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])
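Why does freezing make the tree usable with tree_map and jax.grad? As a rough pure-Python illustration (a sketch of the idea, not PyTreeClass's implementation), assume a hypothetical `Frozen` wrapper that exposes no leaves. A simplified `tree_map` then passes it through untouched, so a gradient-style update only changes the float leaves, and an `is_leaf` predicate (mirroring the `is_leaf=pytc.is_frozen` argument above) lets an unfreeze pass reach the wrapped values again.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Frozen:
    """Hypothetical wrapper: a node that exposes no leaves to tree_map."""
    value: Any


def tree_map(func: Callable, tree: Any, is_leaf: Callable = lambda node: False) -> Any:
    """Simplified, hypothetical tree_map over lists, tuples and dicts."""
    if is_leaf(tree):
        return func(tree)
    if isinstance(tree, Frozen):
        return tree  # no leaves inside: passed through untouched
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(func, child, is_leaf) for child in tree)
    if isinstance(tree, dict):
        return {key: tree_map(func, child, is_leaf) for key, child in tree.items()}
    return func(tree)


def freeze_nondiff(tree: Any) -> Any:
    # wrap every leaf that is not a float, i.e. not differentiable
    return tree_map(lambda x: x if isinstance(x, float) else Frozen(x), tree)


def unfreeze(tree: Any) -> Any:
    # treat Frozen nodes as leaves so the map function can unwrap them
    is_frozen = lambda node: isinstance(node, Frozen)
    return tree_map(lambda x: x.value if is_frozen(x) else x, tree, is_leaf=is_frozen)


params = {"a": 1, "b": (2.0, 3.0)}
frozen = freeze_nondiff(params)                # {'a': Frozen(value=1), 'b': (2.0, 3.0)}
stepped = tree_map(lambda x: x - 0.5, frozen)  # only the float leaves change
print(unfreeze(stepped))                       # {'a': 1, 'b': (1.5, 2.5)}
```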

☝️ Advanced Indexing with .at[]

Out-of-place updates using a mask, attribute name or index

PyTreeClass offers three ways of indexing through .at[]:

  1. Indexing by boolean mask.
  2. Indexing by attribute name.
  3. Indexing by leaf index.

Since treeclass-wrapped classes are immutable, .at[] operations return a new instance of the tree.

Index update by boolean mask

tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

# lets create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x>4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False  True  True])

print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])

print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])

print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
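To make the out-of-place semantics concrete, here is a minimal pure-Python sketch of mask-based get/set/apply over nested dicts, tuples and lists. The helpers `masked_map`, `at_get`, `at_set` and `at_apply` are hypothetical, and lists stand in for arrays, so unlike the array case above, `get` keeps `None` placeholders instead of returning only the selected entries.

```python
from typing import Any, Callable


def masked_map(leaf_fn: Callable, tree: Any, mask: Any) -> Any:
    """Walk `tree` and a same-shaped boolean `mask` in lockstep.

    `leaf_fn(leaf, selected)` builds each new leaf; containers are rebuilt,
    so the input tree is never modified (out-of-place update).
    """
    if isinstance(tree, (list, tuple)):
        return type(tree)(masked_map(leaf_fn, t, m) for t, m in zip(tree, mask))
    if isinstance(tree, dict):
        return {key: masked_map(leaf_fn, tree[key], mask[key]) for key in tree}
    return leaf_fn(tree, mask)


def at_get(tree, mask):
    # unselected leaves become None
    return masked_map(lambda x, m: x if m else None, tree, mask)


def at_set(tree, mask, value):
    return masked_map(lambda x, m: value if m else x, tree, mask)


def at_apply(tree, mask, func):
    return masked_map(lambda x, m: func(x) if m else x, tree, mask)


tree = {"a": 1, "b": (2, 3), "c": [4, 5, 6]}
mask = {"a": False, "b": (False, False), "c": [False, True, True]}  # values > 4
print(at_get(tree, mask))                     # {'a': None, 'b': (None, None), 'c': [None, 5, 6]}
print(at_set(tree, mask, 10))                 # {'a': 1, 'b': (2, 3), 'c': [4, 10, 10]}
print(at_apply(tree, mask, lambda x: x * 2))  # {'a': 1, 'b': (2, 3), 'c': [4, 10, 12]}
print(tree)                                   # unchanged: {'a': 1, 'b': (2, 3), 'c': [4, 5, 6]}
```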

Index update by attribute name

tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)

print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])

print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])

Index update by integer index

tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

print(tree.at[1].at[0].get())
# Tree(a=None, b=(2.0, None), c=None)

print(tree.at[1].at[0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])

print(tree.at[1].at[0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
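Attribute-name and integer indexing both boil down to following a path into the tree and rebuilding every node along it. The sketch below is a hypothetical stand-in for chained `.at[...]` calls, assuming a frozen `dataclasses` class with tuples in place of arrays: a `str` key selects a field by name, an `int` key selects a field by position or a tuple item by index, and `dataclasses.replace` produces the new instance.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Tuple, Union

Key = Union[str, int]


@dataclass(frozen=True)
class Tree:
    a: int = 1
    b: tuple = (2.0, 3.0)
    c: tuple = (4.0, 5.0, 6.0)  # a tuple stands in for the README's array


def set_at_path(node: Any, path: Tuple[Key, ...], value: Any) -> Any:
    """Return a copy of `node` with the entry at `path` replaced by `value`.

    A str key selects a dataclass field by name; an int key selects a
    dataclass field by position or a tuple item by index.
    """
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(node, tuple):
        items = list(node)
        items[key] = set_at_path(items[key], rest, value)
        return tuple(items)
    name = key if isinstance(key, str) else fields(node)[key].name
    return replace(node, **{name: set_at_path(getattr(node, name), rest, value)})


tree = Tree()
print(set_at_path(tree, ("a",), 10))  # Tree(a=10, b=(2.0, 3.0), c=(4.0, 5.0, 6.0))
print(set_at_path(tree, (1, 0), 10))  # Tree(a=1, b=(10, 3.0), c=(4.0, 5.0, 6.0))
print(tree)                           # the original is untouched
```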

📜 Stateful computations

Under jax.jit, JAX requires state to be explicit: for any class instance, the variables need to be separated from the class and passed explicitly. With @pytc.treeclass there is no need to separate the instance variables; instead, the whole instance is passed as the state.

Using the following pattern, state can be updated functionally under jax.jit:

import jax
import pytreeclass as pytc

@pytc.treeclass
class Counter:
    calls : int = 0

    def increment(self):
        self.calls += 1

counter = Counter()  # Counter(calls=0)

Here, we define the update function. Since the increment method mutates the internal state, we need a functional approach to update the state, which .at provides: .at[method_name].__call__(*args, **kwargs) returns both the value of the call and a new instance with the updated state.

@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls) # 10
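The functional call above can be sketched in plain Python, assuming a hypothetical helper `functional_call` that is not part of PyTreeClass: run the (mutating) method on a shallow copy and return both the method's result and the copy, leaving the caller's instance untouched. PyTreeClass's own mechanism may differ.

```python
import copy
from typing import Any, Tuple


class Counter:
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self) -> int:
        self.calls += 1  # in-place mutation, as in the Counter above
        return self.calls


def functional_call(instance: Any, method_name: str, *args, **kwargs) -> Tuple[Any, Any]:
    """Hypothetical helper: run a mutating method on a shallow copy.

    Returns (value, new_instance); `instance` itself is left unchanged.
    """
    new_instance = copy.copy(instance)
    value = getattr(new_instance, method_name)(*args, **kwargs)
    return value, new_instance


counter = Counter()
for _ in range(10):
    _, counter = functional_call(counter, "increment")
print(counter.calls)  # 10
```

Because `copy.copy` is shallow, a method that mutates a nested mutable attribute in place (for example, appends to a list) would still affect the original; immutable leaves such as ints, tuples and JAX arrays do not have this problem.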

➕ More

[Advanced] Register custom user-defined classes to work with visualization and indexing tools.

Similar to jax.tree_util.register_pytree_node, PyTreeClass registers common data structures and treeclass-wrapped classes so that it can determine the name, type, index and metadata of each leaf along its path.

Here is an example of registering a custom class:

import jax
import pytreeclass as pytc

class Tree:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(a={self.a}, b={self.b})"


# jax flatten rule
def tree_flatten(tree):
    return (tree.a, tree.b), None

# jax unflatten rule
def tree_unflatten(_, children):
    return Tree(*children)

# PyTreeClass flatten rule
def pytc_tree_flatten(tree):
    names = ("a", "b")
    types = (type(tree.a), type(tree.b))
    indices = (0,1)
    metadatas = (None, None)
    return [*zip(names, types, indices, metadatas)]


# Register with `jax`
jax.tree_util.register_pytree_node(Tree, tree_flatten, tree_unflatten)

# Register the `Tree` class trace function to support indexing
pytc.register_pytree_node_trace(Tree, pytc_tree_flatten)

tree = Tree(1, 2)

# works with jax
jax.tree_util.tree_leaves(tree)  # [1, 2]

# works with PyTreeClass viz tools
print(pytc.tree_summary(tree))

# ┌────┬────┬─────┬──────┐
# │Name│Type│Count│Size  │
# ├────┼────┼─────┼──────┤
# │a   │int │1    │28.00B│
# ├────┼────┼─────┼──────┤
# │b   │int │1    │28.00B│
# ├────┼────┼─────┼──────┤
# │Σ   │Tree│2    │56.00B│
# └────┴────┴─────┴──────┘

After registration, you can use internal tools such as:

  • pytc.tree_map_with_trace
  • pytc.tree_leaves_with_trace
  • pytc.tree_flatten_with_trace

More details on that soon.

Validate or convert inputs using callbacks

PyTreeClass fields accept callbacks: a sequence of functions applied to the input when the attribute is set. Callbacks are useful in several cases, for instance to ensure that an input has a certain type and lies within a valid range. See the example:

import jax
import pytreeclass as pytc

def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


@pytc.treeclass
class Tree:
    in_features:int = pytc.field(callbacks=[positive_int_callback])


tree = Tree(1)
# no error

tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive

tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer
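For comparison, similar validation can be approximated with the standard library alone. This sketch is an assumption-based analogue, not how `pytc.field` is implemented: it stores the callbacks in the `metadata` mapping of `dataclasses.field`, runs them in `__post_init__`, and mirrors the error messages shown above.

```python
from dataclasses import dataclass, field, fields


def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


@dataclass(frozen=True)
class Tree:
    # callbacks are kept in the standard `metadata` mapping of the field
    in_features: int = field(metadata={"callbacks": [positive_int_callback]})

    def __post_init__(self):
        for f in fields(self):
            value = getattr(self, f.name)
            for callback in f.metadata.get("callbacks", ()):
                try:
                    value = callback(value)
                except (TypeError, ValueError) as error:
                    # prefix the message with the offending field name
                    raise type(error)(f"Error for field=`{f.name}`:\n{error}") from error
            # frozen dataclass: bypass __setattr__ to store the (possibly converted) value
            object.__setattr__(self, f.name, value)


print(Tree(1))  # Tree(in_features=1)
```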

Add leafwise math operations to PyTreeClass-wrapped classes

import functools as ft
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()

tree + 100
# Tree(a=101, b=(102.0, 103.0), c=f32[3](μ=105.00, σ=0.82, ∈[104.00,106.00]))

@jax.grad
def loss_func(tree:Tree, x:jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error

@jax.jit
def train_step(tree:Tree, x:jax.Array):
    grads = loss_func(tree, x)
    return tree - grads*1e-3  # <--- eliminate `tree_map`

# freeze the non-differentiable parts of the tree: in essence, any leaf that is
# not of an inexact type (floating-point or complex) should be frozen so that the
# tree can be differentiated and passed through JAX transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))

print(jaxable_tree)
# frozen leaves are displayed with a "#" prefix
# Tree(a=#1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783  3.024264 ])


# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783  3.024264 ])
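What `leafwise=True` adds can be sketched with a frozen dataclass whose arithmetic and comparison methods map an operator over matching leaves, broadcasting a scalar operand to every leaf. The `_leafwise` and `_binary` helpers are hypothetical and only handle floats and tuples; the point is that an update like `tree - grads*1e-3` then needs no explicit `tree_map`.

```python
import operator
from dataclasses import dataclass, fields


def _leafwise(op, lhs, rhs):
    """Apply `op` leaf by leaf; a non-tuple `rhs` is broadcast to every leaf."""
    if isinstance(lhs, tuple):
        rhs_items = rhs if isinstance(rhs, tuple) else (rhs,) * len(lhs)
        return tuple(_leafwise(op, x, y) for x, y in zip(lhs, rhs_items))
    return op(lhs, rhs)


@dataclass(frozen=True)
class Tree:
    a: float = 1.0
    b: tuple = (2.0, 3.0)

    def _binary(self, op, other):
        # pair each field with the same field of `other`, or with `other` itself
        values = []
        for f in fields(self):
            rhs = getattr(other, f.name) if isinstance(other, Tree) else other
            values.append(_leafwise(op, getattr(self, f.name), rhs))
        return Tree(*values)

    def __add__(self, other):
        return self._binary(operator.add, other)

    def __sub__(self, other):
        return self._binary(operator.sub, other)

    def __mul__(self, other):
        return self._binary(operator.mul, other)

    def __gt__(self, other):
        return self._binary(operator.gt, other)


tree = Tree()
print(tree + 100)         # Tree(a=101.0, b=(102.0, 103.0))
print(tree - tree * 0.5)  # Tree(a=0.5, b=(1.0, 1.5))
print(tree > 2)           # Tree(a=False, b=(False, True))
```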

Eliminate tree_map using bcmap + treeclass(..., leafwise=True)

TL;DR

import functools as ft
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

tree = Tree()

print(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

bcmap(func, is_leaf) maps a function over PyTree leaves, with automatic broadcasting of scalar arguments.

bcmap is a function transformation that broadcasts scalars to match the structure of the function's first argument. This makes it possible to convert a function like jnp.where to work with arbitrary tree structures without writing a specific function for each broadcasting case.

For example, let's say we want to use jnp.where to zero out all values in an arbitrary tree structure that are not positive:

tree = ([1], {"a":1, "b":2}, (1,), -1,)

We can use jax.tree_util.tree_map to apply jnp.where to the tree, but we need to write a specific function that broadcasts the scalar to the tree:

import jax.numpy as jnp
import jax.tree_util as jtu

def map_func(leaf):
    # here we encoded the scalar `0` inside the function
    return jnp.where(leaf>0, leaf, 0)

jtu.tree_map(map_func, tree)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))

However, let's say we want jnp.where to take the replacement value from the matching leaf of another tree instead. Then we need yet another function:

def map_func2(lhs_leaf, rhs_leaf):
    # here the replacement value comes from the matching leaf of the second tree
    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)

tree2 = jtu.tree_map(lambda x: 1000, tree)

jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))

Now, bcmap makes this easier by figuring out the broadcasting case.

broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)

Case 1: broadcast a scalar

broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))

Case 2: broadcast a tree

broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))

Let's take this a step further and eliminate the explicit mask by using PyTreeClass with leafwise=True:

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    tree : tuple = ([1], {"a":1, "b":2}, (1,), -1,)

tree = Tree()
# Tree(tree=([1], {a:1, b:2}, (1), -1))

Case 1: broadcast a scalar to the tree

print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))

Case 2: broadcast a tree to the tree
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))

bcmap works with both positional and keyword arguments of the wrapped function:

print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))

In conclusion, bcmap is a function transformation that makes functions work with arbitrary tree structures without writing a specific function for each broadcasting case.

Moreover, bcmap is even more powerful when combined with PyTreeClass, since arbitrary functions can then operate on PyTree objects without tree_map.
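The broadcasting rule can be sketched in pure Python, assuming nested lists, tuples and dicts are the only containers. The hypothetical `broadcast_map` below (not the real `pytc.bcmap`) walks the structure of the first argument; every other argument, positional or keyword, is indexed in lockstep if it is a container and broadcast unchanged if it is a scalar. A scalar Python `where` stands in for `jnp.where`.

```python
from typing import Any, Callable


def _child(arg: Any, key: Any) -> Any:
    # containers are indexed in lockstep with the first argument; scalars are broadcast
    return arg[key] if isinstance(arg, (list, tuple, dict)) else arg


def broadcast_map(func: Callable) -> Callable:
    """Hypothetical bcmap-like transformation for nested lists, tuples and dicts."""

    def mapped(tree: Any, *args: Any, **kwargs: Any) -> Any:
        if not isinstance(tree, (list, tuple, dict)):
            return func(tree, *args, **kwargs)  # reached a leaf of the first argument
        keys = list(tree.keys()) if isinstance(tree, dict) else range(len(tree))
        children = [
            mapped(
                tree[key],
                *(_child(arg, key) for arg in args),
                **{name: _child(value, key) for name, value in kwargs.items()},
            )
            for key in keys
        ]
        if isinstance(tree, dict):
            return dict(zip(keys, children))
        return type(tree)(children)

    return mapped


def where(condition, x, y):
    # scalar stand-in for jnp.where
    return x if condition else y


tree = ([1], {"a": 1, "b": 2}, (1,), -1)
mask = broadcast_map(lambda leaf: leaf > 0)(tree)
tree2 = broadcast_map(lambda leaf: 1000)(tree)
broadcastable_where = broadcast_map(where)

print(broadcastable_where(mask, tree, 0))      # ([1], {'a': 1, 'b': 2}, (1,), 0)
print(broadcastable_where(mask, tree, tree2))  # ([1], {'a': 1, 'b': 2}, (1,), 1000)
print(broadcastable_where(mask, x=tree, y=0))  # ([1], {'a': 1, 'b': 2}, (1,), 0)
```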

Use PyTreeClass visualization tools with arbitrary PyTrees
import jax
import pytreeclass as pytc

tree = [1, [2,3], 4]

print(pytc.tree_diagram(tree, depth=1))
# list
# ├── [0]=1
# ├── [1]=[...]
# └── [2]=4

print(pytc.tree_diagram(tree, depth=2))
# list
# ├── [0]=1
# ├── [1]:list
# │   ├── [0]=2
# │   └── [1]=3
# └── [2]=4


print(pytc.tree_summary(tree, depth=1))
# ┌────┬────┬─────┐
# │Name│Type│Count│
# ├────┼────┼─────┤
# │[0] │int │1    │
# ├────┼────┼─────┤
# │[1] │list│1    │
# ├────┼────┼─────┤
# │[2] │int │1    │
# ├────┼────┼─────┤
# │Σ   │list│3    │
# └────┴────┴─────┘

print(pytc.tree_summary(tree, depth=2))
# ┌──────┬────┬─────┐
# │Name  │Type│Count│
# ├──────┼────┼─────┤
# │[0]   │int │1    │
# ├──────┼────┼─────┤
# │[1][0]│int │1    │
# ├──────┼────┼─────┤
# │[1][1]│int │1    │
# ├──────┼────┼─────┤
# │[2]   │int │1    │
# ├──────┼────┼─────┤
# │Σ     │list│4    │
# └──────┴────┴─────┘
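How does `depth` decide what gets collapsed? A rough, hypothetical re-implementation for nested lists and tuples only (not PyTreeClass's code) reproduces the two diagrams above: containers at the depth limit are shown as `[...]` or `(...)`, while shallower ones are expanded with box-drawing branches.

```python
def tree_diagram_sketch(tree, depth=float("inf")):
    """Render nested lists/tuples as a box-drawing tree, collapsing below `depth`."""
    lines = [type(tree).__name__]

    def walk(node, prefix, level):
        for index, child in enumerate(node):
            last = index == len(node) - 1
            branch = "└── " if last else "├── "
            is_container = isinstance(child, (list, tuple))
            if is_container and level < depth:
                # expand this container one level deeper
                lines.append(f"{prefix}{branch}[{index}]:{type(child).__name__}")
                walk(child, prefix + ("    " if last else "│   "), level + 1)
            elif is_container:
                # depth limit reached: collapse the container
                collapsed = "[...]" if isinstance(child, list) else "(...)"
                lines.append(f"{prefix}{branch}[{index}]={collapsed}")
            else:
                lines.append(f"{prefix}{branch}[{index}]={child!r}")

    walk(tree, "", 1)
    return "\n".join(lines)


tree = [1, [2, 3], 4]
print(tree_diagram_sketch(tree, depth=1))
print(tree_diagram_sketch(tree, depth=2))
```

With `depth=1` and `depth=2` this prints the same two list diagrams as `pytc.tree_diagram` above.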
Use PyTreeClass components with other libraries
import jax
import pytreeclass as pytc
from flax import struct

@struct.dataclass
class FlaxTree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jax.numpy.array([4.,5.,6.])

    def __repr__(self) -> str:
        return pytc.tree_repr(self)
    def __str__(self) -> str:
        return pytc.tree_str(self)
    @property
    def at(self):
        return pytc.tree_indexer(self)

def pytc_flatten_rule(tree):
    names =("a","b","c")
    types = map(type, (tree.a, tree.b, tree.c))
    indices = range(3)
    metadatas= (None, None, None)
    return [*zip(names, types, indices, metadatas)]

pytc.register_pytree_node_trace(FlaxTree, pytc_flatten_rule)

flax_tree = FlaxTree()

print(f"{flax_tree!r}")
# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

print(f"{flax_tree!s}")
# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

print(pytc.tree_diagram(flax_tree))
# FlaxTree
# ├── a=1
# ├── b:tuple
# │   ├── [0]=2.0
# │   └── [1]=3.0
# └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])

print(pytc.tree_summary(flax_tree))
# ┌────┬────────┬─────┐
# │Name│Type    │Count│
# ├────┼────────┼─────┤
# │a   │int     │1    │
# ├────┼────────┼─────┤
# │b[0]│float   │1    │
# ├────┼────────┼─────┤
# │b[1]│float   │1    │
# ├────┼────────┼─────┤
# │c   │f32[3]  │3    │
# ├────┼────────┼─────┤
# │Σ   │FlaxTree│6    │
# └────┴────────┴─────┘

flax_tree.at[0].get()
# FlaxTree(a=1, b=(None, None), c=None)

flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

Benchmark flatten/unflatten compared to Flax and Equinox

[Figures: flatten/unflatten benchmark results on CPU and GPU, from a Colab notebook]

Use tree_map_with_trace

Version 0.2 of PyTreeClass registers common Python data types and treeclass-wrapped classes in a trace registry. While JAX uses jax.tree_util.register_pytree_node to define the flatten rule for leaves, PyTreeClass extends this by also registering a flatten rule for (1) names, (2) types, (3) indices and (4) metadata, if it exists.

For demonstration, the following figure shows four variants of the same instance of the Tree class defined below:

import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

tree = Tree()

[Figure: the value, name, type and index variants of the same Tree instance]

  1. Value leaves variant.
  2. Name leaves variant.
  3. Type leaves variant.
  4. Indexing leaves variant.

The four variants can be accessed using pytc.tree_map_with_trace. Like jax.tree_util.tree_map, pytc.tree_map_with_trace accepts a map function, but the function's first argument must be the trace. The trace is a four-item tuple holding the names, types, indices and metadata paths of each leaf. For example, for the tree above, the resulting trace path for each leaf is:

Named tree variant

>>> name_tree = pytc.tree_map_with_trace(lambda trace,x : trace[0], tree)
>>> print(name_tree)
Tree(a=(a), b=((b, [0]), (b, [1])), c=(c))

Typed tree variant

>>> type_tree = pytc.tree_map_with_trace(lambda trace,x : f"{trace[1]!s}", tree)
>>> print(type_tree)
Tree(
  a=(<class 'int'>,),
  b=((<class 'tuple'>, <class 'float'>), (<class 'tuple'>, <class 'float'>)),
  c=(<class 'jaxlib.xla_extension.ArrayImpl'>,)
)

Index tree variant

>>> index_tree = pytc.tree_map_with_trace(lambda trace,x : trace[2], tree)
>>> print(index_tree)
Tree(a=(0), b=((1, 0), (1, 1)), c=(2))

In essence, each leaf carries information about its name path, type path and index path. The rules for custom types can be registered using pytc.register_pytree_node_trace.
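The trace idea can be sketched without JAX: a registry maps each container type to a rule that lists its `(name, child)` pairs, and a recursive walk accumulates the name, type and index paths of every leaf. `TRACE_RULES`, `leaves_with_trace` and the plain `Tree` class are hypothetical (here `c` is a float instead of an array, and metadata is omitted); adding a rule for `Tree` plays the role that `pytc.register_pytree_node_trace` plays in PyTreeClass.

```python
from typing import Any, Callable, Dict, Iterator, List, Tuple


class Tree:
    """Plain container standing in for the Tree above (c is a float here)."""

    def __init__(self, a=1, b=(2.0, 3.0), c=4.0):
        self.a, self.b, self.c = a, b, c


# Hypothetical registry: each rule returns the (name, child) pairs of a node.
TRACE_RULES: Dict[type, Callable[[Any], List[Tuple[str, Any]]]] = {
    list: lambda node: [(f"[{i}]", child) for i, child in enumerate(node)],
    tuple: lambda node: [(f"[{i}]", child) for i, child in enumerate(node)],
    dict: lambda node: [(str(key), child) for key, child in node.items()],
    Tree: lambda node: [("a", node.a), ("b", node.b), ("c", node.c)],
}


def leaves_with_trace(tree: Any, names=(), types=(), indices=()) -> Iterator:
    """Yield ((names, types, indices), leaf) for every leaf of `tree`."""
    rule = TRACE_RULES.get(type(tree))
    if rule is None:  # unregistered type: treat it as a leaf
        yield (names, types, indices), tree
        return
    for index, (name, child) in enumerate(rule(tree)):
        yield from leaves_with_trace(
            child, names + (name,), types + (type(child),), indices + (index,)
        )


for (names, types, indices), leaf in leaves_with_trace(Tree()):
    print(names, [t.__name__ for t in types], indices, leaf)
# ('a',) ['int'] (0,) 1
# ('b', '[0]') ['tuple', 'float'] (1, 0) 2.0
# ('b', '[1]') ['tuple', 'float'] (1, 1) 3.0
# ('c',) ['float'] (2,) 4.0
```

Lookup is by exact type, so subclasses such as namedtuples fall through to the leaf case unless they get their own rule.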

Comparison with dataclass

| Feature               | PyTreeClass | dataclass |
|-----------------------|-------------|-----------|
| Generated init method | ✅          | ✅        |
| Generated repr method | ✅          | ✅        |
| Generated str method  | ✅          |           |
| Generated hash method | ✅          | ✅        |
| Generated eq method   | ✅✅*       | ✅        |
| Support slots         |             | ✅        |
| Keyword-only args     | ✅          | ✅ 3.10+  |
| Positional-only args  | ✅          |           |
| Frozen instance       | ✅**        | ✅        |
| Match args support    |             | ✅        |
| Support callbacks     | ✅          |           |

* Either compares the whole instance and returns True/False, or compares leafwise when using treeclass(..., leafwise=True) and returns Tree(a=True, ...).

** Always frozen; non-frozen instances are not supported.

*** The treeclass decorator is also a bit faster than dataclasses.dataclass (see the benchmark in the Colab notebook).

📙 Acknowledgements



Download files


Source Distribution

pytreeclass-0.2.6.tar.gz (58.9 kB)

Uploaded Source

Built Distribution

pytreeclass-0.2.6-py3-none-any.whl (57.8 kB)

Uploaded Python 3
