
A Python library for rapid prototyping, experimentation, and logging of federated learning using state-of-the-art models and datasets. Built using PyTorch and PyTorch Lightning.

Project description



Features

  • Python 3.6+ support. Built using torch-1.10.1, torchvision-0.11.2, and pytorch-lightning-1.5.7.
  • Customizable implementations of state-of-the-art deep learning models that can be trained in federated or non-federated settings.
  • Supports finetuning of pre-trained deep learning models, allowing for faster training via transfer learning.
  • PyTorch LightningDataModule wrappers for the most commonly used datasets, reducing boilerplate code before experiments.
  • Datamodules and models are built bottom-up, which preserves the abstractions while allowing for customization.
  • Implementations of federated learning (FL) samplers, aggregators, and wrappers for prototyping FL experiments on the go.
  • Backwards compatible with the PyTorch LightningDataModule, LightningModule, loggers, and DevOps tools.
  • More details about the examples and usage can be found below.
  • For more documentation, visit https://torchfl.readthedocs.io/.

Installation

Stable Release

As of now, torchfl is available on PyPI and can be installed using the following command in your terminal:

$ pip install torchfl

This is the preferred method to install torchfl with the most stable release. If you don't have pip installed, this Python installation guide can guide you through the process.
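
To sanity-check the installation, a bare import in the same environment is the safest quick test (torchfl may or may not expose a __version__ attribute, so we don't rely on it):

$ python -c "import torchfl"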

Examples and Usage

Although torchfl is primarily built for quick prototyping of federated learning experiments, its models, datasets, and abstractions can also speed up non-federated learning experiments. In this section, we explore examples and usage in both settings.

Non-Federated Learning

At a high level, the following steps train a non-federated learning experiment. We use the EMNIST (MNIST) dataset and densenet121 for this example.

  1. Import the relevant modules.

    import os

    from torchfl.datamodules.emnist import EMNISTDataModule
    from torchfl.models.wrapper.emnist import MNISTEMNIST

    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import (
    	ModelCheckpoint,
    	LearningRateMonitor,
    	DeviceStatsMonitor,
    	ModelSummary,
    	ProgressBar,
    	# ... any other callbacks you need
    )
    

    For more details, view the full list of PyTorch Lightning callbacks and loggers on the official website.

  2. Set up the PyTorch Lightning trainer.

    trainer = pl.Trainer(
    	# ... other Trainer arguments (e.g., max_epochs, default_root_dir,
    	# which determines where checkpoints and logs are stored)
    	logger=[
    		# experiment_name and checkpoint_save_path are user-defined strings
    		TensorBoardLogger(
    			name=experiment_name,
    			save_dir=os.path.join(checkpoint_save_path, experiment_name),
    		)
    	],
    	callbacks=[
    		ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
    		LearningRateMonitor("epoch"),
    		DeviceStatsMonitor(),
    		ModelSummary(),
    		ProgressBar(),
    	],
    	# ...
    )
    

    More details about the PyTorch Lightning Trainer API can be found on their official website.

  3. Prepare the dataset using the wrappers provided by torchfl.datamodules.

    datamodule = EMNISTDataModule(dataset_name="mnist")
    datamodule.prepare_data()
    datamodule.setup()
    
  4. Initialize the model using the wrappers provided by torchfl.models.wrappers.

    # check if the model can be loaded from a given checkpoint;
    # load_from_checkpoint is a classmethod, so it is called on the wrapper class
    if checkpoint_load_path and os.path.isfile(checkpoint_load_path):
    	model = MNISTEMNIST.load_from_checkpoint(checkpoint_load_path)
    else:
    	pl.seed_everything(42)
    	model = MNISTEMNIST("densenet121", "adam", {"lr": 0.001})
    	trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
    
  5. Collect the results.

    # `dataloaders` supersedes the deprecated `test_dataloaders` argument
    val_result = trainer.test(
    	model, dataloaders=datamodule.val_dataloader(), verbose=True
    )
    test_result = trainer.test(
    	model, dataloaders=datamodule.test_dataloader(), verbose=True
    )
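
    Each call to trainer.test returns a list with one metrics dictionary per dataloader; the exact keys depend on the metrics the model logs (e.g., the val_acc monitored by the checkpoint callback above):

    # one dict per dataloader; keys depend on what the model logs
    print(val_result[0])
    print(test_result[0])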
    
  6. The corresponding files for the experiment (model checkpoints and logger metadata) will be stored under the default_root_dir argument given to the PyTorch Lightning Trainer object in Step 2. For this experiment, we use the TensorBoard logger. To view the logs (and the related plots and metrics), go to the default_root_dir path and find the TensorBoard log files. Upload the files to the TensorBoard.dev portal following the instructions here. Once the log files are uploaded, a unique URL for your experiment will be generated, which can be shared with ease! An example can be found here.
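
    Alternatively, you can inspect the logs locally with the standard TensorBoard CLI (assuming tensorboard is installed); point it at the directory you passed as default_root_dir:

    $ tensorboard --logdir <path-to-default_root_dir>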

  7. Note that torchfl is compatible with all the loggers supported by PyTorch Lightning. More information about the PyTorch Lightning loggers can be found here.

For full non-federated learning example scripts, check examples/trainers.

Federated Learning

At a high level, the following steps train a federated learning experiment.

  1. Pick a dataset and use the datamodules to create federated data shards with an IID or non-IID distribution (see the non-IID sketch after this block).
    def get_datamodule() -> EMNISTDataModule:
    	datamodule: EMNISTDataModule = EMNISTDataModule(
    		dataset_name=SUPPORTED_DATASETS_TYPE.MNIST, train_batch_size=10
    	)
    	datamodule.prepare_data()
    	datamodule.setup()
    	return datamodule
    
    # fl_params is the FLParams object defined in Step 3 below
    agent_data_shard_map = get_datamodule().federated_iid_dataloader(
        num_workers=fl_params.num_agents,
        workers_batch_size=fl_params.local_train_batch_size,
    )
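
    The non-IID counterpart follows the same pattern. As a hedged sketch (the method name below is an assumption based on the IID API above; check torchfl.datamodules for the exact name and signature):

    # hypothetical non-IID counterpart; verify the exact signature in torchfl.datamodules
    agent_data_shard_map = get_datamodule().federated_non_iid_dataloader(
        num_workers=fl_params.num_agents,
        workers_batch_size=fl_params.local_train_batch_size,
    )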
    
  2. Use the torchfl agents and models modules to initialize the global model and the agents, and distribute the models among the agents (see examples/federated for the complete imports).
    def initialize_agents(
    	fl_params: FLParams, agent_data_shard_map: Dict[int, DataLoader]
    ) -> List[V1Agent]:
    	"""Initialize agents."""
    	agents = []
    	for agent_id in range(fl_params.num_agents):
    		agent = V1Agent(
    			id=agent_id,
    			model=MNISTEMNIST(
    				model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL,
    				optimizer_name=OPTIMIZERS_TYPE.ADAM,
    				optimizer_hparams={"lr": 0.001},
    				model_hparams={"pre_trained": True, "feature_extract": True},
    				fl_hparams=fl_params,
    			),
    			data_shard=agent_data_shard_map[agent_id],
    		)
    		agents.append(agent)
    	return agents
    
    global_model = MNISTEMNIST(
        model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL,
        optimizer_name=OPTIMIZERS_TYPE.ADAM,
        optimizer_hparams={"lr": 0.001},
        model_hparams={"pre_trained": True, "feature_extract": True},
        fl_hparams=fl_params,
    )
    
    all_agents = initialize_agents(fl_params, agent_data_shard_map)
    
  3. Initialize an FLParams object with the desired FL hyperparameters and pass it to the Entrypoint object, which abstracts away the training.
    fl_params = FLParams(
        experiment_name="iid_mnist_fedavg_10_agents_5_sampled_50_epochs_mobilenetv3small_latest",
        num_agents=10,
        global_epochs=10,
        local_epochs=2,
        sampling_ratio=0.5,
    )
    entrypoint = Entrypoint(
        global_model=global_model,
        global_datamodule=get_datamodule(),
        fl_hparams=fl_params,
        agents=all_agents,
        aggregator=FedAvgAggregator(all_agents=all_agents),
        sampler=RandomSampler(all_agents=all_agents),
    )
    entrypoint.run()
    

For full federated learning example scripts, check examples/federated.

Available Models

For the initial release, torchfl only supports state-of-the-art computer vision models. The following table summarizes the available models, support for pre-training, and the possibility of feature extraction. All the models have been tested with the available datasets; links to the tests are provided in the next section.
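
The pre-training and feature-extraction switches in the table correspond to the model_hparams accepted by the model wrappers. As a minimal sketch, assuming the wrapper takes the same keyword arguments as in the federated example above, with a plain string for the model name as in the non-federated example:

    from torchfl.models.wrapper.emnist import MNISTEMNIST

    # pre_trained loads pre-trained weights; feature_extract freezes the
    # backbone so that only the classifier head is finetuned
    model = MNISTEMNIST(
    	model_name="densenet121",
    	optimizer_name="adam",
    	optimizer_hparams={"lr": 0.001},
    	model_hparams={"pre_trained": True, "feature_extract": True},
    )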

| Name | Pre-Training | Feature Extraction |
| --- | --- | --- |
| AlexNet | :white_check_mark: | :white_check_mark: |
| DenseNet121 | :white_check_mark: | :white_check_mark: |
| DenseNet161 | :white_check_mark: | :white_check_mark: |
| DenseNet169 | :white_check_mark: | :white_check_mark: |
| DenseNet201 | :white_check_mark: | :white_check_mark: |
| LeNet | :x: | :x: |
| MLP | :x: | :x: |
| MobileNetV2 | :white_check_mark: | :white_check_mark: |
| MobileNetV3Small | :white_check_mark: | :white_check_mark: |
| MobileNetV3Large | :white_check_mark: | :white_check_mark: |
| ResNet18 | :white_check_mark: | :white_check_mark: |
| ResNet34 | :white_check_mark: | :white_check_mark: |
| ResNet50 | :white_check_mark: | :white_check_mark: |
| ResNet101 | :white_check_mark: | :white_check_mark: |
| ResNet152 | :white_check_mark: | :white_check_mark: |
| ResNext50(32x4d) | :white_check_mark: | :white_check_mark: |
| ResNext101(32x8d) | :white_check_mark: | :white_check_mark: |
| WideResNet(50x2) | :white_check_mark: | :white_check_mark: |
| WideResNet(101x2) | :white_check_mark: | :white_check_mark: |
| ShuffleNetv2(x0.5) | :white_check_mark: | :white_check_mark: |
| ShuffleNetv2(x1.0) | :white_check_mark: | :white_check_mark: |
| ShuffleNetv2(x1.5) | :x: | :x: |
| ShuffleNetv2(x2.0) | :x: | :x: |
| SqueezeNet1.0 | :white_check_mark: | :white_check_mark: |
| SqueezeNet1.1 | :white_check_mark: | :white_check_mark: |
| VGG11 | :white_check_mark: | :white_check_mark: |
| VGG11_BatchNorm | :white_check_mark: | :white_check_mark: |
| VGG13 | :white_check_mark: | :white_check_mark: |
| VGG13_BatchNorm | :white_check_mark: | :white_check_mark: |
| VGG16 | :white_check_mark: | :white_check_mark: |
| VGG16_BatchNorm | :white_check_mark: | :white_check_mark: |
| VGG19 | :white_check_mark: | :white_check_mark: |
| VGG19_BatchNorm | :white_check_mark: | :white_check_mark: |

Available Datasets

The following datasets have been wrapped inside a LightningDataModule and made available for the initial release of torchfl. To add a new dataset, check the source code in torchfl.datamodules, add tests, and create a PR with the Features tag.

| Group | IID Split | Non-IID Split |
| --- | --- | --- |
| CIFAR | :white_check_mark: | :white_check_mark: |
| EMNIST | :white_check_mark: | :white_check_mark: |
| FashionMNIST | :white_check_mark: | :white_check_mark: |

Contributing

Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given.

You can contribute in many ways:

Types of Contributions

Report Bugs

Report bugs at https://github.com/vivekkhimani/torchfl/issues.

If you are reporting a bug, please include:

  • Your operating system name and version.
  • Any details about your local setup that might be helpful in troubleshooting.
  • Detailed steps to reproduce the bug.

Fix Bugs

Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants to implement it.

Implement Features

Look through the GitHub issues for features. Anything tagged with "enhancement", "help wanted", or "feature" is open to whoever wants to implement it.

Write Documentation

torchfl could always use more documentation, whether as part of the official torchfl docs, in docstrings, or even on the web in blog posts, articles, and such.

Submit Feedback

The best way to send feedback is to file an issue at https://github.com/vivekkhimani/torchfl/issues. If you are proposing a feature:

  • Explain in detail how it would work.
  • Keep the scope as narrow as possible, to make it easier to implement.
  • Remember that this is a volunteer-driven project, and that contributions are welcome :)

Get Started

Ready to contribute? Here's how to set up torchfl for local development.

  1. Fork the torchfl repo on GitHub.
  2. Clone your fork locally:
$ git clone git@github.com:<your_username_here>/torchfl.git
  3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:
$ mkvirtualenv torchfl
$ cd torchfl/
$ python setup.py develop
  4. Create a branch for local development:
$ git checkout -b name-of-your-bugfix-or-feature

Now you can make your changes locally and maintain them on your own branch.

  5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:
$ tox

Run tox --help to explore more features of tox.

  6. Your changes need to pass the existing test cases; add new ones if required under the tests directory. The following approaches can be used to run the test cases:
  • To run all the test cases:
$ coverage run -m pytest tests
  • To run a specific file containing test cases:
$ coverage run -m pytest <path-to-the-file>
  7. Commit your changes and push your branch to GitHub:
$ git add --all
$ git commit -m "Your detailed description of your changes."
$ git push origin <name-of-your-bugfix-or-feature>
  8. Submit a pull request through the GitHub web interface.
  9. Once the pull request has been submitted, the continuous integration pipelines on GitHub Actions will be triggered. Ensure that all of them pass before requesting a review from one of the maintainers.

Pull Request Guidelines

Before you submit a pull request, check that it meets these guidelines:

  1. The pull request should include tests.
    • Try adding new test cases for new features or enhancements and make changes to the CI pipelines accordingly.
    • Modify the existing tests (if required) for the bug fixes.
  2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.md.
  3. The pull request should pass all the existing CI pipelines (GitHub Actions), and new/modified workflows should be added as required.
  4. Please note that the test cases should only be run in the CI pipeline if the direct/indirect dependencies of the tests have changed. Look at the workflow YAML files for more details or reach out to one of the contributors.

Deploying

A reminder for the maintainers on how to deploy. Make sure all your changes are committed (including an entry in HISTORY.rst). Then run:

$ bump2version patch # possible: major / minor / patch
$ git push
$ git push --tags

Citation

Please cite the following article if you end up using this software:

@misc{https://doi.org/10.48550/arxiv.2211.00735,
  doi = {10.48550/ARXIV.2211.00735},
  url = {https://arxiv.org/abs/2211.00735},
  author = {Khimani, Vivek and Jabbari, Shahin},
  keywords = {Machine Learning (cs.LG), Distributed, Parallel, and Cluster Computing (cs.DC), Systems and Control (eess.SY), FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, I.2.11},
  title = {TorchFL: A Performant Library for Bootstrapping Federated Learning Experiments},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchfl-0.1.9.tar.gz (59.0 kB)

Uploaded Source

Built Distribution

torchfl-0.1.9-py3-none-any.whl (124.4 kB)

Uploaded Python 3

File details

Details for the file torchfl-0.1.9.tar.gz.

File metadata

  • Download URL: torchfl-0.1.9.tar.gz
  • Size: 59.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.10.6 Linux/5.15.0-1033-azure

File hashes

Hashes for torchfl-0.1.9.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 71cb1c245593cd092235e829c5ba1821f9a68e591890d4f160af8a46de2d3a21 |
| MD5 | 6930be36ca555836488b6ae62789d1e0 |
| BLAKE2b-256 | ac1350f347c0fd46bc796f6880f50c4d5a052be93d07923ceb264daa70c8e35c |

See more details on using hashes here.
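
For example, a downloaded sdist can be checked against the SHA256 digest above with standard tooling (sha256sum on Linux, or shasum -a 256 on macOS):

$ sha256sum torchfl-0.1.9.tar.gz
71cb1c245593cd092235e829c5ba1821f9a68e591890d4f160af8a46de2d3a21  torchfl-0.1.9.tar.gz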

File details

Details for the file torchfl-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: torchfl-0.1.9-py3-none-any.whl
  • Size: 124.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.10.6 Linux/5.15.0-1033-azure

File hashes

Hashes for torchfl-0.1.9-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 351a8d4e80766167bdcf3c2a197b919c8db56fb239c5be228ce02849d4e1c7bc |
| MD5 | b1a305c140422223c62f062d9d5afcc4 |
| BLAKE2b-256 | 62981cdbd30b6a47dbbb35c0c0f8cecbb1eefabef61d27cc8a506e75eccd9a6a |

See more details on using hashes here.
