Skip to main content

XAIN is an open source framework for federated learning.

Project description

Workflow Lint and test (master) PyPI GitHub license Documentation Status Gitter chat


Disclaimer: This is work-in-progress and not production-ready, expect errors to occur! Use at your own risk! Do not use for any security related issues!


XAIN SDK

Overview

The XAIN project is building a privacy layer for machine learning so that AI projects can meet compliance such as GDPR and CCPA. The approach relies on Federated Learning as enabling technology that allows production AI applications to be fully privacy compliant.

Federated Learning also enables different use-cases that are not strictly privacy related such as connecting data lakes, reaching higher model performance in unbalanced datasets and utilising AI models on the edge.

The main components:

  • Coordinator: The entity that manages all aspects of the execution of rounds for Federated Learning. This includes the registration of clients, the selection of participants for a given round, the determination of whether sufficiently many participants have sent updated local models, the computation of an aggregated global model, and the sending of the latter model to storage or other entities.
  • Participant: An entity that is the originator of a local dataset that can be selected for local training in the Federated Learning.
  • Selected Participant: A Participant that has been selected by the Coordinator to participate in the next or current round.
  • SDK: The library which allows Participants to interact with the Coordinator and the XAIN Platform.

The source code in this project implements the XAIN SDK to provide your local application a way to communicate with the XAIN Coordinator.

Getting Started

Run The XAIN Coordinator

There are two options to run the XAIN Coordinator to perform Federated Learning on locally trained models:

  • Go to the main page of the project and request a demo for the XAIN Platform.
  • For the self-hosted solution, see XAIN FL Project for more details.

Integrate The XAIN SDK Into Your Project

1. Install The XAIN SDK

To install the XAIN SDK package on your machine, simply run in your terminal:

pip install xain-sdk

2. Register Your Application And The Device To Participate In The Aggregation

Now you can register your Participants to participate in the Federated Learning rounds. To do so, just send the registration request to the XAIN Coordinator:

participant.py
from typing import Optional, Tuple
from numpy import ndarray
from xain_sdk.participant import Participant

class MyParticipant(Participant):
    def __init__(self):

        super(MyParticipant, self).__init__()

        # define or load a model to be trained
        self.model = ...
        self.model_shapes = self.get_MLframework_shapes(self.model)

        # define or load data to be trained on
        self.trainset = ...

    def init_weights(self) -> ndarray:

        ...  # initialize fresh weights for the model

        # return the new weights
        return self.get_MLframework_weights(self.model)

    def train_round(self, weights: ndarray, epochs: int, epoch_base: int) -> Tuple[ndarray, int]:

        # load weights into the model
        self.set_MLframework_weights(weights, self.model_shapes, self.model)

        # train the model for the specified number of epochs and optionally gather metrics
        number_samples = ...
        for epoch in range(epochs):
            ... # training
            self.update_metrics(epoch, epoch_base, MetricName=metric_value, ...)

        # return the updated weights and the number of training samples
        return self.get_MLframework_weights(self.model), number_samples

where the machine learning framework dependent utility methods get_MLframework_shapes(), set_MLframework_weights() and get_MLframework_weights() are described in more detail in section Utility and the gathering of metrics via update_metrics() is further explained in section Model Metrics.

start.py
from xain_sdk.participant_state_machine import start_participant

# Import MyParticipant from your participant.py file 
from participant import MyParticipant

# Create a new participant
p = MyParticipant()

# Register your new participant to interact with XAIN Coordinator 
# (hosted at XAIN Platform or self-hosted solution). 
# The function start_participant requires two arguments:
#   - your new participant to register to interact with Coordinator,
#   - a valid configuration (an example configuration can be found in config/example-config.toml).
config = Config.load("my_config.toml")
start_participant(p, config)

Now you have registered a participant. Simply repeat this step for all the participants you wish to register.

The XAIN Coordinator will take care of the rest:

  • The aggregation of your locally pretrained models.
  • Triggering new training rounds for the selected participants and aggregating these models.

Utility

The Participant base class provide some utility methods to help with the implementation of the train_round() method, namely:

  • set_tensorflow_weights(): Set the weights of a Tensorflow model from a flat weight vector.
  • get_tensorflow_weights(): Get and flatten the weights of a Tensorflow model.
  • get_tensorflow_shapes(): Get the shapes of the weights of a Tensorflow model.
  • set_pytorch_weights(): Set the weights of a Pytorch model from a flat weight vector.
  • get_pytorch_weights(): Get and flatten the weights of a Pytorch model.
  • get_pytorch_shapes(): Get the shapes of the weights of a Pytorch model.
  • update_metrics(): Gather metrics to be send to the coordinator.

Model Metrics

A monitoring feature, which will be available as a XAIN Platform solution. If you would like to compare the performance of aggregated models, please send the specific metrics of your use case that you wish to monitor to the XAIN Coordinator. This will then be reflected in the web interface under the Project Management tab.

In order to send your metrics to the XAIN Coordinator, you can simply add the following line to the training loop in the train_round() method

self.update_metrics(epoch, epoch_base, MetricName=metric_value, ...)

where epoch is the current local epoch number, epoch_base is the global epoch number and MetricName=metric_value can be any number of key-value pairs with a name of a metric as key and a numerical scalar/list/array as value.

Examples

Please see the following examples showing how to implement your Participant with the SDK:

Testing

You can connect multiple participants at once running in parallel to a coordinator with the following script:

Bash script for starting multiple participants

Getting help

If you get stuck, don't worry, we are here to help. You can find more information here:

In case you don't find answers to your problems or questions, simply ask your question to the community here:

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

xain-sdk-0.6.0.tar.gz (19.8 kB view hashes)

Uploaded Source

Built Distribution

xain_sdk-0.6.0-py3-none-any.whl (22.4 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page