Client application to interface with the BranchKey system
BranchKey Python Client
Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.
Installation
pip install branchkey
Requirements: Python 3.9 or higher
Quick Start
1. Get Credentials
Create a leaf entity through the BranchKey platform to obtain credentials via the /v2/entities API endpoint.
2. Initialise Client
from branchkey import (
    Client,
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
)
# Create credentials
credentials = Credentials(
    id="your-leaf-uuid",
    name="my-client",
    session_token="your-session-token-uuid",
    owner_id="your-user-uuid",
    tree_id="your-tree-uuid",
    branch_id="your-branch-uuid",
)
# Initialise client with default settings
client = Client(credentials)
# Or with custom configuration
client = Client(
    credentials=credentials,
    api_config=APIConfig(
        host="https://app.branchkey.com",
        ssl=True,
    ),
    rabbitmq_config=RabbitMQConfig(
        port=5671,
        ssl=True,
    ),
    run_config=RunConfig(
        wait_for_run=False,
        check_interval_s=30,
    ),
)
3. Upload Model Weights
import numpy as np
# Prepare model weights
weighting = 1000 # Weight for aggregation (typically number of samples)
parameters = [layer1_weights, layer2_weights, ...]
# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")
4. Download Aggregated Results
# Wait for aggregation notification
aggregation_id = client.queue.get(block=True) # Blocks until aggregation ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_output/{aggregation_id}.npz")
# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
Configuration
All configuration uses immutable dataclasses for type safety and clarity.
Credentials
from branchkey import Credentials
credentials = Credentials(
    id="leaf-uuid",
    name="my-leaf",
    session_token="token-uuid",
    owner_id="user-uuid",
    tree_id="tree-uuid",
    branch_id="branch-uuid",
)
# Or from a dictionary
credentials = Credentials.from_dict(creds_dict)
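A minimal sketch of the dictionary shape `from_dict` presumably expects, assuming the keys mirror the constructor arguments shown above (verify against your platform export):

```python
# Hypothetical credentials dictionary; field names are taken from the
# Credentials constructor above -- confirm them against your own export.
creds_dict = {
    "id": "leaf-uuid",
    "name": "my-leaf",
    "session_token": "token-uuid",
    "owner_id": "user-uuid",
    "tree_id": "tree-uuid",
    "branch_id": "branch-uuid",
}

# In practice you would likely load this from a JSON file kept out of
# version control:
#   import json
#   with open("credentials.json") as f:
#       creds_dict = json.load(f)
```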
API Configuration
from branchkey import APIConfig
api_config = APIConfig(
    host="https://app.branchkey.com",  # API endpoint (default)
    ssl=True,                          # Verify SSL certificates (default)
    proxies=None,                      # Optional proxy dict
)
Transport: AMQP (RabbitMQ) vs WebSocket
The client supports two transport mechanisms for receiving aggregation notifications:
AMQP/RabbitMQ (Default)
from branchkey import Client, Credentials, RabbitMQConfig
client = Client(
    credentials=credentials,
    rabbitmq_config=RabbitMQConfig(
        host=None,                     # Auto-derived from API host
        port=5671,                     # TLS port (default)
        ssl=True,                      # Use TLS (default)
        max_reconnect_attempts=0,      # 0 = infinite retry (default)
        reconnect_backoff_factor=2.0,  # Exponential backoff multiplier
        reconnect_max_delay=60,        # Max delay in seconds
    ),
    use_websocket=False,               # Default
)
# Receive aggregations via queue
aggregation_id = client.queue.get(block=True)
WebSocket
from branchkey import Client, Credentials, WebSocketConfig
client = Client(
    credentials=credentials,
    websocket_config=WebSocketConfig(
        max_reconnect_attempts=0,      # 0 = infinite retry (default)
        reconnect_backoff_factor=2.0,  # Exponential backoff multiplier
        reconnect_max_delay=60,        # Max delay in seconds
    ),
    use_websocket=True,                # Enable WebSocket transport
)
# Receive aggregations via polling
aggregation_id = client.get_latest_aggregation_id()
if aggregation_id:
    client.file_download(aggregation_id)
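Because the WebSocket transport is polled rather than pushed into a queue, a small polling loop is the natural consumption pattern. The sketch below uses stub callables standing in for `client.get_latest_aggregation_id` and `client.file_download`; `interval_s` and `max_polls` are illustrative knobs, not library parameters:

```python
import time

def poll_for_aggregation(get_latest_id, download, interval_s=5.0, max_polls=None):
    """Poll get_latest_id() until it returns an id, then download it.

    get_latest_id / download stand in for client.get_latest_aggregation_id
    and client.file_download.  Returns the id, or None if max_polls is
    exhausted without seeing an aggregation.
    """
    polls = 0
    while max_polls is None or polls < max_polls:
        aggregation_id = get_latest_id()
        if aggregation_id:
            download(aggregation_id)
            return aggregation_id
        polls += 1
        time.sleep(interval_s)
    return None
```

With a real client this would be called as `poll_for_aggregation(client.get_latest_aggregation_id, client.file_download)`.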
Run Configuration
from branchkey import RunConfig
run_config = RunConfig(
    wait_for_run=False,   # If True, block uploads while the run is paused (default: False)
    check_interval_s=30,  # Run status check interval in seconds
)
HTTP Retry Configuration
The client automatically retries failed HTTP requests with exponential backoff:
from branchkey import RetryConfig
retry_config = RetryConfig(
    max_retries=3,       # Maximum retry attempts
    backoff_factor=1.0,  # Backoff multiplier (seconds)
    total_timeout=30,    # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504),  # HTTP codes to retry
    allowed_methods=("GET", "POST", "PUT"),           # Methods that support retry
)
client = Client(credentials, retry_config=retry_config)
Retry Behaviour:
- Retries on: 408, 429, 5xx errors, connection timeouts
- Does NOT retry: Other 4xx client errors (400, 401, 403, 404)
- Backoff delays: Exponential (1s, 2s, 4s, ...)
Configuration Examples:
# Production: More retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)
# Development: Faster failure
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)
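The delay progression above can be sketched as a small helper, assuming the documented `backoff_factor * 2^(attempt - 1)` progression (1s, 2s, 4s, ... for a factor of 1.0); the library's exact schedule may differ slightly (urllib3-style retries, for instance, skip the delay before the first retry):

```python
def backoff_delays(max_retries, backoff_factor, max_delay=None):
    """Exponential backoff schedule: backoff_factor * 2**(attempt - 1).

    Matches the documented 1s, 2s, 4s, ... progression for factor 1.0;
    max_delay (if given) caps each delay, mirroring reconnect_max_delay.
    """
    delays = [backoff_factor * (2 ** (attempt - 1))
              for attempt in range(1, max_retries + 1)]
    if max_delay is not None:
        delays = [min(d, max_delay) for d in delays]
    return delays

print(backoff_delays(3, 1.0))  # [1.0, 2.0, 4.0]
print(backoff_delays(5, 2.0))  # [2.0, 4.0, 8.0, 16.0, 32.0]
```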
Complete Configuration Example
from branchkey import (
    Client,
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
)
client = Client(
    credentials=Credentials(
        id="leaf-uuid",
        name="my-leaf",
        session_token="token",
        tree_id="tree-uuid",
        branch_id="branch-uuid",
        owner_id="user-uuid",
    ),
    api_config=APIConfig(
        host="https://app.branchkey.com",
        ssl=True,
    ),
    rabbitmq_config=RabbitMQConfig(
        port=5671,
        ssl=True,
        max_reconnect_attempts=10,
    ),
    websocket_config=WebSocketConfig(
        max_reconnect_attempts=10,
    ),
    run_config=RunConfig(
        wait_for_run=True,
        check_interval_s=15,
    ),
    retry_config=RetryConfig(
        max_retries=5,
        backoff_factor=2.0,
    ),
    use_websocket=False,  # False for AMQP, True for WebSocket
)
Model Weight Format
Model weights are stored in compressed NPZ format.
Structure
# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000 # Weight for aggregation (see below)
parameters = [layer1, layer2, ...] # List of numpy arrays
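The exact on-disk layout produced by `save_weights` is not documented here; a plausible sketch of one NPZ packing (purely an assumption, reusing the `layer_<i>` key naming seen in aggregated output files) would be:

```python
import os
import tempfile
import numpy as np

# Hypothetical packing -- the real save_weights layout may differ.
# Layer keys follow the layer_<i> naming used by aggregated output files.
weighting = 1000
parameters = [np.zeros((4, 4)), np.ones(4)]

path = os.path.join(tempfile.mkdtemp(), "model_weights.npz")
arrays = {f"layer_{i}": p for i, p in enumerate(parameters)}
np.savez_compressed(path, weighting=np.array(weighting), **arrays)

# Round trip: the structure survives compression intact
data = np.load(path)
restored = [data[f"layer_{i}"] for i in range(len(parameters))]
```

In practice, prefer `client.save_weights` over hand-rolling the file; the sketch only illustrates what a `(weighting, [arrays])` pair looks like in NPZ form.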
Weighting Options
The weighting parameter controls how much influence this update has during aggregation:
1. By Sample Count (Most Common)
weighting = len(train_dataset) # e.g., 1000 samples
# Client with 1000 samples has 2x influence of client with 500 samples
2. Equal Weighting
weighting = 1 # All clients have equal influence
3. Quality-Based Weighting
validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy # Weight by quality
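How weighting translates into influence can be illustrated with a weighted average over client updates. This is a FedAvg-style sketch for intuition only; the platform's server-side aggregation algorithm may differ:

```python
import numpy as np

def weighted_average(updates):
    """updates: list of (weighting, [layer_arrays]) tuples.

    Each layer of the result is the weighting-weighted mean of that
    layer across clients.
    """
    total = sum(w for w, _ in updates)
    n_layers = len(updates[0][1])
    return [
        sum(w * params[i] for w, params in updates) / total
        for i in range(n_layers)
    ]

# Client A: 1000 samples, all-zero layer; Client B: 500 samples, all-one layer.
a = (1000, [np.zeros(3)])
b = (500, [np.ones(3)])
result = weighted_average([a, b])
print(result[0])  # each entry is 1/3 -- A pulls the average toward 0 twice as hard as B
```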
PyTorch Example
import numpy as np
# Using client helper
weighting = len(train_dataset)
parameters = []
for name, param in model.named_parameters():
    parameters.append(param.data.cpu().detach().numpy())
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
# Using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset),
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
TensorFlow/Keras Example
import numpy as np
weighting = len(train_dataset)
parameters = [w.numpy() for w in model.trainable_weights]  # one array per weight tensor
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
Loading Aggregated Weights
import numpy as np
# Load aggregated weights from NPZ file
npz_data = np.load("aggregated_output/aggregation_id.npz")
# Note: aggregated results contain only the layer arrays (no weighting)
# Sort numerically -- a plain lexicographic sort puts layer_10 before layer_2
layer_keys = sorted(
    [k for k in npz_data.files if k.startswith("layer_")],
    key=lambda k: int(k.split("_")[1]),
)
parameters = [npz_data[k] for k in layer_keys]
# Apply to PyTorch model
import torch
for i, param in enumerate(model.parameters()):
    param.data = torch.from_numpy(parameters[i])
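One pitfall worth knowing: with more than ten layers, a plain lexicographic sort of the `layer_` keys misorders `layer_10` before `layer_2`, so layers get applied to the wrong parameters. A small self-contained round trip shows the difference between lexicographic and numeric ordering:

```python
import os
import tempfile
import numpy as np

# Build a file with 12 layers to expose the ordering pitfall
path = os.path.join(tempfile.mkdtemp(), "agg.npz")
np.savez_compressed(path, **{f"layer_{i}": np.full(2, i) for i in range(12)})

data = np.load(path)
lexical = sorted(k for k in data.files if k.startswith("layer_"))
numeric = sorted(
    (k for k in data.files if k.startswith("layer_")),
    key=lambda k: int(k.split("_")[1]),
)
print(lexical[:4])  # ['layer_0', 'layer_1', 'layer_10', 'layer_11']
print(numeric[:4])  # ['layer_0', 'layer_1', 'layer_2', 'layer_3']
```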
Performance Metrics
Submit training or testing metrics:
import json
metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test",  # "test", "train", or "non-federated"
)
Client Properties
client.run_status # Current run status: "start", "stop", or "pause"
client.run_number # Current run iteration
client.leaf_id # Your leaf UUID
client.branch_id # Parent branch UUID
client.tree_id # Tree UUID
client.is_initialized # Initialisation status
client.use_websocket # True if using WebSocket transport
Branch Configuration
Fetch branch configuration including model-specific settings:
config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})
Advanced Features
Proxy Support
from branchkey import Client, Credentials, APIConfig
proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(
    credentials=credentials,
    api_config=APIConfig(proxies=proxies),
)
Context Manager
Use the client as a context manager for automatic cleanup:
from branchkey import Client, Credentials
with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)

# Connections are closed automatically on exit
Error Handling
try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")
# Logs include:
# - HTTP status codes
# - Response content preview
# - Retry attempt information
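The documented retry rules (retry on 408, 429, and 5xx with exponential backoff; fail fast on other 4xx) can be mimicked with a small wrapper. This is illustrative only -- the client already applies these rules internally to its own HTTP calls -- and `HTTPError` here is a stand-in exception, not a library type:

```python
import time

RETRYABLE = {408, 429, 500, 502, 503, 504}

class HTTPError(Exception):
    """Stand-in for an HTTP failure carrying a status code."""
    def __init__(self, status):
        super().__init__(f"HTTP {status}")
        self.status = status

def call_with_retry(fn, max_retries=3, backoff_factor=1.0, sleep=time.sleep):
    """Retry fn() on retryable HTTP statuses with exponential backoff."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except HTTPError as e:
            if e.status not in RETRYABLE or attempt == max_retries:
                raise  # non-retryable (e.g. 401/404) or out of attempts
            sleep(backoff_factor * (2 ** attempt))
```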
Public API
from branchkey import (
    # Main client
    Client,
    # Configuration (frozen dataclasses)
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
    # Utilities
    get_metadata,
    AGGREGATED_OUTPUT_DIR,
)
Support
- Website: https://branchkey.com
- Repository: https://gitlab.com/branchkey/client_application
- Email: info@branchkey.com
BranchKey - Federated Learning Platform
File details

branchkey-2.9.2.tar.gz (source distribution)

- Size: 24.9 kB
- Uploaded via: twine/6.2.0 CPython/3.14.4
- Uploaded using Trusted Publishing? No

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7aa6296ec9f1b92d3e1d24c6bdaa4040fbabc1bb0a9ee67fccfb20bfc1dfa299 |
| MD5 | 2b71f967d8b2badc91ad55902be8f04d |
| BLAKE2b-256 | 40059bd2ae3297ea202d2609970b13218cea219c2a5d5a093c4b846a0deabdfa |

branchkey-2.9.2-py3-none-any.whl (built distribution, Python 3)

- Size: 25.7 kB
- Uploaded via: twine/6.2.0 CPython/3.14.4
- Uploaded using Trusted Publishing? No

| Algorithm | Hash digest |
|---|---|
| SHA256 | 6f84c518f63f24e768c4284e0ce558220c6fdd9466c694d29b7c5135c1e3f698 |
| MD5 | a7a5dd52602653a3a7599a6a5f527199 |
| BLAKE2b-256 | 011d738fad6bc1cb7738875555fd51828c41c7ac6a1950e935400315dcbe8fa1 |