Client application to interface with the BranchKey system

BranchKey Python Client

License: GPL v3

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.

Installation

pip install branchkey

Requirements: Python 3.9 or higher

Quick Start

1. Get Credentials

Create a leaf entity through the BranchKey platform to obtain credentials via the /v2/entities API endpoint:

credentials = {
    "id": "your-leaf-uuid",
    "name": "my-client",
    "session_token": "your-session-token-uuid",
    "owner_id": "your-user-uuid",
    "tree_id": "your-tree-uuid",
    "branch_id": "your-branch-uuid"
}

2. Initialize Client

from branchkey.client import Client

# Connect to BranchKey
client = Client(credentials, host="https://app.branchkey.com")

3. Upload Model Weights

import numpy as np

# Prepare model weights
weighting = 1000  # Weight for aggregation (typically number of samples)
parameters = [layer1_weights, layer2_weights, ...]

# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")

4. Download Aggregated Results

# Wait for aggregation notification from RabbitMQ
# The queue is populated by the background RabbitMQ consumer
aggregation_id = client.queue.get(block=True)  # Blocks until aggregation ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_files/{aggregation_id}.npz")

# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)

Configuration Options

Basic Configuration

client = Client(
    credentials,
    host="https://app.branchkey.com",         # API endpoint
    rbmq_host=None,                           # RabbitMQ host (auto-derived from host)
    rbmq_port=5671,                           # RabbitMQ port (5671 for TLS)
    rbmq_ssl=True,                            # Use TLS for RabbitMQ
    rbmq_max_reconnect_attempts=10,           # Max RabbitMQ reconnection attempts
    rbmq_reconnect_backoff_factor=2.0,        # Exponential backoff multiplier
    rbmq_reconnect_max_delay=60,              # Max reconnection delay (seconds)
    ssl=True,                                 # Verify SSL certificates
    wait_for_run=False,                       # Wait if run is paused
    run_check_interval_s=30,                  # Run status check interval
    proxies=None,                             # HTTP/HTTPS proxy dict
    retry_config=None                         # Custom retry configuration (optional)
)

Retry Configuration

The client automatically retries failed HTTP requests with exponential backoff. Default settings:

from branchkey.retry_config import RetryConfig

# Default configuration (applied automatically)
retry_config = RetryConfig(
    max_retries=3,              # Maximum retry attempts
    backoff_factor=1.0,         # Exponential backoff multiplier (seconds)
    total_timeout=30,           # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504)  # HTTP codes to retry
)

client = Client(credentials, retry_config=retry_config)

Retry Behaviour:

  • Retries on:
    • 408 Request Timeout - Client-side timeouts
    • 429 Too Many Requests - Rate limiting
    • 5xx Server Errors - 500, 502, 503, 504
    • Connection timeouts and network failures
  • Does NOT retry: Other 4xx client errors (400, 401, 403, 404, etc.)
  • Backoff delays: 1s, 2s, 4s (exponential with backoff_factor)
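The delay schedule above follows directly from backoff_factor. The sketch below reproduces the listed progression (1s, 2s, 4s with backoff_factor=1.0); the helper name backoff_delays is ours, for illustration only, and the library's internal schedule may differ in detail:

```python
# Illustration: exponential backoff delays as factor * 2**attempt,
# matching the 1s, 2s, 4s progression documented for the defaults.
def backoff_delays(max_retries, backoff_factor):
    """Return the delay (in seconds) before each retry attempt."""
    return [backoff_factor * (2 ** attempt) for attempt in range(max_retries)]

print(backoff_delays(3, 1.0))  # [1.0, 2.0, 4.0]
print(backoff_delays(5, 2.0))  # [2.0, 4.0, 8.0, 16.0, 32.0]
```

With the production example below (max_retries=5, backoff_factor=2.0), a fully failing request would wait roughly a minute in total before giving up.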

Custom Configuration Examples:

# Production: More retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)
client = Client(credentials, retry_config=production_retry)

# Development: Faster failure
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)
client = Client(credentials, retry_config=dev_retry)

RabbitMQ Reconnection

The RabbitMQ consumer automatically reconnects with exponential backoff if the connection is lost.

Default Settings:

  • Max reconnection attempts: 10
  • Backoff factor: 2.0x (delays: 2s, 4s, 8s, 16s, 32s, 60s...)
  • Max delay: 60 seconds
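The capped schedule above can be sketched in a few lines. This is an illustration (the reconnect_delays helper is ours, not part of the client API), assuming delays grow as backoff_factor raised to the attempt number and are clamped at max_delay, which reproduces the 2s, 4s, 8s, ..., 60s sequence listed:

```python
# Illustration: reconnect delays grow exponentially and are capped at
# max_delay, matching the default sequence 2s, 4s, 8s, 16s, 32s, 60s, ...
def reconnect_delays(max_attempts, backoff_factor=2.0, max_delay=60):
    """Return the delay (in seconds) before each reconnection attempt."""
    return [min(backoff_factor ** attempt, max_delay)
            for attempt in range(1, max_attempts + 1)]

print(reconnect_delays(7))  # [2.0, 4.0, 8.0, 16.0, 32.0, 60, 60]
```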

Custom Configuration:

# Pass reconnection settings as Client parameters
client = Client(
    credentials,
    rbmq_max_reconnect_attempts=20,     # More attempts
    rbmq_reconnect_backoff_factor=3.0,  # Faster backoff increase
    rbmq_reconnect_max_delay=120        # 2 minute max delay (seconds)
)

Model Weight Format

Model weights are stored in compressed NPZ format.

Structure

# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000  # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays

Weighting Options

The weighting parameter controls how much influence this update has during aggregation:

1. By Sample Count (Most Common)

weighting = len(train_dataset)  # e.g., 1000 samples
# Client with 1000 samples has 2x influence of client with 500 samples

2. Equal Weighting

weighting = 1  # All clients have equal influence

3. Quality-Based Weighting

validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality

4. Manual Weighting

weighting = 5.0  # Trusted client gets higher weight
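To see how these weightings translate into influence, here is a minimal sketch of FedAvg-style weighted averaging. The weighted_average helper is illustrative only; the BranchKey server's actual aggregation algorithm is not part of this client library:

```python
import numpy as np

# Illustration of weighted averaging: each client's update is scaled by its
# weighting relative to the total, so a client with 1000 samples has twice
# the influence of a client with 500.
def weighted_average(updates, weightings):
    total = sum(weightings)
    return sum((w / total) * u for u, w in zip(updates, weightings))

client_a = np.array([1.0, 1.0])  # weighting 1000
client_b = np.array([4.0, 4.0])  # weighting 500
print(weighted_average([client_a, client_b], [1000, 500]))  # [2. 2.]
```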

PyTorch Example

import numpy as np

# Method 1: Using client helper (recommended)
weighting = len(train_dataset)
parameters = []
for name, param in model.named_parameters():
    parameters.append(param.data.cpu().detach().numpy())

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

# Method 2: Using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset)
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

TensorFlow/Keras Example

import numpy as np

weighting = len(train_dataset)
parameters = [layer.numpy() for layer in model.trainable_weights]

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

Manual NPZ Creation

import numpy as np

# Save manually (without using client helper)
arrays_dict = {'weighting': np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays_dict[f'layer_{i}'] = arr

np.savez_compressed("model_weights.npz", **arrays_dict)  # Must include .npz

Loading Aggregated Weights

import numpy as np
import torch

# Load aggregated weights from NPZ file
npz_data = np.load("aggregated_files/aggregation_id.npz")

# Note: Aggregated results only contain layers (no weighting)
# Sort by numeric index so 'layer_10' comes after 'layer_2'
layer_keys = sorted([k for k in npz_data.files if k.startswith('layer_')],
                    key=lambda k: int(k.split('_')[1]))
parameters = [npz_data[k] for k in layer_keys]

# Apply to your model (PyTorch shown; parameter order must match upload order)
for i, param in enumerate(model.parameters()):
    param.data = torch.from_numpy(parameters[i])

Example NPZ File Contents

Client Upload Format (with weighting):

>>> npz_data.files
['weighting', 'layer_0', 'layer_1', 'layer_2', 'layer_3', ...]

>>> npz_data['weighting']
array([1530.])  # Weight for aggregation

>>> npz_data['layer_0'].shape, npz_data['layer_0'].dtype
((32, 1, 5, 5), dtype('float32'))

>>> npz_data['layer_0'][:1, :2, :2, :]
array([[[[-0.18576819, -0.03041792,  0.19532707, -0.11234483, -0.01512307],
         [ 0.19993757, -0.06492048,  0.08324468, -0.19899307, -0.0412709 ]]]],
       dtype=float32)

Aggregated Result Format (layers only):

>>> npz_data.files
['layer_0', 'layer_1', 'layer_2', ...]  # No weighting in aggregated results
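The upload format above can be exercised end to end without the client, using plain NumPy. This round-trip sketch mirrors the documented layout (the client's save_weights helper may differ in internal details); note the numeric sort key when reading layers back, since a plain string sort would order 'layer_10' before 'layer_2':

```python
import os
import tempfile

import numpy as np

# Write an NPZ in the documented upload format: a 'weighting' entry plus
# one 'layer_{i}' entry per parameter array.
parameters = [np.random.rand(3, 3).astype(np.float32) for _ in range(12)]
arrays = {'weighting': np.array([1530.0])}
for i, arr in enumerate(parameters):
    arrays[f'layer_{i}'] = arr

path = os.path.join(tempfile.mkdtemp(), 'model_weights.npz')
np.savez_compressed(path, **arrays)

# Read it back, sorting layer keys numerically rather than lexicographically.
npz = np.load(path)
layer_keys = sorted((k for k in npz.files if k.startswith('layer_')),
                    key=lambda k: int(k.split('_')[1]))
restored = [npz[k] for k in layer_keys]
assert all(np.array_equal(a, b) for a, b in zip(parameters, restored))
print(npz['weighting'])  # [1530.]
```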

Performance Metrics

Submit training or testing metrics:

import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test"  # "test", "train", or "non-federated"
)

Client Properties

client.run_status        # Current run status: "start", "stop", or "pause"
client.run_number        # Current run iteration
client.leaf_id           # Your leaf UUID
client.branch_id         # Parent branch UUID
client.is_initialized    # Initialization status

Branch Configuration

Fetch branch configuration including model-specific settings:

config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})

Advanced Features

Proxy Support

For networks requiring proxy access:

proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(credentials, host="https://app.branchkey.com", proxies=proxies)

Context Manager

Use the client as a context manager for automatic cleanup:

with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)
# Connections automatically closed

Error Handling

The client provides detailed error messages for debugging:

try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")
    # Logs include:
    # - HTTP status codes
    # - Response content preview
    # - Retry attempt information

Support


BranchKey - Federated Learning Platform

Download files

Source Distribution

branchkey-2.9.0.tar.gz (24.9 kB)

  • SHA256: bdb9d512bd2fc2009982a5ea26f9c6064caec7fb117e61dc50d767ae717695b3
  • MD5: 7a5f1481583dacc8c983a1e4f228080d
  • BLAKE2b-256: b54697da8240905e7a3a5b44eed5152cb87a6ee0066e3a10b8f89897665cecfc

Built Distribution

branchkey-2.9.0-py3-none-any.whl (25.5 kB)

  • SHA256: 3f7cefcc883c429012fdf9990888528beadd154b09c9f21ce0c0c9ad3412ad7c
  • MD5: 4a2d55d3e78a2ab25e5eb904d30ca041
  • BLAKE2b-256: 47b249116ed874de074376258fe3130101b4c2b85b8f3a5f19eb084fc4b858b4

Both files were uploaded via twine/6.2.0 (CPython/3.14.3) without Trusted Publishing.
