Generic Federated Learning Simulator with PyTorch
FedSim
FedSim is a generic Federated Learning simulator. It aims to provide researchers with a simulator for Federated Learning that is easy to develop and maintain. See the documentation here!
Installation
pip install fedsim
Usage
As module
Here is a demo:
from torch.utils.tensorboard import SummaryWriter
from fedsim.distributed.centralized.training import FedAvg
from fedsim.distributed.data_management import BasicDataManager
from fedsim.models.mcmahan_nets import cnn_cifar100
from fedsim.scores import cross_entropy
from fedsim.scores import accuracy
n_clients = 1000
dm = BasicDataManager("./data", "cifar100", n_clients)
sw = SummaryWriter()
alg = FedAvg(
    data_manager=dm,
    num_clients=n_clients,
    sample_scheme="uniform",
    sample_rate=0.01,
    model_class=cnn_cifar100,
    epochs=5,
    loss_fn=cross_entropy,
    batch_size=32,
    metric_logger=sw,
    device="cuda",
)
alg.hook_global_score_function("test", "accuracy", accuracy)
for key in dm.get_local_splits_names():
    alg.hook_local_score_function(key, "accuracy", accuracy)
alg.train(rounds=1)
Included cli tool
For help with the cli, run:
fedsim --help
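For instance, a run comparable to the Python demo above could be launched from the shell. The following invocation is only a sketch, composed of the subcommand and flags documented later in this README; the trailing ellipsis stands for the remaining options:

fedsim fed-learn --algorithm FedAvg --data-manager BasicDataManager --model cnn_cifar100 ...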
DataManager
Any custom DataManager class should inherit from fedsim.distributed.data_management.data_manager.DataManager (or its children) and implement its abstract methods. For example:
from fedsim.distributed.data_management.data_manager import DataManager

class CustomDataManager(DataManager):
    def __init__(self, root, seed, save_path, other_arg, ...):
        self.other_arg = other_arg
        # note that super().__init__ should be called at the end of __init__
        # because the abstract methods below are invoked inside it
        super(CustomDataManager, self).__init__(root, seed, save_path=save_path)
    def make_datasets(self, root: str) -> Iterable[Dict[str, object]]:
        """Abstract method to be implemented by child classes.

        Args:
            root (str): directory to download and manipulate data.

        Raises:
            NotImplementedError: if not implemented by the child class

        Returns:
            Iterable[Dict[str, object]]: dict of local datasets
                [split: dataset] followed by global ones.
        """
        raise NotImplementedError
    def partition_local_data(
        self, datasets: Dict[str, object]
    ) -> Dict[str, Iterable[Iterable[int]]]:
        """Partitions the local datasets into per-client index lists, per split."""
        raise NotImplementedError
    def get_identifiers(self) -> Sequence[str]:
        """Returns identifiers to be used for saving the partition info.

        Raises:
            NotImplementedError: this abstract method should be
                implemented by child classes

        Returns:
            Sequence[str]: a sequence of str identifying the class instance
        """
        raise NotImplementedError
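For concreteness, here is a minimal sketch of a child class. The MNIST loading and the IID split logic are illustrative assumptions, not FedSim's own implementation, and IIDMNISTManager is a hypothetical name; the base-class import path follows the example above:

from typing import Dict, Iterable, Sequence

import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from fedsim.distributed.data_management.data_manager import DataManager

class IIDMNISTManager(DataManager):
    def __init__(self, root, n_clients, seed=0, save_path=None):
        # attributes used by the abstract methods must be set before super,
        # because DataManager.__init__ invokes those methods
        self.n_clients = n_clients
        self.seed = seed
        super(IIDMNISTManager, self).__init__(root, seed, save_path=save_path)

    def make_datasets(self, root: str) -> Iterable[Dict[str, object]]:
        train = MNIST(root, train=True, download=True, transform=ToTensor())
        test = MNIST(root, train=False, download=True, transform=ToTensor())
        # local splits first, followed by the global ones
        return {"train": train}, {"test": test}

    def partition_local_data(
        self, datasets: Dict[str, object]
    ) -> Dict[str, Iterable[Iterable[int]]]:
        rng = np.random.default_rng(self.seed)
        idxs = rng.permutation(len(datasets["train"]))
        # an IID split: one shuffled index chunk per client
        return {"train": np.array_split(idxs, self.n_clients)}

    def get_identifiers(self) -> Sequence[str]:
        return ("mnist", "iid", str(self.n_clients))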
Integration with included cli (DataManager)
To automatically include your custom data manager in the provided cli tool, place your class in a file under fedsim/distributed/data_management. Then select it with the --data-manager option. To deliver arguments to the __init__ method of your custom data manager, pass options of the form --d-<arg-name>, where <arg-name> is the argument name. Example:
fedsim fed-learn --data-manager CustomDataManager --d-other_arg <other_arg_value> ...
Included DataManager
Provided with the simulator is a basic DataManager, called BasicDataManager, which currently supports common vision datasets such as MNIST, CIFAR10, and CIFAR100 (the ones referenced elsewhere in this README).
It supports the popular partitioning schemes (iid, Dirichlet distribution, unbalanced, etc.); one of these is sketched below.
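To illustrate what Dirichlet (non-iid) partitioning means here, a generic standalone sketch follows. It is not FedSim's internal code, and dirichlet_partition and dataset_labels are hypothetical names: each class is split across clients in proportions drawn from a Dirichlet(alpha) distribution, so a smaller alpha yields more skewed clients.

import numpy as np

def dirichlet_partition(labels, n_clients, alpha, seed=0):
    """Assigns sample indices to clients with Dirichlet class skew."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_idxs = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idxs_c = np.flatnonzero(labels == c)
        rng.shuffle(idxs_c)
        # fraction of class c that each client receives
        props = rng.dirichlet(alpha * np.ones(n_clients))
        cuts = (np.cumsum(props)[:-1] * len(idxs_c)).astype(int)
        for client, part in enumerate(np.split(idxs_c, cuts)):
            client_idxs[client].extend(part.tolist())
    return client_idxs

# e.g., 100 clients with strong label skew:
# parts = dirichlet_partition(dataset_labels, n_clients=100, alpha=0.1)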
FLAlgorithm
Any custom FL algorithm class should inherit from fedsim.distributed.centralized.training.centralized_fl_algorithm.FLAlgorithm (or its children) and implement its abstract methods. For example:
from copy import deepcopy
from typing import Any, Dict, Hashable, Iterable, Mapping, Optional, Union

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector
from torch.optim import SGD

from fedsim.distributed.centralized.training.centralized_fl_algorithm import FLAlgorithm

class CustomFLAlgorithm(FLAlgorithm):
    def __init__(
        self, data_manager, num_clients, sample_scheme, sample_rate, model_class,
        epochs, loss_fn, batch_size, test_batch_size, local_weight_decay, slr, clr,
        clr_decay, clr_decay_type, min_clr, clr_step_size, metric_logger, device,
        log_freq, other_arg, ..., *args, **kwargs,
    ):
        self.other_arg = other_arg
        super(CustomFLAlgorithm, self).__init__(
            data_manager, num_clients, sample_scheme, sample_rate, model_class,
            epochs, loss_fn, batch_size, test_batch_size, local_weight_decay, slr,
            clr, clr_decay, clr_decay_type, min_clr, clr_step_size, metric_logger,
            device, log_freq,
        )
        # make the model and the server optimizer
        model = self.get_model_class()().to(self.device)
        params = deepcopy(parameters_to_vector(model.parameters()).clone().detach())
        optimizer = SGD(params=[params], lr=slr)
        # write model and optimizer to server memory
        self.write_server('model', model)
        self.write_server('cloud_params', params)
        self.write_server('optimizer', optimizer)
        ...
    def send_to_client(self, client_id: int) -> Mapping[Hashable, Any]:
        """Returns the context to send to the client corresponding to client_id.
        Do not send shared (mutable) objects, such as the server model,
        without deepcopying them first.

        Args:
            client_id (int): id of the receiving client

        Raises:
            NotImplementedError: abstract method to be implemented by child classes

        Returns:
            Mapping[Hashable, Any]: the context to be sent, in the form of a Mapping
        """
        raise NotImplementedError
    def send_to_server(
        self, client_id: int, datasets: Dict[str, Iterable], epochs: int,
        loss_fn: nn.Module, batch_size: int, lr: float, weight_decay: float = 0,
        device: Union[int, str] = 'cuda', ctx: Optional[Dict[Hashable, Any]] = None,
        *args, **kwargs,
    ) -> Mapping[str, Any]:
        """Client operation on the received information.

        Args:
            client_id (int): id of the client
            datasets (Dict[str, Iterable]): this comes from the Data Manager
            epochs (int): number of epochs to train
            loss_fn (nn.Module): loss function to optimize locally
                (e.g., fedsim.scores.cross_entropy)
            batch_size (int): training batch_size
            lr (float): client learning rate
            weight_decay (float, optional): weight decay for SGD. Defaults to 0.
            device (Union[int, str], optional): Defaults to 'cuda'.
            ctx (Optional[Dict[Hashable, Any]], optional): context received from
                the server. Defaults to None.

        Raises:
            NotImplementedError: abstract method to be implemented by child classes

        Returns:
            Mapping[str, Any]: client context to be sent to the server
        """
        raise NotImplementedError
    def receive_from_client(self, client_id: int, client_msg: Mapping[Hashable, Any], aggregator: Any):
        """Receives and aggregates info from a selected client.

        Args:
            client_id (int): id of the sender (client)
            client_msg (Mapping[Hashable, Any]): the client context that was sent
            aggregator (Any): aggregator instance to collect the info

        Raises:
            NotImplementedError: abstract method to be implemented by child classes
        """
        raise NotImplementedError
    def optimize(self, aggregator: Any) -> Mapping[Hashable, Any]:
        """Optimizes the server model(s) and returns metrics to be reported.

        Args:
            aggregator (Any): Aggregator instance

        Raises:
            NotImplementedError: abstract method to be implemented by child classes

        Returns:
            Mapping[Hashable, Any]: context to be reported
        """
        raise NotImplementedError
    def deploy(self) -> Optional[Mapping[Hashable, Any]]:
        """Returns a Mapping of name -> parameter set to test the model with.

        Raises:
            NotImplementedError: abstract method to be implemented by child classes
        """
        raise NotImplementedError
    def report(
        self, dataloaders, metric_logger: Any, device: str,
        optimize_reports: Mapping[Hashable, Any],
        deployment_points: Optional[Mapping[Hashable, torch.Tensor]] = None,
    ) -> None:
        """Tests on global data and reports the info.

        Args:
            dataloaders (Any): dict of data loaders to test the global model(s)
            metric_logger (Any): the logging object (e.g., SummaryWriter)
            device (str): 'cuda', 'cpu' or a gpu number
            optimize_reports (Mapping[Hashable, Any]): dict returned by the
                optimize method
            deployment_points (Mapping[Hashable, torch.Tensor], optional): output
                of the deploy method

        Raises:
            NotImplementedError: abstract method to be implemented by child classes
        """
        raise NotImplementedError
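As a concrete illustration, the send_to_client and deploy methods of the skeleton above could be filled in FedAvg-style as below. This is a minimal sketch that relies only on the server-memory helpers listed in the next section; the 'cloud_params' key matches the __init__ example above:

    # possible method bodies for CustomFLAlgorithm (imports as in the skeleton)
    def send_to_client(self, client_id: int) -> Mapping[Hashable, Any]:
        # 'cloud_params' was stored via write_server in __init__ above;
        # deepcopy so the client cannot mutate shared server state
        return dict(cloud_params=deepcopy(self.read_server('cloud_params')))

    def deploy(self) -> Optional[Mapping[Hashable, Any]]:
        # expose the current cloud parameters as the single parameter set to test
        return dict(avg=self.read_server('cloud_params'))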
Integration with included cli (FLAlgorithm)
To automatically include your custom algorithm in the provided cli tool, place your class in a file under fedsim/distributed/centralized/training. Then select it with the --algorithm option. To deliver arguments to the __init__ method of your custom algorithm, pass options of the form --a-<arg-name>, where <arg-name> is the argument name. Example:
fedsim fed-learn --algorithm CustomFLAlgorithm --a-other_arg <other_arg_value> ...
Other attributes and methods provided by FLAlgorithm

| method | functionality |
|---|---|
| FLAlgorithm.get_model_class() | returns the class object of the model architecture |
| FLAlgorithm.write_server(key, obj) | stores obj in server memory, accessible with key |
| FLAlgorithm.write_client(client_id, key, obj) | stores obj in client_id's memory, accessible with key |
| FLAlgorithm.read_server(key) | returns the obj associated with key in server memory |
| FLAlgorithm.read_client(client_id, key) | returns the obj associated with key in client_id's memory |
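Inside an algorithm, these helpers behave like simple key-value stores; for example (the keys here are arbitrary illustrations):

# inside a CustomFLAlgorithm method; keys can be any Hashable
self.write_client(client_id, 'local_steps', 12)
self.read_client(client_id, 'local_steps')   # -> 12
self.write_server('round', 3)
self.read_server('round')                    # -> 3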
Included FL algorithms
| Alias | Paper |
|---|---|
| FedAvg | |
| FedAvgM | |
| FedNova | |
| FedProx | |
| FedDyn | |
| AdaBest | |
Model Architectures
Included Architectures
The models used in the FedAvg paper are supported:

- McMahan's 2-layer MLP for MNIST
- McMahan's CNN for CIFAR10 and CIFAR100

To use them, import fedsim.models.mcmahan_nets.
Integration with included cli
If you want to use a custom PyTorch model class with the cli tool, simply place it under fedsim.models and select it:
fedsim fed-learn --model CustomModule ...
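A minimal sketch of such a module follows; the architecture and the num_classes parameter are arbitrary illustrations, with no FedSim-specific requirement implied beyond being an ordinary nn.Module:

import torch.nn as nn

class CustomModule(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # an arbitrary small MLP; any nn.Module works here
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 200),
            nn.ReLU(),
            nn.Linear(200, num_classes),
        )

    def forward(self, x):
        return self.net(x)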
0.1.2 (2022-07-05)
changed ownership of repo from fedsim-dev to varnio
0.1.1 (2022-06-22)
added fedsim.scores, which wraps torch loss functions and sklearn scores
moved the reporting mechanism of distributed algorithms to support auto monitoring
added AppendixAggregator, which is used to hold metric scores and report final results
applied a patch for the incorrectly declared supported Python versions on PyPI
0.1.0 (2022-06-21)
First major pre-release.
The package is restructured
docs are updated and checked to pass the tox steps
0.0.4 (2022-06-14)
Fourth release on PyPI.