Client application to interface with the BranchKey system

BranchKey Python Client

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.

Installation

pip install branchkey

Requirements: Python 3.9 or higher

Quick Start

1. Get Credentials

Create a leaf entity through the BranchKey platform to obtain credentials via the /v2/entities API endpoint:

credentials = {
    "id": "your-leaf-uuid",
    "name": "my-client",
    "session_token": "your-session-token-uuid",
    "owner_id": "your-user-uuid",
    "tree_id": "your-tree-uuid",
    "branch_id": "your-branch-uuid"
}

2. Initialize Client

from branchkey.client import Client

# Connect to BranchKey
client = Client(credentials, host="https://app.branchkey.com")

3. Upload Model Weights

import numpy as np

# Prepare model weights
weighting = 1000  # Weight for aggregation (typically number of samples)
parameters = [layer1_weights, layer2_weights, ...]

# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")

4. Download Aggregated Results

# Wait for aggregation notification from RabbitMQ
# The queue is populated by the background RabbitMQ consumer
aggregation_id = client.queue.get(block=True)  # Blocks until aggregation ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_files/{aggregation_id}.npz")

# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)

Configuration Options

Basic Configuration

client = Client(
    credentials,
    host="https://app.branchkey.com",         # API endpoint
    rbmq_host=None,                           # RabbitMQ host (auto-derived from host)
    rbmq_port=5671,                           # RabbitMQ port (5671 for TLS)
    rbmq_ssl=True,                            # Use TLS for RabbitMQ
    rbmq_max_reconnect_attempts=10,           # Max RabbitMQ reconnection attempts
    rbmq_reconnect_backoff_factor=2.0,        # Exponential backoff multiplier
    rbmq_reconnect_max_delay=60,              # Max reconnection delay (seconds)
    ssl=True,                                 # Verify SSL certificates
    wait_for_run=False,                       # Wait if run is paused
    run_check_interval_s=30,                  # Run status check interval
    proxies=None,                             # HTTP/HTTPS proxy dict
    retry_config=None                         # Custom retry configuration (optional)
)

Retry Configuration

The client automatically retries failed HTTP requests with exponential backoff. Default settings:

from branchkey.retry_config import RetryConfig

# Default configuration (applied automatically)
retry_config = RetryConfig(
    max_retries=3,              # Maximum retry attempts
    backoff_factor=1.0,         # Exponential backoff multiplier (seconds)
    total_timeout=30,           # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504)  # HTTP codes to retry
)

client = Client(credentials, retry_config=retry_config)

Retry Behaviour:

  • Retries on:
    • 408 Request Timeout - Server timed out waiting for the request
    • 429 Too Many Requests - Rate limiting
    • 5xx Server Errors - 500, 502, 503, 504
    • Connection timeouts and network failures
  • Does NOT retry: Other 4xx client errors (400, 401, 403, 404, etc.)
  • Backoff delays: 1s, 2s, 4s (exponential with backoff_factor)
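The delay schedule above follows the common urllib3-style formula, `backoff_factor * 2**attempt`. This sketch is illustrative only (the actual delays are computed inside the client); it just shows how `max_retries` and `backoff_factor` interact:

```python
def retry_delays(max_retries: int, backoff_factor: float) -> list[float]:
    """Exponential backoff: backoff_factor * 2**attempt seconds per retry."""
    return [backoff_factor * (2 ** attempt) for attempt in range(max_retries)]

print(retry_delays(3, 1.0))  # default config -> [1.0, 2.0, 4.0]
print(retry_delays(5, 2.0))  # -> [2.0, 4.0, 8.0, 16.0, 32.0]
```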

Custom Configuration Examples:

# Production: More retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)
client = Client(credentials, retry_config=production_retry)

# Development: Faster failure
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)
client = Client(credentials, retry_config=dev_retry)

RabbitMQ Reconnection

The RabbitMQ consumer automatically reconnects with exponential backoff if the connection is lost.

Default Settings:

  • Max reconnection attempts: 10
  • Backoff factor: 2.0x (delays: 2s, 4s, 8s, 16s, 32s, 60s...)
  • Max delay: 60 seconds
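The defaults above imply a schedule of `backoff_factor**attempt`, capped at the max delay. A sketch reproducing the listed delays (the consumer computes this internally; this is only to show how the three parameters combine):

```python
def reconnect_delays(max_attempts: int, backoff_factor: float,
                     max_delay: float) -> list[float]:
    """Delay before each reconnection attempt, capped at max_delay seconds."""
    return [min(backoff_factor ** attempt, max_delay)
            for attempt in range(1, max_attempts + 1)]

print(reconnect_delays(10, 2.0, 60.0))
# -> [2.0, 4.0, 8.0, 16.0, 32.0, 60.0, 60.0, 60.0, 60.0, 60.0]
```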

Custom Configuration:

# Pass reconnection settings as Client parameters
client = Client(
    credentials,
    rbmq_max_reconnect_attempts=20,     # More attempts
    rbmq_reconnect_backoff_factor=3.0,  # Faster backoff increase
    rbmq_reconnect_max_delay=120        # 2 minute max delay (seconds)
)

Model Weight Format

Model weights are stored in compressed NPZ format.

Structure

# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000  # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays

Weighting Options

The weighting parameter controls how much influence this update has during aggregation:

1. By Sample Count (Most Common)

weighting = len(train_dataset)  # e.g., 1000 samples
# Client with 1000 samples has 2x influence of client with 500 samples

2. Equal Weighting

weighting = 1  # All clients have equal influence

3. Quality-Based Weighting

validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality

4. Manual Weighting

weighting = 5.0  # Trusted client gets higher weight
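Whichever option you pick, the weighting is relative: assuming a FedAvg-style weighted average on the server (which the semantics above imply), each client's influence is its weighting divided by the sum across clients. A quick sketch of the sample-count example:

```python
def influence(weightings: list[float]) -> list[float]:
    """Relative share of each client in a weighted average."""
    total = sum(weightings)
    return [w / total for w in weightings]

# Two clients weighted by sample count: 1000 samples vs 500 samples.
print(influence([1000, 500]))  # -> [0.6666666666666666, 0.3333333333333333]

# Equal weighting: every client contributes the same share.
print(influence([1, 1, 1]))
```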

PyTorch Example

import numpy as np

# Method 1: Using client helper (recommended)
weighting = len(train_dataset)
parameters = []
for name, param in model.named_parameters():
    parameters.append(param.data.cpu().detach().numpy())

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

# Method 2: Using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset)
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

TensorFlow/Keras Example

import numpy as np

weighting = len(train_dataset)
parameters = [layer.numpy() for layer in model.trainable_weights]

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

Manual NPZ Creation

import numpy as np

# Save manually (without using client helper)
arrays_dict = {'weighting': np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays_dict[f'layer_{i}'] = arr

np.savez_compressed("model_weights.npz", **arrays_dict)  # Must include .npz

Loading Aggregated Weights

# Load aggregated weights from NPZ file
npz_data = np.load("aggregated_files/aggregation_id.npz")

# Note: Aggregated results only contain layers (no weighting)
layer_keys = sorted([k for k in npz_data.files if k.startswith('layer_')],
                    key=lambda k: int(k.split('_')[1]))  # numeric sort, so layer_10 follows layer_9
parameters = [npz_data[k] for k in layer_keys]

# Apply to your model
for i, param in enumerate(model.parameters()):
    param.data = torch.from_numpy(parameters[i])

Example NPZ File Contents

Client Upload Format (with weighting):

>>> npz_data.files
['weighting', 'layer_0', 'layer_1', 'layer_2', 'layer_3', ...]

>>> npz_data['weighting']
array([1530.])  # Weight for aggregation

>>> npz_data['layer_0'].shape, npz_data['layer_0'].dtype
((32, 1, 5, 5), dtype('float32'))

>>> npz_data['layer_0'][:1, :2, :2, :]
array([[[[-0.18576819, -0.03041792,  0.19532707, -0.11234483, -0.01512307],
         [ 0.19993757, -0.06492048,  0.08324468, -0.19899307, -0.0412709 ]]]],
       dtype=float32)

Aggregated Result Format (layers only):

>>> npz_data.files
['layer_0', 'layer_1', 'layer_2', ...]  # No weighting in aggregated results
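Both formats can be exercised end to end with plain NumPy, without the client. The sketch below mirrors the key layout shown above (`weighting` plus `layer_0`, `layer_1`, ...); the shapes and the 1530.0 weighting are arbitrary example values:

```python
import numpy as np

weighting = 1530.0
parameters = [
    np.random.rand(32, 1, 5, 5).astype(np.float32),  # e.g. a conv kernel
    np.zeros(32, dtype=np.float32),                  # e.g. a bias vector
]

# Client-upload format: a 'weighting' array plus one 'layer_i' key per layer.
arrays = {"weighting": np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays[f"layer_{i}"] = arr
np.savez_compressed("model_weights.npz", **arrays)

# Read it back and recover the layers in order (numeric sort, so that
# 'layer_10' would come after 'layer_9' rather than after 'layer_1').
npz = np.load("model_weights.npz")
layer_keys = sorted(
    [k for k in npz.files if k.startswith("layer_")],
    key=lambda k: int(k.split("_")[1]),
)
restored = [npz[k] for k in layer_keys]

assert float(npz["weighting"][0]) == weighting
assert all(np.array_equal(a, b) for a, b in zip(parameters, restored))
```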

Performance Metrics

Submit training or testing metrics:

import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test"  # "test", "train", or "non-federated"
)

Client Properties

client.run_status        # Current run status: "start", "stop", or "pause"
client.run_number        # Current run iteration
client.leaf_id           # Your leaf UUID
client.branch_id         # Parent branch UUID
client.is_initialized    # Initialization status

Branch Configuration

Fetch branch configuration including model-specific settings:

config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})

Advanced Features

Proxy Support

For networks requiring proxy access:

proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(credentials, host="https://app.branchkey.com", proxies=proxies)

Context Manager

Use the client as a context manager for automatic cleanup:

with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)
# Connections automatically closed

Error Handling

The client provides detailed error messages for debugging:

try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")
    # Logs include:
    # - HTTP status codes
    # - Response content preview
    # - Retry attempt information


BranchKey - Federated Learning Platform
