Skip to main content

JAX compatible dataclass.

Project description



Installation |Description |Quick Example |StatefulComputation |More |Acknowledgements

Tests pyver pyver codestyle Open In Colab Downloads codecov Documentation Status GitHub commit activity DOI PyPI

For previous PyTreeClass use v0.1 branch

๐Ÿ› ๏ธ Installation

pip install pytreeclass

Install development version

pip install git+https://github.com/ASEM000/PyTreeClass

๐Ÿ“– Description

PyTreeClass is a JAX-compatible class builder to create and operate on stateful JAX PyTrees.

The package aims to achieve two goals:

  1. ๐Ÿ”’ To maintain safe and correct behaviour by using immutable modules with functional API.
  2. To achieve the most intuitive user experience in the JAX ecosystem by :
    • ๐Ÿ—๏ธ Defining layers similar to PyTorch or TensorFlow subclassing style.
    • โ˜๏ธ Filtering\Indexing layer values similar to jax.numpy.at[].{get,set,apply,...}
    • ๐ŸŽจ Visualize defined layers in plethora of ways.

โฉ Quick Example

๐Ÿ—๏ธ Simple Tree example

import jax
import jax.numpy as jnp
import pytreeclass as pytc


class Tree(pytc.TreeClass):
    a: int = 1
    b: tuple = (2, 3.)
    c: jax.Array = jnp.array([4., 5., 6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree \
       .at["a"].set(10) \
       .at["b"].at[0].set(10) \
       .at[mask].set(100)

print(tree)
# Tree(a=10, b=(10, 3.0), c=[  4.   5. 100.])

print(pytc.tree_diagram(tree))
# Tree
# โ”œโ”€โ”€ .a=10
# โ”œโ”€โ”€ .b:tuple
# โ”‚   โ”œโ”€โ”€ [0]=10
# โ”‚   โ””โ”€โ”€ [1]=3.0
# โ””โ”€โ”€ .c=f32[3](ฮผ=36.33, ฯƒ=45.02, โˆˆ[4.00,100.00])

print(pytc.tree_summary(tree))
# โ”Œโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
# โ”‚Name โ”‚Type  โ”‚Countโ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.a   โ”‚int   โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.b[0]โ”‚int   โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.b[1]โ”‚float โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.c   โ”‚f32[3]โ”‚3    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚ฮฃ    โ”‚Tree  โ”‚6    โ”‚
# โ””โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜


# ** pass it to jax transformations **

# freeze all non-differentiable parameters to make it
# work with jax trnasformations
mask = jax.tree_map(pytc.is_nondiff, tree)
tree = tree.at[mask].apply(pytc.freeze)

@jax.jit
@jax.grad
def sum_tree(tree:Tree, x):
    # unfreeze before calling tree
    tree = tree.at[...].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
    return sum(tree(x))

print(sum_tree(tree, 1.0))
# Tree(a=#10, b=(#10, 0.0), c=[1. 1. 1.])

๐ŸŽจ Visualize

Visualize PyTrees
tree_summary tree_diagram [tree_mermaid](https://mermaid.js.org)(Native support in Github/Notion) tree_repr tree_str
print(pytc.tree_summary(tree, depth=1))
โ”Œโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
โ”‚Nameโ”‚Type  โ”‚Countโ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.a  โ”‚int   โ”‚1    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.b  โ”‚tuple โ”‚1    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.c  โ”‚f32[3]โ”‚3    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ   โ”‚Tree  โ”‚5    โ”‚
โ””โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜
print(pytc.tree_diagram(tree, depth=1))
Tree
โ”œโ”€โ”€ .a=1
โ”œโ”€โ”€ .b=(...)
โ””โ”€โ”€ .c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])
print(pytc.tree_mermaid(tree, depth=1))
flowchart LR
    id0(<b>Tree</b>)
    id0 --- id1("</b>.a=1</b>")
    id0 --- id2("</b>.b=(...)</b>")
    id0 --- id3("</b>.c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])</b>")
print(pytc.tree_repr(tree, depth=1))
Tree(a=1, b=(...), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))
print(pytc.tree_str(tree, depth=1))
Tree(a=1, b=(...), c=[4. 5. 6.])
print(pytc.tree_summary(tree, depth=2))
โ”Œโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
โ”‚Name โ”‚Type  โ”‚Countโ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.a   โ”‚int   โ”‚1    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.b[0]โ”‚int   โ”‚1    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.b[1]โ”‚float โ”‚1    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.c   โ”‚f32[3]โ”‚3    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ    โ”‚Tree  โ”‚6    โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜
print(pytc.tree_diagram(tree, depth=2))
Tree
โ”œโ”€โ”€ .a=1
โ”œโ”€โ”€ .b:tuple
โ”‚   โ”œโ”€โ”€ [0]=2.0
โ”‚   โ””โ”€โ”€ [1]=3.0
โ””โ”€โ”€ .c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])
print(pytc.tree_mermaid(tree, depth=2))
flowchart LR
    id2 --- id3("</b>[0]=2.0</b>")
    id2 --- id4("</b>[1]=3.0</b>")
    id0(<b>Tree</b>)
    id0 --- id1("</b>.a=1</b>")
    id0 --- id2("</b>.b:tuple</b>")
    id0 --- id5("</b>.c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])</b>")
