
Visualize, create, and operate on JAX PyTree in the most intuitive way possible.




Installation | Description | Quick Example | Stateful Computation | Benchmarks | Acknowledgements


🛠️ Installation

pip install pytreeclass

Install development version

pip install git+https://github.com/ASEM000/pytreeclass

📖 Description

pytreeclass is a JAX-compatible class builder for creating and operating on stateful JAX PyTrees in a performant and intuitive way, building on familiar concepts found in NumPy, dataclasses, and others.
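For context, this is roughly what turning a class into a JAX PyTree looks like when done by hand with `jax.tree_util.register_pytree_node_class`; a class builder such as pytreeclass is meant to spare you this boilerplate. This is a plain-JAX sketch, not pytreeclass's own implementation, and the `Point` class is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Point:
    """Hypothetical two-field class registered as a PyTree by hand."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # children: the leaves JAX traces and transforms
        # aux_data: static metadata needed to rebuild the object (none here)
        children = (self.x, self.y)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild a Point from (possibly transformed) leaves
        return cls(*children)


p = Point(jnp.array(1.0), jnp.array(2.0))
leaves, treedef = jax.tree_util.tree_flatten(p)       # two array leaves
doubled = jax.tree_util.tree_map(lambda v: v * 2, p)  # a new Point
total = jax.jit(lambda q: q.x + q.y)(p)               # the instance passes through jit
```

Every method and field above has to be kept in sync manually; a class builder generates the flatten/unflatten pair from the class definition instead.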

See the documentation and 🍳 Common recipes to check whether this library is a good fit for your work. If you find the package useful, consider giving it a 🌟.

โฉ Quick Example

import jax
import jax.numpy as jnp
import pytreeclass as tc

@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)
tree = tree\
       .at["a"].set(100.0)\
       .at["b"][0].set(10.0)\
       .at[mask].set(100.0)

print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[  4.   5. 100.])

print(tc.tree_diagram(tree))
# Tree
# ├── .a=100.0
# ├── .b:tuple
# │   ├── [0]=10.0
# │   └── [1]=3.0
# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])

print(tc.tree_summary(tree))
# ┌─────┬──────┬─────┬──────┐
# │Name │Type  │Count│Size  │
# ├─────┼──────┼─────┼──────┤
# │.a   │float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[0]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[1]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.c   │f32[3]│3    │12.00B│
# ├─────┼──────┼─────┼──────┤
# │Σ    │Tree  │6    │12.00B│
# └─────┴──────┴─────┴──────┘

# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.

@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
    return sum(tree(x))

print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
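To see what the mask step does, the same "build a boolean mask with `tree_map`, then update only the selected leaves" idea can be written with plain `jax.tree_util` on a dict that mirrors `Tree`. This is an illustrative sketch of the semantics, not how `.at[mask].set` is implemented inside pytreeclass. The leaf count and byte total at the end mimic the Σ row of `tree_summary`, assuming Python floats count as one element with no byte size.

```python
import jax
import jax.numpy as jnp

# A plain dict standing in for the Tree instance above (before any updates)
tree = {"a": 1.0, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}

# Boolean mask with the same structure: True wherever a value exceeds 5
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)

# Set masked entries to 100.0; unmasked entries keep their original value
updated = jax.tree_util.tree_map(lambda x, m: jnp.where(m, 100.0, x), tree, mask)

# Leaf statistics like the Σ row of tree_summary, computed on the original dict
# (its a/b leaves are still Python floats, so only the array `c` has a byte size)
leaves = jax.tree_util.tree_leaves(tree)
count = sum(leaf.size if isinstance(leaf, jax.Array) else 1 for leaf in leaves)
nbytes = sum(leaf.nbytes for leaf in leaves if isinstance(leaf, jax.Array))
```

The mask is a PyTree of the same shape as the data, which is why a single `tree_map` over both trees is enough to apply it.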

📜 Stateful computations

Under jax.jit, JAX requires state to be explicit: for any class instance, the variables need to be separated from the class and passed explicitly. With TreeClass, however, there is no need to separate the instance variables; instead, the whole instance is passed as the state.
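For contrast, in plain JAX a counter's state lives outside any object and is threaded through the jitted function by hand. This is a minimal sketch of that explicit-state style; `increment` and `calls` are illustrative names, not part of pytreeclass.

```python
import jax
import jax.numpy as jnp


@jax.jit
def increment(calls):
    # Pure function: returns the new state instead of mutating anything
    return calls + 1


calls = jnp.array(0)  # the state is a bare value held by the caller
for _ in range(10):
    calls = increment(calls)  # old state in, new state out
```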

With TreeClass, state can be updated functionally under jax.jit using the following pattern:

import jax
import pytreeclass as tc

class Counter(tc.TreeClass):
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1

counter = Counter()  # Counter(calls=0)

Here, we define the update function. Since the increment method mutates the internal state, we need to update the state functionally using .at. To do this, use .at[method_name].__call__(*args, **kwargs); this functional call returns the value of the call and a new instance with the updated state.

@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls) # 10
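Conceptually, `counter.at["increment"]()` behaves like calling the method on a copy and returning both the result and the modified copy, leaving the original untouched. The sketch below imitates that behaviour with an ordinary Python class and `copy.copy`. It illustrates the semantics only and is not pytreeclass's implementation; `PlainCounter` and `call_on_copy` are hypothetical names.

```python
import copy


class PlainCounter:
    """Hypothetical stand-in for Counter, using ordinary mutable attributes."""

    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1
        return self.calls


def call_on_copy(instance, method_name, *args, **kwargs):
    """Hypothetical helper: run a mutating method on a shallow copy.

    Returns (the method's return value, the mutated copy);
    `instance` itself is never modified.
    """
    new_instance = copy.copy(instance)
    value = getattr(new_instance, method_name)(*args, **kwargs)
    return value, new_instance


original = PlainCounter()
value, updated = call_on_copy(original, "increment")
```

A shallow copy suffices here because `increment` rebinds an attribute; a method that mutated a shared container in place would need a deeper copy to keep the original intact.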

➕ Benchmarks

Benchmark flatten/unflatten compared to Flax and Equinox


[Benchmark plots: CPU and GPU]
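A flatten/unflatten round-trip can be timed with `jax.tree_util` and the standard `timeit` module. The sketch below uses a plain list-of-dicts tree rather than Flax, Equinox, or pytreeclass models, so it only shows how such a measurement can be taken; it is not the benchmark behind the plots, and the tree shape is an assumption.

```python
import timeit

import jax
import jax.numpy as jnp

# Hypothetical nested PyTree: 100 "layers", each a dict of weight and bias arrays
tree = [{"w": jnp.ones((4, 4)), "b": jnp.zeros(4)} for _ in range(100)]

leaves, treedef = jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# Average seconds per call over n repetitions
n = 1000
flatten_time = timeit.timeit(lambda: jax.tree_util.tree_flatten(tree), number=n) / n
unflatten_time = timeit.timeit(
    lambda: jax.tree_util.tree_unflatten(treedef, leaves), number=n
) / n
print(f"flatten: {flatten_time * 1e6:.1f} us, unflatten: {unflatten_time * 1e6:.1f} us")
```

The absolute numbers depend on the machine and JAX version; comparisons between libraries are only meaningful when the same tree structure is measured on the same setup.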
Benchmark simple training against Flax and Equinox

Training time for a simple sequential stack of linear layers, reported as the ratio of Flax or Equinox time to pytreeclass (tc) time:

| Num of layers | Flax/tc time | Equinox/tc time |
|---|---|---|
| 10 | 1.427 | 6.671 |
| 100 | 1.1130 | 2.714 |
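As a rough plain-JAX picture of the kind of workload being timed (not the benchmark script itself), a sequential linear model can be a list of `(weight, bias)` pairs trained with SGD, where one `tree_map` updates every parameter leaf. The layer width, learning rate, data, and helper names below are assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # assumed value, not taken from the benchmark


def init_layers(key, num_layers, width):
    """Return a list of (weight, bias) pairs, one per linear layer."""
    keys = jax.random.split(key, num_layers)
    return [
        (jax.random.normal(k, (width, width)) / jnp.sqrt(width), jnp.zeros(width))
        for k in keys
    ]


def forward(layers, x):
    # Apply the linear layers in sequence (no activations: a purely linear stack)
    for w, b in layers:
        x = x @ w + b
    return x


def loss_fn(layers, x, y):
    return jnp.mean((forward(layers, x) - y) ** 2)


@jax.jit
def train_step(layers, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(layers, x, y)
    # Plain SGD: update every leaf of the parameter PyTree at once
    layers = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, layers, grads)
    return layers, loss


key_params, key_data = jax.random.split(jax.random.PRNGKey(0))
layers = init_layers(key_params, num_layers=3, width=8)
x = jax.random.normal(key_data, (32, 8))
y = jnp.zeros((32, 8))

losses = []
for _ in range(5):
    layers, loss = train_step(layers, x, y)
    losses.append(float(loss))
```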

📙 Acknowledgements

Download files

Source Distribution

pytreeclass-0.9.1.tar.gz (55.9 kB)

Built Distribution

pytreeclass-0.9.1-py3-none-any.whl (46.4 kB, Python 3)

File details

Details for the file pytreeclass-0.9.1.tar.gz.

File metadata

  • Download URL: pytreeclass-0.9.1.tar.gz
  • Size: 55.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for pytreeclass-0.9.1.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 4cb654c18c81d8ba8647819c0c1b86178423f5d02a7fb27253d27c23f3d8429a |
| MD5 | 9aabc5e529922b55a797a4cb291ef9d4 |
| BLAKE2b-256 | e1aa2a66cfc078835e902fc4bcba2dba6ed066c9eb880067b0578d2bbebdf7fa |


File details

Details for the file pytreeclass-0.9.1-py3-none-any.whl.

File metadata

  • Download URL: pytreeclass-0.9.1-py3-none-any.whl
  • Size: 46.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for pytreeclass-0.9.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | d39bc5f656097795e283546bb40317d92639bbb8503c32b1e78501bf521de7ff |
| MD5 | fa7103e5b52cd2fe5ab8932244c268ae |
| BLAKE2b-256 | 41ce48c7e6012893397ae4ae6a31d765a2bd440f1c9c10f837affb3b9349bef9 |

