Simple Pytree
A dead simple Python package for creating custom JAX pytree objects.
- Strives to be minimal: the implementation is just ~100 lines of code
- Has no dependencies other than JAX
- Is compatible with both `dataclasses` and regular classes
- Has no intention of supporting neural network use cases (e.g. partitioning)
Installation
```bash
pip install simple-pytree
```
Usage
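```python
import jax
from simple_pytree import Pytree


class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
# jax.tree_util.tree_map; the older jax.tree_map alias is deprecated in recent JAX releases
foo = jax.tree_util.tree_map(lambda x: -x, foo)
assert foo.x == -1 and foo.y == -2
```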
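A class only behaves like this if JAX knows how to flatten and unflatten it. The sketch below shows one way a base class could automate that, assuming a design where `__init_subclass__` registers every subclass with `jax.tree_util.register_pytree_node` and every instance attribute is treated as a child. `AutoPytree`, `_flatten` and `_unflatten` are hypothetical names; this is not simple_pytree's actual implementation.

```python
import jax


class AutoPytree:
    """Hypothetical base class: each subclass is registered as a JAX pytree node."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        # Sort attribute names so the order of the leaves is deterministic
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, name) for name in names)
        # The tuple of names is hashable, so it can serve as the aux data
        return children, names

    @classmethod
    def _unflatten(cls, names, children):
        # Rebuild the object without calling __init__
        obj = object.__new__(cls)
        obj.__dict__.update(zip(names, children))
        return obj


class Foo(AutoPytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = jax.tree_util.tree_map(lambda v: -v, Foo(1, 2))
```

Treating every attribute as a child is the simplest choice; a static-field mechanism would move some names out of the children and into the aux data instead.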
Static fields
You can mark fields as static by assigning `static_field()` to a class attribute with the same name as the instance attribute:
```python
import jax
from simple_pytree import Pytree, static_field


class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified
assert foo.x == -1 and foo.y == 2
```
Static fields are not included in the pytree leaves; instead, they are passed as pytree metadata.
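To make the split between leaves and metadata concrete, here is a hand-written JAX pytree (no simple_pytree involved) that treats `y` as static by returning it as auxiliary data from `tree_flatten`. It uses the real `jax.tree_util.register_pytree_node_class` decorator; the class itself is a hypothetical stand-in for what a static field amounts to, not the library's code.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Foo:
    def __init__(self, x, y):
        self.x = x  # dynamic: flattened into a leaf
        self.y = y  # static: carried in the tree structure as metadata

    def tree_flatten(self):
        children = (self.x,)
        aux_data = (self.y,)  # should be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (x,) = children
        (y,) = aux_data
        return cls(x, y)


foo = Foo(1, 2)
leaves, treedef = jax.tree_util.tree_flatten(foo)  # leaves == [1]
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # only x is negated
```

Because static values live in the treedef rather than in the leaves, JAX compares them when matching tree structures and `jax.jit` includes them in its cache key, so they should be hashable and comparable with `==`.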
Dataclasses
`simple_pytree` provides a `dataclass` decorator you can use with classes that contain `static_field`s:
```python
import jax
from simple_pytree import Pytree, dataclass, static_field


@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)


foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified
assert foo.x == -1 and foo.y == 2
```
`simple_pytree.dataclass` is just a wrapper around `dataclasses.dataclass`, but when you use it, static analysis tools and IDEs will understand that `static_field` is a field specifier just like `dataclasses.field`.
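This kind of IDE and type-checker support is what PEP 681's `typing.dataclass_transform` provides. A minimal sketch of how such a wrapper could be declared, assuming a hypothetical `static_field` that tags its field with a `"static"` metadata key; the key name and both function bodies are illustrative, not simple_pytree's source.

```python
import dataclasses
from typing import Any

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:  # older Pythons: fall back to a no-op decorator
    def dataclass_transform(**kwargs):
        return lambda obj: obj


def static_field(default: Any = dataclasses.MISSING, **kwargs: Any) -> Any:
    """Hypothetical: a dataclasses.field that records 'static' in its metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = True
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


# Tells type checkers that classes decorated with `dataclass` behave like
# dataclasses and that `static_field(...)` declares a field
@dataclass_transform(field_specifiers=(static_field, dataclasses.field))
def dataclass(cls=None, **kwargs):
    """Hypothetical thin wrapper that forwards to dataclasses.dataclass."""
    if cls is None:
        return lambda c: dataclasses.dataclass(c, **kwargs)
    return dataclasses.dataclass(cls, **kwargs)


@dataclass
class Foo:
    x: int
    y: int = static_field(default=2)


static_names = [f.name for f in dataclasses.fields(Foo) if f.metadata.get("static")]
```

A pytree base class could then read that metadata when deciding which fields become leaves and which go into the aux data.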
Mutability
`Pytree` objects are immutable by default after `__init__`:
```python
from simple_pytree import Pytree, static_field


class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo.x = 3  # raises AttributeError
```
If you want to make them mutable, you can pass the `mutable` argument in the class definition:
```python
from simple_pytree import Pytree, static_field


class Foo(Pytree, mutable=True):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo.x = 3  # OK
```
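One way to get "immutable after `__init__`, mutable on request" semantics in plain Python is a metaclass that sets a flag while `__init__` runs, plus a `__setattr__` that checks it. The sketch below assumes that design; `FreezeMeta`, `Frozen` and the `_initializing` flag are hypothetical names, not simple_pytree internals.

```python
from typing import Any


class FreezeMeta(type):
    """Hypothetical metaclass: allows attribute writes only while __init__ runs."""

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        obj = cls.__new__(cls)
        # Write the flag straight into __dict__ to bypass __setattr__
        obj.__dict__["_initializing"] = True
        try:
            obj.__init__(*args, **kwargs)
        finally:
            del obj.__dict__["_initializing"]
        return obj


class Frozen(metaclass=FreezeMeta):
    _mutable = False

    def __init_subclass__(cls, mutable: bool = False, **kwargs: Any) -> None:
        # Consumes the `mutable` keyword in `class Bar(Frozen, mutable=True)`
        super().__init_subclass__(**kwargs)
        cls._mutable = mutable

    def __setattr__(self, name: str, value: Any) -> None:
        if not self._mutable and not self.__dict__.get("_initializing", False):
            raise AttributeError(f"{type(self).__name__} is immutable; cannot set {name!r}")
        object.__setattr__(self, name, value)


class Foo(Frozen):
    def __init__(self, x, y):
        self.x = x
        self.y = y


class Bar(Frozen, mutable=True):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
try:
    foo.x = 3
except AttributeError as err:
    error_message = str(err)  # "Foo is immutable; cannot set 'x'"

bar = Bar(1, 2)
bar.x = 3  # OK: Bar opted in to mutability
```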
Replacing fields
If you want to make a copy of a `Pytree` object with some fields modified, you can use the `.replace()` method:
```python
from simple_pytree import Pytree, static_field


class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo = foo.replace(x=10)
assert foo.x == 10 and foo.y == 2
```
`replace` works for both mutable and immutable `Pytree` objects. If the class is a dataclass, `replace` internally uses `dataclasses.replace`.
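The behaviour described above can be approximated with the standard library alone. The sketch below is a hypothetical free function `replace`, not the library's method: it delegates to `dataclasses.replace` for dataclass instances and otherwise makes a shallow `copy.copy` and writes the new values with `object.__setattr__`, which is one way to sidestep an immutable class's `__setattr__`.

```python
import copy
import dataclasses
from typing import Any


def replace(obj: Any, **updates: Any) -> Any:
    """Hypothetical helper: return a copy of `obj` with some attributes changed."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # Dataclass instances: defer to the standard library, which re-runs __init__
        return dataclasses.replace(obj, **updates)
    # Everything else: shallow-copy, then write the new values directly.
    # object.__setattr__ also works on classes whose __setattr__ blocks assignment.
    new = copy.copy(obj)
    for name, value in updates.items():
        object.__setattr__(new, name, value)
    return new


@dataclasses.dataclass(frozen=True)
class Point:
    x: int
    y: int = 2


class Foo:
    def __init__(self, x, y):
        self.x = x
        self.y = y


point = replace(Point(1), x=10)
foo = Foo(1, 2)
foo2 = replace(foo, x=10)  # foo itself is left untouched
```

A shallow copy shares any nested values with the original, which is usually fine when those values are immutable JAX arrays.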
Project details
File details
Details for the file `simple_pytree-0.2.2.tar.gz`
File metadata
- Download URL: simple_pytree-0.2.2.tar.gz
- Upload date:
- Size: 5.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.4.0 CPython/3.8.17 Linux/5.15.0-1041-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | b61eddb5b5de558209dfb6464a041c5faf06af2c3cab582f4d69543a773dab0e
MD5 | 05d4b11e511d5d1241366d19d868bc1a
BLAKE2b-256 | 8eb0b2e7ea15dfb26bf014cfb6243a9bb20b9477ee2f12d754257514f508639a
File details
Details for the file `simple_pytree-0.2.2-py3-none-any.whl`
File metadata
- Download URL: simple_pytree-0.2.2-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.4.0 CPython/3.8.17 Linux/5.15.0-1041-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3a7a2f66883194ab14875dd01c4306c4f102f31d90e7e2d421366f12a2c49bab
MD5 | 8074401ea8a9f7b45489baf73398a222
BLAKE2b-256 | 7c160272467306ef489512a843222567c9939b9aff7003f15474a1ef90168c8f