
Elegy is a neural network framework based on JAX and Haiku.


Elegy



A High Level API for Deep Learning in JAX

Main Features

  • 😀 Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks.
  • 💪 Flexible: Elegy provides a PyTorch Lightning-like low-level API that offers maximum flexibility when needed.
  • 🔌 Compatible: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, PyTorch DataLoaders, and more.

Elegy is built on top of Treex and Treeo and reexports their APIs for convenience.


What is included?

  • A Model class with an Estimator-like API.
  • A callbacks module with common Keras callbacks.

From Treex

  • A Module class.
  • An nn module with common layers.
  • A losses module with common loss functions.
  • A metrics module with common metrics.

Installation

Install using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows yet.

Quick Start: High-level API

Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:

1. Define the architecture inside a Module:

import jax
import elegy as eg

class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x
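To make step 1 concrete, here is a plain NumPy sketch of the computation MLP performs: a 300-unit dense layer, a ReLU, and a 10-unit output layer. Parameters are passed around explicitly, and the scaled-normal initialization is a hypothetical choice made for illustration; it is not how Elegy/Treex initializes eg.Linear.

```python
import numpy as np

def init_mlp(rng: np.random.Generator, features_in: int, hidden: int = 300, out: int = 10):
    # Hypothetical initialization: scaled-normal weights, zero biases
    return {
        "w1": rng.normal(0.0, 1.0 / np.sqrt(features_in), size=(features_in, hidden)),
        "b1": np.zeros(hidden),
        "w2": rng.normal(0.0, 1.0 / np.sqrt(hidden), size=(hidden, out)),
        "b2": np.zeros(out),
    }

def mlp_forward(params, x):
    # Linear(300) -> ReLU -> Linear(10), mirroring MLP.__call__
    h = x @ params["w1"] + params["b1"]
    h = np.maximum(h, 0.0)
    return h @ params["w2"] + params["b2"]

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 784))  # illustrative batch of 32 flattened 28x28 inputs
params = init_mlp(rng, features_in=x.shape[1])
logits = mlp_forward(params, x)
print(logits.shape)  # (32, 10)
```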

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
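The loss argument above is a list: a data term (crossentropy) plus a weight penalty (L2 with l=1e-5). The sketch below shows how such a combined objective is usually computed, under two assumptions not spelled out in this README: the list entries are added together, and the L2 term follows the Keras convention of l times the sum of squared parameters. The helper names are hypothetical.

```python
import numpy as np

def l2_regularization(params, l=1e-5):
    # Assumed Keras-style L2 penalty: l times the sum of squared parameter values
    return l * sum(float(np.sum(np.square(p))) for p in params.values())

def total_loss(data_loss, params, l=1e-5):
    # Assumption: the entries of loss=[...] are summed into one scalar objective
    return data_loss + l2_regularization(params, l)

params = {"w": np.ones((4, 10)), "b": np.zeros(10)}
# The crossentropy of a uniform prediction over 10 classes is log(10)
print(total_loss(np.log(10), params))
```

With l=1e-5 the penalty is tiny relative to the data term, so it only nudges weights toward zero rather than dominating training.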

3. Train the model using the fit method:

model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
)
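Assuming Keras-like semantics for these arguments, each epoch runs steps_per_epoch optimization steps on mini-batches of batch_size examples (here 200 × 64 = 12,800 examples per epoch), and shuffle=True reorders the data between passes. The generator below is a hypothetical NumPy sketch of that bookkeeping, not Elegy's data pipeline; in particular, skipping the leftover partial batch at the end of a pass is a design choice of this sketch.

```python
import numpy as np

def iterate_epoch(X, y, batch_size, steps_per_epoch, shuffle, rng):
    # Hypothetical sketch: yield steps_per_epoch mini-batches of batch_size examples
    n = len(X)
    assert batch_size <= n, "batch_size must not exceed the dataset size"

    def new_order():
        # a fresh permutation for each pass over the data when shuffling
        return rng.permutation(n) if shuffle else np.arange(n)

    order, pos = new_order(), 0
    for _ in range(steps_per_epoch):
        if pos + batch_size > n:
            # not enough examples left for a full batch: start a new pass
            order, pos = new_order(), 0
        idx = order[pos:pos + batch_size]
        pos += batch_size
        yield X[idx], y[idx]

rng = np.random.default_rng(0)
X = np.arange(100).reshape(100, 1)
y = np.arange(100)
batches = list(iterate_epoch(X, y, batch_size=32, steps_per_epoch=5, shuffle=True, rng=rng))
print(len(batches), batches[0][0].shape)  # 5 (32, 1)
```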

Using Flax


To use Flax with Elegy just create a flax.linen.Module and pass it to Model.

import jax
import elegy as eg
import optax
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

As shown here, Flax Modules can optionally request a training argument in __call__, which will be provided by Elegy / Treex.
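The MLP above accepts training but never uses it. The usual reason a module asks for the flag is a layer such as dropout (in Flax, nn.Dropout takes a deterministic argument) that behaves differently during training and evaluation. The NumPy function below is a hypothetical illustration of that pattern using inverted dropout; it is not Elegy or Flax code.

```python
import numpy as np

def dropout(x, rate, training, rng):
    # Inverted dropout: the identity at evaluation time
    if not training or rate == 0.0:
        return x
    # keep each unit with probability 1 - rate and rescale so the expected value is unchanged
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones((4, 5))
print(dropout(x, 0.5, training=False, rng=rng))  # unchanged
print(dropout(x, 0.5, training=True, rng=rng))   # entries are 0.0 or 2.0
```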

Using Haiku


To use Haiku with Elegy do the following:

  • Create a forward function.
  • Create a TransformedWithState object by feeding forward to hk.transform_with_state.
  • Pass your TransformedWithState to Model.

You can also optionally create your own hk.Module and use it in forward if needed. Putting everything together should look like this:

import jax
import elegy as eg
import optax
import haiku as hk


def forward(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(10)(x)
    return x


model = eg.Model(
    module=hk.transform_with_state(forward),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

As shown here, forward can optionally request a training argument, which will be provided by Elegy / Treex.
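hk.transform_with_state turns forward into a pair of pure functions: init(rng, ...) returns (params, state), and apply(params, state, rng, ...) returns (output, new_state). The hand-written NumPy pair below imitates that contract for a single linear layer so the data flow is visible. It is only an analogue: Haiku derives both functions automatically from forward, rng here is a NumPy Generator rather than a JAX PRNG key, and the num_calls state is hypothetical (real Haiku state typically holds things like batch-norm statistics).

```python
import numpy as np

def init(rng, x, training):
    # Like transformed.init: build parameters and initial state from example inputs
    features_in = x.shape[-1]
    params = {"w": rng.normal(size=(features_in, 10)) / np.sqrt(features_in), "b": np.zeros(10)}
    state = {"num_calls": 0}  # hypothetical mutable state
    return params, state

def apply(params, state, rng, x, training):
    # Like transformed.apply: a pure function of its arguments returning (output, new_state)
    out = x @ params["w"] + params["b"]
    new_state = {"num_calls": state["num_calls"] + 1}
    return out, new_state

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 784))
params, state = init(rng, x, training=True)
out, state = apply(params, state, None, x, training=True)
print(out.shape, state)  # (8, 10) {'num_calls': 1}
```

Because nothing is mutated in place, the caller (Elegy, in the example above) is responsible for carrying the returned state into the next call.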

Quick Start: Low-level API

Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom Model to implement a LinearClassifier with pure JAX:

1. Define a custom init_step method:

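import jax
import jax.numpy as jnp
import numpy as np
import optax
import elegy as eg

class LinearClassifier(eg.Model):
    # use treex's API to declare parameter nodes
    w: jnp.ndarray = eg.Parameter.node()
    b: jnp.ndarray = eg.Parameter.node()

    def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray):
        # test_step flattens each example, so the input size is the product of the non-batch dims
        features_in = int(np.prod(inputs.shape[1:]))

        self.w = jax.random.uniform(
            key=key,
            shape=[features_in, 10],
        )
        self.b = jnp.zeros([10])

        self.optimizer = self.optimizer.init(self)

        return self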

Here we declared the parameters w and b using Treex's Parameter.node() for pedagogical reasons; normally you don't have to do this, since you would typically use a sub-Module instead.

2. Define a custom test_step method:

    def test_step(self, inputs, labels):
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits = jnp.dot(inputs, self.w) + self.b

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, self

3. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
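optax.rmsprop(1e-3) keeps, for every parameter, a running average of squared gradients and divides each gradient by its square root, so the effective step size adapts per parameter. The NumPy sketch below shows one update under that standard formulation. decay=0.9 and eps=1e-8 are assumed to match optax's defaults, and whether eps sits inside or outside the square root varies between implementations, so consult the optax documentation for the exact form.

```python
import numpy as np

def rmsprop_init(params):
    # One running average of squared gradients per parameter array, starting at zero
    return {k: np.zeros_like(v) for k, v in params.items()}

def rmsprop_update(params, grads, nu, learning_rate=1e-3, decay=0.9, eps=1e-8):
    # nu <- decay * nu + (1 - decay) * g^2
    new_nu = {k: decay * nu[k] + (1 - decay) * grads[k] ** 2 for k in params}
    # param <- param - lr * g / sqrt(nu + eps)
    new_params = {k: params[k] - learning_rate * grads[k] / np.sqrt(new_nu[k] + eps) for k in params}
    return new_params, new_nu

params = {"w": np.array([1.0, -1.0])}
grads = {"w": np.array([0.5, -2.0])}
nu = rmsprop_init(params)
params, nu = rmsprop_update(params, grads, nu)
print(params["w"])
```

With nu starting at zero, the first step moves every parameter by about learning_rate / sqrt(1 - decay), roughly 3.16e-3 here, regardless of the gradient's magnitude.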

4. Train the model using the fit method:

model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
)

Using other JAX Frameworks


It is straightforward to integrate other functional JAX libraries with this low-level API. Here is an example with Flax:

import jax
import jax.numpy as jnp
import optax
import elegy as eg
import flax.linen as nn
from typing import Any, Mapping

class LinearClassifier(eg.Model):
    params: Mapping[str, Any] = eg.Parameter.node()
    batch_stats: Mapping[str, Any] = eg.BatchStat.node()
    next_key: eg.KeySeq

    def __init__(self, module: nn.Module, **kwargs):
        self.flax_module = module
        super().__init__(**kwargs)

    def init_step(self, key, inputs):
        self.next_key = eg.KeySeq(key)

        variables = self.flax_module.init(
            {"params": self.next_key(), "dropout": self.next_key()}, inputs
        )
        self.params = variables["params"]
        # modules without BatchNorm layers have no "batch_stats" collection
        self.batch_stats = variables.get("batch_stats", {})

        self.optimizer = self.optimizer.init(self.parameters())

        return self

    def test_step(self, inputs, labels):
        # forward
        variables = dict(
            params=self.params,
            batch_stats=self.batch_stats,
        )
        logits, variables = self.flax_module.apply(
            variables,
            inputs,
            rngs={"dropout": self.next_key()},
            mutable=True,
        )
        self.batch_stats = variables.get("batch_stats", {})

        # loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # logs
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"])
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, self
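test_step writes the returned batch_stats back to self.batch_stats because layers such as flax.linen.BatchNorm update running statistics during the forward pass, and apply(..., mutable=True) returns those updated values instead of modifying anything in place. The NumPy function below sketches that update as an exponential moving average. The momentum of 0.99 and the zeros/ones initial values reflect our understanding of Flax's BatchNorm defaults rather than anything stated in this README, and the function name is hypothetical.

```python
import numpy as np

def update_batch_stats(batch_stats, x, momentum=0.99):
    # Exponential moving average of per-feature statistics over the batch axis
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return {
        "mean": momentum * batch_stats["mean"] + (1 - momentum) * mean,
        "var": momentum * batch_stats["var"] + (1 - momentum) * var,
    }

batch_stats = {"mean": np.zeros(3), "var": np.ones(3)}
x = np.full((8, 3), 2.0)  # every feature has batch mean 2 and variance 0
batch_stats = update_batch_stats(batch_stats, x)
print(batch_stats)
```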

Examples

Check out the /examples directory for some inspiration. To run an example, first install some requirements:

pip install -r examples/requirements.txt

Then run it with Python, e.g.:

python examples/flax/mnist_vae.py

Contributing

If you are interested in helping improve Elegy, check out the Contributing Guide.

Sponsors 💚

Citing Elegy

BibTeX

@software{elegy2020repository,
	title        = {Elegy: A High Level API for Deep Learning in JAX},
	author       = {PoetsAI},
	year         = 2021,
	url          = {https://github.com/poets-ai/elegy},
	version      = {0.8.1}
}



Download files


Source Distribution

elegy-0.9.0.dev1.tar.gz (71.7 kB)

Built Distribution

elegy-0.9.0.dev1-py3-none-any.whl (85.8 kB, Python 3)