print(pytc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))
print(pytc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

๐Ÿƒ Working with jax transformation

Make arbitrary PyTrees work with jax transformations

Parameters are defined in Tree at the top of class definition similar to defining dataclasses.dataclass field. Lets optimize our parameters

import pytreeclass as pytc
import jax
import jax.numpy as jnp


class Tree(pytc.TreeClass)
    a: int = 1
    b: tuple[float] = (2., 3.)
    c: jax.Array = jnp.array([4., 5., 6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()


@jax.grad
def loss_func(tree: Tree, x: jax.Array):
    tree = tree.at[...].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)  # <--- unfreeze the tree before calling it
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error


@jax.jit
def train_step(tree: Tree, x: jax.Array):
    grads = loss_func(tree, x)
    # apply a small gradient step
    return jax.tree_util.tree_map(lambda x, g: x - 1e-3 * g, tree, grads)


# lets freeze the non-differentiable parts of the tree
# in essence any non inexact type should be frozen to
# make the tree differentiable and work with jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10, 1]))

print(jaxable_tree)
# **the `frozen` params have "#" prefix**
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])


# unfreeze the tree
tree = jaxable_tree.at[...].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
# the previous line is equivalent to:
# >>> tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])

โ˜๏ธ Advanced Indexing with .at[]

Out-of-place updates using mask, attribute name or index

PyTreeClass offers 3 means of indexing through .at[]

  1. Indexing by boolean mask.
  2. Indexing by attribute name.
  3. Indexing by Leaf index.

Since treeclass wrapped class are immutable, .at[] operations returns new instance of the tree

Index update by boolean mask

tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4,6]))

# lets create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x>4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False  True  True])

print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])

print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])

print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])

Index update by attribute name

tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4,6]))

print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)

print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])

print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])

Index update by integer index

tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4,6]))

print(tree.at[1].at[0].get())
# Tree(a=None, b=(2.0, None), c=None)

print(tree.at[1].at[0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])

print(tree.at[1].at[0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])

Mix, match , and chain index update

import jax
import jax.numpy as jnp
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    a: int = 1
    b: str = "b"
    c: float = 1.0
    d: bool = True
    e: tuple = (1, 2, 3)
    f: jax.Array = jax.numpy.array([1, 2, 3])

tree = Tree()

integer_mask = jax.tree_util.tree_map(lambda x: isinstance(x, int), tree)

tree = (
    tree
    .at["a"].set(10)
    .at["b"].set("B")
    .at["c"].set(10.0)
    .at["d"].set(False)
    .at["e"].at[0].set(10)  # set first element of tuple to 10
    .at["f"].apply(jnp.sin)  # apply to all elements in array
    .at[integer_mask].apply(float)  # cast all `int` to `float`
)

