Federated Learning framework with CNN and image input
mds3fl Function Reference
mds3fl provides a TensorFlow-based federated learning workflow for image
regression tasks. The package exposes one main class:
```python
from mds3fl import MDS3FL
```
MDS3FL can be used as either:
- a federated learning server that aggregates client model weights, or
- a federated learning client that trains locally and sends updated weights back to the server.
Main Class
MDS3FL
```python
MDS3FL(
    is_server: bool,
    path1: str,
    path2: str,
    path3: str,
    path4: str,
    wr2: str,
    client_num: int,
    server_ip: str,
    port: int,
    input_shape,
    num_conv,
    num_nodes,
    use_batchnorm,
    use_dropout,
    use_reg,
    LR,
)
```
Creates a federated learning object, loads labels and image paths, and builds a configurable CNN model.
Parameters
| Name | Type | Description |
|---|---|---|
| `is_server` | `bool` | Set to `True` for the aggregation server. Set to `False` for a training client. |
| `path1` | `str` | Path to the Excel workbook containing training labels. |
| `path2` | `str` | Directory containing PB1 image arrays. |
| `path3` | `str` | Directory containing PB2 image arrays. |
| `path4` | `str` | Directory containing PB3 image arrays. |
| `wr2` | `str` | Directory containing WR2 image arrays. |
| `client_num` | `int` | Number of clients expected by the server. |
| `server_ip` | `str` | Server IP address used by clients when connecting. |
| `port` | `int` | TCP port for server/client communication. |
| `input_shape` | `tuple` | Shape of one model input, for example `(128, 128, 1)`. |
| `num_conv` | `int` | Number of convolution blocks in the CNN. |
| `num_nodes` | `list[int]` | Dense layer sizes, for example `[128, 64]`. |
| `use_batchnorm` | `bool` | Adds batch normalization after convolution layers when enabled. |
| `use_dropout` | `bool` | Adds dropout after dense layers when enabled. |
| `use_reg` | `int` | Regularization mode: `0` none, `1` L1, `2` L2, `3` combined L1/L2. |
| `LR` | `float` | Learning rate for the Adam optimizer. |
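The four `use_reg` modes amount to choosing an L1 penalty, an L2 penalty, both, or neither. The helper below is a hypothetical illustration, not mds3fl code: `regularization_coefficients` and its default `strength` of `1e-4` are assumptions, because the actual penalty strengths are not documented here.

```python
def regularization_coefficients(use_reg: int, strength: float = 1e-4) -> tuple[float, float]:
    """Map a use_reg mode to hypothetical (l1, l2) penalty coefficients.

    The default strength is an illustrative assumption; mds3fl may use other values.
    """
    modes = {
        0: (0.0, 0.0),            # no regularization
        1: (strength, 0.0),       # L1 only
        2: (0.0, strength),       # L2 only
        3: (strength, strength),  # combined L1/L2
    }
    if use_reg not in modes:
        raise ValueError(f"use_reg must be 0, 1, 2, or 3, got {use_reg!r}")
    return modes[use_reg]


# In Keras, a pair like this could be passed to
# tf.keras.regularizers.L1L2(l1=l1, l2=l2) as a layer's kernel_regularizer.
l1, l2 = regularization_coefficients(3)
print(l1, l2)
```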
Example
```python
from mds3fl import MDS3FL

fl = MDS3FL(
    is_server=True,
    path1="/path/to/labels.xlsx",
    path2="/path/to/pb1",
    path3="/path/to/pb2",
    path4="/path/to/pb3",
    wr2="/path/to/wr2",
    client_num=2,
    server_ip="127.0.0.1",
    port=5000,
    input_shape=(128, 128, 1),
    num_conv=3,
    num_nodes=[128, 64],
    use_batchnorm=True,
    use_dropout=True,
    use_reg=0,
    LR=1e-4,
)
```
Public Methods
train(num_rounds=10)
Runs the federated learning process.
```python
fl.train(num_rounds=10)
```
Behavior depends on `is_server`:
- Server mode:
  - opens a TCP socket,
  - waits for `client_num` clients,
  - sends global model weights to each client,
  - receives updated client weights,
  - aggregates client weights with federated averaging.
- Client mode:
  - connects to the server,
  - receives global model weights,
  - trains a local model,
  - sends updated weights back to the server.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| `num_rounds` | `int` | `10` | Number of federated training rounds. |
Returns
This method does not currently return a value.
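The order of operations in `train()` can be simulated in a single process. This is a hypothetical sketch, not mds3fl code: sockets are replaced by direct function calls, model weights are plain lists of floats, and `fake_local_training` is an invented stand-in for local Keras training that simply pulls each weight toward a client-specific target.

```python
def fake_local_training(weights, target, steps=5, lr=0.1):
    """Hypothetical client update: move each weight toward a client-specific target."""
    updated = list(weights)  # copy so the global weights are not mutated
    for _ in range(steps):
        updated = [w - lr * (w - target) for w in updated]
    return updated


def run_rounds(global_weights, client_targets, client_sizes, num_rounds=10):
    """Simulate num_rounds of: send weights, local training, weighted averaging."""
    total = sum(client_sizes)
    for _ in range(num_rounds):
        # "Send" the current global weights to every client and collect the updates.
        client_updates = [fake_local_training(global_weights, t) for t in client_targets]
        # Federated averaging: each client contributes client_size / total_size.
        global_weights = [
            sum(update[i] * size / total for update, size in zip(client_updates, client_sizes))
            for i in range(len(global_weights))
        ]
    return global_weights


final = run_rounds([0.0, 0.0], client_targets=[1.0, 3.0], client_sizes=[100, 100])
print(final)
```

With equal client sizes the global weights drift toward the midpoint of the two client targets; with unequal sizes the larger client pulls the average toward its own target, which is the intended effect of size-weighted averaging.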
prepare_data(client_id: int, batch_size=8)
Builds a TensorFlow dataset for one client.
```python
dataset = fl.prepare_data(client_id=0, batch_size=8)
```
The current implementation combines PB1, PB2, and WR2 samples from index 800
onward, shuffles them, and then selects a fixed slice of 100 samples for the
requested client.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| `client_id` | `int` | required | Client index used to select a data slice. |
| `batch_size` | `int` | `8` | Number of samples per training batch. |
Returns
Returns a batched tf.data.Dataset of (image, label) pairs.
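The sample-selection step of `prepare_data()` can be sketched in plain Python. Several details here are assumptions: that each source list is cut at index 800 before combining, that client `k` takes positions `k*100` to `k*100 + 99` of the shuffled list, and that a fixed seed (the hypothetical `seed=42`) keeps the shuffle identical in every client process so the slices do not overlap. The description above only says that a fixed slice of 100 samples is selected.

```python
import random


def client_slice(pb1, pb2, wr2, client_id, start_index=800, slice_size=100, seed=42):
    """Hypothetical sketch of per-client sample selection.

    pb1, pb2, and wr2 are lists of (image_path, label) pairs.
    """
    # Keep samples from start_index onward in each source, then combine them.
    combined = pb1[start_index:] + pb2[start_index:] + wr2[start_index:]
    # A fixed seed gives every client process the same shuffled order (assumption).
    rng = random.Random(seed)
    rng.shuffle(combined)
    start = client_id * slice_size
    return combined[start:start + slice_size]


def fake(prefix):
    """Illustrative data: 900 fake samples per source, so 100 per source survive the cut."""
    return [(f"{prefix}/{i}.npy", float(i)) for i in range(900)]


pb1, pb2, wr2 = fake("pb1"), fake("pb2"), fake("wr2")
client0 = client_slice(pb1, pb2, wr2, client_id=0)
client1 = client_slice(pb1, pb2, wr2, client_id=1)
print(len(client0), len(client1))
```

In mds3fl the selected pairs are then turned into a batched `tf.data.Dataset`; that step needs TensorFlow and is left out of this sketch.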
dynamic_cnn_model()
Builds and returns a TensorFlow Keras CNN model from the current object configuration.
```python
model = fl.dynamic_cnn_model()
```
The model contains:
- an input layer,
- configurable convolution blocks,
- a flatten layer,
- configurable dense layers,
- a single-value regression output.
Returns
Returns a tf.keras.Model.
federated_avg(weights, client_sizes)
Computes a weighted average of model weights from multiple clients.
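```python
global_weights = fl.federated_avg(client_weights, client_sizes)
```
Each client's weights are scaled by its share of the total training data, and the scaled weights are summed:
```
client_weight * (client_size / total_size)
```
where `total_size` is the sum of `client_sizes`.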
Parameters
| Name | Type | Description |
|---|---|---|
| `weights` | `list` | List of Keras weight lists, one per client. |
| `client_sizes` | `list` | Number of training samples for each client. |
Returns
Returns a list of NumPy arrays representing aggregated model weights.
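The weighted average can be written in a few lines. This standalone sketch follows the formula above but uses nested Python lists of floats in place of the NumPy arrays that Keras returns from `get_weights()`; it illustrates the technique and is not the package's implementation.

```python
def federated_avg(weights, client_sizes):
    """Weighted average of per-client weights.

    weights: one entry per client; each entry is a list of layers,
             and each layer is a flat list of floats.
    client_sizes: number of training samples per client.
    """
    if len(weights) != len(client_sizes):
        raise ValueError("weights and client_sizes must have the same length")
    total_size = sum(client_sizes)
    averaged = []
    # zip(*weights) groups the same layer from every client together.
    for layer_group in zip(*weights):
        layer_avg = [
            sum(layer[i] * size / total_size for layer, size in zip(layer_group, client_sizes))
            for i in range(len(layer_group[0]))
        ]
        averaged.append(layer_avg)
    return averaged


client_a = [[0.0, 4.0], [1.0]]  # two layers
client_b = [[4.0, 8.0], [5.0]]
print(federated_avg([client_a, client_b], client_sizes=[1, 3]))  # [[3.0, 7.0], [4.0]]
```

With NumPy arrays the inner list comprehension collapses to `sum(layer * (size / total_size) for layer, size in zip(layer_group, client_sizes))`, because array arithmetic is elementwise.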
train_client_model(dataset, initial_weights, num_epochs=1)
Trains a local client model starting from server-provided weights.
```python
updated_weights = fl.train_client_model(
    dataset=dataset,
    initial_weights=global_weights,
    num_epochs=1,
)
```
The method creates a new CNN model, compiles it with Adam and mean squared
error, loads initial_weights, trains on the provided dataset, and returns the
updated weights.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| `dataset` | `tf.data.Dataset` | required | Local client training dataset. |
| `initial_weights` | `list` | required | Initial Keras model weights received from the server. |
| `num_epochs` | `int` | `1` | Number of local training epochs. |
Returns
Returns the trained model weights as a list of NumPy arrays.
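The local-training contract (start from the server's weights, train for `num_epochs`, return new weights without modifying the input) can be illustrated without TensorFlow. This hypothetical sketch fits a one-feature linear model `y = w * x + b` with plain full-batch gradient descent on mean squared error; mds3fl itself trains a Keras CNN with Adam, so only the shape of the procedure carries over.

```python
def mse(weights, data):
    """Mean squared error of y = w*x + b on (x, y) pairs."""
    w, b = weights
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def train_client_model(data, initial_weights, num_epochs=1, lr=0.05):
    """Hypothetical stand-in: full-batch gradient descent on MSE for y = w*x + b."""
    w, b = initial_weights  # unpacking copies the values; initial_weights is untouched
    n = len(data)
    for _ in range(num_epochs):
        # Gradients of mean((w*x + b - y)^2) with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return [w, b]


local_data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # points on y = 2x
global_weights = [0.0, 0.0]
updated = train_client_model(local_data, global_weights, num_epochs=200)
print(updated, mse(updated, local_data))
```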
Data Loading Methods
retrieveData()
Loads labels from the Excel workbook and collects image file paths from the PB1, PB2, PB3, and WR2 directories.
```python
data = fl.retrieveData()
```
The method expects the workbook to contain a sheet named `Beta Fraction` and
uses fixed column positions and row ranges.
Returns
Returns a tuple:
```python
(
    pb1_image_paths,
    pb2_image_paths,
    pb3_image_paths,
    wr2_image_paths,
    pb1_label,
    pb2_label,
    pb3_label,
    wr2_label,
)
```
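The path-collection half of `retrieveData()` can be sketched with the standard library. The helper `collect_npy_paths` is hypothetical, not part of mds3fl: it lists the `.npy` files in one directory in sorted order so the path order is deterministic, which matters if paths are paired with labels by position (an assumption about how the pairing works). Reading the `Beta Fraction` sheet would typically go through a library such as pandas (`pandas.read_excel(path1, sheet_name="Beta Fraction")`), but the fixed columns and row ranges are not documented here, so that part is omitted.

```python
import tempfile
from pathlib import Path


def collect_npy_paths(directory):
    """Hypothetical helper: return sorted string paths of all .npy files in directory."""
    # sorted() makes the order deterministic; the sort is lexicographic,
    # so "10.npy" comes before "2.npy" unless names are zero-padded.
    return sorted(str(p) for p in Path(directory).glob("*.npy"))


# Demonstration with a temporary directory standing in for the PB1 folder.
with tempfile.TemporaryDirectory() as pb1_dir:
    for name in ["002.npy", "001.npy", "notes.txt"]:
        (Path(pb1_dir) / name).touch()
    pb1_image_paths = collect_npy_paths(pb1_dir)
    print([Path(p).name for p in pb1_image_paths])
```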
path_label_generator(img_paths, labels)
Generates (image_path, label) pairs from parallel path and label lists.
```python
for image_path, label in fl.path_label_generator(img_paths, labels):
    ...
```
This method is mainly used internally by `prepare_data()`.
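A generator of this kind is short. The standalone sketch below is an assumption about how such a method might look; the length check is an illustrative addition, since `zip()` would otherwise silently drop unmatched items. A generator like this is the kind of input `tf.data.Dataset.from_generator` accepts, which is one plausible way `prepare_data()` consumes it.

```python
def path_label_generator(img_paths, labels):
    """Yield (image_path, label) pairs from two parallel lists."""
    # Because this is a generator, the check below runs when iteration starts,
    # not when the function is called.
    if len(img_paths) != len(labels):
        raise ValueError("img_paths and labels must have the same length")
    for image_path, label in zip(img_paths, labels):
        yield image_path, label


pairs = list(path_label_generator(["a.npy", "b.npy"], [0.25, 0.5]))
print(pairs)  # [('a.npy', 0.25), ('b.npy', 0.5)]
```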
load_image(file_path, label)
Loads one .npy image array and pairs it with its label.
```python
image, label = fl.load_image(file_path, label)
```
The method:
- loads the file with `np.load`,
- adds a channel dimension,
- converts the image to `tf.float32`,
- converts the label to `tf.float32`.

This method is called through TensorFlow's `tf.py_function`.
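The NumPy half of `load_image()` can be sketched as follows. This assumes each `.npy` file holds a single 2-D grayscale array such as `(128, 128)`, so appending a trailing channel axis yields the `(128, 128, 1)` shape used for `input_shape` in the examples. Casting to NumPy's `float32` stands in for the `tf.float32` conversion, and the `tf.py_function` wrapping is omitted so the sketch runs without TensorFlow; `load_image_numpy` is a hypothetical name.

```python
import os
import tempfile

import numpy as np


def load_image_numpy(file_path, label):
    """Hypothetical NumPy-only version of load_image (no TensorFlow)."""
    image = np.load(file_path)              # e.g. shape (128, 128)
    image = np.expand_dims(image, axis=-1)  # add channel axis -> (128, 128, 1)
    image = image.astype(np.float32)
    return image, np.float32(label)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.npy")
    np.save(path, np.zeros((128, 128), dtype=np.uint8))
    image, label = load_image_numpy(path, 0.42)
    print(image.shape, image.dtype, label.dtype)
```

Inside `tf.py_function` the arguments arrive as eager tensors, so a string path is typically recovered with `file_path.numpy().decode()` before it is passed to `np.load`.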
load_image_wrapper(file_path, label)
TensorFlow wrapper around load_image().
```python
image, label = fl.load_image_wrapper(file_path, label)
```
This wrapper allows NumPy-based file loading to run inside a TensorFlow dataset pipeline.
Socket Communication Methods
send_data(sock, data)
Serializes Python data with pickle and sends it through a socket.
```python
fl.send_data(sock, global_weights)
```
Parameters
| Name | Type | Description |
|---|---|---|
| `sock` | `socket.socket` | Connected socket object. |
| `data` | Any | Pickle-serializable data to send. |
recv_data(sock)
Receives bytes from a socket and deserializes them with pickle.
```python
weights = fl.recv_data(sock)
```
Parameters
| Name | Type | Description |
|---|---|---|
| `sock` | `socket.socket` | Connected socket object. |
Returns
Returns the deserialized Python object.
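A TCP stream has no message boundaries, so the receiver needs some way to know when a pickled payload is complete. This reference does not say how `send_data()` and `recv_data()` frame their messages; the sketch below assumes a common convention, an 8-byte big-endian length prefix, and is not necessarily what mds3fl does. As with the package itself, `pickle.loads` must only be applied to data from trusted peers.

```python
import pickle
import socket
import struct

HEADER = struct.Struct("!Q")  # 8-byte unsigned length prefix, network byte order


def send_data(sock, data):
    """Pickle data and send it with a length prefix (framing is an assumption)."""
    payload = pickle.dumps(data)
    sock.sendall(HEADER.pack(len(payload)) + payload)


def _recv_exact(sock, n):
    """Read exactly n bytes, looping because recv() may return fewer."""
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(min(remaining, 1 << 20))
        if not chunk:
            raise ConnectionError("socket closed before the full message arrived")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recv_data(sock):
    """Receive one length-prefixed message and unpickle it (trusted peers only)."""
    (length,) = HEADER.unpack(_recv_exact(sock, HEADER.size))
    return pickle.loads(_recv_exact(sock, length))


# Demonstration over a connected pair of local sockets.
left, right = socket.socketpair()
send_data(left, {"round": 1, "weights": [[0.1, 0.2], [0.3]]})
received = recv_data(right)
left.close()
right.close()
print(received)
```

The length prefix lets the receiver read exactly one message even when several are sent back to back, which matters when the server exchanges weights with each client once per round.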
Minimal Server and Client Example
Start the server first:
```python
from mds3fl import MDS3FL

server = MDS3FL(
    is_server=True,
    path1="/path/to/labels.xlsx",
    path2="/path/to/pb1",
    path3="/path/to/pb2",
    path4="/path/to/pb3",
    wr2="/path/to/wr2",
    client_num=2,
    server_ip="127.0.0.1",
    port=5000,
    input_shape=(128, 128, 1),
    num_conv=3,
    num_nodes=[128, 64],
    use_batchnorm=True,
    use_dropout=True,
    use_reg=0,
    LR=1e-4,
)
server.train(num_rounds=10)
```
Then start each client process:
```python
from mds3fl import MDS3FL

client = MDS3FL(
    is_server=False,
    path1="/path/to/labels.xlsx",
    path2="/path/to/pb1",
    path3="/path/to/pb2",
    path4="/path/to/pb3",
    wr2="/path/to/wr2",
    client_num=2,
    server_ip="127.0.0.1",
    port=5000,
    input_shape=(128, 128, 1),
    num_conv=3,
    num_nodes=[128, 64],
    use_batchnorm=True,
    use_dropout=True,
    use_reg=0,
    LR=1e-4,
)
client.train(num_rounds=10)
```
Notes for PyPI Users
- This package is intended for trusted research environments.
- The socket transport uses `pickle`, so it should not receive data from untrusted clients or servers.
- The current data loader expects a specific Excel workbook structure.
- Image files should be NumPy `.npy` arrays.
- The model is currently configured for regression with mean squared error.
Download files
File details
Details for the file mds3fl-0.1.2.tar.gz.
File metadata
- Download URL: mds3fl-0.1.2.tar.gz
- Upload date:
- Size: 12.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | acb3d965ef91206705a7a72646dfa20360851da47bbd317e9ac8b8ebea1dd6f3 |
| MD5 | dcbe5ac9ef6e49cdb38fb97bcf889951 |
| BLAKE2b-256 | c3ed1c1716cc29b0421d8a9ddc9d7c156b6da7abe2aa1238c8e3753c6aa62aba |
File details
Details for the file mds3fl-0.1.2-py3-none-any.whl.
File metadata
- Download URL: mds3fl-0.1.2-py3-none-any.whl
- Upload date:
- Size: 7.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ade7386953ac7662a14966f4735f4ff8ae13e0f6f931850b6cee56a8d4c4a886 |
| MD5 | 11f4d01a647aa2ae648114c561b40ae3 |
| BLAKE2b-256 | 0f1a085ab2d27436e64e19828785617e483ceff4ac4decd78d76c07ac48934a1 |