Weights & Biases plugin for Flyte
Weights & Biases Plugin
This plugin provides integration between Flyte and Weights & Biases (W&B) for experiment tracking, including support for distributed training with PyTorch Elastic.
Quickstart
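import flyte
from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run

# Placeholder environment for illustration; the name is arbitrary.
env = flyte.TaskEnvironment(name="wandb-example")

@wandb_init(project="my-project", entity="my-team")
@env.task
def train():
    run = get_wandb_run()
    run.log({"loss": 0.5, "accuracy": 0.9})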
Core concepts
Decorator order
@wandb_init and @wandb_sweep must be the outermost decorators (applied after @env.task):
@wandb_init  # Outermost
@env.task  # Task decorator
def my_task():
    ...
Run modes
The run_mode parameter controls how W&B runs are created:
- "auto" (default): Creates a new run if no parent exists, otherwise shares the parent's run
- "new": Always creates a new W&B run with a unique ID
- "shared": Always shares the parent's run ID (useful for child tasks)
Accessing the run
Use get_wandb_run() to access the current W&B run:
from flyteplugins.wandb import get_wandb_run
run = get_wandb_run()
if run:
    run.log({"metric": value})  # value: any metric computed by your task
get_wandb_run() returns None when called outside a @wandb_init-decorated task, or when the current rank should not log (in distributed training).
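Because the run can be None, logging calls are often wrapped in a small guard. The sketch below shows one way to do that; `log_metrics` and `FakeRun` are hypothetical helpers written for this example, and `FakeRun` merely stands in for the object returned by get_wandb_run().

```python
from typing import Any, Mapping, Optional


def log_metrics(run: Optional[Any], metrics: Mapping[str, float]) -> bool:
    """Log metrics if a run is available; return whether anything was logged."""
    if run is None:
        # Outside a decorated task, or on a rank that should not log.
        return False
    run.log(dict(metrics))
    return True


class FakeRun:
    """Stand-in for a W&B run that records what was logged."""

    def __init__(self) -> None:
        self.logged = []

    def log(self, data: dict) -> None:
        self.logged.append(data)


fake = FakeRun()
print(log_metrics(fake, {"loss": 0.5}))  # True
print(log_metrics(None, {"loss": 0.5}))  # False
```

Inside a task, the same helper would be called as log_metrics(get_wandb_run(), {...}), so non-logging ranks skip the call silently.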
Distributed training
The plugin automatically detects distributed training environments (PyTorch Elastic) and configures W&B appropriately.
Environment variables
Distributed training is detected via these environment variables (set by torchrun/torch.distributed.elastic):
| Variable | Description |
|---|---|
| RANK | Global rank of the process |
| WORLD_SIZE | Total number of processes |
| LOCAL_RANK | Rank within the current node |
| LOCAL_WORLD_SIZE | Number of processes per node |
| GROUP_RANK | Worker/node index (0, 1, 2, ...) |
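To make the table concrete, the following sketch reads these variables into a dictionary. It assumes that the presence of RANK and WORLD_SIZE is what signals a distributed run; the plugin's actual detection logic and the keys returned by get_distributed_info() may differ. `read_distributed_env` is a hypothetical name.

```python
import os
from typing import Mapping, Optional

_VARS = ("RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE", "GROUP_RANK")


def read_distributed_env(environ: Optional[Mapping[str, str]] = None) -> Optional[dict]:
    """Parse torchrun-style variables; return None when not distributed."""
    env = os.environ if environ is None else environ
    # Assumption: RANK and WORLD_SIZE together indicate a distributed launch.
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return None
    # Lower-cased keys are an arbitrary choice made for this example.
    return {name.lower(): int(env[name]) for name in _VARS if name in env}


sample = {
    "RANK": "5",
    "WORLD_SIZE": "8",
    "LOCAL_RANK": "1",
    "LOCAL_WORLD_SIZE": "4",
    "GROUP_RANK": "1",
}
print(read_distributed_env(sample))
print(read_distributed_env({}))  # None
```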
Run modes in distributed context
| Mode | Single-Node | Multi-Node |
|---|---|---|
| "auto" | Only rank 0 logs → 1 run | Local rank 0 of each worker logs → N runs (1 per worker) |
| "shared" | All ranks log to 1 shared run | All ranks per worker log to shared run → N runs (1 per worker) |
| "new" | Each rank gets its own run (grouped) → N runs | Each rank gets its own run (grouped per worker) → N×GPUs runs |
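The table can be restated as two small functions: one answering whether a given rank logs, and one counting the resulting runs. This is a reading of the table only, not the plugin's code; `rank_logs` and `expected_run_count` are hypothetical names.

```python
def rank_logs(run_mode: str, rank: int, local_rank: int, multi_node: bool) -> bool:
    """Whether a process logs to W&B under the given mode, per the table."""
    if run_mode in ("shared", "new"):
        return True  # every rank logs, to a shared run or to its own run
    if run_mode == "auto":
        # Single-node: only global rank 0. Multi-node: local rank 0 per worker.
        return local_rank == 0 if multi_node else rank == 0
    raise ValueError(f"unknown run_mode: {run_mode!r}")


def expected_run_count(run_mode: str, nnodes: int, nproc_per_node: int) -> int:
    """Number of W&B runs the table implies for a given topology."""
    if run_mode == "new":
        return nnodes * nproc_per_node  # one run per rank
    if run_mode in ("auto", "shared"):
        return nnodes  # one run per worker (a single run when nnodes == 1)
    raise ValueError(f"unknown run_mode: {run_mode!r}")


# 2 nodes with 4 processes each
for mode in ("auto", "shared", "new"):
    print(mode, expected_run_count(mode, nnodes=2, nproc_per_node=4))
```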
Run ID patterns
| Scenario | Run ID Pattern |
|---|---|
| Single-node auto/shared | {run_name}-{action_name} |
| Single-node new | {run_name}-{action_name}-rank-{rank} |
| Multi-node auto/shared | {run_name}-{action_name}-worker-{worker_index} |
| Multi-node new | {run_name}-{action_name}-worker-{worker_index}-rank-{local_rank} |
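The four patterns compose predictably: multi-node adds a worker suffix, and "new" mode adds a rank suffix. A minimal sketch that reproduces the patterns in the table follows; `build_run_id` and its parameters are hypothetical, and the plugin may construct IDs differently internally.

```python
def build_run_id(
    run_name: str,
    action_name: str,
    run_mode: str,
    *,
    multi_node: bool = False,
    rank: int = 0,
    local_rank: int = 0,
    worker_index: int = 0,
) -> str:
    """Compose a run ID following the patterns in the table above."""
    run_id = f"{run_name}-{action_name}"
    if multi_node:
        run_id += f"-worker-{worker_index}"
    if run_mode == "new":
        # Multi-node uses the rank within the worker; single-node uses the global rank.
        run_id += f"-rank-{local_rank if multi_node else rank}"
    return run_id


print(build_run_id("r1", "a0", "auto"))
print(build_run_id("r1", "a0", "new", multi_node=True, worker_index=1, local_rank=2))
```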
Example: Distributed training task
import flyte
from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run, get_distributed_info
from flyteplugins.pytorch.task import Elastic

# Multi-node environment (2 nodes, 4 GPUs each)
multi_node_env = flyte.TaskEnvironment(
    name="multi_node_env",
    resources=flyte.Resources(gpu="V100:4", shm="auto"),
    plugin_config=Elastic(nproc_per_node=4, nnodes=2),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init  # run_mode="auto" by default
@multi_node_env.task
def train_multi_node():
    import torch.distributed as dist

    dist.init_process_group("nccl")
    run = get_wandb_run()  # Returns run for local_rank 0, None for others
    dist_info = get_distributed_info()
    # Training loop...
    if run:
        run.log({"loss": loss.item()})  # loss comes from the elided training loop
    dist.destroy_process_group()
Shared mode for all-rank logging
Use run_mode="shared" when you want all ranks to log to the same W&B run:
@wandb_init(run_mode="shared")
@multi_node_env.task
def train_all_ranks_log():
    run = get_wandb_run()  # All ranks get a run object
    # All ranks can log - W&B handles deduplication
    run.log({"loss": loss.item(), "rank": dist.get_rank()})
New mode for per-rank runs
Use run_mode="new" when you want each rank to have its own W&B run:
@wandb_init(run_mode="new")
@multi_node_env.task
def train_per_rank():
    run = get_wandb_run()  # Each rank gets its own run
    # Runs are grouped in W&B UI for easy comparison
    run.log({"loss": loss.item()})
Configuration
wandb_config
Use wandb_config() to pass configuration that propagates to child tasks:
import flyte
from flyteplugins.wandb import wandb_config

# With flyte.with_runcontext
run = flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        entity="my-team",
        tags=["experiment-1"],
    )
).run(my_task)

# As a context manager (inside an async task)
with wandb_config(project="override-project"):
    await child_task()
Decorator vs context config
- Decorator arguments (@wandb_init(project=...)) are available only within the current task and its traces
- Context config (wandb_config(...)) propagates to child tasks
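The scoping difference can be pictured with a toy model built on Python's contextvars. This is not how the plugin is implemented and uses none of its APIs; `toy_config`, `parent_task`, and `child_task` are hypothetical names that only mimic the two behaviours described above.

```python
import contextvars
from contextlib import contextmanager

# Toy stand-in for context config that travels with the call chain.
_toy_ctx = contextvars.ContextVar("toy_wandb_config", default={})


@contextmanager
def toy_config(**overrides):
    """Layer overrides onto the ambient config for the duration of the block."""
    token = _toy_ctx.set({**_toy_ctx.get(), **overrides})
    try:
        yield
    finally:
        _toy_ctx.reset(token)


def child_task() -> dict:
    # A child sees only the propagated context config.
    return dict(_toy_ctx.get())


def parent_task(decorator_args: dict):
    # Decorator-style arguments stay local to this task.
    seen_by_parent = dict(decorator_args)
    seen_by_child = child_task()
    return seen_by_parent, seen_by_child


with toy_config(project="ctx-project"):
    print(parent_task({"project": "decorator-project"}))
print(child_task())  # {} once the context block has exited
```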
W&B links
Tasks decorated with @wandb_init or @wandb_sweep automatically get W&B links in the Flyte UI:
- For distributed training with multiple workers, each worker gets its own link
- Links point directly to the corresponding W&B runs or sweeps
- Project/entity are retrieved from decorator parameters or context configuration
Sweeps
Use @wandb_sweep to create W&B sweeps:
import wandb
from flyteplugins.wandb import (
    wandb_init, wandb_sweep, wandb_sweep_config, get_wandb_run, get_wandb_sweep_id,
)

@wandb_init
def objective():
    # Training logic - this runs for each sweep trial
    run = get_wandb_run()
    config = run.config  # Sweep parameters are passed via run.config
    # Train with sweep-suggested hyperparameters
    # (train, loss, and accuracy stand in for your own training code)
    model = train(lr=config.lr, batch_size=config.batch_size)
    wandb.log({"loss": loss, "accuracy": accuracy})

@wandb_sweep
@env.task
def run_sweep():
    sweep_id = get_wandb_sweep_id()
    # Launch sweep agents to run trials
    # count=10 means run 10 trials total
    wandb.agent(sweep_id, function=objective, count=10)
Note: A maximum of 20 sweep agents can be launched at a time.
Configure sweeps with wandb_sweep_config():
run = flyte.with_runcontext(
    custom_context=wandb_sweep_config(
        method="bayes",
        metric={"name": "loss", "goal": "minimize"},
        parameters={"lr": {"min": 1e-5, "max": 1e-2}},
        project="my-project",
    )
).run(run_sweep)
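The method, metric, and parameters arguments mirror the keys of a standard W&B sweep configuration. As a rough sketch, the same settings can be written as a plain dictionary and sanity-checked before launching a run. Whether wandb_sweep_config() forwards these values unchanged is an assumption here, and `check_sweep_config` is a hypothetical helper, not part of the plugin or of wandb.

```python
VALID_METHODS = {"grid", "random", "bayes"}


def check_sweep_config(config: dict) -> list:
    """Return a list of problems found in a sweep config dict (empty if none)."""
    problems = []
    method = config.get("method")
    if method not in VALID_METHODS:
        problems.append(f"method must be one of {sorted(VALID_METHODS)}")
    # Bayesian search optimises a named metric, so it needs one to be set.
    if method == "bayes" and "name" not in config.get("metric", {}):
        problems.append("bayes search requires metric['name']")
    parameters = config.get("parameters") or {}
    if not parameters:
        problems.append("parameters must not be empty")
    for name, spec in parameters.items():
        if "min" in spec and "max" in spec and spec["min"] >= spec["max"]:
            problems.append(f"{name}: min must be smaller than max")
    return problems


sweep = {
    "method": "bayes",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {"lr": {"min": 1e-5, "max": 1e-2}},
}
print(check_sweep_config(sweep))  # []
```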
Downloading logs
Set download_logs=True to download W&B run/sweep logs after task completion. The download I/O is traced by Flyte's @flyte.trace, making the logs visible in the Flyte UI:
@wandb_init(download_logs=True)
@env.task
def train():
    ...

# Or via context
wandb_config(download_logs=True)
wandb_sweep_config(download_logs=True)
The downloaded logs include all files uploaded to W&B during the run (metrics, artifacts, etc.).
API reference
Functions
- get_wandb_run() - Get the current W&B run object (or None)
- get_wandb_sweep_id() - Get the current sweep ID (or None)
- get_distributed_info() - Get distributed training info dict (or None)
- wandb_config(...) - Create W&B configuration for context
- wandb_sweep_config(...) - Create sweep configuration for context
Decorators
- @wandb_init - Initialize W&B for a task or function
- @wandb_sweep - Create a W&B sweep for a task
Links
- Wandb - Link class for W&B runs
- WandbSweep - Link class for W&B sweeps