FedJAX: Federated learning simulation with JAX
NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.
What is FedJAX?
FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research. With its simple primitives for implementing federated learning algorithms, prepackaged datasets, models and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. FedJAX works on accelerators (GPU and TPU) without much additional effort. Additional details and benchmarks can be found in our paper.
Quickstart
The following tutorial notebooks provide an introduction to FedJAX:
You can also take a look at some of our examples:
Below, we walk through a simple example of federated averaging for linear regression implemented in FedJAX. The first steps are to set up the experiment by loading the federated dataset, initializing the model parameters, and defining the loss and gradient functions. The federated dataset can be thought of as a simple mapping from client identifiers to each client's local dataset.
```python
import jax
import jax.numpy as jnp
import fedjax

# {'client_id': client_dataset}.
federated_data = fedjax.FederatedData()

# Initialize model parameters.
server_params = jnp.array(0.5)

# Mean squared error.
mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)

# jax.jit for XLA and jax.grad for autograd.
grad_fn = jax.jit(jax.grad(mse_loss))
```
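As a quick sanity check, the loss and gradient functions above can be exercised on a synthetic batch (the data here is invented for illustration and is not part of FedJAX):

```python
import jax
import jax.numpy as jnp

# Same scalar linear model loss as above.
mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
grad_fn = jax.jit(jax.grad(mse_loss))

# Synthetic batch satisfying y = 2 * x, so the loss is minimized at params = 2.0.
batch = {'x': jnp.array([1.0, 2.0, 3.0]), 'y': jnp.array([2.0, 4.0, 6.0])}
params = jnp.array(0.5)
print(float(mse_loss(params, batch)))  # 10.5
print(float(grad_fn(params, batch)))   # -14.0, pointing toward params = 2.0
```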
Next, we use fedjax.for_each_client to coordinate the training that occurs across multiple clients. For federated averaging, client_init initializes the client model using the server model, client_step completes one step of local mini-batch SGD, and client_final returns the difference between the initial server model and the trained client model. Because fedjax.for_each_client is backed by jax.jit and jax.pmap, this work will run on any available accelerators, possibly in parallel. That said, while this is already straightforward to write, the same computation could also be written out as a basic for loop over clients if desired.
```python
# For loop over clients with client learning rate 0.1.
for_each_client = fedjax.for_each_client(
    client_init=lambda server_params, _: server_params,
    client_step=(
        lambda params, batch: params - grad_fn(params, batch) * 0.1),
    client_final=lambda server_params, params: server_params - params)
```
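For comparison, the same client work can be written as a plain loop over clients in pure JAX. The toy client datasets below are invented for illustration (one batch per client); fedjax.for_each_client additionally handles batching and dispatch via jax.jit and jax.pmap:

```python
import jax
import jax.numpy as jnp

mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
grad_fn = jax.jit(jax.grad(mse_loss))

server_params = jnp.array(0.5)
# Toy stand-in for a federated dataset: client_id -> one local batch.
clients = {
    'client_0': {'x': jnp.array([1.0, 2.0]), 'y': jnp.array([2.0, 4.0])},
    'client_1': {'x': jnp.array([3.0]), 'y': jnp.array([6.0])},
}

client_updates = {}
for client_id, batch in clients.items():
  # client_init: start from the server model.
  params = server_params
  # client_step: one step of local SGD with client learning rate 0.1.
  params = params - grad_fn(params, batch) * 0.1
  # client_final: report the delta between server and client models.
  client_updates[client_id] = server_params - params
```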
Finally, we run federated averaging for 100 training rounds by sampling clients from the federated dataset, training across these clients using fedjax.for_each_client, and aggregating the client updates via weighted averaging to update the server model.
```python
# 100 rounds of federated training.
for _ in range(100):
  clients = federated_data.clients()
  client_updates = []
  client_weights = []
  for client_id, update in for_each_client(server_params, clients):
    client_updates.append(update)
    client_weights.append(federated_data.client_size(client_id))
  # Stack the per-client results so they can be combined elementwise.
  client_updates = jnp.array(client_updates)
  client_weights = jnp.array(client_weights)

  # Weighted average of client updates.
  server_update = (
      jnp.sum(client_updates * client_weights) /
      jnp.sum(client_weights))

  # Server learning rate of 0.01.
  server_params = server_params - server_update * 0.01
```
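To check the mechanics, the whole algorithm can be sketched end to end in pure JAX with synthetic clients (everything below is invented for illustration and uses none of FedJAX's APIs). All client data satisfies y = 2 * x, and with such a small server learning rate, 100 rounds move the model only part of the way from 0.5 toward the optimum at 2.0:

```python
import jax
import jax.numpy as jnp

mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
grad_fn = jax.jit(jax.grad(mse_loss))

# Synthetic clients whose data all satisfies y = 2 * x.
client_data = {
    'a': {'x': jnp.array([1.0, 2.0]), 'y': jnp.array([2.0, 4.0])},
    'b': {'x': jnp.array([3.0, 4.0, 5.0]), 'y': jnp.array([6.0, 8.0, 10.0])},
}

server_params = jnp.array(0.5)
for _ in range(100):
  updates, weights = [], []
  for cid, batch in client_data.items():
    # One local SGD step per client (client learning rate 0.1).
    client_params = server_params - grad_fn(server_params, batch) * 0.1
    updates.append(server_params - client_params)
    weights.append(batch['x'].shape[0])  # weight by client dataset size
  # Weighted average of client deltas, applied with server learning rate 0.01.
  server_update = (jnp.sum(jnp.array(updates) * jnp.array(weights)) /
                   jnp.sum(jnp.array(weights)))
  server_params = server_params - server_update * 0.01
```

The effective step per round shrinks the distance to the optimum by a constant factor, so the server model drifts steadily from 0.5 toward 2.0.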
Installation
You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:
```
pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version
```
For other devices (e.g. GPU), follow these instructions.
Then, install fedjax from PyPI:

```
pip install fedjax
```
Or, to upgrade to the latest development version from GitHub:

```
pip install --upgrade git+https://github.com/google/fedjax.git
```
Citing FedJAX
To cite this repository:
```
@software{fedjax2020github,
  author = {Jae Hun Ro and Ananda Theertha Suresh and Ke Wu},
  title = {{F}ed{JAX}: Federated learning simulation with {JAX}},
  url = {http://github.com/google/fedjax},
  version = {0.0.6},
  year = {2020},
}
```
In the above bibtex entry, the version number is intended to be that from fedjax/version.py, and the year corresponds to the project's open-source release. There is also an associated paper.