Simple Pytree
A dead simple Python package for creating custom JAX pytree objects.
- Strives to be minimal: the implementation is just ~100 lines of code
- Has no dependencies other than JAX
- It's compatible with both dataclasses and regular classes
- Has no intention of supporting neural network use cases (e.g. partitioning)
Installation
```shell
pip install simple-pytree
```
Usage
```python
import jax
from simple_pytree import Pytree


class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)
assert foo.x == -1 and foo.y == -2
```
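For comparison, here is roughly what the same Foo looks like when registered by hand with plain JAX, which is the kind of boilerplate a Pytree base class is meant to remove. This is a sketch using jax.tree_util.register_pytree_node, not simple_pytree's actual implementation; the helpers _foo_flatten and _foo_unflatten are hypothetical names.

```python
import jax


class Foo:
    # A plain class; nothing from simple_pytree is involved here.
    def __init__(self, x, y):
        self.x = x
        self.y = y


def _foo_flatten(foo):
    # children: the values JAX transforms; aux_data: nothing extra (None).
    return (foo.x, foo.y), None


def _foo_unflatten(aux_data, children):
    # Rebuild a Foo from the (possibly transformed) children.
    return Foo(*children)


jax.tree_util.register_pytree_node(Foo, _foo_flatten, _foo_unflatten)

foo = jax.tree_util.tree_map(lambda v: -v, Foo(1, 2))
assert foo.x == -1 and foo.y == -2
```

Every class written this way needs its own flatten/unflatten pair, so a base class that derives them automatically keeps user code down to an ordinary __init__.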
Static fields
You can mark fields as static by assigning static_field() to a class attribute with the same name
as the instance attribute:
```python
import jax
from simple_pytree import Pytree, static_field


class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified
assert foo.x == -1 and foo.y == 2
```
Static fields are not included in the pytree leaves; they are passed as pytree metadata instead.
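The split between leaves and metadata described above can be illustrated with plain JAX. In this hypothetical manual registration, x is returned as a child while y goes into the aux_data that JAX stores in the tree definition. It is a sketch of the idea only, not how simple_pytree builds its flatten functions.

```python
import jax


class Foo:
    def __init__(self, x, y):
        self.x = x  # dynamic: becomes a pytree leaf
        self.y = y  # static: travels in the treedef's aux_data


def _flatten(foo):
    # Only x is a child; y is returned as metadata.
    return (foo.x,), foo.y


def _unflatten(y, children):
    (x,) = children
    return Foo(x, y)


jax.tree_util.register_pytree_node(Foo, _flatten, _unflatten)

foo = Foo(1, 2)
print(jax.tree_util.tree_leaves(foo))  # [1]

# y is part of the tree structure, so a different y gives a different treedef.
same = jax.tree_util.tree_structure(Foo(5, 2)) == jax.tree_util.tree_structure(foo)
different = jax.tree_util.tree_structure(Foo(1, 3)) == jax.tree_util.tree_structure(foo)
print(same, different)  # True False
```

Because static values live in the tree definition, they should be hashable and comparable, and changing one changes the treedef; under jax.jit that leads to a retrace rather than a new traced argument.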
Dataclasses
simple_pytree provides a dataclass decorator you can use with classes
that contain static fields:
```python
import jax
from simple_pytree import Pytree, dataclass, static_field


@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)


foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified
assert foo.x == -1 and foo.y == 2
```
simple_pytree.dataclass is just a wrapper around dataclasses.dataclass, but
when it is used, static analysis tools and IDEs will understand that static_field
is a field specifier, just like dataclasses.field.
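To illustrate how field-level metadata can drive the leaf/static split for dataclasses, here is a sketch built only on the standard dataclasses module and jax.tree_util. The names my_static_field and register_dataclass_pytree and the "static" metadata key are hypothetical, not simple_pytree's API. The sketch does not cover the IDE and type-checker support mentioned above, which is a typing-level concern (PEP 681's dataclass_transform is the standard mechanism for declaring field specifiers).

```python
import dataclasses

import jax


def my_static_field(**kwargs):
    # Hypothetical helper: a dataclasses.field tagged as static via metadata.
    return dataclasses.field(metadata={"static": True}, **kwargs)


def register_dataclass_pytree(cls):
    # Hypothetical decorator: make cls a dataclass, then register it with
    # JAX, splitting its fields into leaves and static metadata.
    cls = dataclasses.dataclass(cls)
    fields = dataclasses.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get("static", False)]
    static = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in dynamic)
        aux = tuple(getattr(obj, name) for name in static)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(dynamic, children))
        kwargs.update(zip(static, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_dataclass_pytree
class Foo:
    x: int
    y: int = my_static_field(default=2)


foo = jax.tree_util.tree_map(lambda v: -v, Foo(1))  # y is not modified
assert foo.x == -1 and foo.y == 2
```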
Mutability
Pytree objects are immutable by default after __init__:
```python
from simple_pytree import Pytree, static_field


class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo.x = 3  # raises AttributeError
```
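One way to get "immutable after __init__" behaviour is to flip an instance flag around the call to __init__ and have __setattr__ consult it. The sketch below does this with a metaclass; the names _FrozenAfterInitMeta, FrozenAfterInit, Point and _initializing are hypothetical, and this is not necessarily how simple_pytree implements it.

```python
class _FrozenAfterInitMeta(type):
    # Runs __init__ with assignments allowed, then locks the instance.
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        object.__setattr__(obj, "_initializing", True)
        try:
            obj.__init__(*args, **kwargs)
        finally:
            object.__setattr__(obj, "_initializing", False)
        return obj


class FrozenAfterInit(metaclass=_FrozenAfterInitMeta):
    def __setattr__(self, name, value):
        # Only assignments made while __init__ is running are accepted.
        if not self._initializing:
            raise AttributeError(
                f"{type(self).__name__} is immutable, cannot set {name!r}"
            )
        object.__setattr__(self, name, value)


class Point(FrozenAfterInit):
    def __init__(self, x, y):
        self.x = x
        self.y = y


p = Point(1, 2)
try:
    p.x = 3
except AttributeError as err:
    print(err)  # Point is immutable, cannot set 'x'
```

In a pytree class, a bookkeeping attribute like _initializing would also have to be kept out of the leaves.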
If you want Pytree objects to be mutable, you can pass mutable=True in the class definition:
```python
from simple_pytree import Pytree, static_field


class Foo(Pytree, mutable=True):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo.x = 3  # OK
```
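Class keyword arguments such as mutable=True are forwarded to __init_subclass__ (PEP 487), which is a natural place to record a per-class flag. A minimal sketch, assuming a hypothetical base class MutabilityFlag that stores the flag as _class_is_mutable; a __setattr__ guard could then skip its check when the flag is set:

```python
class MutabilityFlag:
    # Default for classes that do not pass mutable=...
    _class_is_mutable = False

    def __init_subclass__(cls, mutable: bool = False, **kwargs):
        # Receives the keyword from `class Sub(MutabilityFlag, mutable=True)`.
        super().__init_subclass__(**kwargs)
        cls._class_is_mutable = mutable


class Frozen(MutabilityFlag):
    pass


class Thawed(MutabilityFlag, mutable=True):
    pass


print(Frozen._class_is_mutable, Thawed._class_is_mutable)  # False True
```

In this sketch a subclass of Thawed is immutable again unless it also passes mutable=True; inheriting the parent's flag instead would be an equally reasonable design.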
Replacing fields
If you want to make a copy of a Pytree object with some fields modified, you can use the .replace() method:
```python
from simple_pytree import Pytree, static_field


class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
foo = foo.replace(x=10)
assert foo.x == 10 and foo.y == 2
```
replace works for both mutable and immutable Pytree objects. If the class
is a dataclass, replace internally uses dataclasses.replace.
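The replace behaviour described above can be sketched as a standalone function: dataclasses are delegated to dataclasses.replace, while for regular classes a shallow copy has its __dict__ updated directly, which sidesteps any __setattr__ guard without touching the original. The function name replace_fields, the example classes, and the unknown-field check are assumptions of this sketch, not a description of simple_pytree's code.

```python
import copy
import dataclasses


def replace_fields(obj, **updates):
    # Dataclass instances (frozen or not) are handled by the stdlib.
    if dataclasses.is_dataclass(obj):
        return dataclasses.replace(obj, **updates)
    # Reject names the object does not already have
    # (dataclasses.replace likewise raises TypeError for unknown fields).
    unknown = set(updates) - set(vars(obj))
    if unknown:
        raise TypeError(f"unknown fields: {sorted(unknown)}")
    new = copy.copy(obj)  # shallow copy; obj itself is not modified
    new.__dict__.update(updates)  # direct write, no __setattr__ involved
    return new


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


@dataclasses.dataclass(frozen=True)
class FrozenPoint:
    x: int
    y: int


p = Point(1, 2)
q = replace_fields(p, x=10)  # p is left unchanged
fq = replace_fields(FrozenPoint(1, 2), x=10)  # FrozenPoint(x=10, y=2)
```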
Download files
File details
Details for the file simple_pytree-0.2.2.tar.gz.
File metadata
- Download URL: simple_pytree-0.2.2.tar.gz
- Upload date:
- Size: 5.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.4.0 CPython/3.8.17 Linux/5.15.0-1041-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b61eddb5b5de558209dfb6464a041c5faf06af2c3cab582f4d69543a773dab0e |
| MD5 | 05d4b11e511d5d1241366d19d868bc1a |
| BLAKE2b-256 | 8eb0b2e7ea15dfb26bf014cfb6243a9bb20b9477ee2f12d754257514f508639a |
File details
Details for the file simple_pytree-0.2.2-py3-none-any.whl.
File metadata
- Download URL: simple_pytree-0.2.2-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.4.0 CPython/3.8.17 Linux/5.15.0-1041-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3a7a2f66883194ab14875dd01c4306c4f102f31d90e7e2d421366f12a2c49bab |
| MD5 | 8074401ea8a9f7b45489baf73398a222 |
| BLAKE2b-256 | 7c160272467306ef489512a843222567c9939b9aff7003f15474a1ef90168c8f |