Skip to main content

Flax: A neural network library for JAX designed for flexibility

Project description

logo

Flax: A neural network library and ecosystem for JAX designed for flexibility

Flax - Test PyPI version

Overview | Quick install | What does Flax look like? | Documentation

Released in 2024, Flax NNX is a new simplified Flax API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, enabling reference sharing and mutability.

Flax NNX evolved from the Flax Linen API, which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team.

You can learn more about Flax NNX on the dedicated Flax documentation site. Make sure you check out:

Note: Flax Linen's documentation has its own site.

The Flax team's mission is to serve the growing JAX neural network research ecosystem - both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads.

You can make feature requests, let us know what you are working on, report issues, ask questions in our Flax GitHub discussion forum.

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

Quick install

Flax uses JAX, so do check out JAX installation instructions on CPUs, GPUs and TPUs.

You will need Python 3.8 or later. Install Flax from PyPi:

pip install flax

To upgrade to the latest version of Flax, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To install some additional dependencies (like matplotlib) that are required but not included by some dependencies, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, check out our docs, our broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

Example of an MLP:

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

Example of a CNN:

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

Example of an autoencoder:

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.12.3},
  year = {2024},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

flax-0.12.3.tar.gz (5.0 MB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

flax-0.12.3-py3-none-any.whl (490.5 kB view details)

Uploaded Python 3

File details

Details for the file flax-0.12.3.tar.gz.

File metadata

  • Download URL: flax-0.12.3.tar.gz
  • Upload date:
  • Size: 5.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flax-0.12.3.tar.gz
Algorithm Hash digest
SHA256 8b9ca031d0a4759b0719718de578b27703a8b2d6105552b30f98989ce1be04f0
MD5 482bf190f756f0ea73b4d08de3c9eaf5
BLAKE2b-256 4dae6cbdc663dd3b66dfb0f556342813639d8d04d2fdb7660dbfa75e4000ebf1

See more details on using hashes here.

Provenance

The following attestation bundles were made for flax-0.12.3.tar.gz:

Publisher: flax_publish.yml on google/flax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flax-0.12.3-py3-none-any.whl.

File metadata

  • Download URL: flax-0.12.3-py3-none-any.whl
  • Upload date:
  • Size: 490.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flax-0.12.3-py3-none-any.whl
Algorithm Hash digest
SHA256 add0432489da2a532357d8a3a014a349d70774156a5a0a594706ba4319efa599
MD5 853d1db705a08b05de5286c0b3153a9d
BLAKE2b-256 4a13b3d5c22d8d4f266eba54795130afda106b12cce0fb92f3e196984725e024

See more details on using hashes here.

Provenance

The following attestation bundles were made for flax-0.12.3-py3-none-any.whl:

Publisher: flax_publish.yml on google/flax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page