# BranchKey Python Client

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.
## Installation

```bash
pip install branchkey
```

Requires Python 3.9 or higher.
## Quick Start

### 1. Get Credentials

Create a leaf entity through the BranchKey platform to obtain credentials via the /v2/entities API endpoint:
```python
credentials = {
    "id": "your-leaf-uuid",
    "name": "my-client",
    "session_token": "your-session-token-uuid",
    "owner_id": "your-user-uuid",
    "tree_id": "your-tree-uuid",
    "branch_id": "your-branch-uuid",
}
```
### 2. Initialize Client

```python
from branchkey.client import Client

# Connect to BranchKey
client = Client(credentials, host="https://app.branchkey.com")
```
### 3. Upload Model Weights

```python
import numpy as np

# Prepare model weights
weighting = 1000  # Weight for aggregation (typically the number of samples)
parameters = [layer1_weights, layer2_weights, ...]

# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")
```
### 4. Download Aggregated Results

```python
# Wait for an aggregation notification from RabbitMQ.
# The queue is populated by the background RabbitMQ consumer.
aggregation_id = client.queue.get(block=True)  # Blocks until an aggregation is ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_files/{aggregation_id}.npz")

# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
```
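Both patterns above follow standard-library `queue.Queue` semantics, which is what the `get`/`empty` calls on `client.queue` suggest it exposes (an assumption; check the client source if in doubt). A standalone sketch of the blocking and non-blocking styles, including a timeout so a stalled aggregation cannot block forever:

```python
import queue

# Stand-in for client.queue; in the real client, the background RabbitMQ
# consumer thread puts aggregation IDs on this queue as they become ready.
notifications = queue.Queue()
notifications.put("aggregation-1234")

# Non-blocking: returns immediately, raises queue.Empty if nothing is ready
try:
    aggregation_id = notifications.get(block=False)
    print(f"Ready: {aggregation_id}")
except queue.Empty:
    print("No aggregation ready yet")

# Blocking with a timeout: wait at most 5 seconds for the next notification
try:
    aggregation_id = notifications.get(block=True, timeout=5)
except queue.Empty:
    aggregation_id = None  # No aggregation arrived within the timeout
```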
## Configuration Options

### Basic Configuration

```python
client = Client(
    credentials,
    host="https://app.branchkey.com",   # API endpoint
    rbmq_host=None,                     # RabbitMQ host (auto-derived from host)
    rbmq_port=5671,                     # RabbitMQ port (5671 for TLS)
    rbmq_ssl=True,                      # Use TLS for RabbitMQ
    rbmq_max_reconnect_attempts=10,     # Max RabbitMQ reconnection attempts
    rbmq_reconnect_backoff_factor=2.0,  # Exponential backoff multiplier
    rbmq_reconnect_max_delay=60,        # Max reconnection delay (seconds)
    ssl=True,                           # Verify SSL certificates
    wait_for_run=False,                 # Wait if the run is paused
    run_check_interval_s=30,            # Run status check interval (seconds)
    proxies=None,                       # HTTP/HTTPS proxy dict
    retry_config=None,                  # Custom retry configuration (optional)
)
```
### Retry Configuration

The client automatically retries failed HTTP requests with exponential backoff. Default settings:

```python
from branchkey.retry_config import RetryConfig

# Default configuration (applied automatically)
retry_config = RetryConfig(
    max_retries=3,                                    # Maximum retry attempts
    backoff_factor=1.0,                               # Exponential backoff multiplier (seconds)
    total_timeout=30,                                 # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504),  # HTTP codes to retry
)

client = Client(credentials, retry_config=retry_config)
```
**Retry behaviour:**

- Retries on:
  - 408 Request Timeout (client-side timeouts)
  - 429 Too Many Requests (rate limiting)
  - 5xx server errors (500, 502, 503, 504)
  - Connection timeouts and network failures
- Does NOT retry other 4xx client errors (400, 401, 403, 404, etc.)
- Backoff delays: 1s, 2s, 4s (exponential, scaled by `backoff_factor`)
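Assuming the delays follow the common `backoff_factor * 2**attempt` pattern (which is consistent with the 1s, 2s, 4s sequence quoted above for `backoff_factor=1.0`), the schedule can be sketched as:

```python
def retry_delays(max_retries: int, backoff_factor: float) -> list[float]:
    """Delay before each retry attempt: backoff_factor * 2**attempt."""
    return [backoff_factor * (2 ** attempt) for attempt in range(max_retries)]

print(retry_delays(3, 1.0))  # [1.0, 2.0, 4.0]  -> default configuration
print(retry_delays(5, 2.0))  # [2.0, 4.0, 8.0, 16.0, 32.0]
```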
**Custom configuration examples:**

```python
# Production: more retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)
client = Client(credentials, retry_config=production_retry)

# Development: fail fast
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)
client = Client(credentials, retry_config=dev_retry)
```
## RabbitMQ Reconnection

The RabbitMQ consumer automatically reconnects with exponential backoff if the connection is lost.

**Default settings:**

- Max reconnection attempts: 10
- Backoff factor: 2.0 (delays: 2s, 4s, 8s, 16s, 32s, 60s, ...)
- Max delay: 60 seconds
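The quoted sequence (2s, 4s, 8s, ... capped at 60s) matches exponential growth in the backoff factor with a ceiling. A minimal sketch, assuming each delay is `factor**attempt` clamped to the maximum:

```python
def reconnect_delay(attempt: int, factor: float = 2.0, max_delay: float = 60.0) -> float:
    """Delay before reconnect attempt `attempt` (1-based), capped at max_delay."""
    return min(factor ** attempt, max_delay)

delays = [reconnect_delay(n) for n in range(1, 8)]
print(delays)  # [2.0, 4.0, 8.0, 16.0, 32.0, 60.0, 60.0]
```

The cap keeps a long outage from producing multi-minute gaps between reconnect attempts.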
**Custom configuration:**

```python
# Pass reconnection settings as Client parameters
client = Client(
    credentials,
    rbmq_max_reconnect_attempts=20,     # More attempts
    rbmq_reconnect_backoff_factor=3.0,  # Steeper backoff growth
    rbmq_reconnect_max_delay=120,       # 2-minute max delay (seconds)
)
```
## Model Weight Format

Model weights are stored in compressed NPZ format.

### Structure

```python
# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000                    # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays
```
### Weighting Options

The `weighting` parameter controls how much influence an update has during aggregation:

**1. By sample count (most common)**

```python
weighting = len(train_dataset)  # e.g. 1000 samples
# A client with 1000 samples has 2x the influence of a client with 500 samples
```

**2. Equal weighting**

```python
weighting = 1  # All clients have equal influence
```

**3. Quality-based weighting**

```python
validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality
```

**4. Manual weighting**

```python
weighting = 5.0  # Trusted client gets a higher weight
```
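To make the effect of `weighting` concrete, FedAvg-style aggregation computes a weighted mean of client parameters. This is a simplified standalone sketch; the platform's actual aggregation logic may differ:

```python
import numpy as np

def weighted_average(updates):
    """updates: list of (weighting, [layer arrays]) tuples, one per client."""
    total = sum(w for w, _ in updates)
    n_layers = len(updates[0][1])
    return [
        sum(w * layers[i] for w, layers in updates) / total
        for i in range(n_layers)
    ]

client_a = (1000, [np.array([2.0])])  # 1000 samples
client_b = (500, [np.array([5.0])])   # 500 samples
result = weighted_average([client_a, client_b])
print(result[0])  # [3.]  -> (1000*2.0 + 500*5.0) / 1500, pulled toward client_a
```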
### PyTorch Example

```python
import numpy as np

# Method 1: using the client helper (recommended)
weighting = len(train_dataset)
parameters = []
for name, param in model.named_parameters():
    parameters.append(param.data.cpu().detach().numpy())

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

# Method 2: using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset),
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```
### TensorFlow/Keras Example

```python
import numpy as np

weighting = len(train_dataset)
parameters = [w.numpy() for w in model.trainable_weights]

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```
### Manual NPZ Creation

```python
import numpy as np

# Save manually (without using the client helper)
arrays_dict = {'weighting': np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays_dict[f'layer_{i}'] = arr

np.savez_compressed("model_weights.npz", **arrays_dict)  # Filename must include .npz
```
### Loading Aggregated Weights

```python
import numpy as np
import torch

# Load aggregated weights from the NPZ file
npz_data = np.load("aggregated_files/aggregation_id.npz")

# Note: aggregated results contain only layers (no weighting).
# Sort numerically: a plain string sort would put "layer_10" before "layer_2".
layer_keys = sorted(
    (k for k in npz_data.files if k.startswith("layer_")),
    key=lambda k: int(k.split("_")[1]),
)
parameters = [npz_data[k] for k in layer_keys]

# Apply to your model
for param, arr in zip(model.parameters(), parameters):
    param.data = torch.from_numpy(arr)
```
### Example NPZ File Contents

**Client upload format (with weighting):**

```python
>>> npz_data.files
['weighting', 'layer_0', 'layer_1', 'layer_2', 'layer_3', ...]
>>> npz_data['weighting']
array([1530.])  # Weight for aggregation
>>> npz_data['layer_0'].shape, npz_data['layer_0'].dtype
((32, 1, 5, 5), dtype('float32'))
>>> npz_data['layer_0'][:1, :2, :2, :]
array([[[[-0.18576819, -0.03041792,  0.19532707, -0.11234483, -0.01512307],
         [ 0.19993757, -0.06492048,  0.08324468, -0.19899307, -0.0412709 ]]]],
      dtype=float32)
```
**Aggregated result format (layers only):**

```python
>>> npz_data.files
['layer_0', 'layer_1', 'layer_2', ...]  # No weighting in aggregated results
```
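Since uploads contain a `weighting` key while aggregated results do not, one loader can handle both formats. This is a hypothetical helper, not part of the client API; note the numeric sort so `layer_10` does not precede `layer_2`:

```python
import numpy as np

def load_branchkey_npz(path):
    """Return (weighting or None, [layer arrays]) from either NPZ format."""
    data = np.load(path)
    weighting = float(data["weighting"][0]) if "weighting" in data.files else None
    layer_keys = sorted(
        (k for k in data.files if k.startswith("layer_")),
        key=lambda k: int(k.split("_")[1]),  # numeric, not lexicographic
    )
    return weighting, [data[k] for k in layer_keys]

# Round-trip check using the upload format described above
np.savez_compressed(
    "demo_weights.npz",
    weighting=np.array([1530.0]),
    layer_0=np.zeros((2, 2)),
    layer_1=np.ones(3),
)
w, layers = load_branchkey_npz("demo_weights.npz")
print(w, [a.shape for a in layers])  # 1530.0 [(2, 2), (3,)]
```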
## Performance Metrics

Submit training or testing metrics:

```python
import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test",  # "test", "train", or "non-federated"
)
```
## Client Properties

```python
client.run_status      # Current run status: "start", "stop", or "pause"
client.run_number      # Current run iteration
client.leaf_id         # Your leaf UUID
client.branch_id       # Parent branch UUID
client.is_initialized  # Initialization status
```
## Branch Configuration

Fetch the branch configuration, including model-specific settings:

```python
config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})
```
## Advanced Features

### Proxy Support

For networks requiring proxy access:

```python
proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(credentials, host="https://app.branchkey.com", proxies=proxies)
```
### Context Manager

Use the client as a context manager for automatic cleanup:

```python
with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)
# Connections are closed automatically on exit
```
## Error Handling

The client provides detailed error messages for debugging:

```python
try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")

# Logs include:
# - HTTP status codes
# - Response content preview
# - Retry attempt information
```
## Support

- Website: https://branchkey.com
- Repository: https://gitlab.com/branchkey/client_application
- Email: info@branchkey.com

*BranchKey - Federated Learning Platform*