PyTorch-like neural networks in JAX

Project description

Equinox

Equinox is a JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees.

In doing so:

  • We get a PyTorch-like API...
  • ...that's fully compatible with native JAX transformations...
  • ...with no new concepts you have to learn. (It's all just PyTrees.)

The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.

(In other words, why should you care? Because Equinox is really simple to learn, and really simple to use.)

Installation

pip install equinox

Requires Python 3.7+ and JAX 0.3.4+.

Documentation

Available at https://docs.kidger.site/equinox.

Quick example

Models are defined using PyTorch-like syntax:

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

and are fully compatible with normal JAX operations:

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)
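Because the model and its gradients are PyTrees with identical structure, a plain gradient-descent update is just a `jax.tree_util.tree_map` over the two trees. A minimal sketch (using a dict as a stand-in PyTree; the same call works on an `eqx.Module`, and the learning rate here is arbitrary):

```python
import jax
import jax.numpy as jnp

# A dict is a PyTree, just like an eqx.Module; grads mirrors its structure.
params = {"weight": jnp.ones((3, 2)), "bias": jnp.zeros(3)}
grads = {"weight": jnp.full((3, 2), 0.5), "bias": jnp.full((3,), 0.5)}

learning_rate = 0.1
# Apply p - lr * g leaf-by-leaf across the whole tree.
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)
```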

Finally, there's no magic behind the scenes. All eqx.Module does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.
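To make the "no magic" point concrete, here is a rough illustration of the mechanism using JAX's public `register_pytree_node` API on a plain class. This is a sketch of the general idea, not Equinox's actual implementation:

```python
import jax

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

# Registering tells JAX how to flatten/unflatten the class -- roughly
# what eqx.Module arranges for you automatically.
jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),      # flatten: leaves + static aux data
    lambda _, leaves: Point(*leaves),  # unflatten: rebuild from leaves
)

p = Point(1.0, 2.0)
leaves = jax.tree_util.tree_leaves(p)              # the leaves [1.0, 2.0]
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)  # a new Point
```

Once registered, `Point` works with `jax.jit`, `jax.grad`, `jax.vmap` and friends like any other PyTree.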

Citation

If you found this library to be useful in academic work, then please cite: (arXiv link)

@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

(Also consider starring the project on GitHub.)

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

equinox-0.4.0.tar.gz (40.3 kB)

Uploaded Source

Built Distribution

equinox-0.4.0-py3-none-any.whl (48.2 kB)

Uploaded Python 3

File details

Details for the file equinox-0.4.0.tar.gz.

File metadata

  • Download URL: equinox-0.4.0.tar.gz
  • Upload date:
  • Size: 40.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.12

File hashes

Hashes for equinox-0.4.0.tar.gz
Algorithm Hash digest
SHA256 4c35d34a1f518fd7a713faac849e9e949d9a494dbe9d3bcb09897a285a050732
MD5 0e7fe43be244af3f30e2d89c99e44a13
BLAKE2b-256 d2856459a95347d6959242f9d6e5c6a7394ac61a1d7cc810f864dbdecdb2d678

See more details on using hashes here.
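To check a downloaded file against the digests above, you can hash it locally. A minimal sketch in Python's standard `hashlib` (the path is whichever file you downloaded):

```python
import hashlib

def sha256_of_file(path, chunk_size=8192):
    """Return the hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare the result against the SHA256 digest published on this page:
# sha256_of_file("equinox-0.4.0.tar.gz")
```

Reading in chunks keeps memory use constant regardless of file size; the same pattern works for MD5 or BLAKE2b via `hashlib.md5` / `hashlib.blake2b`.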

File details

Details for the file equinox-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: equinox-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 48.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.12

File hashes

Hashes for equinox-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bc369f665a7185902a4619de19d341d9e348639fb95f91c4d6bdca582a9b397d
MD5 5ca8187995e932598a0adf5caaed4382
BLAKE2b-256 27bcdb2ab3f4350c35260d689c09ded957b74435da8211a42a50817f9ac27d65

See more details on using hashes here.
