FedJAX: Federated learning simulation with JAX

Documentation | Paper

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

What is FedJAX?

FedJAX is a JAX-based open-source library for federated learning simulations that emphasizes ease of use in research. With its simple primitives for implementing federated learning algorithms, prepackaged datasets, models, and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. FedJAX works on accelerators (GPU and TPU) without much additional effort. Additional details and benchmarks can be found in our paper.

Quickstart

The tutorial notebooks and examples in the repository provide an introduction to FedJAX.

Below, we walk through a simple example of federated averaging for linear regression implemented in FedJAX. The first steps are to set up the experiment by loading the federated dataset, initializing the model parameters, and defining the loss and gradient functions. The federated dataset can be thought of as a simple mapping from client identifiers to each client's local dataset.

import jax
import jax.numpy as jnp
import fedjax

# {'client_id': client_dataset}.
federated_data = fedjax.FederatedData()
# Initialize model parameters.
server_params = jnp.array(0.5)
# Mean squared error.
mse_loss = lambda params, batch: jnp.mean(
        (jnp.dot(batch['x'], params) - batch['y'])**2)
# jax.jit for XLA and jax.grad for autograd.
grad_fn = jax.jit(jax.grad(mse_loss))
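To make concrete what grad_fn computes, here is a pure-Python check of the analytic gradient of the same mean squared error on a toy scalar batch (no JAX required; the batch values and helper names are illustrative, not part of the FedJAX API):

```python
def mse_loss(params, batch):
    # Mean squared error for a scalar params, matching the lambda above.
    n = len(batch["x"])
    return sum((x * params - y) ** 2 for x, y in zip(batch["x"], batch["y"])) / n

def grad_mse(params, batch):
    # Analytic derivative: d/dp mean((x*p - y)^2) = mean(2 * (x*p - y) * x).
    n = len(batch["x"])
    return sum(2 * (x * params - y) * x for x, y in zip(batch["x"], batch["y"])) / n

# Toy batch whose data fit y = 2x exactly.
batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}
print(grad_mse(0.5, batch))  # negative, since the loss decreases as params grows toward 2.0
print(grad_mse(2.0, batch))  # zero at the optimum
```

`jax.grad(mse_loss)` computes this same derivative automatically, and `jax.jit` compiles it with XLA.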

Next, we use fedjax.for_each_client to coordinate the training that occurs across multiple clients. For federated averaging, client_init initializes the client model using the server model, client_step completes one step of local mini-batch SGD, and client_final returns the difference between the initial server model and the trained client model. Because fedjax.for_each_client is backed by jax.jit and jax.pmap, this work runs on any available accelerators, possibly in parallel. That said, the same logic could also be written out as a basic for loop over clients if desired.

# For loop over clients with client learning rate 0.1.
for_each_client = fedjax.for_each_client(
  client_init=lambda server_params, _: server_params,
  client_step=(
    lambda params, batch: params - grad_fn(params, batch) * 0.1),
  client_final=lambda server_params, params: server_params - params)
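As noted above, the same per-client logic can be written as a plain loop. The following is a pure-Python sketch, not the FedJAX implementation: it assumes a scalar model, an analytic MSE gradient in place of jax.grad, and clients represented as (client_id, batches) pairs.

```python
def grad_fn(params, batch):
    # Analytic gradient of mean squared error for a scalar params.
    n = len(batch["x"])
    return sum(2 * (x * params - y) * x for x, y in zip(batch["x"], batch["y"])) / n

def for_each_client(server_params, clients, client_lr=0.1):
    """Plain-loop equivalent of the client_init/client_step/client_final pattern."""
    for client_id, batches in clients:
        params = server_params                        # client_init
        for batch in batches:                         # client_step: local mini-batch SGD
            params = params - client_lr * grad_fn(params, batch)
        yield client_id, server_params - params       # client_final: model delta

clients = [("a", [{"x": [1.0], "y": [2.0]}])]
for client_id, delta in for_each_client(0.5, clients):
    print(client_id, delta)  # delta is negative: the client moved params toward 2.0
```

The real fedjax.for_each_client follows this same shape but compiles and batches the work across clients with jax.jit and jax.pmap.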

Finally, we run federated averaging for 100 training rounds: in each round, we sample clients from the federated dataset, train across these clients using fedjax.for_each_client, and aggregate the client updates by weighted averaging to update the server model.

# 100 rounds of federated training.
for _ in range(100):
  clients = federated_data.clients()
  client_updates = []
  client_weights = []
  for client_id, update in for_each_client(server_params, clients):
    client_updates.append(update)
    client_weights.append(federated_data.client_size(client_id))
  # Weighted average of client updates.
  client_updates = jnp.array(client_updates)
  client_weights = jnp.array(client_weights)
  server_update = (
    jnp.sum(client_updates * client_weights) /
    jnp.sum(client_weights))
  # Server learning rate of 0.01.
  server_params = server_params - server_update * 0.01
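Putting the pieces together, here is a self-contained pure-Python run of the same server loop on toy data. It is a sketch, not FedJAX: the two clients and their data are illustrative, each client takes a single local gradient step, and the server learning rate is 1.0 rather than 0.01 so the toy problem converges within 100 rounds.

```python
def grad_fn(params, batch):
    # Analytic gradient of mean squared error for a scalar params.
    n = len(batch["x"])
    return sum(2 * (x * params - y) * x for x, y in zip(batch["x"], batch["y"])) / n

# Two illustrative clients whose data fit y = 2x exactly.
clients = {
    "a": {"x": [1.0, 2.0], "y": [2.0, 4.0]},
    "b": {"x": [3.0], "y": [6.0]},
}

server_params = 0.5
for _ in range(100):
    updates, weights = [], []
    for client_id, batch in clients.items():
        params = server_params - 0.1 * grad_fn(server_params, batch)  # one local step
        updates.append(server_params - params)                        # client delta
        weights.append(len(batch["x"]))                               # weight by client size
    # Weighted average of client deltas, applied with server learning rate 1.0.
    server_update = sum(u * w for u, w in zip(updates, weights)) / sum(weights)
    server_params = server_params - 1.0 * server_update

print(server_params)  # close to the true slope 2.0
```

Each round shrinks the error toward the true slope, since every client's delta points from the current server model toward 2.0.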

Installation

You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow these instructions.

Then, install fedjax from PyPI:

pip install fedjax

Or, to install the latest development version of fedjax from GitHub:

pip install --upgrade git+https://github.com/google/fedjax.git

Citing FedJAX

To cite this repository:

@software{fedjax2020github,
  author = {Jae Hun Ro and Ananda Theertha Suresh and Ke Wu},
  title = {{F}ed{JAX}: Federated learning simulation with {JAX}},
  url = {http://github.com/google/fedjax},
  version = {0.0.6},
  year = {2020},
}

In the above bibtex entry, the version number is intended to be that from fedjax/version.py, and the year corresponds to the project's open-source release. There is also an associated paper.

