Elegy is a Neural Networks framework based on Jax and Haiku.
Elegy
A High Level API for Deep Learning in JAX
Main Features
- 😀 Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to use for most common tasks.
- 💪 Flexible: Elegy provides a Pytorch Lightning-like low-level API that offers maximum flexibility when needed.
- 🔌 Compatible: Elegy supports various frameworks and data sources including Flax & Haiku Modules, Optax Optimizers, TensorFlow Datasets, Pytorch DataLoaders, and more.
Elegy is built on top of Treex and Treeo and reexports their APIs for convenience.
What is included?
- A Model class with an Estimator-like API.
- A callbacks module with common Keras callbacks.
From Treex
- A Module class.
- An nn module with common layers.
- A losses module with common loss functions.
- A metrics module with common metrics.
Installation
Install using pip:
pip install elegy
For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not yet support Windows.
Quick Start: High-level API
Elegy's high-level API provides a straightforward interface you can use by implementing the following steps:
1. Define the architecture inside a Module:
import jax
import elegy as eg

class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x
2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:
import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
3. Train the model using the fit method:
model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)
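The fit call above assumes that the arrays X_train, y_train, X_test, and y_test already exist. As a minimal sketch, placeholder data with MNIST-like shapes (28x28 grayscale images, 10 classes) could be generated as follows; the shapes, the sample counts, and the helper name make_fake_split are illustrative assumptions and not part of Elegy.

```python
import numpy as np

def make_fake_split(n_samples, seed):
    """Hypothetical helper: random uint8 'images' and integer class labels."""
    rng = np.random.default_rng(seed)
    # pixel values in [0, 255], one 28x28 image per sample (assumed shape)
    images = rng.integers(0, 256, size=(n_samples, 28, 28), dtype=np.uint8)
    # one integer class label in [0, 9] per sample
    labels = rng.integers(0, 10, size=(n_samples,), dtype=np.int32)
    return images, labels

X_train, y_train = make_fake_split(1024, seed=0)
X_test, y_test = make_fake_split(256, seed=1)
```

Random data like this is only useful for checking that the pipeline runs end to end; a real dataset is needed for meaningful metrics.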
Using Flax
To use Flax with Elegy just create a flax.linen.Module and pass it to Model.
import jax
import elegy as eg
import optax
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
As shown here, Flax Modules can optionally request a training argument to __call__ which will be provided by Elegy / Treex.
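To illustrate what "optionally request a training argument" can mean in practice, here is a minimal sketch of one way a wrapper could decide whether to pass the flag, using Python's inspect module. This is not claimed to be how Elegy or Treex implements it; the helper names accepts_training and call_with_optional_training are hypothetical.

```python
import inspect

def accepts_training(fn):
    """Return True if `fn` declares a parameter named `training`."""
    return "training" in inspect.signature(fn).parameters

def call_with_optional_training(fn, x, training):
    # Forward `training` only when the callable asks for it.
    if accepts_training(fn):
        return fn(x, training=training)
    return fn(x)

def forward_plain(x):
    # does not request the flag
    return x * 2

def forward_with_flag(x, training: bool):
    # behaves differently in training and inference mode
    return x * 2 if training else x

print(call_with_optional_training(forward_plain, 3, training=False))      # 6
print(call_with_optional_training(forward_with_flag, 3, training=False))  # 3
```

The same check works for bound methods such as a module's __call__, since inspect.signature omits self for bound methods.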
Using Haiku
To use Haiku with Elegy do the following:
- Create a forward function.
- Create a TransformedWithState object by feeding forward to hk.transform_with_state.
- Pass your TransformedWithState to Model.
You can also optionally create your own hk.Module and use it in forward if needed. Putting everything together should look like this:
import jax
import elegy as eg
import optax
import haiku as hk

def forward(x, training: bool):
    x = hk.Linear(300)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(10)(x)
    return x

model = eg.Model(
    module=hk.transform_with_state(forward),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
As shown here, forward can optionally request a training argument which will be provided by Elegy / Treex.
Quick Start: Low-level API
Elegy's low-level API lets you explicitly define what goes on during training, testing, and inference. Let's define our own custom Model to implement a LinearClassifier with pure JAX:
1. Define a custom init_step method:
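import jax
import jax.numpy as jnp
import numpy as np
import optax
import elegy as eg

class LinearClassifier(eg.Model):
    # use treex's API to declare parameter nodes
    w: jnp.ndarray = eg.Parameter.node()
    b: jnp.ndarray = eg.Parameter.node()

    def init_step(self, key: jnp.ndarray, inputs: jnp.ndarray):
        # assumption: features_in is the flattened size of one input sample,
        # matching the flatten performed in test_step
        features_in = int(np.prod(inputs.shape[1:]))

        self.w = jax.random.uniform(
            key=key,
            shape=[features_in, 10],
        )
        self.b = jnp.zeros([10])

        self.optimizer = self.optimizer.init(self)

        return self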
Here we declared the parameters w and b using Treex's Parameter.node() for pedagogical reasons. Normally you don't have to do this, since you would typically use a sub-Module instead.
2. Define a custom test_step method:
    def test_step(self, inputs, labels):
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits = jnp.dot(inputs, self.w) + self.b

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, self
3. Instantiate our LinearClassifier with an optimizer:
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
4. Train the model using the fit method:
model.fit(
    inputs=X_train,
    labels=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)
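To make the role of test_step more concrete, the following is a minimal pure-JAX sketch of how a training step can be derived from a loss function of this shape. It does not use Elegy and is not Elegy's implementation: plain SGD replaces the optax optimizer, the cross-entropy is written out with jax.nn.log_softmax, and the names loss_and_logs, sgd_step, and LEARNING_RATE, as well as the fake data, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # assumed value, chosen small enough for stable SGD here

def loss_and_logs(params, inputs, targets):
    # flatten + scale, mirroring test_step
    x = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255
    # forward pass of the linear classifier
    logits = jnp.dot(x, params["w"]) + params["b"]
    # softmax cross-entropy against one-hot targets, averaged over the batch
    one_hot = jax.nn.one_hot(targets, 10)
    loss = -jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1).mean()
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == targets)
    return loss, dict(loss=loss, acc=acc)

@jax.jit
def sgd_step(params, inputs, targets):
    # differentiate the loss w.r.t. params; the logs are returned as auxiliary output
    grads, logs = jax.grad(loss_and_logs, has_aux=True)(params, inputs, targets)
    # plain gradient descent update applied to every leaf of the params dict
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, logs

key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, (32, 28, 28)) * 255  # fake images
targets = jnp.arange(32) % 10                          # fake labels
params = dict(w=jnp.zeros((28 * 28, 10)), b=jnp.zeros((10,)))

params, first_logs = sgd_step(params, inputs, targets)
for _ in range(20):
    params, logs = sgd_step(params, inputs, targets)
```

The logs returned by each call are computed before that call's update, so first_logs reflects the zero-initialized parameters and logs reflects the parameters after 20 updates.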
Using other JAX Frameworks
It is straightforward to integrate other functional JAX libraries with this low-level API. Here is an example with Flax:
from typing import Any, Mapping

import jax
import jax.numpy as jnp
import optax
import elegy as eg
import flax.linen as nn

class LinearClassifier(eg.Model):
    params: Mapping[str, Any] = eg.Parameter.node()
    batch_stats: Mapping[str, Any] = eg.BatchStat.node()
    next_key: eg.KeySeq

    def __init__(self, module: nn.Module, **kwargs):
        self.flax_module = module
        super().__init__(**kwargs)

    def init_step(self, key, inputs):
        self.next_key = eg.KeySeq(key)

        # initialize the Flax variables from a sample batch
        variables = self.flax_module.init(
            {"params": self.next_key(), "dropout": self.next_key()}, inputs
        )
        self.params = variables["params"]
        self.batch_stats = variables["batch_stats"]

        self.optimizer = self.optimizer.init(self.parameters())

        return self

    def test_step(self, inputs, labels):
        # forward
        variables = dict(
            params=self.params,
            batch_stats=self.batch_stats,
        )
        logits, variables = self.flax_module.apply(
            variables,
            inputs,
            rngs={"dropout": self.next_key()},
            mutable=True,
        )
        self.batch_stats = variables["batch_stats"]

        # loss
        target = jax.nn.one_hot(labels["target"], 10)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # logs
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"])
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )

        return loss, logs, self
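The example above calls self.next_key() whenever it needs fresh randomness. As a rough illustration of that idea, here is a minimal sketch of a key sequence built on jax.random.split. The class SimpleKeySeq is a hypothetical stand-in written for this explanation; it is not Elegy's or Treex's KeySeq and makes no claim about how that class is implemented.

```python
import jax

class SimpleKeySeq:
    """Hypothetical stand-in for a key sequence; not Elegy's implementation."""

    def __init__(self, key):
        self.key = key

    def __call__(self):
        # Split the stored key: keep one half for later calls, hand out the other.
        self.key, subkey = jax.random.split(self.key)
        return subkey

next_key = SimpleKeySeq(jax.random.PRNGKey(42))
k1 = next_key()  # e.g. for parameter initialization
k2 = next_key()  # e.g. for dropout
```

Each call returns a different key, so separate consumers such as "params" and "dropout" do not share a random stream. Because this sketch mutates Python state, it is meant for use outside jitted functions.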
Examples
Check out the examples directory for some inspiration. To run an example, first install some requirements:
pip install -r examples/requirements.txt
Then run it normally with Python, e.g.:
python examples/flax/mnist_vae.py
Contributing
If you are interested in helping improve Elegy, check out the Contributing Guide.
Sponsors 💚
- Quansight - paid development time
Citing Elegy
BibTeX
@software{elegy2020repository,
  title = {Elegy: A High Level API for Deep Learning in JAX},
  author = {PoetsAI},
  year = 2021,
  url = {https://github.com/poets-ai/elegy},
  version = {0.8.1}
}
File details
Details for the file elegy-0.8.6.tar.gz.
File metadata
- Download URL: elegy-0.8.6.tar.gz
- Upload date:
- Size: 60.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.4 CPython/3.8.12 Linux/5.11.0-1028-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1c21b2854b4c55395ec2cf5bfc110990eb67b9206168f3c4ed38dd66f9f0bb48 |
| MD5 | 5d9c75b6b6ef1689624f6138f35e7945 |
| BLAKE2b-256 | 31262efee9bb7bcb8b0d2f4d5d77a630f0cd5da71ee5e89d721c29855297c896 |
File details
Details for the file elegy-0.8.6-py3-none-any.whl.
File metadata
- Download URL: elegy-0.8.6-py3-none-any.whl
- Upload date:
- Size: 72.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.4 CPython/3.8.12 Linux/5.11.0-1028-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ce6ab4653a573581c1b1faceee3c114d1045df9ba37913f10a8e02fd30744b42 |
| MD5 | 646200d1adfad0ceb3f9067166f3eb02 |
| BLAKE2b-256 | daef96bfb702632bf3d5f82185ddac7d5420ad3e183b315e67459919dc304f30 |