Partial Model Sharing for Federated Learning - A PyTorch and Flower extension
PSFed: Partial Model Sharing for Federated Learning
PSFed is a research-grade Python package that implements partial model sharing in federated learning. Instead of communicating the entire model between server and clients, PSFed enables selective parameter synchronization based on configurable masking strategies.
Key Features
- 🎯 Parameter-level granularity: Share any subset of model parameters
- 🔀 Dynamic masking: Masks change per round, ensuring eventual full synchronization
- 📊 Multiple strategies: Random, top-k magnitude, gradient-based, and custom selectors
- 🌸 Flower integration: Drop-in strategy and client implementations
- 🔥 PyTorch native: Works with any `nn.Module`
- ⚡ Communication efficient: Cut the per-round parameter payload by 50-90% (mask_fraction of 0.5 down to 0.1)
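A rough check of the bandwidth figure: the sketch below counts the per-round payload for the two-layer MLP from the Quick Start. It assumes float32 values and assumes the mask indices themselves are not sent (for instance, because both sides regenerate the mask from a shared seed). `mlp_param_count` and `payload_bytes` are hypothetical helpers, not part of PSFed.

```python
def mlp_param_count(layer_sizes: list[int]) -> int:
    """Weights plus biases of a fully connected network (hypothetical helper)."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


def payload_bytes(num_parameters: int, mask_fraction: float, bytes_per_value: int = 4) -> int:
    """Bytes of parameter values sent per round.

    Assumes float32 values (4 bytes) and that mask indices are NOT transmitted.
    """
    return round(num_parameters * mask_fraction) * bytes_per_value


n = mlp_param_count([784, 128, 10])  # same layer sizes as the Quick Start model
print(n)  # 101770
full = payload_bytes(n, 1.0)
for fraction in (0.5, 0.1):
    partial = payload_bytes(n, fraction)
    print(f"mask_fraction={fraction}: {partial:,} of {full:,} bytes ({1 - partial / full:.0%} saved)")
```

Under these assumptions the saving equals 1 - mask_fraction; transmitting indices alongside the values would shrink it.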
Installation
```bash
pip install psfed
```
Or install from source:
```bash
git clone https://github.com/ehsan-lari/psfed.git
cd psfed
pip install -e ".[dev]"
```
Quick Start
Basic Usage (Without Flower)
```python
import torch
import torch.nn as nn
from psfed import FlattenedModel, RandomMaskSelector

# Your PyTorch model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Wrap for partial sharing
flat_model = FlattenedModel(model)
print(f"Total parameters: {flat_model.num_parameters}")

# Create a mask selector (50% of parameters, changes each round)
selector = RandomMaskSelector(fraction=0.5, seed=42)

# Generate mask for round 1
mask = selector.select(flat_model.num_parameters, round_num=1)
print(f"Selected {mask.count} / {mask.size} parameters")

# Extract selected parameters (for sending to clients)
partial_params = flat_model.extract(mask)

# ... transmit partial_params to client ...

# Apply received parameters (on client side)
flat_model.apply(partial_params, mask)
```
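Conceptually, `extract` gathers the masked entries of the flattened parameter vector and `apply` writes received values back into those same positions, leaving the other entries untouched. The stand-alone sketch below shows that gather/scatter pattern on plain Python lists; `extract` and `apply_partial` here are hypothetical stand-ins, not PSFed's implementation.

```python
from typing import Sequence


def extract(flat_params: Sequence[float], mask: Sequence[bool]) -> list[float]:
    """Sender side: gather the entries selected by a boolean mask, in index order."""
    if len(flat_params) != len(mask):
        raise ValueError("mask must have one entry per parameter")
    return [p for p, keep in zip(flat_params, mask) if keep]


def apply_partial(flat_params: list[float], partial: Sequence[float], mask: Sequence[bool]) -> None:
    """Receiver side: scatter received values into the masked positions, in place."""
    if len(flat_params) != len(mask) or len(partial) != sum(mask):
        raise ValueError("shape mismatch between parameters, mask, and partial vector")
    values = iter(partial)
    for i, keep in enumerate(mask):
        if keep:
            flat_params[i] = next(values)


# Usage: the server holds the global vector, the client holds a stale copy.
server = [0.1, 0.2, 0.3, 0.4]
client = [9.0, 9.0, 9.0, 9.0]
mask = [True, False, True, False]

sent = extract(server, mask)       # [0.1, 0.3]
apply_partial(client, sent, mask)
print(client)                      # [0.1, 9.0, 0.3, 9.0]
```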
Federated Learning with Flower
Server:
```python
import flwr as fl
from psfed.flower import PSFedAvg

# Define strategy with partial sharing
strategy = PSFedAvg(
    fraction_fit=0.1,
    fraction_evaluate=0.1,
    min_fit_clients=2,
    min_available_clients=2,
    # PSFed-specific parameters
    mask_fraction=0.5,  # Share 50% of parameters
    mask_strategy="random",  # Per-round random selection
    mask_seed=42,  # Reproducibility
)

# Start server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=10),
    strategy=strategy,
)
```
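How `PSFedAvg` aggregates is not spelled out here. As an assumption for illustration, a FedAvg-style server could average each masked coordinate over the clients' returned values, weighted by their local example counts, and keep every unmasked coordinate at its previous global value. `aggregate_partial` below is a hypothetical sketch of that rule on flat lists.

```python
def aggregate_partial(
    global_params: list[float],
    client_updates: list[tuple[list[float], int]],
    mask: list[bool],
) -> list[float]:
    """Weighted FedAvg restricted to the masked coordinates (hypothetical).

    client_updates holds (partial_values, num_examples) pairs, where partial_values
    lists one client's values for the masked coordinates in index order.
    Coordinates outside the mask keep their current global value.
    """
    selected = [i for i, keep in enumerate(mask) if keep]
    if any(len(values) != len(selected) for values, _ in client_updates):
        raise ValueError("each client must return one value per masked coordinate")
    total = sum(n for _, n in client_updates)
    new_params = list(global_params)
    for j, i in enumerate(selected):
        new_params[i] = sum(values[j] * n for values, n in client_updates) / total
    return new_params


result = aggregate_partial(
    global_params=[0.0, 0.0, 0.0],
    client_updates=[([1.0, 3.0], 1), ([3.0, 5.0], 3)],  # client B has 3x the data
    mask=[True, False, True],
)
print(result)  # [2.5, 0.0, 4.5]
```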
Client:
```python
import flwr as fl
import torch
import torch.nn as nn
from psfed.flower import PSFedClient


class MyClient(PSFedClient):
    def __init__(self, model, trainloader):
        super().__init__(model)
        self.trainloader = trainloader

    def train_local(self, epochs: int = 1):
        # Your training logic here (plain local SGD shown)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for _ in range(epochs):
            for images, labels in self.trainloader:
                optimizer.zero_grad()
                loss = criterion(self.model(images), labels)
                loss.backward()
                optimizer.step()


# Start client (model and trainloader are your own nn.Module and DataLoader)
client = MyClient(model, trainloader)
fl.client.start_client(server_address="127.0.0.1:8080", client=client)
```
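On the client side, one plausible round (an assumption, not a description of `PSFedClient` internals) is: overwrite the masked coordinates with the server's values, train the full local model, and upload only the masked coordinates. The hypothetical `client_round` below walks through those steps on a flat vector, with a toy quadratic loss standing in for `train_local`.

```python
def client_round(
    local: list[float],
    received: list[float],
    mask: list[bool],
    target: list[float],
    lr: float = 0.1,
    steps: int = 10,
) -> list[float]:
    """One hypothetical client round on a flat parameter vector."""
    # 1. Merge: overwrite the masked coordinates with the server's values.
    values = iter(received)
    for i, keep in enumerate(mask):
        if keep:
            local[i] = next(values)
    # 2. Train all coordinates locally; toy loss 0.5 * sum((w - target) ** 2),
    #    whose gradient with respect to w[i] is (w[i] - target[i]).
    for _ in range(steps):
        for i in range(len(local)):
            local[i] -= lr * (local[i] - target[i])
    # 3. Upload only the masked coordinates; the rest stay on the client.
    return [w for w, keep in zip(local, mask) if keep]


local = [0.0, 0.0]
upload = client_round(local, received=[1.0], mask=[True, False], target=[2.0, 2.0])
print(len(upload))  # 1: only the masked coordinate is sent back
```

Note that the unmasked coordinate is still trained locally; it simply is not uploaded this round.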
Mask Selection Strategies
PSFed provides several built-in mask selection strategies:
| Strategy | Description | Use Case |
|---|---|---|
| `RandomMaskSelector` | Per-round random selection | Default, ensures coverage over time |
| `TopKMagnitudeSelector` | Select largest parameters by absolute value | Focus on important weights |
| `GradientBasedSelector` | Select by gradient magnitude | Active/changing parameters |
| `StructuredMaskSelector` | Layer-aware selection | Preserve structure |
| `FixedMaskSelector` | User-defined indices | Full control |
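The table does not say exactly how each selector builds its mask. The sketch below is one plausible reading of two rows, a per-round random mask that is reproducible from `(seed, round_num)` and a top-k magnitude mask, using plain Python lists. The function names and the way the seed is combined with the round number are assumptions, not PSFed's code.

```python
import random


def random_mask(num_parameters: int, fraction: float, seed: int, round_num: int) -> list[bool]:
    """Per-round random mask; the same (seed, round_num) always gives the same mask."""
    k = round(num_parameters * fraction)
    # Combine the seed and round number into one integer seed (an arbitrary choice).
    rng = random.Random(seed * 1_000_003 + round_num)
    chosen = set(rng.sample(range(num_parameters), k))
    return [i in chosen for i in range(num_parameters)]


def topk_magnitude_mask(params: list[float], fraction: float) -> list[bool]:
    """Select the k = round(fraction * n) entries with the largest absolute value."""
    k = round(len(params) * fraction)
    ranked = sorted(range(len(params)), key=lambda i: abs(params[i]), reverse=True)
    chosen = set(ranked[:k])
    return [i in chosen for i in range(len(params))]


print(sum(random_mask(10, 0.5, seed=42, round_num=1)))  # 5
print(topk_magnitude_mask([0.1, -2.0, 0.5, 1.5], 0.5))  # [False, True, False, True]
```

A gradient-based variant would apply the same top-k ranking to gradient magnitudes instead of parameter values.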
Custom Selector
```python
from psfed import Mask, MaskSelector


class MyCustomSelector(MaskSelector):
    def select(
        self,
        num_parameters: int,
        round_num: int,
        **kwargs,
    ) -> Mask:
        # Your selection logic: `your_custom_logic` is a placeholder for your own
        # function returning the indices of the parameters to share this round
        indices = your_custom_logic(num_parameters, round_num)
        return Mask.from_indices(indices, size=num_parameters)
```
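`your_custom_logic` in the example is a placeholder. One deterministic choice, sketched here as the hypothetical `round_robin_indices`, rotates through fixed index groups so that every parameter is shared exactly once every `num_groups` rounds. It returns a plain list of indices; that this is what `Mask.from_indices` accepts is an assumption.

```python
def round_robin_indices(num_parameters: int, round_num: int, num_groups: int = 4) -> list[int]:
    """Indices offset, offset + num_groups, ..., with the offset rotating each round.

    Over any num_groups consecutive rounds, each index is selected exactly once,
    so each round shares roughly 1 / num_groups of the parameters.
    """
    offset = round_num % num_groups
    return list(range(offset, num_parameters, num_groups))


print(round_robin_indices(10, round_num=1))  # [1, 5, 9]
```

Unlike a random mask, this schedule guarantees full coverage after a fixed number of rounds rather than with high probability.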
API Reference
Core Classes
- `FlattenedModel`: Wraps a PyTorch model for flatten/unflatten operations
- `Mask`: Binary mask with efficient storage and operations
- `MaskSelector`: Abstract base class for selection strategies
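The summary says `Mask` uses efficient storage without giving the layout. A common approach, sketched below as the hypothetical `BitMask` (not PSFed's class), packs one bit per parameter; it reuses the `from_indices`, `count`, and `size` names from the examples above for readability only.

```python
class BitMask:
    """Hypothetical bit-packed binary mask: one bit per parameter."""

    def __init__(self, size: int) -> None:
        self.size = size
        self._bits = bytearray((size + 7) // 8)  # zero-initialized

    @classmethod
    def from_indices(cls, indices: list[int], size: int) -> "BitMask":
        mask = cls(size)
        for i in indices:
            if not 0 <= i < size:
                raise IndexError(f"index {i} out of range for size {size}")
            mask._bits[i // 8] |= 1 << (i % 8)  # set bit i
        return mask

    def __contains__(self, i: int) -> bool:
        return 0 <= i < self.size and bool((self._bits[i // 8] >> (i % 8)) & 1)

    @property
    def count(self) -> int:
        """Number of selected parameters (set bits)."""
        return sum(bin(byte).count("1") for byte in self._bits)

    def indices(self) -> list[int]:
        return [i for i in range(self.size) if i in self]


m = BitMask.from_indices([0, 3, 9], size=10)
print(m.count, m.indices())  # 3 [0, 3, 9]
```

At one bit per entry, a mask over the 101,770 parameters of the Quick Start MLP takes 12,722 bytes, versus 101,770 bytes for a one-byte-per-entry boolean array.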
Flower Integration
- `PSFedAvg`: FedAvg strategy with partial model sharing
- `PSFedClient`: Base client class handling partial parameters
Research Background
This package implements the concept of partial model sharing in federated learning, where communication efficiency is achieved by transmitting only a subset of model parameters each round.
Key properties:
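- Communication reduction: Proportional to (1 - mask_fraction); for example, mask_fraction=0.5 halves the per-round parameter payload
- Convergence: Dynamic masking lets every parameter be synchronized eventually (with random masks, the chance that a given parameter is never shared decays geometrically with the number of rounds)
- Privacy: Parameters outside a round's mask stay on the client in that round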
For theoretical analysis, see docs/research.md.
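For random masks, the eventual-synchronization property can be quantified. If each round's mask is drawn independently and includes each parameter with probability p = mask_fraction, the chance that a given parameter is never shared in T rounds is (1 - p)^T. The hypothetical helpers below solve for T and check the bound by simulation; they are illustrative, not part of PSFed.

```python
import math
import random


def rounds_for_coverage(mask_fraction: float, miss_prob: float) -> int:
    """Smallest T with (1 - mask_fraction) ** T <= miss_prob."""
    if not 0.0 < mask_fraction <= 1.0 or not 0.0 < miss_prob < 1.0:
        raise ValueError("need 0 < mask_fraction <= 1 and 0 < miss_prob < 1")
    if mask_fraction == 1.0:
        return 1  # full sharing covers everything in one round
    return math.ceil(math.log(miss_prob) / math.log(1.0 - mask_fraction))


def simulate_coverage(num_parameters: int, mask_fraction: float, rounds: int, seed: int = 0) -> float:
    """Fraction of parameters shared at least once over `rounds` independent random masks."""
    rng = random.Random(seed)
    k = round(num_parameters * mask_fraction)
    seen: set[int] = set()
    for _ in range(rounds):
        seen.update(rng.sample(range(num_parameters), k))
    return len(seen) / num_parameters


print(rounds_for_coverage(0.5, 1e-3))  # 10
print(rounds_for_coverage(0.1, 1e-3))  # 66
print(simulate_coverage(10_000, 0.5, rounds=10))
```

Sampling exactly k = p * n indices per round still includes each parameter with probability p, so the (1 - p)^T bound applies per parameter.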
Examples
- MNIST Basic - Simple partial sharing example
- CIFAR-10 Advanced - Custom selectors and analysis
Contributing
Contributions are welcome! Please see CONTRIBUTING.md for guidelines.
```bash
# Setup development environment
pip install -e ".[dev]"
pre-commit install

# Run tests
pytest

# Type checking
mypy src/psfed

# Linting
ruff check src/psfed
```
Citation
If you use PSFed in your research, please cite:
```bibtex
@article{lari2025resilience,
  title={Resilience in Online Federated Learning: Mitigating Model-Poisoning Attacks via Partial Sharing},
  author={Lari, Ehsan and Arablouei, Reza and Gogineni, Vinay Chakravarthi and Werner, Stefan},
  journal={IEEE Transactions on Signal and Information Processing over Networks},
  year={2025},
  publisher={IEEE}
}
```
License
MIT License - see LICENSE for details.
Download files
File details
Details for the file psfed-0.1.2.tar.gz.
File metadata
- Download URL: psfed-0.1.2.tar.gz
- Upload date:
- Size: 1.3 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8d210f6950ea5ecab6acb220c43d07464d3ff5d72e3678a6046ed58c4810ba7a |
| MD5 | a28c55f3a8b35eee07b79455fb50c5e8 |
| BLAKE2b-256 | 9f7427c669162e298a1b2f3d5ff316b82701e13959e7c7a0db0cfc0956d81e11 |
File details
Details for the file psfed-0.1.2-py3-none-any.whl.
File metadata
- Download URL: psfed-0.1.2-py3-none-any.whl
- Upload date:
- Size: 24.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f001f4e477994f6ce88f12b8ceec6f07254680dc5236ca6ab89b577d7ecbedb0 |
| MD5 | e49edfb5be4643e8a67045f2a0f0b489 |
| BLAKE2b-256 | 12afcce9f17fe5966ef062f331e2cb23fabe60e2d709a022ef0cefd3f6582a6f |