Federated learning simulation with JAX.

FedJAX: Federated learning simulation with JAX

Documentation | Paper

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

What is FedJAX?

FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research. With its simple primitives for implementing federated learning algorithms, prepackaged datasets, models and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. FedJAX works on accelerators (GPU and TPU) without much additional effort. Additional details and benchmarks can be found in our paper.

Installation

You will need a moderately recent version of Python. Please check the PyPI page for the up-to-date version requirement.

First, install JAX. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow the JAX installation instructions.

Then, install FedJAX from PyPI:

pip install fedjax

Or, to upgrade to the latest development version of FedJAX directly from the GitHub repository:

pip install --upgrade git+https://github.com/google/fedjax.git
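
After installing, it can be useful to confirm which versions actually ended up in the environment. The following is a minimal sketch that uses only the Python standard library; the helper name `installed_version` is hypothetical and is not part of FedJAX or JAX.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the packages mentioned in the installation steps above.
for name in ("jax", "jaxlib", "fedjax"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Because the lookup reads package metadata instead of importing the package, it also works when an import would fail for unrelated reasons (for example, a missing accelerator driver).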

Getting Started

Below is a simple example to verify FedJAX is installed correctly.

import fedjax
import jax
import jax.numpy as jnp
import numpy as np

# {'client_id': client_dataset}.
fd = fedjax.InMemoryFederatedData({
    'a': {
        'x': np.array([1.0, 2.0, 3.0]),
        'y': np.array([2.0, 4.0, 6.0]),
    },
    'b': {
        'x': np.array([4.0]),
        'y': np.array([12.0])
    }
})
# Initial model parameters.
params = jnp.array(0.5)
# Mean squared error.
mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
# Loss for clients 'a' and 'b'.
print(f"client a loss = {mse_loss(params, fd.get_client('a').all_examples())}")
print(f"client b loss = {mse_loss(params, fd.get_client('b').all_examples())}")
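
To connect the loss computation above to federated training, here is a rough sketch of a few rounds of federated averaging on the same toy data. It is written in plain Python on purpose and does not use the FedJAX API, so that no FedJAX function signatures are assumed. The helpers `mse_grad`, `client_update` and `fedavg_round`, the learning rate, and the step counts are all illustrative choices, not values taken from the library.

```python
# Same toy data as in the example above: {'client_id': {'x': [...], 'y': [...]}}.
clients = {
    "a": {"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]},
    "b": {"x": [4.0], "y": [12.0]},
}


def mse_loss(w, data):
    """Mean squared error of the scalar model y = w * x on one client's data."""
    pairs = list(zip(data["x"], data["y"]))
    return sum((w * x - y) ** 2 for x, y in pairs) / len(pairs)


def mse_grad(w, data):
    """Derivative of mse_loss with respect to w."""
    pairs = list(zip(data["x"], data["y"]))
    return sum(2.0 * (w * x - y) * x for x, y in pairs) / len(pairs)


def client_update(w, data, lr=0.01, steps=5):
    """Run a few local gradient descent steps starting from the server weight."""
    for _ in range(steps):
        w -= lr * mse_grad(w, data)
    return w


def fedavg_round(w, clients, lr=0.01, steps=5):
    """One round: every client trains locally, then the server averages the
    resulting weights, weighting each client by its number of examples."""
    total = sum(len(d["x"]) for d in clients.values())
    return sum(
        len(d["x"]) / total * client_update(w, d, lr, steps)
        for d in clients.values()
    )


w = 0.5  # Same initial parameter as in the example above.
for _ in range(20):
    w = fedavg_round(w, clients)

print(f"server weight after training = {w:.3f}")
for client_id, data in clients.items():
    print(f"client {client_id} loss = {mse_loss(w, data):.3f}")
```

Client 'a' is fit exactly by w = 2 and client 'b' by w = 3, so the averaged weight settles between the two; this is the kind of client heterogeneity that federated algorithms have to handle.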

The following tutorial notebooks provide an introduction to FedJAX:

You can also take a look at some of our working examples:

Citing FedJAX

To cite this repository:

@article{fedjax2021,
  title={{F}ed{JAX}: Federated learning simulation with {JAX}},
  author={Jae Hun Ro and Ananda Theertha Suresh and Ke Wu},
  journal={arXiv preprint arXiv:2108.02117},
  year={2021}
}

Useful pointers

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fedjax-0.0.17.tar.gz (133.6 kB)

Uploaded Source

Built Distribution

fedjax-0.0.17-py3-none-any.whl (616.8 kB)

Uploaded Python 3

File details

Details for the file fedjax-0.0.17.tar.gz.

File metadata

  • Download URL: fedjax-0.0.17.tar.gz
  • Upload date:
  • Size: 133.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.2

File hashes

Hashes for fedjax-0.0.17.tar.gz
Algorithm    Hash digest
SHA256       8eab7a82b41b02095e804e50cb09c676edc1170affe4b881b34e320fac4e7b0c
MD5          ad5489c22462405f9dda99af803b4b2f
BLAKE2b-256  cf64c7c7929d9bdec871d6fe637e583b015ffe03ed5c6613c294488d7d47eeb3

See more details on using hashes here.
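
As a hedged illustration of how a published digest can be checked against a downloaded file, the sketch below computes a SHA256 digest with the standard library's `hashlib` and compares it to an expected value. The helper names `sha256_of_file` and `matches_published_hash` are hypothetical, and the local file path passed on the command line is whatever path the file was saved to.

```python
import hashlib
import sys


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at `path`, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """Compare the file's digest with a published hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python check_hash.py <downloaded-file> <expected-sha256>
    ok = matches_published_hash(sys.argv[1], sys.argv[2])
    print("hash matches" if ok else "hash MISMATCH")
```

For the source distribution, the expected value would be the SHA256 digest listed above for fedjax-0.0.17.tar.gz. Reading in chunks keeps memory use flat regardless of file size.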

File details

Details for the file fedjax-0.0.17-py3-none-any.whl.

File metadata

  • Download URL: fedjax-0.0.17-py3-none-any.whl
  • Upload date:
  • Size: 616.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.2

File hashes

Hashes for fedjax-0.0.17-py3-none-any.whl
Algorithm    Hash digest
SHA256       42be1d21a57843ccdf3f2af802fe6f8fcdf8530da9b2812b178d18af8ac8c0a1
MD5          e5a88248ecabb68643e70bca8431987f
BLAKE2b-256  81ba6c5195fbbe38d6bd92936f18fc85fbaa0042c1e8b6a0424ccc30e82472fb
