Flax: A neural network library for JAX designed for flexibility

Flax: A neural network library and ecosystem for JAX designed for flexibility

Released in 2024, Flax NNX is a new, simplified Flax API designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first-class support for Python reference semantics, which lets users express their models as regular Python objects and enables reference sharing and mutability.

Flax NNX evolved from the Flax Linen API, which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team.

You can learn more about Flax NNX on the dedicated Flax documentation site. Make sure you check out:

Note: Flax Linen's documentation has its own site.

The Flax team's mission is to serve the growing JAX neural network research ecosystem, both within Alphabet and in the broader community, and to explore the use cases where JAX shines. We use GitHub for almost all of our coordination and planning, and it is also where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue, and pull request threads.

You can make feature requests, let us know what you are working on, report issues, and ask questions in our Flax GitHub discussion forum.

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: try new forms of training by forking an example and modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

Quick install

Flax uses JAX, so check out the JAX installation instructions for CPUs, GPUs, and TPUs.

You will need Python 3.8 or later. Install Flax from PyPI:

pip install flax

To install the latest development version of Flax directly from GitHub, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To install additional dependencies (like matplotlib) that some parts of Flax require but that the default install does not include, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN, and an autoencoder.

To learn more about the Module abstraction, check out our docs, including our broad introduction to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

Example of an MLP:

import jax
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    # linear1 -> BatchNorm -> dropout -> GELU, then the output projection.
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

Example of a CNN:

from functools import partial

import jax
from flax import nnx

class CNN(nnx.Module):
  # Sized for 28x28 single-channel inputs (e.g. MNIST) in NHWC layout.
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)  # 3136 = 7 * 7 * 64
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = self.avg_pool(nnx.relu(self.conv1(x)))  # -> (N, 14, 14, 32)
    x = self.avg_pool(nnx.relu(self.conv2(x)))  # -> (N, 7, 7, 64)
    x = x.reshape(x.shape[0], -1)  # flatten to (N, 3136)
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

Example of an autoencoder:

import jax
from flax import nnx

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)  # 2 -> 10 features
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)  # 10 -> 2 features

class AutoEncoder(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    # Encode, then reconstruct the input from the latent code.
    return self.decoder(self.encoder(x))

  def encode(self, x: jax.Array) -> jax.Array:
    return self.encoder(x)

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.11.0},
  year = {2024},
}

In the above BibTeX entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.

Download files

Source Distribution

flax-0.11.0.tar.gz (5.1 MB)

Built Distribution

flax-0.11.0-py3-none-any.whl (455.9 kB), for Python 3

File details

Details for the file flax-0.11.0.tar.gz.

File metadata

  • Download URL: flax-0.11.0.tar.gz
  • Upload date:
  • Size: 5.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flax-0.11.0.tar.gz
Algorithm Hash digest
SHA256 4b03b7938b5e1f8c5843f59af41116e9184f98e10ce61a8d33b5dd7d2ba91edd
MD5 2fe53d3f1adea86d769c916db39a6294
BLAKE2b-256 47441db7010083a43a4655325731729b413f085373315ce53f90e2ae92665043
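To check a downloaded file against the SHA256 digest listed above, a small sketch using Python's standard `hashlib` module might look like this. The helper names `sha256_of` and `matches_digest` are hypothetical, and the path is wherever you saved the download. (pip can also enforce hashes itself through its `--require-hashes` mode.)

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for flax-0.11.0.tar.gz.
SDIST_SHA256 = "4b03b7938b5e1f8c5843f59af41116e9184f98e10ce61a8d33b5dd7d2ba91edd"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
  """Return the hex SHA256 digest of a file, reading it in chunks."""
  digest = hashlib.sha256()
  with Path(path).open("rb") as f:
    for chunk in iter(lambda: f.read(chunk_size), b""):
      digest.update(chunk)
  return digest.hexdigest()


def matches_digest(path: Path, expected: str) -> bool:
  """Compare a file's SHA256 digest with an expected hex string."""
  return sha256_of(path) == expected.strip().lower()
```

For an intact download, `matches_digest(Path("flax-0.11.0.tar.gz"), SDIST_SHA256)` should return `True`. Reading in chunks keeps memory use flat regardless of file size.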

Provenance

The following attestation bundles were made for flax-0.11.0.tar.gz:

Publisher: flax_publish.yml on google/flax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flax-0.11.0-py3-none-any.whl.

File metadata

  • Download URL: flax-0.11.0-py3-none-any.whl
  • Upload date:
  • Size: 455.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flax-0.11.0-py3-none-any.whl
Algorithm Hash digest
SHA256 9dd67b7296e70c311a8cfd9a5fcd7bdaac9ccbc732ac70cc5c40b3534532d30a
MD5 328fd327fa62aab8a9443fc040d26c3f
BLAKE2b-256 4b70476765f56b64d0533e3d2c4ee2772f1c12a14fe505dd985bf3ee2c6e55b8

Provenance

The following attestation bundles were made for flax-0.11.0-py3-none-any.whl:

Publisher: flax_publish.yml on google/flax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
