Paving the way between black-box and white-box modeling.
About WeightsLab
WeightsLab is a powerful tool for editing and inspecting data & AI models.
What Problems Does It Solve?
WeightsLab addresses critical AI research challenges:
- Dataset insights & optimization
- Overfitting and training plateaus
- Over- and under-parameterization
Key Capabilities
The granular statistics and interactive paradigm enable powerful workflows:
- Monitor granular insights on data samples, signals, and weight parameters
- Use the AI agent to:
  - Create slices of data and discard them for the next training iteration
  - Discard low-quality samples from the training data
- Iteratively prune or grow the architecture (upcoming feature)
Quick Start
Requirements
- Docker Desktop v4.77 or newer: required to deploy the Weights Studio UI (weightslab ui launch).
- Docker Compose v2 (the docker compose CLI plugin, bundled with Docker Desktop): recommended. The legacy v1 standalone binary (docker-compose, ≥ 1.27) also works; weightslab ui launch auto-detects whichever is installed and uses it. Compose v1 below 1.27 is not supported.
- Python >=3.10, <3.15: to install and run the weightslab framework.
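The requirements above can be checked before installing. The following is a minimal preflight sketch that assumes only the version bounds stated in this list. The detection order (the v2 plugin first, then the legacy binary) mirrors the description of weightslab ui launch but is not its actual implementation, and every function name here is hypothetical.

```python
from __future__ import annotations

import re
import shutil
import subprocess
import sys


def python_supported(version_info=sys.version_info) -> bool:
    """True if the interpreter satisfies Python >=3.10, <3.15."""
    return (3, 10) <= tuple(version_info[:2]) < (3, 15)


def parse_version(text: str) -> tuple[int, ...]:
    """Extract the first dotted version number, e.g. '1.29.2' -> (1, 29, 2)."""
    match = re.search(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"no version found in {text!r}")
    return tuple(int(part) for part in match.groups() if part is not None)


def compose_v1_supported(version_text: str) -> bool:
    """The legacy docker-compose binary must be at least 1.27."""
    return parse_version(version_text)[:2] >= (1, 27)


def detect_compose() -> str | None:
    """Return 'v2', 'v1', or None depending on which Compose flavour is usable."""
    if shutil.which("docker"):
        result = subprocess.run(["docker", "compose", "version"], capture_output=True, text=True)
        if result.returncode == 0:
            return "v2"
    if shutil.which("docker-compose"):
        result = subprocess.run(["docker-compose", "--version"], capture_output=True, text=True)
        if result.returncode == 0 and compose_v1_supported(result.stdout):
            return "v1"
    return None


if __name__ == "__main__":
    print("Python OK" if python_supported() else "Python version not supported")
    print("Docker Compose:", detect_compose() or "not found")
```

Checking the plugin first matches the recommendation above: when both flavours are installed, Compose v2 wins.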
Installation
Install directly on your machine.
[!TIP] Setting up a clean Python environment:
python -m venv weightslab_venv
Then activate it with source weightslab_venv/bin/activate (Linux/macOS) or weightslab_venv\Scripts\activate (Windows).
Install our framework:
pip install weightslab
Deploy our interface:
weightslab ui launch
The command weightslab ui launch removes any stale weightslab/weights_studio Docker resources that could break the launch, then starts the UI stack. By default, it runs unsecured (HTTP, no gRPC authentication) and generates no certificates, so communication is neither encrypted nor authenticated.
[!TIP] To secure communication, pass the --certs argument:
weightslab ui launch --certs # generates TLS certs + a gRPC auth token if missing, then launches secured
When using certs, set WEIGHTSLAB_CERTS_DIR so that the training backend and any new terminal use the same certificates (it is the single source of truth). Both weightslab se and weightslab ui launch --certs print the exact export/setx command for your shell. You can also generate certs up front with weightslab se.
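Because WEIGHTSLAB_CERTS_DIR is the single source of truth for certificates, a training script may want to fail fast when the variable points to a missing directory. The following is a minimal sketch. resolve_certs_dir is a hypothetical helper, not part of the weightslab package, and it does not check which files the directory should contain.

```python
from __future__ import annotations

import os
from pathlib import Path


def resolve_certs_dir(env: dict[str, str] | None = None) -> Path | None:
    """Return WEIGHTSLAB_CERTS_DIR as a Path, None if unset, or raise if it is missing."""
    env = os.environ if env is None else env
    raw = env.get("WEIGHTSLAB_CERTS_DIR", "").strip()
    if not raw:
        return None  # unsecured default: no certificates configured
    path = Path(raw).expanduser()
    if not path.is_dir():
        raise FileNotFoundError(f"WEIGHTSLAB_CERTS_DIR points to a missing directory: {path}")
    return path


if __name__ == "__main__":
    certs = resolve_certs_dir()
    print(f"Using certificates from {certs}" if certs else "No WEIGHTSLAB_CERTS_DIR set: running unsecured")
```

Passing an explicit env mapping keeps the helper easy to test without touching the real environment.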
[!IMPORTANT] For a detailed installation guide and more advanced features, please see the Installation Documentation.
Quick Training Example
Step-by-Step Integration
- Add the import at the top of your script:
import weightslab as wl # ← Include our SDK into your experiment
- Wrap your parameters with WeightsLab tracking:
parameters = wl.watch_or_edit(parameters, flag='hp', ...) # ← Now WeightsLab monitors your parameters and lets you update them from your UI
- Wrap your model with WeightsLab tracking:
model = wl.watch_or_edit(SimpleModel(...), flag='model', ...) # ← Now WeightsLab monitors your model state
- Wrap your optimizer with WeightsLab tracking:
optimizer = wl.watch_or_edit(optim.Adam(...), flag='opt', ...) # ← Tracks the optimizer state and lets you update its learning rate from your UI
- Wrap your signals with WeightsLab tracking:
train_criterion = wl.watch_or_edit(nn.CrossEntropyLoss(reduction="none"), flag='signal', name="train_loss/sample", per_sample=True, log=True) # ← Tracks this signal (and others, such as metrics) from your UI
test_criterion = wl.watch_or_edit(nn.CrossEntropyLoss(reduction="none"), flag='signal', name="test_loss/sample", per_sample=True, log=False) # ← log=False: no plot, only the per-sample values are recorded
- Wrap your datasets with WeightsLab tracking:
train_loader = wl.watch_or_edit(train_dataset, flag='data', loader_name="train_loader", ...) # ← Tracks this dataset from your UI
val_loader = wl.watch_or_edit(val_dataset, flag='data', loader_name="val_loader", ...) # ← Tracks the validation dataset the same way (test sets work too)
- Run your training script as usual:
python train.py
- Launch the UI in another terminal:
weightslab ui launch
- Open your browser to https://localhost:5173 to track the evolution and results of your experiment!
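The signal steps above use nn.CrossEntropyLoss(reduction="none"), which returns one loss value per sample instead of a batch average; presumably this is what makes per_sample=True tracking possible. The pure-Python sketch below (no WeightsLab or PyTorch calls) reproduces what that reduction computes for a two-sample batch, and shows the mean you still need before calling backward(), which requires a scalar.

```python
from __future__ import annotations

import math


def cross_entropy_per_sample(logits: list[list[float]], targets: list[int]) -> list[float]:
    """One cross-entropy value per sample, like nn.CrossEntropyLoss(reduction="none")."""
    losses = []
    for row, target in zip(logits, targets):
        peak = max(row)  # subtract the max before exponentiating for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        losses.append(log_sum_exp - row[target])  # -log softmax(row)[target]
    return losses


logits = [[2.0, 0.0], [0.0, 2.0]]  # two samples, two classes
targets = [0, 0]                   # the second sample is confidently wrong
per_sample = cross_entropy_per_sample(logits, targets)
batch_loss = sum(per_sample) / len(per_sample)  # the scalar that backward() needs
print([round(v, 4) for v in per_sample], round(batch_loss, 4))  # → [0.1269, 2.1269] 1.1269
```

The per-sample view singles out the misclassified second sample (loss ≈ 2.13), which the batch average alone would hide.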
Here's a complete example showing how to integrate WeightsLab into a basic PyTorch training script:
#!/usr/bin/env python3
"""
Basic PyTorch training script with WeightsLab integration
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import yaml
from torch.utils.data import TensorDataset

import weightslab as wl  # ← Import WeightsLab (uses TLS certs from WEIGHTSLAB_CERTS_DIR if present)


# Define a simple classifier
class SimpleModel(nn.Module):
    def __init__(self, input_shape=10, output_shape=2):
        super().__init__()
        self.linear = nn.Linear(input_shape, output_shape)

    def forward(self, x):
        return self.linear(x)


# Create synthetic binary-classification data
def create_data(n_samples=1000, n_features=10):
    X = torch.randn(n_samples, n_features)
    y = (X.sum(dim=1) > 0).long()  # one class index per sample, as CrossEntropyLoss expects
    return TensorDataset(X, y)


# Main training function
def main():
    # Certificates are only used if WEIGHTSLAB_CERTS_DIR is set (see weightslab ui launch --certs)
    print("🚀 Initializing WeightsLab...")

    # Load hyperparameters (from YAML if present)
    parameters = {}
    config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r") as fh:
            parameters = yaml.safe_load(fh) or {}
    parameters = wl.watch_or_edit(
        parameters,
        flag="hyperparameters",
        defaults=parameters,
        poll_interval=1.0,
    ) or {}  # Wrap the hyperparameters
    model_cfg = parameters.get('model', {})
    loader_cfg = parameters.get('data', {}).get('train_loader', {})

    # Wrap your model and optimizer with WeightsLab
    model = wl.watch_or_edit(
        SimpleModel(
            input_shape=model_cfg.get('input_shape', 10),
            output_shape=model_cfg.get('output_shape', 2),
        ),
        flag='model',
    )  # ← WeightsLab tracks your model
    optimizer = wl.watch_or_edit(
        optim.Adam(model.parameters(), lr=model_cfg.get('optimizer', {}).get('lr', 0.01)),
        flag='optimizer',
    )  # ← WeightsLab tracks the optimizer

    # Create and wrap the criterion (reduction="none" keeps one loss value per sample)
    criterion = wl.watch_or_edit(
        nn.CrossEntropyLoss(reduction="none"),
        flag="loss",
        signal_name="train-loss-CE",
        log=True,  # If log is False, only the per-sample values are saved and the criterion is not plotted
    )

    # Create data and dataloader
    dataset = create_data(n_features=model_cfg.get('input_shape', 10))
    train_loader = wl.watch_or_edit(
        dataset,
        flag="data",
        loader_name="loader",
        batch_size=loader_cfg.get('batch_size', 8),
        shuffle=loader_cfg.get('shuffle', False),
        is_training=True,  # Is it the training dataloader?
        compute_hash=loader_cfg.get('compute_hash', True),  # Hash samples to allow dynamic augmentations and dataset sanity checks
        preload_labels=loader_cfg.get('preload_labels', True),
        preload_metadata=loader_cfg.get('preload_metadata', True),
        enable_h5_persistence=loader_cfg.get('enable_h5_persistence', True),
        num_workers=loader_cfg.get('num_workers', 4),
    )

    # Training loop
    print("🏃 Starting training...")
    print("💡 Launch the UI with: weightslab ui launch")
    print("🌐 Open browser to: https://localhost:5173")
    n_epochs = parameters.get('n_epochs', 5)
    epochs = range(n_epochs)
    pbar = tqdm.tqdm(epochs, desc='Training..') if parameters.get('tqdm_display', False) else epochs
    for epoch in pbar:
        total_loss = 0.0
        for batch_X, batch_y in train_loader:
            # Forward pass
            predictions = model(batch_X)
            loss = criterion(predictions, batch_y).mean()  # average the per-sample losses into a scalar
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{n_epochs} - Loss: {avg_loss:.4f}")

    print("✅ Training complete!")


if __name__ == "__main__":
    main()
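The example reads its settings from config.yaml through chained .get() calls, so the YAML file needs a matching nesting (n_epochs and tqdm_display at the top level, model.optimizer.lr, data.train_loader.batch_size, and so on). The sketch below shows that shape as a Python dict, with illustrative values equal to the script's defaults, plus a hypothetical get_nested helper (not part of weightslab) that collapses the .get() chains into one dotted key.

```python
from __future__ import annotations

from typing import Any


def get_nested(config: dict, dotted_key: str, default: Any = None) -> Any:
    """Look up 'a.b.c' in nested dicts, returning default if any level is missing."""
    node: Any = config
    for key in dotted_key.split("."):
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node


# Shape of the config.yaml the example reads (values mirror the script's defaults)
CONFIG = {
    "n_epochs": 5,
    "tqdm_display": False,
    "model": {"input_shape": 10, "output_shape": 2, "optimizer": {"lr": 0.01}},
    "data": {"train_loader": {"batch_size": 8, "shuffle": False, "num_workers": 4}},
}

print(get_nested(CONFIG, "model.optimizer.lr"))           # → 0.01
print(get_nested(CONFIG, "data.val_loader.batch_size", 32))  # → 32 (missing key, default used)
```

Returning the default for any missing level keeps the same forgiving behaviour as the original chain of .get(..., {}) calls.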
Migrating from wandb? See the diff:
--- train_baseline.py
+++ train_wl.py
@@ -1,11 +1,12 @@
 import argparse
 import torch
 import torch.nn as nn
-from torch.utils.data import DataLoader
 from torchvision import datasets, transforms, models
 from torchmetrics.classification import MulticlassAccuracy
-import wandb
+import weightslab as wl
+from weightslab.components.global_monitoring import (
+    guard_training_context, guard_testing_context)
+
+@wl.signal(name="byte_adjusted_loss", subscribe_to="loss/CE")
+def byte_adjusted_loss(ctx): return ctx.subscribed_value / ctx.image_bytes  # chains on image_bytes
+
 def main():
@@ -15,29 +16,38 @@
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     parameters = {"batch_size": 128, "lr": 1e-3}
-    wandb.init(project="cifar10")
-
     transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
     ])
     train_set = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
     test_set = datasets.CIFAR10("./data", train=False, download=True, transform=transform)
-    train_loader = DataLoader(train_set, batch_size=parameters["batch_size"], shuffle=True, num_workers=2)
-    test_loader = DataLoader(test_set, batch_size=256, num_workers=2)
+    wl.watch_or_edit(parameters, flag="hyperparameters")  # live-editable in UI
+
+    train_loader = wl.watch_or_edit(
+        train_set, flag="data", loader_name="train_loader",
+        batch_size=parameters["batch_size"], shuffle=True, is_training=True)
+    test_loader = wl.watch_or_edit(
+        test_set, flag="data", loader_name="test_loader",
+        batch_size=256, shuffle=False, is_training=False)
     model = models.resnet18(weights=None)
     model.fc = nn.Linear(model.fc.in_features, 10)
     model = model.to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=parameters["lr"])
-    criterion = nn.CrossEntropyLoss()
-    accuracy = MulticlassAccuracy(num_classes=10).to(device)
+    criterion = wl.watch_or_edit(
+        nn.CrossEntropyLoss(), flag="loss", signal_name="loss/CE")
+    accuracy = wl.watch_or_edit(
+        MulticlassAccuracy(num_classes=10).to(device),
+        flag="metric", signal_name="acc")
+
+    wl.serve(serving_grpc=True)
     for epoch in range(1, args.epochs + 1):
         model.train()
         accuracy.reset()
         for x, y in train_loader:
+            with guard_training_context:
                 x, y = x.to(device), y.to(device)
                 logits = model(x)
                 loss = criterion(logits, y)
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
                 accuracy.update(logits, y)
-            wandb.log({"train/loss": loss.item()})
-        wandb.log({"train/acc": accuracy.compute().item(), "epoch": epoch})
+                wl.save_signals(preds_raw=logits, targets=y,
+                                signals={"metric/accuracy": accuracy.compute().item()})
         model.eval()
         accuracy.reset()
         with torch.no_grad():
             for x, y in test_loader:
+                with guard_testing_context:
                     x, y = x.to(device), y.to(device)
-                    accuracy.update(model(x), y)
+                    logits = model(x)  # test-set predictions, not the last training batch
+                    accuracy.update(logits, y)
-        wandb.log({"test/acc": accuracy.compute().item(), "epoch": epoch})
+                    wl.save_signals(preds_raw=logits, targets=y,
+                                    signals={"metric/accuracy": accuracy.compute().item()})
-    wandb.finish()
+    wl.keep_serving()
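The @wl.signal decorator in the diff derives byte_adjusted_loss from the loss/CE signal. How WeightsLab schedules such callbacks is not documented here; the toy SignalBus below only illustrates the subscribe-and-derive idea that the decorator's name and subscribe_to arguments suggest, reusing the same division. SignalBus and emit are hypothetical, and image_bytes is passed in explicitly, whereas in the diff it comes from the context WeightsLab provides.

```python
from __future__ import annotations

from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Callable


class SignalBus:
    """Toy registry: derived signals recompute whenever their upstream signal is emitted."""

    def __init__(self) -> None:
        # upstream signal name -> list of (derived signal name, callback)
        self._subscribers: dict[str, list[tuple[str, Callable[[Any], float]]]] = defaultdict(list)
        self.values: dict[str, float] = {}

    def signal(self, name: str, subscribe_to: str):
        """Decorator registering fn(ctx) as signal `name`, fed by `subscribe_to`."""
        def decorator(fn: Callable[[Any], float]) -> Callable[[Any], float]:
            self._subscribers[subscribe_to].append((name, fn))
            return fn
        return decorator

    def emit(self, name: str, value: float, **context: Any) -> None:
        """Record a value, then propagate it to every subscriber (recursively, so chains work)."""
        self.values[name] = value
        ctx = SimpleNamespace(subscribed_value=value, **context)
        for derived_name, fn in self._subscribers[name]:
            self.emit(derived_name, fn(ctx), **context)


bus = SignalBus()


@bus.signal(name="byte_adjusted_loss", subscribe_to="loss/CE")
def byte_adjusted_loss(ctx):
    return ctx.subscribed_value / ctx.image_bytes


bus.emit("loss/CE", 2.0, image_bytes=4096)
print(bus.values)  # → {'loss/CE': 2.0, 'byte_adjusted_loss': 0.00048828125}
```

Because emit recurses, a signal subscribed to byte_adjusted_loss would be updated in the same call, which is one plausible reading of "chains" in the diff's comment.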
What WeightsLab Does Automatically
- 📊 Experiment tracking for reproducibility
- 📈 Live metrics and visualization in the web UI
- 🔄 Data supervision during training, and hyperparameter tuning through the UI
Examples
Local examples
After starting the UI, launch a local experiment with the command:
weightslab start example # classification (default)
# weightslab start example --cls # classification
# weightslab start example --seg # segmentation
# weightslab start example --det # detection
# weightslab start example --clus # clustering
# weightslab start example --gen # generation
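To launch these examples from a Python script (for instance, to run several tasks in a row), the commands above can be built as argument lists. The following is a minimal sketch; example_command and run_example are hypothetical helpers that only wrap the CLI flags listed above, and running an example still requires weightslab on your PATH and the UI already started.

```python
from __future__ import annotations

import subprocess

# Task names mapped to the CLI flags listed above
TASK_FLAGS = {
    "classification": "--cls",
    "segmentation": "--seg",
    "detection": "--det",
    "clustering": "--clus",
    "generation": "--gen",
}


def example_command(task: str = "classification") -> list[str]:
    """Build the argv list for `weightslab start example` with the task's flag."""
    if task not in TASK_FLAGS:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(TASK_FLAGS)}")
    return ["weightslab", "start", "example", TASK_FLAGS[task]]


def run_example(task: str = "classification") -> None:
    """Launch the example; raises CalledProcessError if the command fails."""
    subprocess.run(example_command(task), check=True)
```

Passing an argument list rather than a shell string avoids quoting issues and works the same way on Linux, macOS, and Windows.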
Cloud examples
Find our sandbox online. The password is graybx.
Documentation (API + SDK)
Find our documentation online.
Contributing & onboarding
New here (human or AI coding agent)? Start with AGENTS.md — it
captures the cross-repo architecture (weightslab backend ↔ weights_studio
frontend via the shared proto), the module maps, the wl.watch_or_edit
integration pattern, where tests live, and the gotchas that aren't obvious from
any single file. It's the fastest way to orient before a first change.
Community
Graybx is building a wonderful community of AI researchers and engineers. Are you interested in joining our project? Contact us at hello [at] graybx [dot] com
Download files
File details
Details for the file weightslab-1.2.6.tar.gz.
File metadata
- Download URL: weightslab-1.2.6.tar.gz
- Upload date:
- Size: 2.9 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.15
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1eed5966475db4ec1824e9d0a79e1f64e0d73f40d13436781c4fdc2002ca7d98 |
| MD5 | 47241874c09bcff66a947313532fd75c |
| BLAKE2b-256 | 1535c79bb86be4de5bf326fff3db1db2b9897879c16198dd5413d1a54e8829b3 |
File details
Details for the file weightslab-1.2.6-py3-none-any.whl.
File metadata
- Download URL: weightslab-1.2.6-py3-none-any.whl
- Upload date:
- Size: 2.8 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.15
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f40d7dbf613095d1c5a17ac68157cb91944a18912dac9f34a25a04c114e23850 |
| MD5 | 002f215b56d63a042a16db005e200853 |
| BLAKE2b-256 | 14b92b080628ce484199648028a46ecad6bd9506e0955d8fe9500cf8ef4247ba |