print(tree)
# Tree(
#   a=10.0,
#   b=B,
#   c=10.0,
#   d=0.0,
#   e=(10.0, 2.0, 3.0),
#   f=[0.841471  0.9092974 0.14112  ]
# )

๐Ÿ“œ Stateful computations

First, Under jax.jit jax requires states to be explicit, this means that for any class instance; variables needs to be separated from the class and be passed explictly. However when using TreeClass no need to separate the instance variables ; instead the whole instance is passed as a state.

Using the following pattern,Updating state functionally can be achieved under jax.jit

import jax
import pytreeclass as pytc

class Counter(pytc.TreeClass):
    calls : int = 0

    def increment(self):
        self.calls += 1
counter = Counter() # Counter(calls=0)

Here, we define the update function. Since the increment method mutate the internal state, thus we need to use the functional approach to update the state by using .at. To achieve this we can use .at[method_name].__call__(*args,**kwargs), this functional call will return the value of this call and a new model instance with the update state.

@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls) # 10

โž• More

Validate or convert inputs using callbacks

PyTreeClass includes callbacks in the field to apply a sequence of functions on input at setting the attribute stage. The callback is quite useful in several cases, for instance, to ensure a certain input type within a valid range. See example:

import jax
import pytreeclass as pytc

def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


class Tree(pytc.TreeClass):
    in_features:int = pytc.field(callbacks=[positive_int_callback])


tree = Tree(1)
# no error

tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive

tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer
Add leafwise math operations to PyTreeClass wrapped class
import functools as ft
import pytreeclass as pytc
import jax
import jax.tree_util as jtu
import jax.numpy as jnp


class Tree(pytc.TreeClass, leafwise=True):
    a: int = 1
    b: tuple[float] = (2., 3.)
    c: jax.Array = jnp.array([4., 5., 6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()

tree + 100
# Tree(a=101, b=(102.0, 103.0), c=f32[3](ฮผ=105.00, ฯƒ=0.82, โˆˆ[104.00,106.00]))


@jax.grad
def loss_func(tree: Tree, x: jax.Array):
    tree = jtu.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)  # <--- unfreeze the tree before calling it
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error


@jax.jit
def train_step(tree: Tree, x: jax.Array):
    grads = loss_func(tree, x)
    return tree - grads * 1e-3  # <--- eliminate `tree_map`


# lets freeze the non-differentiable parts of the tree
# in essence any non inexact type should be frozen to
# make the tree differentiable and work with jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10, 1]))

print(jaxable_tree)
# **the `frozen` params have "#" prefix**
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])


# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])
Eliminate tree_map using bcmap + treeclass(..., leafwise=True)

TDLR

import functools as ft
import pytreeclass as pytc
import jax.numpy as jnp

class Tree(pytc.TreeClass, leafwise=True):
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

tree = Tree()

print(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

bcmap(func, is_leaf) maps a function over PyTrees leaves with automatic broadcasting for scalar arguments.

bcmap is function transformation that broadcast a scalar to match the first argument of the function this enables us to convert a function like jnp.where to work with arbitrary tree structures without the need to write a specific function for each broadcasting case

For example, lets say we want to use jnp.where to zeros out all values in an arbitrary tree structure that are less than 0

tree = ([1], {"a":1, "b":2}, (1,), -1,)

we can use jax.tree_util.tree_map to apply jnp.where to the tree but we need to write a specific function for broadcasting the scalar to the tree

def map_func(leaf):
    # here we encoded the scalar `0` inside the function
    return jnp.where(leaf>0, leaf, 0)

jtu.tree_map(map_func, tree)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))

However, lets say we want to use jnp.where to set a value to a leaf value from another tree that looks like this

def map_func2(lhs_leaf, rhs_leaf):
    # here we encoded the scalar `0` inside the function
    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)

tree2 = jtu.tree_map(lambda x: 1000, tree)

jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))

Now, bcmap makes this easier by figuring out the broadcasting case.

broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)

case 1

broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))

case 2

broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))

lets then take this a step further to eliminate mask from the equation by using pytreeclass with leafwise=True

