Flax: A neural network library for JAX designed for flexibility

NOTE: This is alpha software, but we encourage trying it out. Changes will come to the API, but we'll use deprecation warnings when we can, and keep track of them in our Changelog.

A growing community of researchers at Google is happily using Flax daily for their research, and now we'd like to extend that support to the open source community. GitHub issues are encouraged for open conversation, but in case you need to reach us directly, we're at flax-dev@google.com.

Quickstart

Full documentation and API reference

Annotated full end-to-end MNIST example

The Flax Guide -- a guided walkthrough of the parts of Flax

Background: JAX

JAX is NumPy + autodiff + GPU/TPU

It allows for fast scientific computing and machine learning with the normal NumPy API, plus additional APIs for special accelerator ops when needed.

JAX comes with powerful primitives, which you can compose arbitrarily:

  • Autodiff (jax.grad): Efficient any-order gradients w.r.t. any variables
  • JIT compilation (jax.jit): Trace any function ⟶ fused accelerator ops
  • Vectorization (jax.vmap): Automatically batch code written for individual samples
  • Parallelization (jax.pmap): Automatically parallelize code across multiple accelerators (including across hosts, e.g. for large TPUs)
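
The primitives listed above can be composed freely. The following is a minimal sketch, assuming a toy squared-error loss; the function `loss` and the sample arrays are made up for illustration and are not part of Flax or JAX. It differentiates the loss for one example, batches the gradient over several examples, and compiles the result.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear prediction for a single example."""
    return (jnp.dot(x, w) - y) ** 2


# jax.grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# jax.vmap maps over the leading axis of x and y; in_axes=None shares w.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jax.jit compiles the composed function.
fast_per_example_grads = jax.jit(per_example_grads)

# Illustrative data: three examples with two features each.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 3.0])

# One gradient per example, so the result has shape (3, 2).
grads = fast_per_example_grads(w, xs, ys)
print(grads)
```

The order of composition is a design choice: wrapping `jax.grad` in `jax.vmap` gives per-example gradients, whereas taking the gradient of a mean over a batched loss would give a single averaged gradient.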

What is Flax?

Flax is a high-performance neural network library for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to the framework.

Flax comes with everything you need to start your research, including:

  • A module abstraction (flax.nn.Module) for parameterized functions such as neural network layers.

  • Common layers (flax.nn): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Optimizers (flax.optim): SGD, Momentum, Adam, LARS

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • HOWTO guides -- diffs that add functionality to educational base examples

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet ImageNet, Transformer LM1b

An annotated MNIST example

See docs/annotated_mnist.md for an MNIST example with detailed annotations for each code block.

Flax Modules

The core of Flax is the Module abstraction. Modules allow you to write parameterized functions just as if you were writing a normal NumPy function with JAX. The Module API allows you to declare parameters and use them directly with the JAX APIs.

Modules are the one part of Flax with "magic" -- the magic is constrained, and enables a very ergonomic style, where modules are defined in a single function with minimal boilerplate.

A few things to know about Modules:

  1. Create a new module by subclassing flax.nn.Module and implementing the apply method.

  2. Within apply, call self.param(name, shape, init_func) to register a new parameter and return its initial value.

  3. Apply submodules by calling MySubModule(...args...) within MyModule.apply. Parameters of MySubModule are stored as a dictionary under the parameters of MyModule. NOTE: this returns the output of MySubModule, not an instance. To get access to an instance of MySubModule for re-use, use Module.partial or Module.shared.

  4. MyModule.init(rng, ...) is a pure function that calls apply in "init mode" and returns a nested Python dict of initialized parameter values.

  5. MyModule.call(params, ...) is a pure function that calls apply in "call mode" and returns the output of the module.

For example you can define a learned linear transformation as follows:

from flax import nn
import jax.numpy as jnp

class Linear(nn.Module):
  def apply(self, x, num_features, kernel_init_fn):
    input_features = x.shape[-1]
    W = self.param('W', (input_features, num_features), kernel_init_fn)
    return jnp.dot(x, W)
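
To make the "init mode" and "call mode" split concrete, here is a rough sketch in plain JAX of what the two modes amount to for a linear layer like the one above. This is not how Flax implements Modules: `linear_init` and `linear_call` are hypothetical names, and the normal initializer is an arbitrary choice standing in for kernel_init_fn.

```python
import jax
import jax.numpy as jnp


def linear_init(rng, input_shape, num_features):
    """Hypothetical 'init mode': build a dict of initial parameter values."""
    input_features = input_shape[-1]
    # Assumption: a standard normal initializer stands in for kernel_init_fn.
    W = jax.random.normal(rng, (input_features, num_features))
    return {'W': W}


def linear_call(params, x):
    """Hypothetical 'call mode': a pure function of parameters and input."""
    return jnp.dot(x, params['W'])


rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 3))
params = linear_init(rng, x.shape, num_features=5)
y = linear_call(params, x)  # shape (4, 5)

# Because linear_call is pure, it composes directly with jax.grad and jax.jit.
grads = jax.jit(jax.grad(lambda p: jnp.sum(linear_call(p, x))))(params)
print(y.shape, grads['W'].shape)
```

The point of the sketch is that parameters live in an ordinary nested dict passed in explicitly, which is what lets JAX transformations treat the module as a plain function.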

You can also use nn.module as a function decorator to create a new module, as long as you don't need access to self for creating parameters directly:

@nn.module
def DenseLayer(x, features):
  x = nn.Dense(x, features)
  x = nn.relu(x)
  return x

Read more about Flax Modules and the other parts of the Flax API in the Flax Guide

CPU-only Installation

You will need Python 3.5 or later.

Now install flax from GitHub:

> pip install git+https://github.com/google-research/flax.git@prerelease

GPU accelerated installation

First install jaxlib by following the instructions in the JAX readme. You will also need the CUDA and cuDNN runtimes if they are not already installed.

Now install flax from GitHub:

> pip install git+https://github.com/google-research/flax.git@prerelease
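
After either installation, a quick sanity check can confirm which devices JAX sees. This is a small sketch that assumes a reasonably recent JAX release in which `jax.default_backend()` is available; the exact device names printed depend on your machine.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; a CPU-only install reports only CPU devices.
devices = jax.devices()
print(devices)

# Name of the backend JAX selected by default, for example 'cpu' or 'gpu'.
backend = jax.default_backend()
print(backend)

# Run a small computation to confirm the installation works end to end.
total = jnp.sum(jnp.arange(4.0))
print(total)
```

If the GPU installation succeeded, the device list should include GPU devices rather than only CPU ones.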

Full end-to-end MNIST example

import jax
import flax
import numpy as onp
import jax.numpy as jnp
import tensorflow as tf  # needed for tf.cast in the input pipeline below
import tensorflow_datasets as tfds

class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x

@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss
  grad = jax.grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer

@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])

def train():
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.map(lambda x: {'image':tf.cast(x['image'], tf.float32),
                                     'label':tf.cast(x['label'], tf.int32)})
  train_ds = train_ds.cache().shuffle(1000).batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))
  test_ds = {'image': test_ds['image'].astype(jnp.float32),
             'label': test_ds['label'].astype(jnp.int32)}

  _, initial_params = CNN.init_by_shape(
      jax.random.PRNGKey(0),
      [((1, 28, 28, 1), jnp.float32)])
  model = flax.nn.Model(CNN, initial_params)

  optimizer = flax.optim.Momentum(
      learning_rate=0.1, beta=0.9).create(model)

  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)

    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
          % (epoch + 1, metrics['loss'], metrics['accuracy'] * 100))
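
The loss in the example above is written for a single example and batched with jax.vmap. The sketch below reuses cross_entropy_loss and compute_metrics unchanged and checks them on a tiny batch in plain JAX, without Flax or a dataset. The scores and labels are made up for illustration, and `jax.nn.log_softmax` stands in for the final layer of the CNN.

```python
import jax
import jax.numpy as jnp


@jax.vmap
def cross_entropy_loss(logits, label):
    # Written for one example: logits has shape (num_classes,), label is scalar.
    return -logits[label]


def compute_metrics(logits, labels):
    loss = jnp.mean(cross_entropy_loss(logits, labels))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}


# Made-up scores for a batch of 2 examples and 3 classes.
scores = jnp.array([[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]])
log_probs = jax.nn.log_softmax(scores)
labels = jnp.array([0, 1])

# vmap turns the single-example loss into a batched one with shape (2,).
per_example = cross_entropy_loss(log_probs, labels)

# The same values computed with an explicit gather over the batch.
expected = -jnp.take_along_axis(log_probs, labels[:, None], axis=1)[:, 0]

metrics = compute_metrics(log_probs, labels)
print(per_example, metrics)
```

Note that the function expects log-probabilities, not raw scores, which is why the CNN above ends with log_softmax.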

More end-to-end examples

NOTE: We are still testing these examples across all supported hardware configurations.

Getting involved

Have questions? Want to learn more? Reach out to us at flax-dev@google.com

Want to help?

We're happy to work together, either remotely or in Amsterdam.

In addition to general improvements to the framework, here are some specific things that would be great to have:

Help build more HOWTOs

(TODO: clarify list)

Help build new end-to-end examples

  • Semantic Segmentation
  • GAN
  • VAE
  • ...and your proposal!

Note

This is not an official Google product.
