Elegy is a neural networks framework based on JAX and Haiku.

Elegy

A High Level API for Deep Learning in JAX

Main Features

  • 😀 Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks.
  • 💪‍ Flexible: Elegy provides a Pytorch Lightning-like low-level API that offers maximum flexibility when needed.
  • 🔌 Compatible: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, Pytorch DataLoaders, and more.

Elegy is built on top of Treex and Treeo and reexports their APIs for convenience.

Getting Started | Examples | Documentation

What is included?

  • A Model class with an Estimator-like API.
  • A callbacks module with common Keras callbacks.

From Treex

  • A Module class.
  • An nn module with common layers.
  • A losses module with common loss functions.
  • A metrics module with common metrics (a quick sketch of these pieces in use follows below).
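
Since these are reexports, the pieces listed above can be used directly from the elegy namespace. A minimal sketch, using only names that also appear later in this README:

import elegy as eg

layer = eg.Linear(10)                              # an nn layer (from Treex)
loss = eg.losses.Crossentropy()                    # a loss function
metric = eg.metrics.Accuracy()                     # a metric
callback = eg.callbacks.TensorBoard("summaries")   # a Keras-style callback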

Installation

Install using pip:

pip install elegy

For Windows users, we recommend using the Windows Subsystem for Linux 2 (WSL2), since JAX does not yet support Windows natively.

Quick Start: High-level API

Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:

1. Define the architecture inside a Module:

import jax
import elegy as eg

class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x

2. Create a Model from this module and specify additional components such as losses, metrics, and optimizers:

import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
)
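
For reference, here is a hedged sketch of the kind of data the fit call above expects. The array names and shapes are placeholders (MNIST-like flattened images with integer class labels), not something defined by Elegy:

import numpy as np

X_train = np.random.uniform(size=(60_000, 28 * 28)).astype(np.float32)
y_train = np.random.randint(0, 10, size=(60_000,))
X_test = np.random.uniform(size=(10_000, 28 * 28)).astype(np.float32)
y_test = np.random.randint(0, 10, size=(10_000,))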

Using Flax

To use Flax with Elegy, just create a flax.linen.Module and pass it to Model.

import jax
import elegy as eg
import optax
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

As shown here, Flax Modules can optionally request a training argument to __call__, which will be provided by Elegy / Treex.
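
For example, here is a hedged sketch of a Flax module that actually uses the training flag to toggle dropout (nn.Dropout and its deterministic argument are standard Flax, not Elegy-specific):

class MLPWithDropout(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        # disable dropout outside of training
        x = nn.Dropout(rate=0.1)(x, deterministic=not training)
        x = nn.Dense(10)(x)
        return x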

Using Haiku

To use Haiku with Elegy, do the following:

  • Create a forward function.
  • Create a TransformedWithState object by feeding forward to hk.transform_with_state.
  • Pass your TransformedWithState to Model.

You can also optionally create your own hk.Module and use it in forward if needed. Putting everything together should look like this:

import jax
import elegy as eg
import optax
import haiku as hk


def forward(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(10)(x)
    return x


model = eg.Model(
    module=hk.transform_with_state(forward),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

As shown here, forward can optionally request a training argument, which will be provided by Elegy / Treex.
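
For example, here is a hedged sketch of a forward function that uses the training flag to apply dropout only during training (hk.dropout and hk.next_rng_key are standard Haiku, not Elegy-specific):

def forward_with_dropout(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    if training:
        # hk.dropout takes an explicit rng key and a drop rate
        x = hk.dropout(hk.next_rng_key(), 0.1, x)
    x = hk.Linear(10)(x)
    return x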

Quick Start: Low-level API

Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom Model to implement a LinearClassifier with pure JAX:

1. Define a custom init_step method:

import jax
import jax.numpy as jnp
import optax
import elegy as eg


class LinearClassifier(eg.Model):
    # use treex's API to declare parameter nodes
    w: jnp.ndarray = eg.Parameter.node()
    b: jnp.ndarray = eg.Parameter.node()

    def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray):
        # infer the number of input features from the flattened inputs
        features_in = jnp.reshape(inputs, (inputs.shape[0], -1)).shape[-1]

        self.w = jax.random.uniform(
            key=key,
            shape=[features_in, 10],
        )
        self.b = jnp.zeros([10])

        self.optimizer = self.optimizer.init(self)

        return self

Here we declare the parameters w and b using Treex's Parameter.node() for pedagogical reasons; normally you don't have to do this, since you would typically use a sub-Module instead.

2. Define a custom test_step method:

    def test_step(self, inputs, labels):
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits = jnp.dot(inputs, self.w) + self.b

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, self

3. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

4. Train the model using the fit method:

model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
)

Using other JAX Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API; here is an example with Flax:

import jax
import jax.numpy as jnp
import optax
import elegy as eg
import flax.linen as nn
from typing import Any, Mapping


class LinearClassifier(eg.Model):
    params: Mapping[str, Any] = eg.Parameter.node()
    batch_stats: Mapping[str, Any] = eg.BatchStat.node()
    next_key: eg.KeySeq

    def __init__(self, module: nn.Module, **kwargs):
        self.flax_module = module
        super().__init__(**kwargs)

    def init_step(self, key, inputs):
        self.next_key = eg.KeySeq(key)

        variables = self.flax_module.init(
            {"params": self.next_key(), "dropout": self.next_key()}, inputs
        )
        self.params = variables["params"]
        self.batch_stats = variables["batch_stats"]

        self.optimizer = self.optimizer.init(self.parameters())

        return self

    def test_step(self, inputs, labels):
        # forward
        variables = dict(
            params=self.params,
            batch_stats=self.batch_stats,
        )
        logits, variables = self.flax_module.apply(
            variables,
            inputs,
            rngs={"dropout": self.next_key()},
            mutable=True,
        )
        self.batch_stats = variables["batch_stats"]

        # loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"])

        # logs
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, self
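
A hedged usage sketch: since init_step and test_step above read the batch_stats collection, the wrapped Flax module should carry batch statistics, for example via nn.BatchNorm. The module below and its name are illustrative only, not part of Elegy's API:

class MLPWithBatchNorm(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = nn.BatchNorm(use_running_average=False)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x


model = LinearClassifier(
    module=MLPWithBatchNorm(),
    optimizer=optax.rmsprop(1e-3),
)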

Examples

Check out the /examples directory for some inspiration. To run an example, first install some requirements:

pip install -r examples/requirements.txt

And then run it normally with python, e.g.

python examples/flax/mnist_vae.py

Contributing

If you are interested in helping improve Elegy, check out the Contributing Guide.

Sponsors 💚

Citing Elegy

BibTeX

@software{elegy2020repository,
	title        = {Elegy: A High Level API for Deep Learning in JAX},
	author       = {PoetsAI},
	year         = 2021,
	url          = {https://github.com/poets-ai/elegy},
	version      = {0.8.1}
}
