Flax: A neural network library for JAX designed for flexibility

Flax: A neural network library and ecosystem for JAX designed for flexibility

Overview | Quick install | What does Flax look like? | Documentation

See our full documentation to learn everything you need to know about Flax.

Flax is developed by a group within the Brain Team in Google AI, in close collaboration with the JAX team. Flax is being used by a growing community of hundreds of folks in various Alphabet research departments for their daily work, as well as a growing community of open source projects.

The Flax team's mission is to serve the growing JAX neural network research ecosystem, both within Alphabet and with the broader community, and to explore the use cases where JAX shines. We use GitHub for almost all of our coordination and planning, and it is also where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests. We hope to increasingly engage with the needs and clarifications of the broader ecosystem. Please let us know how we can help!

NOTE: The new Flax "Linen" module API is now stable and we recommend it for all new projects. The old flax.nn API will be deprecated.

Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!

We expect to add some improvements to Flax, but we only expect minor API changes to the core API. We will use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Optimizers (flax.optim): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b

Quick install

You will need Python 3.6 or later and a working JAX installation (with or without GPU support; see the JAX installation instructions). For a CPU-only version:

> pip install --upgrade pip # To support manylinux2010 wheels.
> pip install --upgrade jax jaxlib # CPU-only

Then install Flax from PyPI:

> pip install flax

To upgrade to the latest development version of Flax from GitHub, you can use:

> pip install --upgrade git+https://github.com/google/flax.git

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, please check our docs, our broad intro to the Module abstraction or visit our patterns page for additional concrete demonstrations of best practices.

from typing import Sequence, Tuple

import numpy as np
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    # Hidden layers use ReLU; the final layer is left linear.
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x


class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x


class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Tuple[int, ...] = None

  def setup(self):
    self.encoder = MLP(self.encoder_widths)
    # The decoder's last layer emits one value per input element. The size is
    # computed with NumPy so that it stays a static Python int.
    output_dim = int(np.prod(self.input_shape))
    self.decoder = MLP(tuple(self.decoder_widths) + (output_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    # Flatten everything except the batch dimension.
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    # Restore the original per-example shape.
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

Note

This is not an official Google product.
