Flax: A neural network library for JAX designed for flexibility

Flax: A neural network library and ecosystem for JAX designed for flexibility

Released in 2024, Flax NNX is a new, simplified Flax API designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first-class support for Python reference semantics. This lets users express their models as regular Python objects, which enables reference sharing and mutability.

Flax NNX evolved from the Flax Linen API, which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team.

You can learn more about Flax NNX on the dedicated Flax documentation site. Make sure you check out:

Note: Flax Linen's documentation has its own site.

The Flax team's mission is to serve the growing JAX neural network research ecosystem, both within Alphabet and in the broader community, and to explore the use cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as for discussing upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads.

You can make feature requests, let us know what you are working on, report issues, and ask questions in our Flax GitHub discussion forum.

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: try new forms of training by forking an example and modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

Quick install

Flax uses JAX, so do check out JAX installation instructions on CPUs, GPUs and TPUs.

You will need Python 3.8 or later. Install Flax from PyPI:

pip install flax

To install the latest development version of Flax directly from GitHub, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To also install optional extra dependencies (such as matplotlib) that some parts of Flax require but that are not installed by default, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an autoencoder.

To learn more about the Module abstraction, check out our docs, including the broad introduction to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

Example of an MLP:

import jax
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    # Linear -> BatchNorm -> Dropout -> GELU, then the output projection.
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

Example of a CNN:

from functools import partial

from flax import nnx

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # 3136 = 7 * 7 * 64, assuming 28x28 inputs (e.g. MNIST) pooled twice by 2x2.
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

Example of an autoencoder:

import jax
from flax import nnx

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.10.3},
  year = {2020},
}

In the BibTeX entry above, authors are listed in alphabetical order, the version number is intended to match the one in flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.

Download files

Download the file for your platform.

Source Distribution

flax-0.10.3.tar.gz (5.2 MB)

Uploaded Source

Built Distribution

flax-0.10.3-py3-none-any.whl (435.7 kB)

Uploaded Python 3

File details

Details for the file flax-0.10.3.tar.gz.

File metadata

  • Download URL: flax-0.10.3.tar.gz
  • Upload date:
  • Size: 5.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.1

File hashes

Hashes for flax-0.10.3.tar.gz
Algorithm     Hash digest
SHA256        29cde8cf05ffbff39b7f7167f0fe9916694cce76ce4c14e8be3549c1fd1b7c81
MD5           626f560a332bf80fd37c8add0b137bb1
BLAKE2b-256   956379b204d9f99e855ff36207d4810b4932f83dc7c100e9512eebe1466f0c4f
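To check a downloaded file against the SHA256 digest listed above, a minimal sketch using only Python's standard `hashlib` is shown below. The helper names `sha256_of` and `matches` are hypothetical, and the usage line assumes the sdist was saved to the current directory. (pip can also enforce digests itself through its hash-checking mode, e.g. a requirements line with `--hash=sha256:...`.)

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
  """Return the hex SHA256 digest of a file, reading it in chunks."""
  digest = hashlib.sha256()
  with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(chunk_size), b""):
      digest.update(chunk)
  return digest.hexdigest()

def matches(path: str, expected: str) -> bool:
  """Compare a file's SHA256 against a published hex digest (case-insensitive)."""
  return sha256_of(path) == expected.strip().lower()

# Usage, assuming flax-0.10.3.tar.gz was downloaded to the current directory:
# matches("flax-0.10.3.tar.gz",
#         "29cde8cf05ffbff39b7f7167f0fe9916694cce76ce4c14e8be3549c1fd1b7c81")
```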

File details

Details for the file flax-0.10.3-py3-none-any.whl.

File metadata

  • Download URL: flax-0.10.3-py3-none-any.whl
  • Upload date:
  • Size: 435.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.1

File hashes

Hashes for flax-0.10.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        7158b5dd6a05837e662a1ce1beea7adbad6d3612c0551c986b1c0a56071e3021
MD5           b79ceffb867adcca1f42684b01b401a4
BLAKE2b-256   77285d84597a2b1703486b554ab430f1e22f5f26a732ad98c398a8b414cd22c4