Skip to main content

No project description provided

Project description

Treeo

A small library for creating and manipulating custom JAX Pytree classes

  • Light-weight: has no dependencies other than jax.
  • Compatible: Treeo Tree objects are compatible with any jax function that accepts Pytrees.
  • Standards-based: treeo.field is built on top of python's dataclasses.field.
  • Flexible: Treeo is compatible with both dataclass and non-dataclass classes.

Treeo lets you easily create class-based Pytrees so your custom objects can easily interact seamlessly with JAX. Uses of Treeo can range from just creating simple simple JAX-aware utility classes to using it as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of Treex and shares a lot in common with flax.struct.

Documentation | User Guide

Installation

Install using pip:

pip install treeo

Basics

With Treeo you can easily define your own custom Pytree classes by inheriting from Treeo's Tree class and using the field function to declare which fields are nodes (children) and which are static (metadata):

import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.array = to.field(node=True) # I am a node field!
    name: str = to.field(node=False) # I am a static field!

field is just a wrapper around dataclasses.field so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well. Since all Tree instances are Pytree they work with the various functions from thejax library as expected:

p = Person(height=jnp.array(1.8), name="John")

# Trees can be jitted!
jax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')

# Trees can be mapped!
jax.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')

Kinds

Treeo also include a kind system that lets you give semantic meaning to fields (what a field represents within your application). A kind is just a type you pass to field via its kind argument:

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

Kinds are very useful as a filtering mechanism via treeo.filter:

model = BatchNorm(...)

# select only Parameters, mean is filtered out
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)

Nothing behaves like None in Python, but it is a special value that is used to represent the absence of a value within Treeo.

Treeo also offers the merge function which lets you rejoin filtered Trees with a logic similar to Python dict.update but done recursively:

def loss_fn(params, model, ...):
    # add traced params to model
    model = to.merge(model, params)
    ...

# gradient only w.r.t. params
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)
grads = jax.grad(loss_fn)(params, model, ...)

For a more in-depth tour check out the User Guide.

Examples

A simple Tree

from dataclasses import dataclass
import treeo as to

@dataclass
class Character(to.Tree):
    position: jnp.ndarray = to.field(node=True)    # node field
    name: str = to.field(node=False, opaque=True)  # static field

character = Character(position=jnp.array([0, 0]), name='Adam')

# character can freely pass through jit
@jax.jit
def update(character: Character, velocity, dt) -> Character:
    character.position += velocity * dt
    return character

character = update(character velocity=jnp.array([1.0, 0.2]), dt=0.1)

A Stateful Tree

from dataclasses import dataclass
import treeo as to

@dataclass
class Counter(to.Tree):
    n: jnp.array = to.field(default=jnp.array(0), node=True) # node
    step: int = to.field(default=1, node=False) # static

    def inc(self):
        self.n += self.step

counter = Counter(step=2) # Counter(n=jnp.array(0), step=2)

@jax.jit
def update(counter: Counter):
    counter.inc()
    return counter

counter = update(counter) # Counter(n=jnp.array(2), step=2)

# map over the tree

Full Example - Linear Regression

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import treeo as to


class Linear(to.Tree):
    w: jnp.ndarray = to.node()
    b: jnp.ndarray = to.node()

    def __init__(self, din, dout, key):
        self.w = jax.random.uniform(key, shape=(din, dout))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b


@jax.value_and_grad
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y) ** 2)

    return loss


def sgd(param, grad):
    return param - 0.1 * grad


@jax.jit
def train_step(model, x, y):
    loss, grads = loss_fn(model, x, y)
    model = jax.tree_map(sgd, model, grads)

    return loss, model


x = np.random.uniform(size=(500, 1))
y = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))

key = jax.random.PRNGKey(0)
model = Linear(1, 1, key=key)

for step in range(1000):
    loss, model = train_step(model, x, y)
    if step % 100 == 0:
        print(f"loss: {loss:.4f}")

X_test = np.linspace(x.min(), x.max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(x, y, c="k", label="data")
plt.plot(X_test, y_pred, c="b", linewidth=2, label="prediction")
plt.legend()
plt.show()

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

treeo-0.2.2.tar.gz (20.1 kB view details)

Uploaded Source

Built Distribution

treeo-0.2.2-py3-none-any.whl (20.2 kB view details)

Uploaded Python 3

File details

Details for the file treeo-0.2.2.tar.gz.

File metadata

  • Download URL: treeo-0.2.2.tar.gz
  • Upload date:
  • Size: 20.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.2.1 CPython/3.8.10 Linux/5.13.0-1027-gcp

File hashes

Hashes for treeo-0.2.2.tar.gz
Algorithm Hash digest
SHA256 1999698b483659bdb8b6981b8cee38330a1d379767593fbc482f6b4a9f229626
MD5 c310feb9aa34e47b4caa56f9511e6d1d
BLAKE2b-256 b6bfdb5fcacbfc18fd617eeeb9bf60b6e163de7a9a6695a38df9dd5513509d1c

See more details on using hashes here.

File details

Details for the file treeo-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: treeo-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 20.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.2.1 CPython/3.8.10 Linux/5.13.0-1027-gcp

File hashes

Hashes for treeo-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 d7862a198ef60478afe6a8615195df7d76ab2ecee22af0de20f5126d4b2f7f05
MD5 643f36ec53d77ed9e2a082ec255ae94f
BLAKE2b-256 1557079b870662b41f80f6b3646731f39380162d93d2bfaa4eeb0e5c68d2b985

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page