
Elegy is a Neural Networks framework based on Jax and Haiku.


Elegy



Elegy is a neural network framework based on Jax and inspired by Keras and Haiku.

Elegy implements the Keras API but makes changes to play better with Jax and to give more flexibility around losses and metrics. It also ports Haiku's excellent module system and makes it easier to use. Elegy is at an early stage, so feel free to send us your feedback!

Main Features

  • Familiar: Elegy should feel very familiar to Keras users.
  • Flexible: Elegy improves upon the basic Keras API by letting users optionally take more control over the definition of losses and metrics.
  • Easy-to-use: Elegy maintains all the simplicity and ease of use that Keras brings with it.
  • Compatible: Elegy strives to be compatible with the rest of the Jax ecosystem.

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since Jax does not support Windows yet.

Quick Start

Elegy greatly simplifies the training of Deep Learning models compared to pure Jax, where, due to Jax's functional nature, users have to do a lot of bookkeeping around the state of the model. In Elegy you just have to follow 3 basic steps:

1. Define the architecture inside an elegy.Module:

import jax
import jax.numpy as jnp
import elegy

class MLP(elegy.Module):
    def call(self, x: jnp.ndarray) -> jnp.ndarray:
        x = elegy.nn.Linear(300)(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(10)(x)
        return x

Note that we can define sub-modules on-the-fly directly in the call (forward) method.

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
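The `loss` list in step 2 combines a data term with a regularizer. The NumPy sketch below shows what such a combined objective computes, under two assumptions that are ours rather than taken from Elegy's source: the entries of the list are summed into one training objective, and `GlobalL2` adds `l` times the sum of squared parameter values (as Keras's L2 regularizer does). The function names are hypothetical.

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(y_true, logits):
    # Numerically stable log-softmax: subtract the row-wise max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of each integer label and average over the batch.
    return -log_probs[np.arange(len(y_true)), y_true].mean()

def global_l2(params, l):
    # Keras-style L2 penalty: l * sum of squared entries over every parameter array.
    return l * sum(np.sum(p ** 2) for p in params.values())

def total_loss(y_true, logits, params, l=1e-5):
    # Assumption: the entries of the `loss` list are simply added together.
    return sparse_categorical_crossentropy_from_logits(y_true, logits) + global_l2(params, l)
```

With a small `l`, the regularizer nudges the weights toward zero without dominating the crossentropy term.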

And you are done! For more information check out:

Why Jax & Elegy?

Given all the well-established Deep Learning frameworks like TensorFlow + Keras or PyTorch + PyTorch Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy. Here are some of the reasons this framework exists.

Why Jax?

Jax is a linear algebra library with the perfect recipe:

  • NumPy's familiar API
  • The speed and hardware support of XLA
  • Automatic Differentiation

The awesome thing about Jax is that Deep Learning is just one use case that it happens to excel at; you can use it for most tasks you would use NumPy for. Jax is so compatible with NumPy that its array type actually inherits from np.ndarray.

In a sense, Jax takes the best of both TensorFlow and PyTorch in a principled manner: while TF and PyTorch historically converged to the same set of features, their APIs still contain quirks they have to keep for compatibility.

Why Elegy?

We believe that Elegy can offer the best experience for coding Deep Learning applications by leveraging the power and familiarity of the Jax API, an easy-to-use and succinct Module system, and packaging everything on top of a convenient Keras-like API. Elegy improves upon other Deep Learning frameworks in the following ways:

  1. Its hook-based Module System makes it easier (less verbose) to write model code than in Keras & PyTorch, since it lets you declare sub-modules, parameters, and states directly in your call (forward) method. Thanks to this you get shape inference for free, so there is no need for a build method (Keras) or for propagating shape information all over the place (PyTorch). A naive implementation of Linear could be as simple as:
import elegy
import jax.numpy as jnp

class Linear(elegy.Module):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def call(self, x):
        # Parameter shapes are inferred from the input the first time call runs.
        w = elegy.get_parameter("w", [x.shape[-1], self.units], initializer=jnp.ones)
        b = elegy.get_parameter("b", [self.units], initializer=jnp.ones)

        return jnp.dot(x, w) + b
  2. It has a very flexible system for defining the inputs to losses and metrics, based on dependency injection. This contrasts with Keras's rigid requirement of matching (output, label) pairs and its inability to use additional information such as inputs, parameters, and states in the definition of losses and metrics.
  3. Its hook system preserves reference information from a module to its sub-modules, parameters, and states while maintaining a functional API. This is crucial because most Jax-based frameworks, such as Flax and Haiku, tend to lose this information, which makes tasks like transfer learning, where you need to mix pre-trained models into a new model, very tricky (this is easier if you keep references).
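Dependency injection here means the framework looks at which arguments each loss or metric declares and passes in only those. The sketch below illustrates the general idea with Python's `inspect.signature`; it is not Elegy's implementation, and the helper `call_with_available`, the example losses, and the argument names (`x`, `y_true`, `y_pred`, `params`) are hypothetical.

```python
import inspect

def call_with_available(fn, **available):
    """Call `fn` with only the keyword arguments its signature asks for."""
    params = inspect.signature(fn).parameters
    # If fn accepts **kwargs, hand it everything; otherwise filter by parameter name.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**available)
    return fn(**{name: value for name, value in available.items() if name in params})

# Hypothetical loss functions, each declaring only the inputs it needs.
def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def weight_penalty(params):
    return 0.01 * sum(w ** 2 for w in params["w"])

def input_reconstruction(x, y_pred):
    return sum(abs(a - b) for a, b in zip(x, y_pred)) / len(x)

def combined_loss(losses, **available):
    # Each loss receives only the subset of `available` it declares, then all are summed.
    return sum(call_with_available(fn, **available) for fn in losses)
```

Because each function states its own inputs, a loss can use `x` or `params` without the `(output, label)` pairing that Keras requires.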

Features

  • Model estimator class
  • losses module
  • metrics module
  • regularizers module
  • callbacks module
  • nn layers module

For more information, check out the Reference API section in the Documentation.

Contributing

Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are some friends passionate about ML.

License

Apache

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A Keras-like deep learning framework based on Jax},
url = {https://github.com/poets-ai/elegy},
version = {0.2.2},
year = {2020},
}

The current version can be retrieved either from the Release tag or from the file elegy/__init__.py, and the year corresponds to the project's release year.
