
Flax: A neural network library for JAX designed for flexibility



Flax: A neural network library and ecosystem for JAX designed for flexibility



Released in 2024, Flax NNX is a new, simplified Flax API designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It does this by adding first-class support for Python reference semantics, which lets users express their models as regular Python objects and enables reference sharing and mutability.

Flax NNX evolved from the Flax Linen API, which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team.

You can learn more about Flax NNX on the dedicated Flax documentation site. Make sure you check out:

Note: Flax Linen's documentation has its own site.

The Flax team's mission is to serve the growing JAX neural network research ecosystem, both within Alphabet and in the broader community, and to explore the use cases where JAX shines. We use GitHub for almost all of our coordination and planning, and it is also where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads.

You can make feature requests, let us know what you are working on, report issues, and ask questions in our Flax GitHub discussion forum.

We expect to keep improving Flax, but we don't anticipate significant breaking changes to the core API. When changes are made, we announce them with changelog entries and deprecation warnings whenever possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

Quick install

Flax uses JAX, so check out the JAX installation instructions for CPUs, GPUs and TPUs.

You will need Python 3.8 or later. Install Flax from PyPI:

pip install flax

To install the latest development version of Flax directly from GitHub, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To install additional dependencies (like matplotlib) that some parts of Flax need but that are not installed by default, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an autoencoder.

To learn more about the Module abstraction, check out our docs, including our broad introduction to it. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

Example of an MLP:

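import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    # linear -> batch norm -> dropout -> GELU, then a final linear projection
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(784, 256, 10, rngs=nnx.Rngs(0))
y = model(jnp.ones((32, 784)))  # y.shape == (32, 10)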

Example of a CNN:

from functools import partial

import jax.numpy as jnp
from flax import nnx

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

model = CNN(rngs=nnx.Rngs(0))
logits = model(jnp.ones((4, 28, 28, 1)))  # logits.shape == (4, 10)

Example of an autoencoder:

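import jax
import jax.numpy as jnp
from flax import nnx

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

model = AutoEncoder(rngs=nnx.Rngs(0))
z = model.encode(jnp.ones((8, 2)))  # z.shape == (8, 10)
x_hat = model(jnp.ones((8, 2)))     # x_hat.shape == (8, 2)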

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.12.2},
  year = {2024},
}

In the above BibTeX entry, names are in alphabetical order, the version number is intended to be the one from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

flax-0.12.2.tar.gz (5.0 MB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

flax-0.12.2-py3-none-any.whl (488.0 kB)

Uploaded Python 3

File details

Details for the file flax-0.12.2.tar.gz.

File metadata

  • Download URL: flax-0.12.2.tar.gz
  • Size: 5.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flax-0.12.2.tar.gz
Algorithm Hash digest
SHA256 e9723b0881e571abe61885bb8770f53fdb3c383b6b3f5a923dcf6f1e9a687905
MD5 3b555ad122b3d8299b546bd92a0209f1
BLAKE2b-256 6b7ec4c66ab9b41149cf7a1961907d9a844832af1e76b121b35235a618c92825
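To check a downloaded file against the SHA256 digest listed above, here is a minimal sketch using only Python's standard library. The file name and location are assumptions; adjust the path to wherever you saved the archive.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
  """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
  h = hashlib.sha256()
  with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(chunk_size), b""):
      h.update(chunk)
  return h.hexdigest()


# SHA256 digest listed for flax-0.12.2.tar.gz.
EXPECTED = "e9723b0881e571abe61885bb8770f53fdb3c383b6b3f5a923dcf6f1e9a687905"

archive = Path("flax-0.12.2.tar.gz")  # assumed download location
if archive.exists():
  print("OK" if sha256_of(archive) == EXPECTED else "MISMATCH")
```

pip can also enforce hashes itself. In hash-checking mode (pip install --require-hashes -r requirements.txt), every requirement, including each dependency, must be pinned with == and carry a hash, e.g. flax==0.12.2 --hash=sha256:<digest>.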


Provenance

The following attestation bundles were made for flax-0.12.2.tar.gz:

Publisher: flax_publish.yml on google/flax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flax-0.12.2-py3-none-any.whl.

File metadata

  • Download URL: flax-0.12.2-py3-none-any.whl
  • Size: 488.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flax-0.12.2-py3-none-any.whl
Algorithm Hash digest
SHA256 912fdd8a7c623ec8b2694b28d2827608e7fc82a3a6f8fff17ec5038f2bca66f4
MD5 33ad6e93da6ed76b761f538ae2a9031e
BLAKE2b-256 b76b7b75508251f4220df8f68e7718b476ee3d614a2a51f9eace97393ee91b46


Provenance

The following attestation bundles were made for flax-0.12.2-py3-none-any.whl:

Publisher: flax_publish.yml on google/flax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
