
Client application to interface with the BranchKey system


BranchKey Python Client


License: GPL v3

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.

Installation

pip install branchkey

Requirements: Python 3.9 or higher

Quick Start

1. Get Credentials

Create a leaf entity on the BranchKey platform (via the /v2/entities API endpoint) to obtain its credentials.

2. Initialise Client

from branchkey import (
    Client,
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
)

# Create credentials
credentials = Credentials(
    id="your-leaf-uuid",
    name="my-client",
    session_token="your-session-token-uuid",
    owner_id="your-user-uuid",
    tree_id="your-tree-uuid",
    branch_id="your-branch-uuid",
)

# Initialise client with default settings
client = Client(credentials)

# Or with custom configuration
client = Client(
    credentials=credentials,
    api_config=APIConfig(
        host="https://app.branchkey.com",
        ssl=True,
    ),
    rabbitmq_config=RabbitMQConfig(
        port=5671,
        ssl=True,
    ),
    run_config=RunConfig(
        wait_for_run=False,
        check_interval_s=30,
    ),
)

3. Upload Model Weights

import numpy as np

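# Prepare model weights
weighting = 1000  # Weight for aggregation (typically number of samples)
layer1_weights = np.zeros((4, 3))  # Placeholder; use your model's arrays
layer2_weights = np.zeros(3)       # Placeholder
parameters = [layer1_weights, layer2_weights]  # One NumPy array per layer, in order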

# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")

4. Download Aggregated Results

# Wait for aggregation notification
aggregation_id = client.queue.get(block=True)  # Blocks until aggregation ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_output/{aggregation_id}.npz")

# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
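
The calls above suggest that client.queue follows the interface of Python's standard queue.Queue. That is an assumption, not something stated here. Under that assumption, waiting with a timeout avoids blocking forever when no aggregation arrives. The sketch below uses a plain queue.Queue as a stand-in, and wait_for_aggregation is a hypothetical helper, not part of the library.

```python
import queue
from typing import Optional


def wait_for_aggregation(q: "queue.Queue[str]", timeout_s: float) -> Optional[str]:
    """Return the next aggregation ID, or None if none arrives within timeout_s."""
    try:
        return q.get(block=True, timeout=timeout_s)
    except queue.Empty:
        return None


# Stand-in for client.queue (assumed to behave like queue.Queue).
notifications: "queue.Queue[str]" = queue.Queue()
notifications.put("example-aggregation-id")

first = wait_for_aggregation(notifications, timeout_s=0.1)   # an ID is waiting
second = wait_for_aggregation(notifications, timeout_s=0.1)  # queue is now empty
```

A None result lets the training loop decide whether to keep waiting or continue with local training.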

Configuration

All configuration objects are immutable (frozen) dataclasses, for type safety and clarity.
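
As a rough illustration of what immutability means in practice, the sketch below uses a hypothetical stand-in class, ExampleConfig, rather than the library's own classes. It assumes the configuration classes are standard frozen dataclasses, in which case fields cannot be reassigned and dataclasses.replace is the usual way to derive a modified copy.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical stand-in for a frozen config such as RunConfig."""
    wait_for_run: bool = False
    check_interval_s: int = 30


base = ExampleConfig()

# Assigning to a field of a frozen dataclass raises FrozenInstanceError.
try:
    base.check_interval_s = 15
    mutated = True
except FrozenInstanceError:
    mutated = False

# dataclasses.replace builds a new instance with selected fields changed.
faster = replace(base, check_interval_s=15)
```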

Credentials

from branchkey import Credentials

credentials = Credentials(
    id="leaf-uuid",
    name="my-leaf",
    session_token="token-uuid",
    owner_id="user-uuid",
    tree_id="tree-uuid",
    branch_id="branch-uuid",
)

# Or from a dictionary
credentials = Credentials.from_dict(creds_dict)

API Configuration

from branchkey import APIConfig

api_config = APIConfig(
    host="https://app.branchkey.com",  # API endpoint (default)
    ssl=True,                           # Verify SSL certificates (default)
    proxies=None,                       # Optional proxy dict
)

Transport: AMQP (RabbitMQ) vs WebSocket

The client supports two transport mechanisms for receiving aggregation notifications:

AMQP/RabbitMQ (Default)

from branchkey import Client, Credentials, RabbitMQConfig

client = Client(
    credentials=credentials,
    rabbitmq_config=RabbitMQConfig(
        host=None,                      # Auto-derived from API host
        port=5671,                      # TLS port (default)
        ssl=True,                       # Use TLS (default)
        max_reconnect_attempts=0,       # 0 = infinite retry (default)
        reconnect_backoff_factor=2.0,   # Exponential backoff multiplier
        reconnect_max_delay=60,         # Max delay in seconds
    ),
    use_websocket=False,  # Default
)

# Receive aggregations via queue
aggregation_id = client.queue.get(block=True)

WebSocket

from branchkey import Client, Credentials, WebSocketConfig

client = Client(
    credentials=credentials,
    websocket_config=WebSocketConfig(
        max_reconnect_attempts=0,       # 0 = infinite retry (default)
        reconnect_backoff_factor=2.0,   # Exponential backoff multiplier
        reconnect_max_delay=60,         # Max delay in seconds
    ),
    use_websocket=True,  # Enable WebSocket transport
)

# Receive aggregations via polling
aggregation_id = client.get_latest_aggregation_id()
if aggregation_id:
    client.file_download(aggregation_id)
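
Both transports expose the same three reconnect settings. The sketch below shows one plausible reading of them: a delay that grows by the backoff factor on each attempt and is capped at the maximum delay. The function reconnect_delays and the initial delay of 1 second are assumptions for illustration; the client's actual schedule may differ.

```python
from itertools import count, islice
from typing import Iterator


def reconnect_delays(
    backoff_factor: float = 2.0,
    max_delay: float = 60.0,
    max_attempts: int = 0,
    initial_delay: float = 1.0,  # assumption: not documented in this README
) -> Iterator[float]:
    """Yield capped exponential delays; max_attempts=0 means no limit."""
    delay = min(initial_delay, max_delay)
    attempts = count() if max_attempts == 0 else range(max_attempts)
    for _ in attempts:
        yield delay
        delay = min(delay * backoff_factor, max_delay)


# First eight delays with no attempt limit.
unlimited = list(islice(reconnect_delays(), 8))
# A limit of three attempts yields exactly three delays.
limited = list(reconnect_delays(max_attempts=3))
```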

Run Configuration

from branchkey import RunConfig

run_config = RunConfig(
    wait_for_run=False,     # Wait if run is paused before uploading
    check_interval_s=30,    # Run status check interval in seconds
)
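
To show roughly what waiting for a paused run involves, here is a minimal polling sketch. It is not the client's implementation: wait_until_not_paused is a hypothetical helper, and the status strings are taken from the run_status values listed under Client Properties.

```python
import time
from typing import Callable


def wait_until_not_paused(
    get_status: Callable[[], str],
    check_interval_s: float = 30,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status until it returns something other than "pause"."""
    status = get_status()
    while status == "pause":
        sleep(check_interval_s)
        status = get_status()
    return status


# Simulated run: paused for two checks, then started.
statuses = iter(["pause", "pause", "start"])
slept = []
final_status = wait_until_not_paused(
    get_status=lambda: next(statuses),
    check_interval_s=30,
    sleep=slept.append,  # record the waits instead of actually sleeping
)
```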

HTTP Retry Configuration

The client automatically retries failed HTTP requests with exponential backoff:

from branchkey import RetryConfig

retry_config = RetryConfig(
    max_retries=3,                                   # Maximum retry attempts
    backoff_factor=1.0,                              # Backoff multiplier (seconds)
    total_timeout=30,                                # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504), # HTTP codes to retry
    allowed_methods=("GET", "POST", "PUT"),          # Methods that support retry
)

client = Client(credentials, retry_config=retry_config)

Retry Behaviour:

  • Retries on: 408, 429, the 5xx codes in status_forcelist (500, 502, 503, 504 by default), connection timeouts
  • Does NOT retry: Other 4xx client errors (400, 401, 403, 404)
  • Backoff delays: Exponential (1s, 2s, 4s, ...)
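
The behaviour listed above can be sketched as two small functions. The formula backoff_factor * 2**attempt is an assumption chosen to reproduce the 1s, 2s, 4s pattern for backoff_factor=1.0. The helpers should_retry and backoff_schedule are hypothetical and are not part of the library.

```python
from typing import List, Tuple

STATUS_FORCELIST: Tuple[int, ...] = (408, 429, 500, 502, 503, 504)


def should_retry(status_code: int, forcelist: Tuple[int, ...] = STATUS_FORCELIST) -> bool:
    """Return True if a response with this status code would be retried."""
    return status_code in forcelist


def backoff_schedule(max_retries: int, backoff_factor: float) -> List[float]:
    """Delays in seconds, assuming backoff_factor * 2**attempt."""
    return [backoff_factor * 2 ** attempt for attempt in range(max_retries)]


default_delays = backoff_schedule(max_retries=3, backoff_factor=1.0)
production_delays = backoff_schedule(max_retries=5, backoff_factor=2.0)
```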

Configuration Examples:

# Production: More retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)

# Development: Faster failure
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)

Complete Configuration Example

from branchkey import (
    Client,
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
)

client = Client(
    credentials=Credentials(
        id="leaf-uuid",
        name="my-leaf",
        session_token="token",
        tree_id="tree-uuid",
        branch_id="branch-uuid",
        owner_id="user-uuid",
    ),
    api_config=APIConfig(
        host="https://app.branchkey.com",
        ssl=True,
    ),
    rabbitmq_config=RabbitMQConfig(
        port=5671,
        ssl=True,
        max_reconnect_attempts=10,
    ),
    websocket_config=WebSocketConfig(
        max_reconnect_attempts=10,
    ),
    run_config=RunConfig(
        wait_for_run=True,
        check_interval_s=15,
    ),
    retry_config=RetryConfig(
        max_retries=5,
        backoff_factor=2.0,
    ),
    use_websocket=False,  # False for AMQP, True for WebSocket
)

Model Weight Format

Model weights are stored in compressed NPZ format.

Structure

# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000  # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays
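
To make the container format concrete, the sketch below round-trips a weighting and two arrays through a compressed NPZ archive held in memory. The key names (weighting, layer_0, layer_1) are assumptions for illustration; client.save_weights may lay out its files differently.

```python
import io

import numpy as np

weighting = 1000
parameters = [np.ones((2, 3)), np.zeros(3)]

# Hypothetical key names; the layout written by client.save_weights may differ.
arrays = {f"layer_{i}": p for i, p in enumerate(parameters)}
buffer = io.BytesIO()
np.savez_compressed(buffer, weighting=np.array(weighting), **arrays)

# Read the archive back and recover the weighting and the arrays.
buffer.seek(0)
with np.load(buffer) as archive:
    keys = sorted(archive.files)
    restored_weighting = int(archive["weighting"])
    restored = [archive[f"layer_{i}"] for i in range(len(parameters))]
```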

Weighting Options

The weighting parameter controls how much influence this update has during aggregation:

1. By Sample Count (Most Common)

weighting = len(train_dataset)  # e.g., 1000 samples
# Client with 1000 samples has 2x influence of client with 500 samples

2. Equal Weighting

weighting = 1  # All clients have equal influence

3. Quality-Based Weighting

validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality
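
To see how the weighting shifts the result, the sketch below computes a plain weighted mean over two hypothetical clients. It assumes aggregation is a weighted average of each layer. The platform's actual aggregation method is not described here, and weighted_average is an illustrative helper, not a library function.

```python
import numpy as np


def weighted_average(updates):
    """updates: list of (weighting, [arrays]) tuples with matching shapes."""
    total = sum(weighting for weighting, _ in updates)
    n_layers = len(updates[0][1])
    return [
        sum(weighting * params[i] for weighting, params in updates) / total
        for i in range(n_layers)
    ]


# Two hypothetical clients, one layer each.
client_a = (1000, [np.array([1.0, 1.0])])
client_b = (500, [np.array([4.0, 4.0])])

aggregated = weighted_average([client_a, client_b])
# (1000 * 1.0 + 500 * 4.0) / 1500 = 2.0 for every element
```

With equal weighting the same inputs would average to 2.5, so the larger client pulls the result toward its own values.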

PyTorch Example

import numpy as np

# Option 1: build the parameter list manually
weighting = len(train_dataset)
parameters = []
for name, param in model.named_parameters():
    parameters.append(param.detach().cpu().numpy())

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

# Option 2: use the client's convert_pytorch_numpy helper instead
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset)
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

TensorFlow/Keras Example

import numpy as np

weighting = len(train_dataset)
parameters = [layer.numpy() for layer in model.trainable_weights]

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

Loading Aggregated Weights

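import numpy as np
import torch

# Load aggregated weights from the NPZ file saved by client.file_download
npz_data = np.load(f"aggregated_output/{aggregation_id}.npz")

# Note: Aggregated results only contain layers (no weighting)
# Sort by length, then name, so layer_10 follows layer_9 when indices are not zero-padded
layer_keys = sorted(
    (k for k in npz_data.files if k.startswith("layer_")),
    key=lambda k: (len(k), k),
)
parameters = [npz_data[k] for k in layer_keys]

# Apply to PyTorch model, copying in place so each parameter keeps its dtype and device
with torch.no_grad():
    for param, array in zip(model.parameters(), parameters):
        param.copy_(torch.from_numpy(array))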

Performance Metrics

Submit training or testing metrics:

import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test"  # "test", "train", or "non-federated"
)
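
Because data is passed as a JSON string, it can help to validate and serialise metrics in one place before sending. The sketch below is a hypothetical helper, build_metrics_payload, and not part of the library. The allowed modes are taken from the comment above. Whether the server rejects other values is not stated here.

```python
import json
from typing import Mapping

ALLOWED_MODES = ("test", "train", "non-federated")


def build_metrics_payload(metrics: Mapping[str, float], mode: str) -> str:
    """Validate mode and serialise metrics to a JSON string."""
    if mode not in ALLOWED_MODES:
        raise ValueError(f"mode must be one of {ALLOWED_MODES}, got {mode!r}")
    # float() converts values such as numpy.float32 that json cannot serialise directly.
    return json.dumps({name: float(value) for name, value in metrics.items()})


payload = build_metrics_payload({"accuracy": 0.95, "loss": 0.12}, mode="test")
```

The resulting string could then be passed as the data argument of client.send_performance_metrics.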

Client Properties

client.run_status        # Current run status: "start", "stop", or "pause"
client.run_number        # Current run iteration
client.leaf_id           # Your leaf UUID
client.branch_id         # Parent branch UUID
client.tree_id           # Tree UUID
client.is_initialized    # Initialisation status
client.use_websocket     # True if using WebSocket transport

Branch Configuration

Fetch branch configuration including model-specific settings:

config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})

Advanced Features

Proxy Support

from branchkey import Client, Credentials, APIConfig

proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}

client = Client(
    credentials=credentials,
    api_config=APIConfig(proxies=proxies),
)

Context Manager

Use the client as a context manager for automatic cleanup:

from branchkey import Client, Credentials

with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)
# Connections automatically closed

Error Handling

try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")
    # Logs include:
    # - HTTP status codes
    # - Response content preview
    # - Retry attempt information

Public API

from branchkey import (
    # Main client
    Client,

    # Configuration (frozen dataclasses)
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,

    # Utilities
    get_metadata,
    AGGREGATED_OUTPUT_DIR,
)

Support


BranchKey - Federated Learning Platform
