My🌳

"Module pytrees" that cleanly handle parameter trainability and transformations for JAX models.
Installation

```
pip install mytree
```
Usage
Defining a model
- Define all your class attributes upfront as annotations (a bit like a dataclass!).
- Mark 🍀 attributes with `param_field` to set a default bijector and trainable status.
- Unmarked 🍀 attributes behave as `param_field(bijector=Identity, trainable=True)`.
```python
from mytree import Mytree, param_field, Softplus, Identity


class SimpleModel(Mytree):
    # Marked ☘️ to set a default bijector and trainability.
    weight: float = param_field(bijector=Softplus, trainable=False)

    # Unmarked ☘️ gets the `Identity` bijector and trainability set to `True`.
    bias: float

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, test_point):
        return test_point * self.weight + self.bias
```
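For example, evaluating the model at a test point (a quick check, assuming the class above):

```python
model = SimpleModel(weight=1.0, bias=2.0)
print(model(0.5))  # 0.5 * 1.0 + 2.0 = 2.5
```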
- We are fully compatible with Distrax and TensorFlow Probability bijectors, so feel free to use these!
- As `Mytree` inherits from simple-pytree's `Pytree`, you can mark fields as static via `simple_pytree.static_field` (see the sketch below).
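For instance, a static (non-leaf) field might look like the following. This is a minimal sketch, not taken from the library's docs: the `name` attribute is a hypothetical example.

```python
from simple_pytree import static_field

from mytree import Mytree


class NamedModel(Mytree):
    # `name` is hypothetical: static fields are treated as auxiliary pytree
    # data, so JAX transformations will not trace or differentiate them.
    name: str = static_field()
    weight: float

    def __init__(self, weight, name="linear"):
        self.weight = weight
        self.name = name
```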
Dataclasses

You can seamlessly use the `dataclasses.dataclass` decorator with `Mytree` classes.
```python
from dataclasses import dataclass


@dataclass
class SimpleModel(Mytree):
    weight: float = param_field(bijector=Softplus, trainable=False)
    bias: float

    def __call__(self, test_point):
        return test_point * self.weight + self.bias
```
Replacing values

Update values via `replace`.

```python
model = SimpleModel(1.0, 2.0)
model.replace(weight=123.0)
```
```
SimpleModel(weight=123.0, bias=2.0)
```
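Note that `replace` is out-of-place: it returns a new instance rather than mutating the original. A quick check (assuming the `SimpleModel` above):

```python
model = SimpleModel(1.0, 2.0)
updated = model.replace(weight=123.0)

print(model.weight)    # 1.0 (the original instance is unchanged)
print(updated.weight)  # 123.0
```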
Transformations 🤖
Applying transformations
Use `constrain` / `unconstrain` to return a `Mytree` with each parameter's bijector `forward` / `inverse` operation applied!

```python
model.constrain()
model.unconstrain()
```
```
SimpleModel(weight=1.3132616, bias=2.0)
SimpleModel(weight=0.5413248, bias=2.0)
```
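For reference, `Softplus` maps x to log(1 + e^x), so the constrained weight is log(1 + e^1) ≈ 1.3132616 and the unconstrained weight is its inverse, log(e^1 − 1) ≈ 0.5413248. A quick sanity check with plain JAX (this reproduces the numbers above; it is not how the library computes them internally):

```python
import jax.numpy as jnp

print(jnp.log1p(jnp.exp(1.0)))  # 1.3132616: Softplus forward of weight=1.0
print(jnp.log(jnp.expm1(1.0)))  # 0.5413248: Softplus inverse of weight=1.0
```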
Replacing transformations
Default transformations can be replaced on an instance via the `replace_bijector` method. Replacing `weight`'s bijector with `Identity`, for example, makes `constrain` and `unconstrain` no-ops for that parameter:

```python
new = model.replace_bijector(weight=Identity)
new.constrain()
new.unconstrain()
```
```
SimpleModel(weight=1.0, bias=2.0)
SimpleModel(weight=1.0, bias=2.0)
```
Trainability 🚂
Applying trainability
We begin by creating some simulated data.
```python
import jax

n = 100
key = jax.random.PRNGKey(123)
x = jax.random.uniform(key, (n,))
y = 3.0 * x + 2.0 + 1e-3 * jax.random.normal(key, (n,))
```
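(Aside: reusing the same `key` for both draws is fine for a toy example, but the idiomatic JAX pattern is to split it first; the names below are illustrative.)

```python
key, subkey = jax.random.split(key)
noise = 1e-3 * jax.random.normal(subkey, (n,))
```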
And create a mean-squared-error loss function to evaluate our model on.

```python
def loss(model: SimpleModel) -> float:
    model = model.stop_gradient()
    return jax.numpy.sum((y - model(x)) ** 2)
```

Here we use the `stop_gradient` method within the loss function to prevent gradients from flowing through non-trainable parameters during forward- or reverse-mode automatic differentiation.
```python
jax.grad(loss)(model)
```
```
SimpleModel(weight=0.0, bias=-188.37418)
```

As `weight`'s trainability was set to `False`, its gradient is zero, as expected!
Replacing trainability
Default trainability status can be replaced on an instance via the `replace_trainable` method.

```python
new = model.replace_trainable(weight=True)
jax.grad(loss)(new)
```
```
SimpleModel(weight=-121.42676, bias=-188.37418)
```

And we see that `weight`'s gradient is no longer zero.
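Putting the pieces together, a hand-rolled gradient-descent loop might look like the following. This is a minimal sketch under stated assumptions: the `SimpleModel`, data, and `loss` are those defined above, while `LR`, the step count, and the loop structure are illustrative rather than part of the library.

```python
import jax.tree_util as jtu

LR = 0.01  # illustrative learning rate

model = SimpleModel(1.0, 2.0).replace_trainable(weight=True)
params = model.unconstrain()  # optimise in unconstrained space

for _ in range(100):
    grads = jax.grad(lambda m: loss(m.constrain()))(params)
    # SGD step applied leaf-wise: `params` and `grads` share a pytree structure.
    params = jtu.tree_map(lambda p, g: p - LR * g, params, grads)

model = params.constrain()  # map back to constrained space
```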
Performance 🏎
This is an early experimental library that demonstrates an idea, so it is not yet optimised. Some benchmarks can be found at https://github.com/Daniel-Dodd/mytree/blob/master/benchmarks/benchmarks.ipynb