Elegy is a neural networks framework for the Jax ecosystem.
Elegy
Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.
Main Features
- Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
- Flexible: Elegy provides a functional, PyTorch Lightning-like low-level API that offers maximal flexibility when needed.
- Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
- Compatible: Elegy can consume a wide variety of common data sources, including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and NumPy pytrees.
For more information take a look at the Documentation.
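As a hypothetical sketch (the `batch_generator` name and batching logic are illustrative, not part of Elegy's API), a Python-generator data source of the kind listed above might look like:

```python
import numpy as np

def batch_generator(x, y, batch_size=32, seed=0):
    # yields (inputs, labels) batches forever, as fit-style training loops expect
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        idx = rng.permutation(n)[:batch_size]
        yield x[idx], y[idx]

x = np.zeros((100, 28, 28), dtype=np.float32)
y = np.zeros(100, dtype=np.int32)
batch_x, batch_y = next(batch_generator(x, y, batch_size=16))
```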
Installation
Install Elegy using pip:
```shell
pip install elegy
```
For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since Jax does not yet support Windows natively.
Quick Start: High-level API
Elegy's high-level API provides a very simple interface you can use by following these steps:
1. Define the architecture inside a Module. We will use Flax Linen for this example:

```python
import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x
```
2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:
```python
import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
3. Train the model using the fit method:
```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
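The fit call above assumes `X_train`, `y_train`, `X_test`, and `y_test` already exist. As a minimal sketch, MNIST-shaped placeholder arrays (random values, and smaller than the real dataset, purely for illustration) can be created with NumPy:

```python
import numpy as np

rng = np.random.default_rng(42)

# MNIST-shaped placeholders: 28x28 grayscale images, 10 classes
X_train = rng.integers(0, 256, size=(1000, 28, 28)).astype(np.float32)
y_train = rng.integers(0, 10, size=(1000,)).astype(np.int32)
X_test = rng.integers(0, 256, size=(200, 28, 28)).astype(np.float32)
y_test = rng.integers(0, 10, size=(200,)).astype(np.int32)
```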
Quick Start: Low-level API
Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure Jax:
1. Calculate our loss, logs, and states:
```python
import jax
import jax.numpy as jnp
import numpy as np

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # available names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng

        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )

        return loss, logs, states.update(net_params=(w, b))
```
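As a pure-NumPy sanity check (the `log_softmax` and `one_hot` helpers here are illustrative stand-ins for their `jax.nn` counterparts), the categorical crossentropy and accuracy computed in test_step reduce to:

```python
import numpy as np

def log_softmax(logits):
    # numerically stable log-softmax along the last axis
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def one_hot(y, num_classes):
    return np.eye(num_classes)[y]

logits = np.array([[2.0, 0.5, 0.1], [0.2, 3.0, 0.3]])
y_true = np.array([0, 1])

labels = one_hot(y_true, 3)
loss = np.mean(-np.sum(labels * log_softmax(logits), axis=-1))
accuracy = np.mean(np.argmax(logits, axis=-1) == y_true)  # both predictions correct here
```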
2. Instantiate our LinearClassifier with an optimizer:
```python
model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)
```
3. Train the model using the fit method:
```python
model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
```
)
Using Jax Frameworks
It is straightforward to integrate other functional JAX libraries with this low-level API:
```python
class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )

        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)

        return loss, logs, states.update(net_params=net_params, net_states=net_states)
```
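The `variables.pop("params")` call splits the parameter collection off from the remaining mutable collections, returning the remainder first and the popped collection second. A hypothetical plain-dict analogue of that split (illustrative only, not Flax's implementation):

```python
def pop(variables, key):
    # returns (remaining collections, popped collection),
    # mirroring the (rest, popped) order used above
    rest = {k: v for k, v in variables.items() if k != key}
    return rest, variables[key]

variables = {"params": {"w": [1.0, 2.0]}, "batch_stats": {"mean": [0.0]}}
net_states, net_params = pop(variables, "params")
```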
More Info
- Getting Started: High-level API tutorial.
- Getting Started: Low-level API tutorial.
- Elegy's Documentation.
- The examples directory.
- What is Jax?
Examples
To run the examples, first install the required packages:

```shell
pip install -r examples/requirements.txt
```

Now run the example:

```shell
python examples/flax_mnist_vae.py
```
Contributing
Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.
About Us
We are some friends passionate about ML.
License
Apache
Citing Elegy
To cite this project:
BibTeX
```bibtex
@software{elegy2020repository,
    author = {PoetsAI},
    title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
    url = {https://github.com/poets-ai/elegy},
    version = {0.7.2},
    year = {2020},
}
```
The current version may be retrieved either from the release tag or from the file elegy/__init__.py, and the year corresponds to the project's release year.