Simple Pytree

A dead simple Python package for creating custom JAX pytree objects.

  • Strives to be minimal: the implementation is just ~100 lines of code
  • Has no dependencies other than JAX
  • Is compatible with both dataclasses and regular classes
  • It has no intention of supporting Neural Network use cases (e.g. partitioning)

Installation

pip install simple-pytree

Usage

import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2

Static fields

You can mark fields as static by assigning static_field() to a class attribute with the same name as the instance attribute:

import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2

Static fields are not included in the pytree leaves; they are passed as pytree metadata instead.
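To illustrate the difference between leaves and metadata, here is a minimal sketch that uses only jax.tree_util, without simple_pytree. The Scaled class is hypothetical, and this is not necessarily how simple_pytree is implemented internally; it only shows how JAX treats a value returned as auxiliary data.

```python
from jax import tree_util


@tree_util.register_pytree_node_class
class Scaled:
    """Hypothetical class: x is a dynamic leaf, y is static metadata."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # First element: children (leaves). Second element: auxiliary data.
        return (self.x,), self.y

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (x,) = children
        return cls(x, aux_data)


obj = Scaled(1, 2)

# Only x shows up as a leaf; y travels with the tree definition.
leaves, treedef = tree_util.tree_flatten(obj)

# tree_map therefore transforms x and leaves y untouched.
negated = tree_util.tree_map(lambda v: -v, obj)
```

Because auxiliary data is part of the tree structure rather than a leaf, transformations such as tree_map never see it.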

Dataclasses

simple_pytree provides a dataclass decorator you can use with classes that contain static_fields:

import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)
    
foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2

simple_pytree.dataclass is just a wrapper around dataclasses.dataclass. Its benefit is that static analysis tools and IDEs will recognize static_field as a field specifier, just like dataclasses.field.
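Python exposes this kind of hint to type checkers through dataclass_transform (PEP 681). The sketch below shows one way such a wrapper could be declared. The names my_dataclass and my_field are hypothetical, and whether simple_pytree is implemented exactly this way is an assumption.

```python
import dataclasses

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:
    from typing_extensions import dataclass_transform


def my_field(*, default=dataclasses.MISSING, static=False):
    """Hypothetical field specifier that records a 'static' flag in metadata."""
    return dataclasses.field(default=default, metadata={"static": static})


# Declaring my_field as a field specifier lets type checkers treat it
# the same way they treat dataclasses.field.
@dataclass_transform(field_specifiers=(dataclasses.field, my_field))
def my_dataclass(cls):
    """Hypothetical thin wrapper around dataclasses.dataclass."""
    return dataclasses.dataclass(cls)


@my_dataclass
class Foo:
    x: int
    y: int = my_field(default=2, static=True)


foo = Foo(1)

# The flag can be read back from the standard dataclass field metadata.
static_names = [f.name for f in dataclasses.fields(Foo) if f.metadata.get("static")]
```

At runtime the decorator behaves like dataclasses.dataclass; the dataclass_transform marker only affects what static analysis tools infer.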

Mutability

Pytree objects are immutable by default after __init__:

from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3 # AttributeError

If you want to make them mutable, you can pass mutable=True in the class definition:

from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3 # OK

Replacing fields

If you want to make a copy of a Pytree object with some fields modified, you can use the .replace() method:

from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2

replace works for both mutable and immutable Pytree objects. If the class is a dataclass, replace internally uses dataclasses.replace.
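For reference, the standard library's dataclasses.replace behaves as sketched below. This example uses only the standard library; Point is a hypothetical frozen dataclass, not part of simple_pytree.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    x: int
    y: int = 2


original = Point(1)

# Builds a new instance with x overridden; the original is left unchanged.
updated = dataclasses.replace(original, x=10)
```

The result is a new object, which is why this approach works even when instances cannot be modified in place.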