class Tree(pytc.TreeClass, leafwise=True):
    tree : tuple = ([1], {"a":1, "b":2}, (1,), -1,)

tree = Tree()
# Tree(tree=([1], {a:1, b:2}, (1), -1))

case 1: broadcast scalar to tree

print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))

case 2: broadcast tree to tree
```python
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))

bcmap also works with all kind of arguments in the wrapped function

print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))

in concolusion, bcmap is a function transformation that can be used to to make functions work with arbitrary tree structures without the need to write a specific function for each broadcasting case

Moreover, bcmap can be more powerful when used with pytreeclass to facilitate operation of arbitrary functions on PyTree objects without the need to use tree_map

Use PyTreeClass vizualization tools with arbitrary PyTrees
import jax
import pytreeclass as pytc

tree = [1, [2,3], 4]

print(pytc.tree_diagram(tree, depth=1))
# list
# โ”œโ”€โ”€ [0]=1
# โ”œโ”€โ”€ [1]=[...]
# โ””โ”€โ”€ [2]=4

print(pytc.tree_diagram(tree, depth=2))
# list
# โ”œโ”€โ”€ [0]=1
# โ”œโ”€โ”€ [1]:list
# โ”‚   โ”œโ”€โ”€ [0]=2
# โ”‚   โ””โ”€โ”€ [1]=3
# โ””โ”€โ”€ [2]=4


print(pytc.tree_summary(tree, depth=1))
# โ”Œโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
# โ”‚Nameโ”‚Typeโ”‚Countโ”‚
# โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[0] โ”‚int โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[1] โ”‚listโ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[2] โ”‚int โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚ฮฃ   โ”‚listโ”‚3    โ”‚
# โ””โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜

print(pytc.tree_summary(tree, depth=2))
# โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
# โ”‚Name  โ”‚Typeโ”‚Countโ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[0]   โ”‚int โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[1][0]โ”‚int โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[1][1]โ”‚int โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚[2]   โ”‚int โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚ฮฃ     โ”‚listโ”‚4    โ”‚
# โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜
Use PyTreeClass components with other libraries
import jax
import pytreeclass as pytc
from flax import struct

import jax
import pytreeclass as pytc
from flax import struct

# note that flax is registered with `jax.tree_util.register_pytree_with_keys`
# otherwise for arbitrary objects you need to do the key registration

@struct.dataclass
class FlaxTree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jax.numpy.array([4.,5.,6.])

    def __repr__(self) -> str:
        return pytc.tree_repr(self)
    def __str__(self) -> str:
        return pytc.tree_str(self)
    @property
    def at(self):
        return pytc.tree_indexer(self)

flax_tree = FlaxTree()

print(f"{flax_tree!r}")
# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))

print(f"{flax_tree!s}")
# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

print(pytc.tree_diagram(flax_tree))
# FlaxTree
# โ”œโ”€โ”€ .a=1
# โ”œโ”€โ”€ .b:tuple
# โ”‚   โ”œโ”€โ”€ [0]=2.0
# โ”‚   โ””โ”€โ”€ [1]=3.0
# โ””โ”€โ”€ .c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])

print(pytc.tree_summary(flax_tree))
# โ”Œโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
# โ”‚Name โ”‚Type    โ”‚Countโ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.a   โ”‚int     โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.b[0]โ”‚float   โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.b[1]โ”‚float   โ”‚1    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚.c   โ”‚f32[3]  โ”‚3    โ”‚
# โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
# โ”‚ฮฃ    โ”‚FlaxTreeโ”‚6    โ”‚
# โ””โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜

flax_tree.at[0].get()
# FlaxTree(a=1, b=(None, None), c=None)

flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))
Benchmark flatten/unflatten compared to Flax and Equinox

Open In Colab

CPUGPU

๐Ÿ“™ Acknowledgements

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytreeclass-0.3.7.tar.gz (52.2 kB view hashes)

Uploaded Source

Built Distribution

pytreeclass-0.3.7-py3-none-any.whl (52.4 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page