
Project description

Simple Pytree

A dead simple Python package for creating custom JAX pytree objects.

  • Strives to be minimal: the implementation is just ~100 lines of code
  • Has no dependencies other than JAX
  • Is compatible with both dataclasses and regular classes
  • Has no intention of supporting neural network use cases (e.g. partitioning)

Installation

pip install simple-pytree

Usage

import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2
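A pytree class only has to tell JAX how to split an instance into leaves and how to rebuild it from them. The following is a minimal sketch of that idea, not simple_pytree's actual implementation: the base class `MiniPytree` and its `_flatten`/`_unflatten` helpers are hypothetical names. It registers every subclass with `jax.tree_util.register_pytree_node` from `__init_subclass__` and treats every instance attribute as a child.

```python
import jax


class MiniPytree:
    """Hypothetical base class: every subclass is registered as a JAX pytree."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        # Sort attribute names so the leaf order is deterministic.
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, name) for name in names)
        # children are traversed by JAX; the names tuple is static aux data.
        return children, names

    @classmethod
    def _unflatten(cls, names, children):
        # Rebuild without calling __init__, whose signature may differ.
        obj = object.__new__(cls)
        for name, value in zip(names, children):
            setattr(obj, name, value)
        return obj


class Foo(MiniPytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = jax.tree_util.tree_map(lambda x: -x, Foo(1, 2))
print(foo.x, foo.y)  # -1 -2
print(jax.tree_util.tree_leaves(Foo(1, 2)))  # [1, 2]
```

Rebuilding through `object.__new__` instead of `cls(...)` keeps unflattening independent of whatever arguments `__init__` expects.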

Static fields

You can mark fields as static by assigning static_field() to a class attribute with the same name as the instance attribute:

import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2

Static fields are not included in the pytree leaves; they are passed as pytree metadata instead.
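To illustrate what "passed as pytree metadata" means, here is a hypothetical sketch (not simple_pytree's code). Attribute names listed in a class-level `_static` tuple, a convention invented for this sketch, go into the aux data returned by the flatten function, and everything else becomes a child. Because aux data is part of the tree structure, `tree_map` never visits it and `jax.jit` sees it as a plain Python value rather than a tracer. It therefore needs to be hashable and comparable, and changing it causes a retrace.

```python
import jax
import jax.numpy as jnp


class Foo:
    # Hypothetical convention for this sketch: attribute names listed here are static.
    _static = ("y",)

    def __init__(self, x, y):
        self.x = x
        self.y = y


def foo_flatten(foo):
    names = sorted(vars(foo))
    dynamic = tuple(n for n in names if n not in Foo._static)
    static = tuple((n, getattr(foo, n)) for n in names if n in Foo._static)
    children = tuple(getattr(foo, n) for n in dynamic)
    # The aux data (second item) is tree metadata: hashable, comparable,
    # never visited by tree_map and never traced by jit.
    return children, (dynamic, static)


def foo_unflatten(aux, children):
    dynamic, static = aux
    foo = object.__new__(Foo)
    for name, value in zip(dynamic, children):
        setattr(foo, name, value)
    for name, value in static:
        setattr(foo, name, value)
    return foo


jax.tree_util.register_pytree_node(Foo, foo_flatten, foo_unflatten)

foo = jax.tree_util.tree_map(lambda v: -v, Foo(1, 2))
print(foo.x, foo.y)  # -1 2
print(jax.tree_util.tree_leaves(Foo(1, 2)))  # [1]


@jax.jit
def scale(foo):
    # Inside jit, foo.x is a tracer but foo.y is still the plain int 2.
    return foo.x * foo.y


print(scale(Foo(jnp.float32(3.0), 2)))  # 6.0
```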

Dataclasses

simple_pytree provides a dataclass decorator that you can use with classes containing static fields:

import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)
    
foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2

simple_pytree.dataclass is just a wrapper around dataclasses.dataclass, but when it is used, static analysis tools and IDEs will understand that static_field is a field specifier, just like dataclasses.field.
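One way such a field specifier can work is sketched below under stated assumptions; it is not taken from simple_pytree's source. Here `static_field` forwards to `dataclasses.field` and tags the field through its `metadata` mapping, and a flatten function can read the tag back with `dataclasses.fields`. The `"pytree_node"` metadata key and the names `static_field_sketch`, `dataclass_sketch` and `split_fields` are hypothetical. IDE support for a custom decorator comes from `typing.dataclass_transform` (PEP 681, in `typing` since Python 3.11 and in `typing_extensions` before that), which informs type checkers but does not change runtime behavior.

```python
import dataclasses
import sys
from typing import Any

if sys.version_info >= (3, 11):
    from typing import dataclass_transform
else:
    # Fallback for older Pythons: a no-op decorator so the sketch still runs.
    def dataclass_transform(**kwargs):
        return lambda obj: obj


def static_field_sketch(**kwargs: Any) -> Any:
    """Hypothetical: a dataclasses.field tagged as pytree metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pytree_node"] = False
    return dataclasses.field(metadata=metadata, **kwargs)


@dataclass_transform(field_specifiers=(static_field_sketch, dataclasses.field))
def dataclass_sketch(cls):
    """Hypothetical wrapper: at runtime this is plain dataclasses.dataclass."""
    return dataclasses.dataclass(cls)


@dataclass_sketch
class Foo:
    x: int
    y: int = static_field_sketch(default=2)


def split_fields(obj):
    """Return (dynamic names, static names) by reading each field's metadata."""
    dynamic, static = [], []
    for f in dataclasses.fields(obj):
        if f.metadata.get("pytree_node", True):
            dynamic.append(f.name)
        else:
            static.append(f.name)
    return dynamic, static


print(split_fields(Foo(1)))  # (['x'], ['y'])
```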

Mutability

Pytree objects are immutable by default after __init__:

from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3 # AttributeError

If you want to make them mutable, you can use the mutable argument in the class definition:

from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3 # OK
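Class keyword arguments such as `mutable=True` are delivered to `__init_subclass__`. The following pure-Python sketch reproduces the behavior described above: attributes are frozen after `__init__` unless the class opts out. `FrozenAfterInit`, its metaclass and the temporary `_initializing` marker are hypothetical names, and simple_pytree may implement this differently. Doing the bookkeeping in the metaclass `__call__` means a subclass `__init__` that calls `super().__init__()` can still assign attributes afterwards.

```python
class _FrozenAfterInitMeta(type):
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        # Temporary marker: attribute assignment is allowed while it exists.
        obj.__dict__["_initializing"] = True
        try:
            obj.__init__(*args, **kwargs)
        finally:
            # Remove the marker so it never shows up in vars(obj).
            del obj.__dict__["_initializing"]
        return obj


class FrozenAfterInit(metaclass=_FrozenAfterInitMeta):
    _mutable = False

    def __init_subclass__(cls, mutable: bool = False, **kwargs):
        # `class Foo(FrozenAfterInit, mutable=True)` arrives here as mutable=True.
        super().__init_subclass__(**kwargs)
        cls._mutable = mutable

    def __setattr__(self, name, value):
        if not self._mutable and not self.__dict__.get("_initializing", False):
            raise AttributeError(f"{type(self).__name__} is immutable; cannot set {name!r}")
        object.__setattr__(self, name, value)


class Foo(FrozenAfterInit):
    def __init__(self, x, y):
        self.x = x
        self.y = y


class MutableFoo(FrozenAfterInit, mutable=True):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
try:
    foo.x = 3
except AttributeError as err:
    print(err)  # Foo is immutable; cannot set 'x'

bar = MutableFoo(1, 2)
bar.x = 3
print(bar.x)  # 3
```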

Replacing fields

If you want to make a copy of a Pytree object with some fields modified, you can use the .replace() method:

from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()
    
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2

replace works for both mutable and immutable Pytree objects. If the class is a dataclass, replace uses dataclasses.replace internally.
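A minimal sketch of a `replace` with the behavior described above, assuming nothing about simple_pytree's internals: dataclasses are delegated to `dataclasses.replace`, while regular classes are shallow-copied with `copy.copy` and updated through `object.__setattr__`, which bypasses an immutability guard on the copy. The `replace_sketch` function and the two example classes are hypothetical.

```python
import copy
import dataclasses


def replace_sketch(obj, **changes):
    """Hypothetical helper: return a copy of obj with some attributes replaced."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # dataclasses.replace re-runs __init__ with the updated field values.
        return dataclasses.replace(obj, **changes)
    new = copy.copy(obj)  # shallow copy: unchanged fields are shared, not cloned
    for name, value in changes.items():
        if name not in vars(new):
            raise TypeError(f"{type(obj).__name__} has no field {name!r}")
        # object.__setattr__ skips any __setattr__ override (e.g. an immutability guard).
        object.__setattr__(new, name, value)
    return new


class Frozen:
    def __init__(self, x, y):
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)

    def __setattr__(self, name, value):
        raise AttributeError("Frozen is immutable")


@dataclasses.dataclass(frozen=True)
class FrozenDC:
    x: int
    y: int


a = Frozen(1, 2)
b = replace_sketch(a, x=10)
print(a.x, b.x, b.y)  # 1 10 2
print(replace_sketch(FrozenDC(1, 2), x=10))  # FrozenDC(x=10, y=2)
```

Because the copy is shallow, a replaced object shares its untouched leaves with the original, which is cheap and safe as long as those leaves are immutable arrays.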

Download files

Source Distribution

simple_pytree-0.2.2.tar.gz (5.3 kB)

Uploaded Source

Built Distribution

simple_pytree-0.2.2-py3-none-any.whl (6.2 kB)

Uploaded Python 3

File details

Details for the file simple_pytree-0.2.2.tar.gz.

File metadata

  • Download URL: simple_pytree-0.2.2.tar.gz
  • Upload date:
  • Size: 5.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.0 CPython/3.8.17 Linux/5.15.0-1041-azure

File hashes

Hashes for simple_pytree-0.2.2.tar.gz
Algorithm Hash digest
SHA256 b61eddb5b5de558209dfb6464a041c5faf06af2c3cab582f4d69543a773dab0e
MD5 05d4b11e511d5d1241366d19d868bc1a
BLAKE2b-256 8eb0b2e7ea15dfb26bf014cfb6243a9bb20b9477ee2f12d754257514f508639a
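To check a downloaded file against one of the SHA256 digests listed here, hashing it locally with Python's standard `hashlib` module is enough. The `sha256_of` and `matches` helper names are hypothetical, and the file path is simply wherever the archive was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for simple_pytree-0.2.2.tar.gz.
EXPECTED_SHA256 = "b61eddb5b5de558209dfb6464a041c5faf06af2c3cab582f4d69543a773dab0e"


def sha256_of(path, chunk_size=1 << 16):
    """Hash a file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 hex digest equals the expected one."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    archive = Path("simple_pytree-0.2.2.tar.gz")
    if archive.exists():
        print("OK" if matches(archive) else "MISMATCH")
```

pip can also enforce digests at install time through its hash-checking mode (`--require-hashes`).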

File details

Details for the file simple_pytree-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: simple_pytree-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 6.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.0 CPython/3.8.17 Linux/5.15.0-1041-azure

File hashes

Hashes for simple_pytree-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 3a7a2f66883194ab14875dd01c4306c4f102f31d90e7e2d421366f12a2c49bab
MD5 8074401ea8a9f7b45489baf73398a222
BLAKE2b-256 7c160272467306ef489512a843222567c9939b9aff7003f15474a1ef90168c8f
