Weights & Biases plugin for Flyte
Weights & Biases Plugin
This plugin provides integration between Flyte and Weights & Biases (W&B) for experiment tracking, including support for distributed training with PyTorch Elastic.
Quickstart
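```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_init

env = flyte.TaskEnvironment(
    name="wandb_env",
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)
@wandb_init(project="my-project", entity="my-team")
@env.task
def train():
    run = get_wandb_run()
    run.log({"loss": 0.5, "accuracy": 0.9})
```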
Core concepts
Decorator order
@wandb_init and @wandb_sweep must be the outermost decorators. List them above @env.task so that they are applied after it:

```python
@wandb_init  # Outermost
@env.task  # Task decorator
def my_task():
    ...
```
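Python applies stacked decorators from the bottom up, so the decorator written on top is the last one applied and wraps everything below it. The toy decorators below are hypothetical and unrelated to the plugin. They only record the order in which they are applied, to show why @wandb_init has to sit above @env.task:

```python
applied: list[str] = []


def make_decorator(name: str):
    """Return a toy decorator that records its name when applied."""
    def decorator(fn):
        applied.append(name)
        return fn
    return decorator


outer = make_decorator("outer")  # stands in for @wandb_init
inner = make_decorator("inner")  # stands in for @env.task


@outer
@inner
def my_task():
    ...


print(applied)  # ['inner', 'outer']
```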
Run modes
The run_mode parameter controls how W&B runs are created:
- "auto" (default): Creates a new run if no parent exists; otherwise shares the parent's run
- "new": Always creates a new W&B run with a unique ID
- "shared": Always shares the parent's run ID (useful for child tasks)
Accessing the run
Use get_wandb_run() to access the current W&B run:
```python
from flyteplugins.wandb import get_wandb_run

run = get_wandb_run()
if run:
    run.log({"metric": 0.5})  # Log any dict of metrics
```
Returns None if not within a @wandb_init decorated task or if the current rank should not log (in distributed training).
Distributed training
The plugin automatically detects distributed training environments (PyTorch Elastic) and configures W&B appropriately.
Environment variables
Distributed training is detected via these environment variables (set by torchrun/torch.distributed.elastic):
| Variable | Description |
|---|---|
| RANK | Global rank of the process |
| WORLD_SIZE | Total number of processes |
| LOCAL_RANK | Rank within the current node |
| LOCAL_WORLD_SIZE | Number of processes per node |
| GROUP_RANK | Worker/node index (0, 1, 2, ...) |
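The sketch below collects the variables from the table into a dictionary, roughly the kind of information get_distributed_info() exposes. It is only an illustration. The function name read_distributed_env, the dictionary keys other than "worker_index" (which appears in the worker-scope example), and the assumption that every node runs the same number of processes are all hypothetical. The plugin's real detection logic may differ.

```python
import os
from typing import Mapping, Optional


def read_distributed_env(environ: Mapping[str, str] = os.environ) -> Optional[dict]:
    """Hypothetical: collect torchrun variables, or return None if not a distributed launch."""
    if "RANK" not in environ or "WORLD_SIZE" not in environ:
        return None
    world_size = int(environ["WORLD_SIZE"])
    # If LOCAL_WORLD_SIZE is missing, treat the whole job as a single node
    local_world_size = int(environ.get("LOCAL_WORLD_SIZE", world_size))
    return {
        "rank": int(environ["RANK"]),
        "world_size": world_size,
        "local_rank": int(environ.get("LOCAL_RANK", "0")),
        "local_world_size": local_world_size,
        "worker_index": int(environ.get("GROUP_RANK", "0")),
        # Assumes every node runs the same number of processes
        "num_workers": world_size // local_world_size,
    }


# Local rank 1 on the second of two 4-process nodes
info = read_distributed_env({
    "RANK": "5", "WORLD_SIZE": "8", "LOCAL_RANK": "1",
    "LOCAL_WORLD_SIZE": "4", "GROUP_RANK": "1",
})
print(info["worker_index"], info["num_workers"])  # 1 2
```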
Rank scope
The rank_scope parameter controls the granularity of W&B runs in multi-node distributed training:
- "global" (default): Treat all workers as one unit → 1 run (or 1 group for run_mode="new")
- "worker": Treat each worker/node independently → N runs (or N groups for run_mode="new")
The effect of rank_scope depends on run_mode:
run_mode="auto" + rank_scope
```python
# Global scope (default): Only global rank 0 logs → 1 run total
@wandb_init
@multi_node_env.task
def train():
    run = get_wandb_run()  # Non-None only for global rank 0
    ...

# Worker scope: Local rank 0 of each worker logs → N runs (1 per worker)
@wandb_init(rank_scope="worker")
@multi_node_env.task
def train():
    run = get_wandb_run()  # Non-None for local_rank 0 on each worker
    ...
```
run_mode="shared" + rank_scope
```python
# Global scope: All ranks log to 1 shared run
@wandb_init(run_mode="shared")
@multi_node_env.task
def train():
    run = get_wandb_run()  # All ranks get a run object, all log to same run
    ...

# Worker scope: All ranks on each worker share a run → N runs total
@wandb_init(run_mode="shared", rank_scope="worker")
@multi_node_env.task
def train():
    run = get_wandb_run()  # All ranks get a run, grouped by worker
    ...
```
run_mode="new" + rank_scope
```python
# Global scope: Each rank gets its own run, all grouped together → N×M runs, 1 group
@wandb_init(run_mode="new")
@multi_node_env.task
def train():
    run = get_wandb_run()  # Each rank has its own run
    # Run IDs: {base}-rank-{global_rank}
    ...

# Worker scope: Each rank gets its own run, grouped per worker → N×M runs, N groups
@wandb_init(run_mode="new", rank_scope="worker")
@multi_node_env.task
def train():
    run = get_wandb_run()  # Each rank has its own run
    # Run IDs: {base}-worker-{idx}-rank-{local_rank}
    ...
```
Run modes in distributed context
| run_mode | rank_scope | Who initializes W&B? | W&B runs | Grouping |
|---|---|---|---|---|
| "auto" | "global" | global rank 0 only | 1 | - |
| "auto" | "worker" | local_rank 0 per worker | N | - |
| "shared" | "global" | all ranks (shared mode) | 1 | - |
| "shared" | "worker" | all ranks (shared mode) | N | - |
| "new" | "global" | all ranks | N×M | 1 group |
| "new" | "worker" | all ranks | N×M | N groups |
Where N = number of workers/nodes, M = processes per worker.
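The following pure-Python model makes the table concrete. It covers the "Who initializes W&B?", "W&B runs", and "Grouping" columns. The functions initializes_wandb and expected_runs are hypothetical helpers that restate the table. They are not part of flyteplugins.wandb, which makes these decisions internally.

```python
def initializes_wandb(run_mode: str, rank_scope: str, global_rank: int, local_rank: int) -> bool:
    """Hypothetical restatement of the "Who initializes W&B?" column."""
    if run_mode == "auto":
        # One logger per scope: global rank 0, or local rank 0 on every worker
        return global_rank == 0 if rank_scope == "global" else local_rank == 0
    # "shared" and "new": every rank initializes W&B
    return True


def expected_runs(run_mode: str, rank_scope: str, n_workers: int, procs_per_worker: int) -> tuple[int, int]:
    """Return (W&B runs, W&B groups) for N workers with M processes each."""
    if run_mode == "new":
        # One run per rank; rank_scope only decides how the runs are grouped
        groups = 1 if rank_scope == "global" else n_workers
        return n_workers * procs_per_worker, groups
    # "auto" and "shared": one run overall, or one per worker, with no grouping
    return (1 if rank_scope == "global" else n_workers), 0


# 2 workers x 4 processes: global rank = worker_index * 4 + local_rank
loggers = [g for g in range(8) if initializes_wandb("auto", "worker", g, g % 4)]
print(loggers)  # [0, 4]
print(expected_runs("new", "worker", n_workers=2, procs_per_worker=4))  # (8, 2)
```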
Run ID patterns
| Scenario | Run ID pattern | Group |
|---|---|---|
| Single-node auto/shared | {base} | - |
| Single-node new | {base}-rank-{rank} | {base} |
| Multi-node auto (global) | {base} | - |
| Multi-node auto (worker) | {base}-worker-{idx} | - |
| Multi-node shared (global) | {base} | - |
| Multi-node shared (worker) | {base}-worker-{idx} | - |
| Multi-node new (global) | {base}-rank-{global_rank} | {base} |
| Multi-node new (worker) | {base}-worker-{idx}-rank-{local_rank} | {base}-worker-{idx} |
Where {base} = {run_name}-{action_name}
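The sketch below expresses the run ID patterns table as code. It returns a (run ID, group) pair for each scenario and uses None where the table shows no group. The function name and its parameters are hypothetical, and the base value "myrun-a0" in the usage line is made up. The plugin builds these IDs itself.

```python
from typing import Optional


def run_id_and_group(
    base: str,
    run_mode: str,
    rank_scope: str,
    multi_node: bool,
    global_rank: int,
    local_rank: int,
    worker_idx: int,
) -> tuple[str, Optional[str]]:
    """Hypothetical restatement of the run ID patterns table: returns (run_id, group)."""
    per_worker = multi_node and rank_scope == "worker"
    if run_mode == "new":
        if per_worker:
            group = f"{base}-worker-{worker_idx}"
            return f"{group}-rank-{local_rank}", group
        # Single-node, or multi-node with global scope: one group named {base}
        return f"{base}-rank-{global_rank}", base
    # "auto" and "shared" runs are not grouped
    if per_worker:
        return f"{base}-worker-{worker_idx}", None
    return base, None


print(run_id_and_group("myrun-a0", "new", "worker", True, global_rank=5, local_rank=1, worker_idx=1))
# ('myrun-a0-worker-1-rank-1', 'myrun-a0-worker-1')
```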
Example: Distributed training task
```python
import flyte
import torch
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_distributed_info, get_wandb_run, wandb_init

# Multi-node environment (2 nodes, 4 GPUs each)
multi_node_env = flyte.TaskEnvironment(
    name="multi_node_env",
    resources=flyte.Resources(gpu="V100:4", shm="auto"),
    plugin_config=Elastic(nproc_per_node=4, nnodes=2),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init  # run_mode="auto", rank_scope="global" by default → 1 run total
@multi_node_env.task
def train_multi_node():
    import torch.distributed as dist

    dist.init_process_group("nccl")
    run = get_wandb_run()  # Returns the run on global rank 0 only, None on other ranks
    dist_info = get_distributed_info()  # Distributed training info (or None)

    # Training loop...
    loss = torch.tensor(0.0)  # Placeholder for the loss computed by your training loop
    if run:
        run.log({"loss": loss.item()})

    dist.destroy_process_group()
```
Worker scope for per-worker logging
Use rank_scope="worker" when you want each worker/node to have its own W&B run:
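```python
import torch
from flyteplugins.wandb import get_distributed_info, get_wandb_run, wandb_init

@wandb_init(rank_scope="worker")  # 1 run per worker
@multi_node_env.task
def train_per_worker():
    run = get_wandb_run()  # Returns a run for local_rank 0 of each worker, None elsewhere
    dist_info = get_distributed_info()
    loss = torch.tensor(0.0)  # Placeholder for the loss from your training loop
    if run:
        # Each worker logs to its own run
        run.log({"loss": loss.item(), "worker": dist_info["worker_index"]})
```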
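Shared mode for all-rank logging
Use run_mode="shared" when you want all ranks to log to the same W&B run:
```python
import torch
import torch.distributed as dist
from flyteplugins.wandb import get_wandb_run, wandb_init

@wandb_init(run_mode="shared")
@multi_node_env.task
def train_all_ranks_log():
    dist.init_process_group("nccl")
    run = get_wandb_run()  # All ranks get a run object
    loss = torch.tensor(0.0)  # Placeholder for the loss from your training loop
    # All ranks can log - W&B handles deduplication
    run.log({"loss": loss.item(), "rank": dist.get_rank()})
```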
New mode for per-rank runs
Use run_mode="new" when you want each rank to have its own W&B run:
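```python
import torch
from flyteplugins.wandb import get_wandb_run, wandb_init

@wandb_init(run_mode="new")
@multi_node_env.task
def train_per_rank():
    run = get_wandb_run()  # Each rank gets its own run
    loss = torch.tensor(0.0)  # Placeholder for the loss from your training loop
    # Runs are grouped in the W&B UI for easy comparison
    run.log({"loss": loss.item()})
```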
Configuration
wandb_config
Use wandb_config() to pass configuration that propagates to child tasks:
```python
import flyte
from flyteplugins.wandb import wandb_config

# With flyte.with_runcontext
run = flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        entity="my-team",
        tags=["experiment-1"],
    )
).run(my_task)

# As a context manager, inside an async parent task
@env.task
async def parent_task():
    with wandb_config(project="override-project"):
        await child_task()
```
Configuring run_mode and rank_scope
Both run_mode and rank_scope can be set via decorator or context:
```python
# Via decorator (takes precedence)
@wandb_init(run_mode="shared", rank_scope="worker")
@multi_node_env.task
def train():
    ...

# Via context (useful for dynamic configuration)
run = flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        run_mode="shared",
        rank_scope="worker",
    )
).run(train)
```
When both are specified, decorator arguments take precedence over context config.
Decorator vs context config
| Source | Scope | Use case |
|---|---|---|
| Decorator (@wandb_init(...)) | Current task and traces only | Static per-task config |
| Context (wandb_config(...)) | Propagates to child tasks | Dynamic/shared config |
Priority order (highest to lowest):
- Decorator arguments
- Context config (wandb_config)
- Defaults (run_mode="auto", rank_scope="global")
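The priority order amounts to a layered dictionary merge. In this sketch, resolve_settings is a hypothetical helper that treats a value of None as "not set", which is an assumption. The plugin's internal merging may handle unset values differently.

```python
DEFAULTS = {"run_mode": "auto", "rank_scope": "global"}


def resolve_settings(decorator_args: dict, context_config: dict) -> dict:
    """Hypothetical: merge settings so decorator > context config > defaults."""
    merged = dict(DEFAULTS)
    # Apply lower-priority sources first so later updates override them
    for source in (context_config, decorator_args):
        merged.update({k: v for k, v in source.items() if v is not None})
    return merged


print(resolve_settings({"rank_scope": "worker"}, {"run_mode": "shared", "rank_scope": "global"}))
# {'run_mode': 'shared', 'rank_scope': 'worker'}
```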
W&B links
Tasks decorated with @wandb_init or @wandb_sweep automatically get W&B links in the Flyte UI:
- With rank_scope="global" (default): A single link to the one W&B run
- With rank_scope="worker": Each worker gets its own link
- Links point directly to the corresponding W&B runs or sweeps
- Project/entity are retrieved from decorator parameters or context configuration
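The links combine the entity, the project, and a run or sweep ID. The sketch below assumes the public W&B cloud URL layout (https://wandb.ai/{entity}/{project}/runs/{run_id} and .../sweeps/{sweep_id}). Self-hosted W&B servers use a different host. The helper names are hypothetical, and the plugin's Wandb and WandbSweep link classes may build their URLs differently.

```python
from urllib.parse import quote

# Assumption: the public W&B cloud; self-hosted servers use their own host
WANDB_HOST = "https://wandb.ai"


def wandb_run_url(entity: str, project: str, run_id: str) -> str:
    """URL of a single W&B run page."""
    return f"{WANDB_HOST}/{quote(entity)}/{quote(project)}/runs/{quote(run_id)}"


def wandb_sweep_url(entity: str, project: str, sweep_id: str) -> str:
    """URL of a W&B sweep page."""
    return f"{WANDB_HOST}/{quote(entity)}/{quote(project)}/sweeps/{quote(sweep_id)}"


# rank_scope="worker" with 2 workers: one link per worker run ({base}-worker-{idx})
links = [wandb_run_url("my-team", "my-project", f"myrun-a0-worker-{i}") for i in range(2)]
```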
Sweeps
Use @wandb_sweep to create W&B sweeps:
```python
import wandb
from flyteplugins.wandb import get_wandb_run, get_wandb_sweep_id, wandb_init, wandb_sweep

@wandb_init
def objective():
    # Training logic: this runs once per sweep trial
    run = get_wandb_run()
    config = run.config  # Sweep parameters are passed via run.config
    # Train with the sweep-suggested hyperparameters
    # (train_model is a placeholder for your own function returning loss and accuracy)
    loss, accuracy = train_model(lr=config.lr, batch_size=config.batch_size)
    run.log({"loss": loss, "accuracy": accuracy})

@wandb_sweep
@env.task
def run_sweep():
    sweep_id = get_wandb_sweep_id()
    # Launch a sweep agent to run trials; count=10 runs 10 trials in total
    wandb.agent(sweep_id, function=objective, count=10)
```
Note: A maximum of 20 sweep agents can be launched at a time.
Configure sweeps with wandb_sweep_config():
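```python
import flyte
from flyteplugins.wandb import wandb_sweep_config

run = flyte.with_runcontext(
    custom_context=wandb_sweep_config(
        method="bayes",
        metric={"name": "loss", "goal": "minimize"},
        parameters={
            "lr": {"min": 1e-5, "max": 1e-2},
            # objective() reads config.batch_size, so the sweep must define it
            "batch_size": {"values": [32, 64, 128]},
        },
        project="my-project",
    )
).run(run_sweep)
```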
Downloading logs
Set download_logs=True to download W&B run/sweep logs after task completion. The download I/O is traced by Flyte's @flyte.trace, making the logs visible in the Flyte UI:
```python
@wandb_init(download_logs=True)
@env.task
def train():
    ...

# Or via context
wandb_config(download_logs=True)
wandb_sweep_config(download_logs=True)
```
The downloaded logs include all files uploaded to W&B during the run (metrics, artifacts, etc.).
API reference
Functions
- get_wandb_run() - Get the current W&B run object (or None)
- get_wandb_sweep_id() - Get the current sweep ID (or None)
- get_distributed_info() - Get distributed training info dict (or None)
- wandb_config(...) - Create W&B configuration for context
- wandb_sweep_config(...) - Create sweep configuration for context
Decorators
- @wandb_init - Initialize W&B for a task or function
  - run_mode: "auto" (default), "new", or "shared"
  - rank_scope: "global" (default) or "worker" - controls which ranks log in distributed training
  - download_logs: If True, download W&B logs after task completion
  - project, entity: W&B project and entity names
- @wandb_sweep - Create a W&B sweep for a task
Links
- Wandb - Link class for W&B runs
- WandbSweep - Link class for W&B sweeps
Types
- RankScope - Literal type: "global" | "worker"
- RunMode - Literal type: "auto" | "new" | "shared"
File details
Details for the file flyteplugins_wandb-2.0.0b60-py3-none-any.whl.
File metadata
- Download URL: flyteplugins_wandb-2.0.0b60-py3-none-any.whl
- Upload date:
- Size: 22.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 479404925922fa8172b60d5d315be5fccf063995ec140badd6bec12040be1bcc |
| MD5 | fdcd541f3f4a8614e7d1981077a432b8 |
| BLAKE2b-256 | 4ec06d7a9ff8c7a21b477e8540ecc9e55acf33182db9c4e132bdff228ac58c24 |