Federated Learning framework with CNN and image input

mds3fl Function Reference

mds3fl provides a TensorFlow-based federated learning workflow for image regression tasks. The package exposes one main class:

from mds3fl import MDS3FL

MDS3FL can be used as either:

  • a federated learning server that aggregates client model weights, or
  • a federated learning client that trains locally and sends updated weights back to the server.

Main Class

MDS3FL

MDS3FL(
    is_server: bool,
    path1: str,
    path2: str,
    path3: str,
    path4: str,
    wr2: str,
    client_num: int,
    server_ip: str,
    port: int,
    input_shape,
    num_conv,
    num_nodes,
    use_batchnorm,
    use_dropout,
    use_reg,
    LR,
)

Creates a federated learning object, loads labels and image paths, and builds a configurable CNN model.

Parameters

| Name | Type | Description |
| --- | --- | --- |
| is_server | bool | Set to True for the aggregation server. Set to False for a training client. |
| path1 | str | Path to the Excel workbook containing training labels. |
| path2 | str | Directory containing PB1 image arrays. |
| path3 | str | Directory containing PB2 image arrays. |
| path4 | str | Directory containing PB3 image arrays. |
| wr2 | str | Directory containing WR2 image arrays. |
| client_num | int | Number of clients expected by the server. |
| server_ip | str | Server IP address used by clients when connecting. |
| port | int | TCP port for server/client communication. |
| input_shape | tuple | Shape of one model input, for example (128, 128, 1). |
| num_conv | int | Number of convolution blocks in the CNN. |
| num_nodes | list[int] | Dense layer sizes, for example [128, 64]. |
| use_batchnorm | bool | Adds batch normalization after convolution layers when enabled. |
| use_dropout | bool | Adds dropout after dense layers when enabled. |
| use_reg | int | Regularization mode: 0 none, 1 L1, 2 L2, 3 L1/L2. |
| LR | float | Learning rate for the Adam optimizer. |

Example

from mds3fl import MDS3FL

fl = MDS3FL(
    is_server=True,
    path1="/path/to/labels.xlsx",
    path2="/path/to/pb1",
    path3="/path/to/pb2",
    path4="/path/to/pb3",
    wr2="/path/to/wr2",
    client_num=2,
    server_ip="127.0.0.1",
    port=5000,
    input_shape=(128, 128, 1),
    num_conv=3,
    num_nodes=[128, 64],
    use_batchnorm=True,
    use_dropout=True,
    use_reg=0,
    LR=1e-4,
)

Public Methods

train(num_rounds=10)

Runs the federated learning process.

fl.train(num_rounds=10)

Behavior depends on is_server:

  • Server mode:

    • opens a TCP socket,
    • waits for client_num clients,
    • sends global model weights to each client,
    • receives updated client weights,
    • aggregates client weights with federated averaging.
  • Client mode:

    • connects to the server,
    • receives global model weights,
    • trains a local model,
    • sends updated weights back to the server.

Parameters

| Name | Type | Default | Description |
| --- | --- | --- | --- |
| num_rounds | int | 10 | Number of federated training rounds. |

Returns

This method does not currently return a value.
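
The round structure described above can be illustrated without sockets or TensorFlow. The following is a minimal in-process sketch, not the package's implementation: `local_update` and `federated_rounds` are hypothetical names, weights are plain floats, and "training" is replaced by a step toward a per-client target value.

```python
def local_update(weights, target, step=0.5):
    """Hypothetical stand-in for local training: move each weight toward `target`."""
    return [w + step * (target - w) for w in weights]


def federated_rounds(global_weights, client_targets, client_sizes, num_rounds=10):
    """Simulate the server loop: broadcast, collect client updates, average."""
    total_size = sum(client_sizes)
    for _ in range(num_rounds):
        # Each client starts from the current global weights and "trains" locally.
        client_weights = [local_update(global_weights, t) for t in client_targets]
        # The server aggregates with a dataset-size-weighted average.
        global_weights = [
            sum(cw[i] * (size / total_size)
                for cw, size in zip(client_weights, client_sizes))
            for i in range(len(global_weights))
        ]
    return global_weights


final_weights = federated_rounds(
    [0.0, 0.0], client_targets=[1.0, 3.0], client_sizes=[100, 100], num_rounds=10
)
```

With two equally sized clients pulling toward 1.0 and 3.0, the global weights settle near their mean, which is the behavior federated averaging is meant to produce.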

prepare_data(client_id: int, batch_size=8)

Builds a TensorFlow dataset for one client.

dataset = fl.prepare_data(client_id=0, batch_size=8)

The current implementation combines PB1, PB2, and WR2 samples from index 800 onward, shuffles them, and then selects a fixed slice of 100 samples for the requested client.

Parameters

| Name | Type | Default | Description |
| --- | --- | --- | --- |
| client_id | int | required | Client index used to select a data slice. |
| batch_size | int | 8 | Number of samples per training batch. |

Returns

Returns a batched tf.data.Dataset of (image, label) pairs.
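
The combine, shuffle, and slice steps can be sketched in plain Python. This is an illustration under assumptions, not the package's code: the function name `select_client_slice`, the fixed shuffle seed, and the rule that client `k` takes items `k * 100` to `(k + 1) * 100` are all hypothetical, since the description above does not say how the slice is chosen.

```python
import random


def select_client_slice(pb1, pb2, wr2, client_id, start=800, per_client=100, seed=0):
    """Combine samples from `start` onward, shuffle, and return one client's slice."""
    combined = pb1[start:] + pb2[start:] + wr2[start:]
    # Assumption: a fixed seed, so every process shuffles into the same order
    # and different client_id values receive non-overlapping slices.
    random.Random(seed).shuffle(combined)
    lo = client_id * per_client
    return combined[lo:lo + per_client]


# Toy (source, index) pairs standing in for (image_path, label) samples.
pb1 = [("pb1", i) for i in range(1000)]
pb2 = [("pb2", i) for i in range(1000)]
wr2 = [("wr2", i) for i in range(1000)]

slice0 = select_client_slice(pb1, pb2, wr2, client_id=0)
slice1 = select_client_slice(pb1, pb2, wr2, client_id=1)
```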

dynamic_cnn_model()

Builds and returns a TensorFlow Keras CNN model from the current object configuration.

model = fl.dynamic_cnn_model()

The model contains:

  1. an input layer,
  2. configurable convolution blocks,
  3. a flatten layer,
  4. configurable dense layers,
  5. a single-value regression output.

Returns

Returns a tf.keras.Model.

federated_avg(weights, client_sizes)

Computes a weighted average of model weights from multiple clients.

global_weights = fl.federated_avg(client_weights, client_sizes)

Each client contributes proportionally to its dataset size:

client_weight * (client_size / total_size)

Parameters

| Name | Type | Description |
| --- | --- | --- |
| weights | list | List of Keras weight lists, one per client. |
| client_sizes | list | Number of training samples for each client. |

Returns

Returns a list of NumPy arrays representing aggregated model weights.
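
The weighting rule above can be sketched with NumPy. This is a minimal illustration assuming every client sends the same number of arrays with matching shapes. `federated_avg_sketch` is a hypothetical name and is not the package's own method.

```python
import numpy as np


def federated_avg_sketch(weights, client_sizes):
    """Average per-layer arrays, weighting each client by its dataset size."""
    total_size = float(sum(client_sizes))
    averaged = []
    # zip(*weights) groups the same layer from every client together.
    for layer_group in zip(*weights):
        layer_avg = sum(
            np.asarray(w) * (size / total_size)
            for w, size in zip(layer_group, client_sizes)
        )
        averaged.append(layer_avg)
    return averaged


# Two toy clients, each with two "layers".
client_a = [np.array([1.0, 1.0]), np.array([[0.0]])]
client_b = [np.array([3.0, 5.0]), np.array([[4.0]])]
global_weights = federated_avg_sketch([client_a, client_b], client_sizes=[100, 300])
```

Here client_b holds three quarters of the samples, so each averaged value sits three quarters of the way from client_a's value to client_b's.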

train_client_model(dataset, initial_weights, num_epochs=1)

Trains a local client model starting from server-provided weights.

updated_weights = fl.train_client_model(
    dataset=dataset,
    initial_weights=global_weights,
    num_epochs=1,
)

The method creates a new CNN model, compiles it with Adam and mean squared error, loads initial_weights, trains on the provided dataset, and returns the updated weights.

Parameters

| Name | Type | Default | Description |
| --- | --- | --- | --- |
| dataset | tf.data.Dataset | required | Local client training dataset. |
| initial_weights | list | required | Initial Keras model weights received from the server. |
| num_epochs | int | 1 | Number of local training epochs. |

Returns

Returns the trained model weights as a list of NumPy arrays.

Data Loading Methods

retrieveData()

Loads labels from the Excel workbook and collects image file paths from the PB1, PB2, PB3, and WR2 directories.

data = fl.retrieveData()

The method expects the workbook to contain a sheet named Beta Fraction and uses fixed column positions and row ranges.

Returns

Returns a tuple:

(
    pb1_image_paths,
    pb2_image_paths,
    pb3_image_paths,
    wr2_image_paths,
    pb1_label,
    pb2_label,
    pb3_label,
    wr2_label,
)

path_label_generator(img_paths, labels)

Generates (image_path, label) pairs from parallel path and label lists.

for image_path, label in fl.path_label_generator(img_paths, labels):
    ...

This method is mainly used internally by prepare_data().
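
A generator of this kind is typically a thin wrapper around `zip`. The sketch below shows the idea. The name `path_label_pairs` and the length check are assumptions added for illustration and may not match the package's behavior.

```python
def path_label_pairs(img_paths, labels):
    """Yield (image_path, label) pairs from two parallel sequences."""
    # Assumption: mismatched lengths are treated as an error rather than
    # silently truncated, which is what plain zip() would do.
    if len(img_paths) != len(labels):
        raise ValueError("img_paths and labels must have the same length")
    for image_path, label in zip(img_paths, labels):
        yield image_path, label


pairs = list(path_label_pairs(["a.npy", "b.npy"], [0.12, 0.34]))
```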

load_image(file_path, label)

Loads one .npy image array and pairs it with its label.

image, label = fl.load_image(file_path, label)

The method:

  1. loads the file with np.load,
  2. adds a channel dimension,
  3. converts the image to tf.float32,
  4. converts the label to tf.float32.

This method is called through TensorFlow's tf.py_function.
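
The loading steps listed above can be reproduced with NumPy alone. This sketch is an approximation under assumptions: it casts to `np.float32` instead of `tf.float32`, it assumes the channel dimension is appended as the last axis (consistent with an input shape such as (128, 128, 1)), and `load_image_sketch` is a hypothetical name.

```python
import os
import tempfile

import numpy as np


def load_image_sketch(file_path, label):
    """Load a 2-D .npy array, append a channel axis, and cast to float32."""
    image = np.load(file_path)
    image = np.expand_dims(image, axis=-1)  # (H, W) -> (H, W, 1)
    return image.astype(np.float32), np.float32(label)


# Write a throwaway array to disk so the example is self-contained.
with tempfile.TemporaryDirectory() as tmp_dir:
    sample_path = os.path.join(tmp_dir, "sample.npy")
    np.save(sample_path, np.zeros((128, 128), dtype=np.uint8))
    image, label = load_image_sketch(sample_path, 0.42)
```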

load_image_wrapper(file_path, label)

TensorFlow wrapper around load_image().

image, label = fl.load_image_wrapper(file_path, label)

This wrapper allows NumPy-based file loading to run inside a TensorFlow dataset pipeline.

Socket Communication Methods

send_data(sock, data)

Serializes Python data with pickle and sends it through a socket.

fl.send_data(sock, global_weights)

Parameters

| Name | Type | Description |
| --- | --- | --- |
| sock | socket.socket | Connected socket object. |
| data | Any | Pickle-serializable data to send. |

recv_data(sock)

Receives bytes from a socket and deserializes them with pickle.

weights = fl.recv_data(sock)

Parameters

| Name | Type | Description |
| --- | --- | --- |
| sock | socket.socket | Connected socket object. |

Returns

Returns the deserialized Python object.
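
Because TCP is a byte stream, a receiver needs some way to know where one pickled message ends. The package's framing is not described here, so the sketch below shows one common approach as an assumption: an 8-byte length prefix followed by the pickle payload. The names `send_msg`, `recv_exact`, and `recv_msg` are hypothetical.

```python
import pickle
import socket
import struct


def send_msg(sock, data):
    """Pickle `data` and send it with an 8-byte big-endian length prefix."""
    payload = pickle.dumps(data)
    sock.sendall(struct.pack("!Q", len(payload)) + payload)


def recv_exact(sock, num_bytes):
    """Read exactly `num_bytes` from the socket, looping over partial reads."""
    buffer = bytearray()
    while len(buffer) < num_bytes:
        chunk = sock.recv(num_bytes - len(buffer))
        if not chunk:
            raise ConnectionError("socket closed before the full message arrived")
        buffer.extend(chunk)
    return bytes(buffer)


def recv_msg(sock):
    """Read one length-prefixed message and unpickle it (trusted peers only)."""
    (length,) = struct.unpack("!Q", recv_exact(sock, 8))
    return pickle.loads(recv_exact(sock, length))


# A connected socket pair stands in for the server/client connection.
left, right = socket.socketpair()
try:
    send_msg(left, {"round": 1, "weights": [0.5, 1.5]})
    received = recv_msg(right)
finally:
    left.close()
    right.close()
```

The length prefix lets the receiver loop until the whole payload has arrived, which matters for model weights that are far larger than a single `recv` call returns.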

Minimal Server and Client Example

Start the server first:

from mds3fl import MDS3FL

server = MDS3FL(
    is_server=True,
    path1="/path/to/labels.xlsx",
    path2="/path/to/pb1",
    path3="/path/to/pb2",
    path4="/path/to/pb3",
    wr2="/path/to/wr2",
    client_num=2,
    server_ip="127.0.0.1",
    port=5000,
    input_shape=(128, 128, 1),
    num_conv=3,
    num_nodes=[128, 64],
    use_batchnorm=True,
    use_dropout=True,
    use_reg=0,
    LR=1e-4,
)

server.train(num_rounds=10)

Then start each client process:

from mds3fl import MDS3FL

client = MDS3FL(
    is_server=False,
    path1="/path/to/labels.xlsx",
    path2="/path/to/pb1",
    path3="/path/to/pb2",
    path4="/path/to/pb3",
    wr2="/path/to/wr2",
    client_num=2,
    server_ip="127.0.0.1",
    port=5000,
    input_shape=(128, 128, 1),
    num_conv=3,
    num_nodes=[128, 64],
    use_batchnorm=True,
    use_dropout=True,
    use_reg=0,
    LR=1e-4,
)

client.train(num_rounds=10)

Notes for PyPI Users

  • This package is intended for trusted research environments.
  • The socket transport uses pickle, and unpickling can execute arbitrary code, so never connect it to untrusted clients or servers.
  • The current data loader expects a specific Excel workbook structure.
  • Image files should be NumPy .npy arrays.
  • The model is currently configured for regression with mean squared error.
