Elegy is a neural network framework based on JAX and Haiku.
Elegy
A High Level API for Deep Learning in JAX
Main Features
- 😀 Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks.
- 💪 Flexible: Elegy provides a PyTorch Lightning-like low-level API that offers maximum flexibility when needed.
- 🔌 Compatible: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, PyTorch DataLoaders, and more.
Elegy is built on top of Treex and Treeo and re-exports their APIs for convenience.
What is included?
- A `Model` class with an Estimator-like API.
- A `callbacks` module with common Keras callbacks.
From Treex
- A `Module` class.
- An `nn` module with common layers.
- A `losses` module with common loss functions.
- A `metrics` module with common metrics.
Installation
Install using pip:
```shell
pip install elegy
```
For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows yet.
Quick Start: High-level API
Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:
1. Define the architecture inside a `Module`:

```python
import jax
import elegy as eg


class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x
```
2. Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:

```python
import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
3. Train the model using the `fit` method:

```python
model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)
```
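The `fit` call above assumes that `X_train`, `y_train`, `X_test` and `y_test` already exist. If you only want to try the API, one hypothetical way to create random placeholder arrays with NumPy is sketched below. The sizes (784 features per sample, 10 classes) are assumptions chosen to match the MLP's 10 outputs, and the labels are integer class indices, as in the low-level example further down; none of this is required by Elegy.

```python
import numpy as np

# Hypothetical stand-in data: 1,000 training and 200 test samples with
# 784 float features each (e.g. flattened 28x28 images) and integer
# class labels in the range 0-9.
rng = np.random.default_rng(0)
X_train = rng.random((1000, 784), dtype=np.float32)
y_train = rng.integers(0, 10, size=1000)
X_test = rng.random((200, 784), dtype=np.float32)
y_test = rng.integers(0, 10, size=200)
```

Random data will not teach the model anything useful; it only lets you check that the pipeline runs end to end before plugging in a real dataset.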
Using Flax
To use Flax with Elegy, just create a `flax.linen.Module` and pass it to `Model`.
```python
import jax
import elegy as eg
import optax
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
As shown here, Flax Modules can optionally request a `training` argument to `__call__`, which will be provided by Elegy / Treex.
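A common reason to request `training` is to switch layers such as dropout between their training and inference behaviour. The function below is a hand-written, pure-JAX sketch of inverted dropout that illustrates what such a flag controls; it is not Flax's implementation, and the name `dropout` and the rate used are illustrative. Inside a Flax module you would normally use `nn.Dropout(rate, deterministic=not training)` instead.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate: float, training: bool):
    """Inverted dropout that is only active when training is True."""
    if not training or rate == 0.0:
        return x  # inference: pass activations through unchanged
    # keep each unit with probability (1 - rate)
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    # scale kept units so the expected activation matches inference
    return jnp.where(keep, x / (1.0 - rate), 0.0)


x = jnp.ones((4, 8))
key = jax.random.PRNGKey(0)
out_eval = dropout(key, x, rate=0.5, training=False)
out_train = dropout(key, x, rate=0.5, training=True)
```

With `rate=0.5`, every element of `out_train` is either `0.0` or `2.0`, while `out_eval` is identical to `x`.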
Using Haiku
To use Haiku with Elegy, do the following:
- Create a `forward` function.
- Create a `TransformedWithState` object by feeding `forward` to `hk.transform_with_state`.
- Pass your `TransformedWithState` to `Model`.
You can also optionally create your own `hk.Module` and use it in `forward` if needed. Putting everything together should look like this:
```python
import jax
import elegy as eg
import optax
import haiku as hk


def forward(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(10)(x)
    return x


model = eg.Model(
    module=hk.transform_with_state(forward),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
As shown here, `forward` can optionally request a `training` argument, which will be provided by Elegy / Treex.
Quick Start: Low-level API
Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom `Model` to implement a `LinearClassifier` with pure JAX:
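1. Define a custom `init_step` method:

```python
import math

import jax
import jax.numpy as jnp
import optax
import elegy as eg


class LinearClassifier(eg.Model):
    # use treex's API to declare parameter nodes
    w: jnp.ndarray = eg.Parameter.node()
    b: jnp.ndarray = eg.Parameter.node()

    def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray):
        # test_step flattens every axis except the batch axis
        features_in = math.prod(inputs.shape[1:])

        self.w = jax.random.uniform(
            key=key,
            shape=[features_in, 10],
        )
        self.b = jnp.zeros([10])

        self.optimizer = self.optimizer.init(self.parameters())

        return self
```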
Here we declared the parameters `w` and `b` using Treex's `Parameter.node()` for pedagogical reasons; normally you don't have to do this, since you would typically use a sub-`Module` instead.
2. Define a custom `test_step` method:

```python
    def test_step(self, inputs, labels):
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits = jnp.dot(inputs, self.w) + self.b

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, self
```
3. Instantiate our `LinearClassifier` with an optimizer:

```python
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
```
4. Train the model using the `fit` method:

```python
model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)
```
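To see what `test_step` computes independently of Elegy, the same arithmetic can be written as a standalone function of the parameters. This is an illustrative sketch: `linear_test_step` is a hypothetical name, it takes the raw integer target array instead of `labels["target"]`, and the cross-entropy is written out by hand (the same quantity `optax.softmax_cross_entropy` computes for one-hot targets) so that the snippet only depends on JAX.

```python
import jax
import jax.numpy as jnp


def linear_test_step(w, b, inputs, target, num_classes=10):
    """Hypothetical pure-function version of LinearClassifier.test_step."""
    # flatten + scale, as in test_step
    inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255
    # forward
    logits = jnp.dot(inputs, w) + b
    # softmax cross-entropy: -sum(one_hot * log_softmax(logits)) per sample
    one_hot = jax.nn.one_hot(target, num_classes)
    loss = -jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1).mean()
    # metrics
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == target)
    return loss, dict(acc=acc, loss=loss)


w = jnp.zeros((784, 10))
b = jnp.zeros((10,))
inputs = jnp.ones((4, 28, 28))
target = jnp.array([0, 1, 0, 3])
loss, logs = linear_test_step(w, b, inputs, target)
# all-zero weights give uniform logits, so loss == log(10);
# argmax then picks class 0, which matches 2 of the 4 targets (acc == 0.5)
```

Checking a hand-picked case like this is a cheap way to confirm that a custom `test_step` returns sensible values before handing it to `fit`.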
Using other JAX Frameworks
It is straightforward to integrate other functional JAX libraries with this low-level API. Here is an example with Flax:
```python
from typing import Any, Mapping

import jax
import jax.numpy as jnp
import optax
import elegy as eg
import flax.linen as nn


class LinearClassifier(eg.Model):
    params: Mapping[str, Any] = eg.Parameter.node()
    batch_stats: Mapping[str, Any] = eg.BatchStat.node()
    next_key: eg.KeySeq

    def __init__(self, module: nn.Module, **kwargs):
        self.flax_module = module
        super().__init__(**kwargs)

    def init_step(self, key, inputs):
        self.next_key = eg.KeySeq(key)

        variables = self.flax_module.init(
            {"params": self.next_key(), "dropout": self.next_key()}, inputs
        )
        self.params = variables["params"]
        self.batch_stats = variables["batch_stats"]

        self.optimizer = self.optimizer.init(self.parameters())

        return self

    def test_step(self, inputs, labels):
        # forward
        variables = dict(
            params=self.params,
            batch_stats=self.batch_stats,
        )
        logits, variables = self.flax_module.apply(
            variables,
            inputs,
            rngs={"dropout": self.next_key()},
            mutable=True,
        )
        self.batch_stats = variables["batch_stats"]

        # loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # logs
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"])
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )

        return loss, logs, self
```
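The example draws fresh random keys by calling `self.next_key()`. The class below is a hypothetical, minimal stand-in that only illustrates the idea behind such a key sequence: each call splits the stored key, keeps one half for the next call, and hands out the other. It is not Elegy's or Treex's `KeySeq`; the real one is expected to integrate with Treex's pytree machinery, which this plain Python class does not attempt.

```python
import jax


class SimpleKeySeq:
    """Hypothetical minimal key sequence (illustration only, not eg.KeySeq)."""

    def __init__(self, key):
        self.key = key

    def __call__(self):
        # split the stored key: keep one half, return the other
        self.key, subkey = jax.random.split(self.key)
        return subkey


next_key = SimpleKeySeq(jax.random.PRNGKey(0))
k1 = next_key()
k2 = next_key()
```

Because the stored key changes on every call, `k1` and `k2` differ, so `params` and `dropout` in `init_step` receive independent randomness.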
Examples
Check out the /examples directory for some inspiration. To run an example, first install some requirements:

```shell
pip install -r examples/requirements.txt
```

Then run it normally with python, e.g.:

```shell
python examples/flax/mnist_vae.py
```
Contributing
If you are interested in helping improve Elegy, check out the Contributing Guide.
Sponsors 💚
- Quansight - paid development time
Citing Elegy
BibTeX
@software{elegy2020repository,
title = {Elegy: A High Level API for Deep Learning in JAX},
author = {PoetsAI},
year = 2021,
url = {https://github.com/poets-ai/elegy},
version = {0.8.1}
}
Download files
- Source Distribution: elegy-0.8.6.tar.gz
- Built Distribution: elegy-0.8.6-py3-none-any.whl
File details
Details for the file elegy-0.8.6.tar.gz.
File metadata
- Download URL: elegy-0.8.6.tar.gz
- Upload date:
- Size: 60.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.4 CPython/3.8.12 Linux/5.11.0-1028-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1c21b2854b4c55395ec2cf5bfc110990eb67b9206168f3c4ed38dd66f9f0bb48
MD5 | 5d9c75b6b6ef1689624f6138f35e7945
BLAKE2b-256 | 31262efee9bb7bcb8b0d2f4d5d77a630f0cd5da71ee5e89d721c29855297c896
File details
Details for the file elegy-0.8.6-py3-none-any.whl.
File metadata
- Download URL: elegy-0.8.6-py3-none-any.whl
- Upload date:
- Size: 72.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.4 CPython/3.8.12 Linux/5.11.0-1028-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | ce6ab4653a573581c1b1faceee3c114d1045df9ba37913f10a8e02fd30744b42
MD5 | 646200d1adfad0ceb3f9067166f3eb02
BLAKE2b-256 | daef96bfb702632bf3d5f82185ddac7d5420ad3e183b315e67459919dc304f30
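To check that a downloaded file matches the SHA256 digest listed above, you can hash it locally with Python's standard `hashlib` module. The sketch below streams the file in chunks so large downloads need not fit in memory. The function names are illustrative, and the commented-out usage line assumes the wheel was saved to the current directory.

```python
import hashlib
import io


def sha256_of_stream(stream, chunk_size=1 << 20):
    """Return the SHA256 hex digest of a binary stream, read in chunks."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of_file(path):
    with open(path, "rb") as f:
        return sha256_of_stream(f)


# SHA256 digest of elegy-0.8.6-py3-none-any.whl, copied from the table above
EXPECTED_WHEEL_SHA256 = "ce6ab4653a573581c1b1faceee3c114d1045df9ba37913f10a8e02fd30744b42"

# Usage (assumes the wheel is in the current directory):
# assert sha256_of_file("elegy-0.8.6-py3-none-any.whl") == EXPECTED_WHEEL_SHA256

# quick self-check on an in-memory stream with a well-known test vector
check = sha256_of_stream(io.BytesIO(b"abc"))
```

Alternatively, `pip hash elegy-0.8.6-py3-none-any.whl` prints the file's SHA256 digest, which you can compare against the table by eye.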