Federated learning simulation with JAX.
FedJAX: Federated learning simulation with JAX
NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.
What is FedJAX?
FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease of use in research. With simple primitives for implementing federated learning algorithms, prepackaged datasets, models and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. FedJAX works on accelerators (GPU and TPU) without much additional effort. Additional details and benchmarks can be found in our paper (arXiv:2108.02117).
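To make "implementing federated learning algorithms" concrete, here is a minimal sketch of one round of federated averaging (FedAvg) written in plain JAX, not with FedJAX's own primitives. It assumes a single scalar parameter, a per-client mean squared error loss, a few local gradient-descent steps on each client, and a server that averages the returned client models weighted by each client's number of examples. `client_update` and `fedavg_round` are hypothetical names chosen for illustration; they are not part of the FedJAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def client_update(params, x, y, lr=0.01, steps=5):
    """Runs `steps` of full-batch gradient descent on one client's data."""
    # Scalar linear model y ~ x * params with mean squared error.
    loss = lambda p: jnp.mean((x * p - y) ** 2)
    grad_fn = jax.grad(loss)
    for _ in range(steps):
        params = params - lr * grad_fn(params)
    return params


def fedavg_round(server_params, clients, lr=0.01, steps=5):
    """One FedAvg round: local training on every client, then a weighted average."""
    client_params = [client_update(server_params, x, y, lr, steps) for x, y in clients]
    # Weight each client's resulting model by its number of examples.
    weights = jnp.array([x.shape[0] for x, _ in clients], dtype=jnp.float32)
    return jnp.sum(jnp.stack(client_params) * weights) / jnp.sum(weights)


# Hypothetical toy clients given as (x, y) pairs.
clients = [
    (np.array([1.0, 2.0, 3.0]), np.array([2.0, 4.0, 6.0])),
    (np.array([4.0]), np.array([12.0])),
]
params = jnp.array(0.5)
for _ in range(50):
    params = fedavg_round(params, clients)
print(params)
```

Client-side training and server-side aggregation are kept as separate functions so that either one can be swapped out (for example, a different local optimizer or aggregation rule) without touching the other.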
Installation
You will need a moderately recent version of Python. Please check the PyPI page for the up to date version requirement.
First, install JAX. For a CPU-only version:
```bash
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
```
For other devices (e.g. GPU), follow the JAX installation instructions.
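Before installing FedJAX, you can check which backend JAX picked up. This short check uses only `jax.devices()` and `jax.default_backend()`; on a CPU-only install it should report the CPU, while a correctly configured GPU or TPU install should list accelerator devices instead.

```python
import jax

# Devices visible to JAX and the platform name of the default backend.
devices = jax.devices()
backend = jax.default_backend()
print(backend, devices)
```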
Then, install FedJAX from PyPI:
```bash
pip install fedjax
```
Or, to install the latest development version directly from GitHub:
```bash
pip install --upgrade git+https://github.com/google/fedjax.git
```
Getting Started
Below is a simple example to verify FedJAX is installed correctly.
```python
import fedjax
import jax
import jax.numpy as jnp
import numpy as np

# {'client_id': client_dataset}.
fd = fedjax.InMemoryFederatedData({
    'a': {
        'x': np.array([1.0, 2.0, 3.0]),
        'y': np.array([2.0, 4.0, 6.0]),
    },
    'b': {
        'x': np.array([4.0]),
        'y': np.array([12.0])
    }
})

# Initial model parameters.
params = jnp.array(0.5)
# Mean squared error.
mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
# Loss for clients 'a' and 'b'.
print(f"client a loss = {mse_loss(params, fd.get_client('a').all_examples())}")
print(f"client b loss = {mse_loss(params, fd.get_client('b').all_examples())}")
```
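Continuing the example, this sketch computes each client's loss and gradient with `jax.value_and_grad` and combines the gradients into a single example-weighted average, roughly one server step of FedSGD. It rebuilds the same toy data as plain Python dicts so that it runs with JAX and NumPy alone (no FedJAX calls), and the server learning rate `0.01` is an arbitrary illustrative choice. Working the numbers by hand with `params = 0.5`: client a has loss 10.5 and gradient -14, client b has loss 100 and gradient -80, so the weighted average gradient is (3 * -14 + 1 * -80) / 4 = -30.5.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same toy data as above, keyed by client id (plain dicts, no FedJAX).
clients = {
    'a': {'x': np.array([1.0, 2.0, 3.0]), 'y': np.array([2.0, 4.0, 6.0])},
    'b': {'x': np.array([4.0]), 'y': np.array([12.0])},
}
params = jnp.array(0.5)


def mse_loss(params, batch):
    return jnp.mean((jnp.dot(batch['x'], params) - batch['y']) ** 2)


# Per-client (loss, gradient with respect to params).
loss_and_grad = jax.value_and_grad(mse_loss)
results = {cid: loss_and_grad(params, batch) for cid, batch in clients.items()}
for cid, (loss, grad) in results.items():
    print(f'client {cid}: loss = {loss}, grad = {grad}')

# Average the client gradients, weighted by number of examples, then take one step.
num_examples = {cid: len(batch['x']) for cid, batch in clients.items()}
total = sum(num_examples.values())
avg_grad = sum(num_examples[cid] * g for cid, (_, g) in results.items()) / total
new_params = params - 0.01 * avg_grad
```

Weighting by example count makes the averaged gradient equal to the gradient of the mean loss over all pooled examples, which is why client a (three examples) counts three times as much as client b (one example).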
The following tutorial notebooks provide an introduction to FedJAX:
You can also take a look at some of our working examples:
Citing FedJAX
To cite this repository:
@article{fedjax2021,
title={{F}ed{JAX}: Federated learning simulation with {JAX}},
author={Jae Hun Ro and Ananda Theertha Suresh and Ke Wu},
journal={arXiv preprint arXiv:2108.02117},
year={2021}
}
Useful pointers
Download files
File details
Details for the file fedjax-0.0.17.tar.gz.
File metadata
- Download URL: fedjax-0.0.17.tar.gz
- Upload date:
- Size: 133.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8eab7a82b41b02095e804e50cb09c676edc1170affe4b881b34e320fac4e7b0c
MD5 | ad5489c22462405f9dda99af803b4b2f
BLAKE2b-256 | cf64c7c7929d9bdec871d6fe637e583b015ffe03ed5c6613c294488d7d47eeb3
File details
Details for the file fedjax-0.0.17-py3-none-any.whl.
File metadata
- Download URL: fedjax-0.0.17-py3-none-any.whl
- Upload date:
- Size: 616.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 42be1d21a57843ccdf3f2af802fe6f8fcdf8530da9b2812b178d18af8ac8c0a1
MD5 | e5a88248ecabb68643e70bca8431987f
BLAKE2b-256 | 81ba6c5195fbbe38d6bd92936f18fc85fbaa0042c1e8b6a0424ccc30e82472fb
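The SHA256 digests listed for fedjax 0.0.17 can be checked locally before installing a downloaded file. A minimal sketch using Python's standard `hashlib` module, assuming the files keep their original names on disk; `sha256_of_file` and `verify_download` are illustrative helpers, not part of FedJAX or pip (pip's own `pip hash` command can also print a file's digest).

```python
import hashlib
import os

# SHA256 digests copied from the file details above.
EXPECTED_SHA256 = {
    'fedjax-0.0.17.tar.gz':
        '8eab7a82b41b02095e804e50cb09c676edc1170affe4b881b34e320fac4e7b0c',
    'fedjax-0.0.17-py3-none-any.whl':
        '42be1d21a57843ccdf3f2af802fe6f8fcdf8530da9b2812b178d18af8ac8c0a1',
}


def sha256_of_file(path, chunk_size=1 << 20):
    """Returns the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    """True if the file's SHA256 matches the published digest for its name."""
    expected = EXPECTED_SHA256.get(os.path.basename(path))
    return expected is not None and sha256_of_file(path) == expected
```

Reading in chunks keeps memory use constant regardless of file size; for files this small, `hashlib.sha256(open(path, 'rb').read())` would work just as well.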