
PyTorch-like neural networks in JAX


Equinox

Equinox is a JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees.

In doing so:

  • We get a PyTorch-like API...
  • ...that's fully compatible with native JAX transformations...
  • ...with no new concepts you have to learn. (It's all just PyTrees.)

The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.

(In other words, why should you care? Because Equinox is really simple to learn, and really simple to use.)

Installation

pip install equinox

Requires Python 3.7+ and JAX 0.3.4+.

Documentation

Available at https://docs.kidger.site/equinox.

Quick example

Models are defined using PyTorch-like syntax:

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

and fully compatible with normal JAX operations:

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)

Finally, there's no magic behind the scenes. All eqx.Module does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.

Citation

If you found this library useful in academic work, please cite (arXiv link):

@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

(Also consider starring the project on GitHub.)

See also

Numerical differential equation solvers: Diffrax.

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: jaxtyping.

SymPy<->JAX conversion; train symbolic expressions via gradient descent: sympy2jax.


This version

0.7.1

Download files


Source Distribution

equinox-0.7.1.tar.gz (57.1 kB)

Uploaded Source

Built Distribution


equinox-0.7.1-py3-none-any.whl (68.5 kB)

Uploaded Python 3

File details

Details for the file equinox-0.7.1.tar.gz.

File metadata

  • Download URL: equinox-0.7.1.tar.gz
  • Size: 57.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for equinox-0.7.1.tar.gz
Algorithm    Hash digest
SHA256       41983cb0abfd06b023800908193d04d7da1bde033e25f6d61dba66463aeb581b
MD5          50afa6649e72603665b2069605f0a274
BLAKE2b-256  7841aaf9e90ce6ac5b540b051b5df5bd8aa1ff0c1c99104ab5c7313d3cb13b62
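The published digests can be checked locally after downloading a file. The following is a minimal sketch using Python's standard `hashlib`. The helper name `sha256_of` and the assumed download location (the current directory) are illustrative only.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at `path`, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest listed above for the source distribution.
EXPECTED_SDIST_SHA256 = "41983cb0abfd06b023800908193d04d7da1bde033e25f6d61dba66463aeb581b"

# Usage, assuming the sdist was downloaded to the current directory:
# assert sha256_of("equinox-0.7.1.tar.gz") == EXPECTED_SDIST_SHA256
```

pip can also enforce this itself through its hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).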


File details

Details for the file equinox-0.7.1-py3-none-any.whl.

File metadata

  • Download URL: equinox-0.7.1-py3-none-any.whl
  • Size: 68.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for equinox-0.7.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       b882e0c063cae07fc74c55707ac7e7967f210058b99dee72f5d8825af0db7d51
MD5          e158f85ad247be07654382cedf7c4b30
BLAKE2b-256  9f967f6922ed7d3f1a67927443fadc66a78adcbbeb4fc021c10d309367473af3

