Stainless neural networks in JAX

Inox is a minimal JAX library for neural networks with an intuitive PyTorch-like syntax. As with Equinox, modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.

Inox aims to be a leaner version of Equinox by only retaining its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like NNX and Serket to provide a versatile interface.
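To make the PyTree idea concrete, here is a minimal sketch in plain JAX, without Inox. The `Affine` class is a hypothetical toy module registered by hand with `jax.tree_util`; it is not how Inox defines modules, it only illustrates what representing a module as a PyTree buys you.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Affine:
    """Hypothetical toy module computing y = w * x + b (not part of Inox)."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return self.w * x + self.b

    def tree_flatten(self):
        # The children are the array leaves; no auxiliary data is needed here.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Affine(jnp.array(2.0), jnp.array(1.0))

# The object flattens to its array leaves.
leaves = tree_util.tree_leaves(layer)

# Generic tree utilities and transformations accept it directly.
doubled = tree_util.tree_map(lambda p: 2 * p, layer)
grads = jax.grad(lambda m, x: m(x))(layer, 3.0)
```

Here `grads` is itself an `Affine` whose fields hold the derivatives with respect to `w` and `b`, so the gradient has the same structure as the module.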

Inox means "stainless steel" in French 🔪

Installation

The inox package is available on PyPI and can be installed with pip.

pip install inox

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/inox

Getting started

Modules are defined with an intuitive PyTorch-like syntax,

import jax
import inox
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)

and are compatible with JAX transformations.

X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
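Because the gradients form a tree with the same structure as the model, a parameter update can be written as a tree map. The sketch below shows one possible gradient-descent loop on the same sorting task. To stay self-contained it uses a plain dict of arrays as a stand-in for the model; `params`, `predict`, and `sgd_step` are hypothetical names, and the learning rate is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

data_key, w_key = jax.random.split(jax.random.key(0))

# Same toy task as above: learn to sort 3 numbers.
X = jax.random.normal(data_key, (1024, 3))
Y = jnp.sort(X, axis=-1)

# Stand-in for a model: a PyTree (dict) holding a single linear layer.
params = {
    "w": 0.1 * jax.random.normal(w_key, (3, 3)),
    "b": jnp.zeros(3),
}


def predict(params, x):
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    return jnp.mean((y - predict(params, x)) ** 2)


@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # grads has the same tree structure as params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


initial_loss = loss_fn(params, X, Y)
for _ in range(100):
    params, loss = sgd_step(params, X, Y)
final_loss = loss_fn(params, X, Y)
```

In practice an optimizer library would usually replace the hand-written update, but the tree-map form shows why keeping the model as a PyTree is convenient.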

However, if a tree contains strings or boolean flags, it becomes incompatible with JAX transformations. For this reason, Inox provides lifted transformations that consider all non-array leaves as static.

model.name = 'stainless'

@inox.jit
def loss_fn(model, x, y):
    pred = inox.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
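One way to picture a lifted transformation is as a wrapper that separates array leaves from everything else before calling the underlying JAX transformation. The sketch below illustrates that idea with plain JAX only. It is not Inox's implementation; `lifted_jit`, `is_array`, and the `_ARRAY` sentinel are hypothetical names, and the sketch assumes every non-array leaf is hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util

_ARRAY = object()  # placeholder marking where an array leaf was removed


def is_array(x):
    return isinstance(x, (jax.Array, np.ndarray))


def lifted_jit(fn):
    """Hypothetical jit wrapper that treats non-array leaves as static."""

    @partial(jax.jit, static_argnums=(0, 1))
    def inner(treedef, static, arrays):
        # Merge traced arrays back with the static leaves, then rebuild args.
        arrays = iter(arrays)
        leaves = [next(arrays) if s is _ARRAY else s for s in static]
        args = tree_util.tree_unflatten(treedef, leaves)
        return fn(*args)

    def wrapper(*args):
        leaves, treedef = tree_util.tree_flatten(args)
        arrays = [x for x in leaves if is_array(x)]
        static = tuple(_ARRAY if is_array(x) else x for x in leaves)
        return inner(treedef, static, arrays)

    return wrapper


# A tree mixing arrays with a string and a boolean flag.
model = {"name": "stainless", "w": jnp.ones(3), "train": False}


@lifted_jit
def apply(model, x):
    # The flag is static, so this branch is resolved at trace time.
    scale = 1.0 if model["train"] else 0.5
    return scale * jnp.dot(model["w"], x)


out = apply(model, jnp.arange(3.0))
```

Since the string and the flag are part of the static arguments, changing either one triggers a new compilation, while changing only array values reuses the compiled function.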

For more information, check out the documentation at inox.readthedocs.io.

Contributing

If you have a question, an issue or would like to contribute, please read our contributing guidelines.
