PyTorch-like neural networks in JAX
Equinox
Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX
Equinox brings more power to your model building in JAX.
Represent parameterised functions as data and use filtered transformations for powerful fine-grained control of the model-building process. Equinox demonstrates how to use a PyTorch-like class-based API without compromising on JAX-like functional programming.
Equinox is half tech-demo, half neural network library, and comes with no behind-the-scenes magic, guaranteed.
The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.
Quick start
Installation
pip install equinox
Requires Python 3.7+ and JAX 0.2.18+.
Parameterised functions as data
Equinox represents parameterised functions as PyTrees. (For example a neural network is a function parameterised by its weights, biases, etc.) Now you can JIT/grad/etc. a higher-order function (like a loss function) with respect to a parameterised function as its input (like a model).
Previous libraries have introduced a lot of extra complexity to make this work: for example, custom notions of parameter groups, class-to-functional transformations, or specially-wrapped library.jit and library.grad to be compatible with JAX's JIT/grad/etc.
In contrast, Equinox makes it elegant:
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
class MyModule(eqx.Module):
    # Specify the module's attributes;
    layers: list
    bias: jnp.ndarray

    # and how to initialise them;
    def __init__(self, key):
        key1, key2 = jrandom.split(key)
        self.layers = [eqx.nn.Linear(2, 8, key=key1),
                       eqx.nn.Linear(8, 2, key=key2)]
        self.bias = jnp.ones(2)

    # and the forward pass of the model.
    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jnn.relu(layer(x))
        return self.layers[-1](x) + self.bias

@jax.jit
@jax.grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

x_key, y_key, model_key = jrandom.split(jrandom.PRNGKey(0), 3)
x, y = jrandom.normal(x_key, (100, 2)), jrandom.normal(y_key, (100, 2))
model = MyModule(model_key)
grads = loss(model, x, y)
And there's no magic there! All eqx.Module really does is register your class with JAX as a PyTree node. (In fact the source code for eqx.Module is only about 100 lines long.)
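For instance -- a quick sketch, continuing the example above -- the model is an ordinary PyTree, so standard JAX tree utilities apply to it directly:

import jax

# `model` is the MyModule instance from above.
leaves = jax.tree_util.tree_leaves(model)
print([leaf.shape for leaf in leaves])
# Five arrays: the weight and bias of each Linear layer, plus `self.bias`.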
Filtering
In the previous example, all of the model attributes were Modules and JAX arrays. Arbitrary Python objects are fine too! We just need to handle them appropriately around jax.jit and jax.grad.
import equinox as eqx
import functools as ft
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
class AnotherModule(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2 = jrandom.split(key)
        # Model now has `jnn.relu` -- a Python function -- as part of its PyTree.
        self.layers = [eqx.nn.Linear(2, 8, key=key1),
                       jnn.relu,
                       eqx.nn.Linear(8, 2, key=key2)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

x_key, y_key, model_key = jrandom.split(jrandom.PRNGKey(0), 3)
x, y = jrandom.normal(x_key, (100, 2)), jrandom.normal(y_key, (100, 2))
model = AnotherModule(model_key)

# Option 1: explicitly filter out anything that isn't JIT/grad-able.
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
def loss(params, static, x, y):
    model = eqx.combine(params, static)
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

params, static = eqx.partition(model, eqx.is_array)
loss(params, static, x, y)

# Option 2: use filtered transformations, which automate the above process for you.
# (Can be handy if you want to JIT/grad with respect to different things!)
@eqx.filter_jit
@eqx.filter_grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

loss(model, x, y)
Here, params and static are actually both instances of AnotherModule. params keeps just the attributes that are JAX arrays, and static keeps everything else. Then combine just merges the two PyTrees back together afterwards.
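To make that concrete, here's a minimal sketch of the partition/combine round trip, continuing the example above:

# `params` keeps the arrays; everywhere else it has None.
# `static` keeps e.g. `jnn.relu`; everywhere else it has None.
params, static = eqx.partition(model, eqx.is_array)
# `combine` merges them back: the result matches the original model.
restored = eqx.combine(params, static)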
Integrates smoothly with JAX
And that's it! That's pretty much all of Equinox.
Equinox introduces a powerful yet straightforward way to build neural networks, without introducing lots of new notions or tying you into a framework.
Equinox is all just regular JAX -- PyTrees and transformations. Together, these two pieces allow us to specify complex models in JAX-friendly ways.
Examples
- build_model.py builds an MLP from scratch, demonstrating the easy parameterised-functions-as-data approach that Equinox introduces. We'll then pass it into higher-order functions like JIT and grad. Overall we produce models using a familiar class-based syntax, that are also functional and integrate directly with JAX's JIT/autograd.
- filtered_transformations.py introduces equinox.filter_jit and equinox.filter_grad. These will be used to select the parameters of an MLP and train them.
- frozen_layer.py demonstrates how this approach really shines: some of the parameters will be trained, some of them will be frozen, but all of them will be efficiently JIT-traced.
- train_rnn.py trains an RNN on a toy clockwise/anticlockwise spiral classification problem.
- modules_to_initapply.py demonstrates how to use Equinox in an init/apply-style way, which some JAX libraries have been built around (e.g. Stax).
API
Full API list
# Module                         # Neural networks
equinox.Module                   equinox.nn.Linear
                                 equinox.nn.Identity
# Filtering/combining            equinox.nn.Dropout
equinox.filter                   equinox.nn.GRUCell
equinox.partition                equinox.nn.LSTMCell
equinox.combine                  equinox.nn.Sequential
                                 equinox.nn.MLP
# Filtered transformations
equinox.filter_jit               # Utilities
equinox.filter_grad              equinox.apply_updates
equinox.filter_value_and_grad    equinox.static_field
                                 equinox.tree_at
# Filters                        equinox.tree_equal
equinox.is_array
equinox.is_array_like
equinox.is_inexact_array
equinox.is_inexact_array_like
Module
equinox.Module
Base class; create your model by inheriting from this.
Specify all its attributes at the class level (identical to dataclasses). This defines its children in the PyTree.
class MyModule(equinox.Module):
    weight: typing.Any
    bias: typing.Any
    submodule: Module
In this case a default __init__ method is provided, which just fills in these attributes with the arguments passed: MyModule(weight, bias, submodule). Alternatively you can provide an __init__ method yourself. (For example to specify dimension sizes instead of raw weights.) By the end of __init__, every attribute must have been assigned.
class AnotherModule(equinox.Module):
    weight: Any

    def __init__(self, input_size, output_size, key):
        self.weight = jax.random.normal(key, (output_size, input_size))
After initialisation, attributes cannot be modified: models are immutable, as per functional programming. (Parameter updates are made by creating a new model, not by mutating parameters in-place; see for example train_rnn.py.)
It is typical to also create some methods on the class. As self is an input parameter -- treated as a PyTree -- these methods get access to the attributes of the instance. Defining __call__ gives an easy way to define a forward pass for a model (although any method can be used, and no methods are special-cased):
class LinearWithoutBias(equinox.Module):
    weight: Any

    def __call__(self, x):
        return self.weight @ x
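For example, a hypothetical usage sketch (the weight shape here is arbitrary):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
linear = LinearWithoutBias(jax.random.normal(key, (3, 4)))
y = linear(jnp.ones(4))                   # __call__ as the forward pass; shape (3,)
ys = jax.vmap(linear)(jnp.ones((10, 4)))  # the instance is also a callable PyTree; shape (10, 3)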
If defining a method meth, then take care not to write instance = MyModule(...); jax.jit(instance.meth)(...). (Or similarly with jax.grad, equinox.filter_jit etc.) This is because instance.meth is not a pure function, as it already has the self parameter passed implicitly. Instead do
@jax.jit
def func(instance, args):
    # Also use this pattern with instance(args) if you defined `__call__` instead of `meth`.
    return instance.meth(args)
Filtering/combining
Filtering can be used to organise the contents of PyTrees.
equinox.filter(pytree, filter_spec, inverse=False, replace=None)
Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with replace.

- pytree is any PyTree.
- filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
  - True, in which case the leaf or subtree is kept;
  - False, in which case the leaf or subtree is replaced with replace;
  - a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- replace is what to replace any falsey leaves with. Defaults to None.

Returns a PyTree of the same structure as pytree.

An important special case is something like equinox.filter(pytree, equinox.is_array). Then equinox.is_array is evaluated on all of pytree's leaves, and each leaf is then kept or replaced.

See also equinox.combine to reconstitute the PyTree again.
equinox.partition(pytree, filter_spec, replace=None)
Equivalent to filter(...), filter(..., inverse=True), but slightly more efficient.
equinox.combine(*pytrees)
Every element of pytrees must be a PyTree of the same structure. The return value is also a PyTree of the same structure. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees, as they are iterated over. The intention is that this be used to undo a call to equinox.filter or equinox.partition.
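A minimal sketch of these three functions together (the example pytree is arbitrary):

import equinox as eqx
import jax.numpy as jnp

pytree = [jnp.array([1.0, 2.0]), "hello", 3]
filtered = eqx.filter(pytree, eqx.is_array)           # [array, None, None]
params, static = eqx.partition(pytree, eqx.is_array)  # params: [array, None, None]; static: [None, "hello", 3]
restored = eqx.combine(params, static)                # equal to the original pytree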
Filtered transformations
It's very common to need to filter just to handle JAX transformations. Equinox provides the following convenience wrappers.
They're not designed to handle every edge case -- they're just a way to streamline the common cases. Use separate equinox.filter + jax.jit etc. if you need finer control.
equinox.filter_jit(fun, *, filter_spec=is_array, **kwargs)
Wraps jax.jit.

- fun is a pure function to JIT compile.
- filter_spec is a PyTree whose structure should be a prefix of the structure of the inputs to fun. Each of its leaves should either be True, False, or a callable Leaf -> bool. It behaves exactly as the filter_spec argument to equinox.filter. Truthy values will be traced; falsey values will be held static. Specifically, if calling fun(*args, **kwargs), then filter_spec must have a structure which is a prefix for (args, kwargs).
- **kwargs are any other arguments to jax.jit.
An important special case is to pass a function as filter_spec, which will be applied to every leaf of every input. For example, equinox.filter_jit(fun, filter_spec=equinox.is_array).

See also equinox.is_array, which is the default choice of filter_spec. This will trace every JAX array, and make every other argument static.
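For instance, a minimal sketch (the function and arguments here are arbitrary):

import equinox as eqx
import jax.numpy as jnp

def apply_fn(fn, x):
    return fn(x)

fast_apply = eqx.filter_jit(apply_fn)
# `x` is a JAX array, so it is traced; `fn` is a Python function, so it is
# automatically held static -- no manual static_argnums bookkeeping needed.
fast_apply(jnp.tanh, jnp.array([1.0, 2.0]))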
equinox.filter_grad(fun, *, filter_spec=is_inexact_array, **kwargs)
Wraps jax.grad.

- fun is a pure function to differentiate.
- filter_spec is a PyTree whose structure should be a prefix of the structure of the first input to fun. Each of its leaves should either be True, False, or a callable Leaf -> bool. It behaves exactly as the filter_spec argument to equinox.filter. Truthy values will be differentiated; falsey values will not. Specifically, if calling fun(x, *args, **kwargs), then filter_spec must have a structure which is a prefix for the structure of x.
- **kwargs are any other arguments to jax.grad.
An important special case is to pass a function as filter_spec, which will be applied to every leaf of the first input. For example, equinox.filter_grad(fun, filter_spec=equinox.is_inexact_array).

See also equinox.is_inexact_array, which is the default choice of filter_spec. This will differentiate all floating-point JAX arrays.

Note that as the returned gradients must have the same structure as the inputs, all nondifferentiable components of the input PyTree will have gradient None. See equinox.apply_updates for a convenience to only apply non-None updates.
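For instance, a minimal sketch of the None-gradient behaviour (the pytree here is arbitrary):

import equinox as eqx
import jax.numpy as jnp

def loss(pair, x):
    weight, _label = pair  # `_label` is a string: not differentiable.
    return jnp.sum(weight * x)

pair = (jnp.array([1.0, 2.0]), "my-label")
grads = eqx.filter_grad(loss)(pair, jnp.array([3.0, 4.0]))
# grads == (array([3., 4.]), None): the string gets gradient None.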
equinox.filter_value_and_grad(fun, *, filter_spec=is_inexact_array, **kwargs)
Wraps jax.value_and_grad. Arguments are as equinox.filter_grad.
Filters
Any function Any -> bool can be used as a filter. We provide some convenient common choices.
equinox.is_array(element)
Returns True if element is a JAX array (but not a NumPy array).
equinox.is_array_like(element)
Returns True if element is a JAX array, NumPy array, or a Python float/int/bool/complex.
equinox.is_inexact_array(element)
Returns True if element is a floating-point JAX array (but not a NumPy array).
equinox.is_inexact_array_like(element)
Returns True if element is a floating-point JAX array, a floating-point NumPy array, or a Python float or complex.
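A few illustrative values (a quick sketch):

import equinox as eqx
import jax.numpy as jnp
import numpy as np

eqx.is_array(jnp.array(1.0))          # True
eqx.is_array(np.array(1.0))           # False: NumPy, not JAX
eqx.is_array_like(1.0)                # True: Python floats count
eqx.is_inexact_array(jnp.array(1))    # False: integer dtype, not floating point
eqx.is_inexact_array(jnp.array(1.0))  # True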
Utilities
equinox.apply_updates(model, updates)
Performs a training update to a model.
- model must be a PyTree;
- updates must be a PyTree with the same structure.

It essentially performs jax.tree_map(lambda m, u: m + u, model, updates) (or optax.apply_updates(model, updates)). However, anywhere updates is None, no update is made at all, so as to handle nondifferentiable parts of model.

The returned value is the updated model. (model is not mutated in place, as is usual in JAX and functional programming.)

To produce updates, it is typical to take the gradients from the loss function, and then adjust them according to any standard optimiser; for example Optax provides optax.sgd or optax.adam.
equinox.static_field(**kwargs)
This is a relatively advanced feature. Use it to mark one of the fields of a Module as being "static": that is, never differentiated, and always a static_argnum to JIT. Best used only if you control whatever will be assigned to that field. For example equinox.nn.MLP does not use this for its activation function, as in principle a learnt activation function could be passed.
Example:
class MyModule(equinox.Module):
    value: list = equinox.static_field()
If any **kwargs are passed, then they will be forwarded on to dataclasses.field. (Recall that Equinox uses dataclasses for its modules, so general dataclasses behaviour should work as normal.)
equinox.tree_at(where, pytree, replace=_sentinel, replace_fn=_sentinel)
Modifies an existing tree, and returns the modified tree. (Like the .at syntax for "in-place modification" of JAX arrays.)

- where is a callable PyTree -> Leaf or PyTree -> Tuple[Leaf, ...]. It should consume a PyTree of the same shape as pytree, and return the leaf or leaves that should be replaced. For example where=lambda mlp: mlp.layers[-1].linear.weight.
- pytree is the existing PyTree to modify.
- replace should either be a single element, or a tuple of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
- replace_fn should be a function Leaf -> Any. It will be called on every leaf replaced using where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
For example this can be used to help specify the weights of a model to train or not train:
trainable = jax.tree_map(lambda _: False, model)
trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, trainable, replace=True)
equinox.filter_grad(..., filter_spec=trainable)
equinox.tree_equal(*pytrees)
Returns True if all PyTrees in the list are equal. All arrays must have the same shape, dtype, and values. JAX arrays and NumPy arrays are not considered equal to each other.
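For instance (a quick sketch):

import equinox as eqx
import jax.numpy as jnp

a = [jnp.array([1.0, 2.0]), "x"]
b = [jnp.array([1.0, 2.0]), "x"]
eqx.tree_equal(a, b)                             # True
eqx.tree_equal(a, [jnp.array([1.0, 3.0]), "x"])  # False: array values differ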
Neural network library
Equinox includes a small neural network library, mostly as a tech demo for how the rest of the library can be used. Its API is broadly modelled after PyTorch.
equinox.nn.Linear(in_features, out_features, use_bias=True, *, key)(input)
equinox.nn.Identity(*args, **kwargs)(input) # args and kwargs are ignored
equinox.nn.Dropout(p=0.5, deterministic=False)(input, *, key=None, deterministic=None)
equinox.nn.GRUCell(input_size, hidden_size, use_bias=True, *, key)(input, hidden)
equinox.nn.LSTMCell(input_size, hidden_size, use_bias=True, *, key)(input, hidden)
equinox.nn.Sequential(layers)(input, *, key=None)
equinox.nn.MLP(in_size, out_size, width_size, depth,
               activation=jax.nn.relu, final_activation=lambda x: x, *, key)(input)
These all behave in the way you expect. The key arguments are used to generate the random initial weights, or to generate randomness on the forward pass of stochastic layers like Dropout.
The Dropout(deterministic=...)(deterministic=...) options determine whether to have the layer act as the identity function, as is commonly done with dropout during inference time. The call-time deterministic takes precedence if it is passed; otherwise the init-time deterministic is used. (Note that because models are PyTrees, you can modify the init-time deterministic flag using equinox.tree_at. This is perfectly fine, and might be handy if it's easier than using the call-time flag.)
The MLP(final_activation=...) option determines any final activation function to apply after the last layer. (In some cases it is desirable for this to be different to the activation used in the main part of the network.)
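For example, a minimal sketch constructing a couple of these layers:

import equinox as eqx
import jax
import jax.numpy as jnp

mlp_key, drop_key = jax.random.split(jax.random.PRNGKey(0))
mlp = eqx.nn.MLP(in_size=2, out_size=3, width_size=8, depth=2, key=mlp_key)
y = mlp(jnp.ones(2))                  # shape (3,)

dropout = eqx.nn.Dropout(p=0.5)
out = dropout(y, key=drop_key)        # stochastic on the forward pass
out = dropout(y, deterministic=True)  # identity, e.g. at inference time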