# Sakura

Sakura provides asynchronous training for DNNs.

ML-framework integrations for Zakuro — hide evaluation, logging, and checkpointing behind training so the main loop never waits.

Quick Start • Installation • Lightning • HuggingFace • TensorFlow • Benchmarks

## What is Sakura?
Sakura wraps the framework you're already using (PyTorch Lightning, HuggingFace Trainer, TensorFlow Model.fit) with a callback that dispatches evaluation to a Zakuro worker instead of running it inline. Training keeps stepping while eval runs on a side pool; metrics come back through a non-blocking queue.
The old Sakura used MPI + Redis for this plumbing. The current Sakura uses Zakuro — one `@zk.fn` dispatch per epoch, a shared connection, context-aware allocation across a pool of workers. No MPI, no Redis, no `SAKURA_ROLE` fork.
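The pattern described above — training keeps stepping while eval runs on a side pool, metrics returning through a non-blocking queue — can be sketched with stdlib primitives. This is a framework-free illustration, not Sakura's actual code; `evaluate` is a stand-in for a remote eval function:

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor

def evaluate(state, results):
    """Stand-in for a remote eval: score a model snapshot, report via queue."""
    time.sleep(0.05)  # pretend eval is slow
    results.put({"epoch": state["epoch"], "val_loss": 1.0 / (state["epoch"] + 1)})

results = queue.Queue()
pool = ThreadPoolExecutor(max_workers=1)  # the "side pool"

history = []
for epoch in range(3):
    # ... training steps for this epoch would run here, never waiting on eval ...
    snapshot = {"epoch": epoch}           # cheap copy of the state to evaluate
    pool.submit(evaluate, snapshot, results)
    while True:                           # drain whatever finished, never block
        try:
            history.append(results.get_nowait())
        except queue.Empty:
            break

pool.shutdown(wait=True)                  # end of training: reap stragglers
while not results.empty():
    history.append(results.get_nowait())
```

The key property is that the training loop only ever calls `get_nowait()`, so a slow eval delays its metrics, not the next training step.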
## Quick start

### Laptop-only (no worker setup)

```python
import lightning as L
from sakura.lightning import SakuraTrainer

trainer = SakuraTrainer(
    max_epochs=10,
    accelerator="auto",
    model_factory=MyLightningModule,        # rebuilds on the eval worker
    val_loader_factory=lambda: val_loader,
)
trainer.run(model, train_loader)  # val_compute=None → Zakuro standalone fallback
```
No `zakuro-worker` is needed — eval runs in-process via Zakuro's standalone fallback, but the async dispatch pattern still works.
### HuggingFace Trainer with a real worker

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from sakura.huggingface import SakuraHFCallback
import zakuro as zk

trainer = Trainer(
    model=model,
    args=TrainingArguments(..., eval_strategy="no"),  # we handle eval
    train_dataset=train_ds,
    callbacks=[
        SakuraHFCallback(
            model_factory=lambda: AutoModelForSequenceClassification.from_config(config),
            eval_fn=my_eval_fn,
            eval_payload=(val_inputs, 32),
            val_compute=zk.Compute(uri="quic://worker:4433"),
            fp16_state_dict=True,
            on_backpressure="skip",
        )
    ],
)
trainer.train()
```
`on_backpressure="skip"` makes the callback consult `AdaptiveCompute.is_backpressured()` before every dispatch — if the allocator reports saturation (the slow eval worker can't keep up), that epoch's eval is dropped rather than blocking training.
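The skip policy amounts to a guard before each dispatch. A minimal sketch — `is_backpressured` here is a plain injected function standing in for the allocator's saturation check, not Zakuro's API:

```python
def run_epochs(n_epochs, is_backpressured, dispatch):
    """Dispatch one eval per epoch, dropping epochs while the worker is saturated."""
    dispatched, skipped = [], []
    for epoch in range(n_epochs):
        if is_backpressured():
            skipped.append(epoch)   # drop this epoch's eval rather than block
        else:
            dispatch(epoch)
            dispatched.append(epoch)
    return dispatched, skipped

# Simulate a worker that is saturated on epochs 1 and 2.
saturated = {1, 2}
state = {"epoch": -1}

def fake_backpressure():
    state["epoch"] += 1
    return state["epoch"] in saturated

sent = []
dispatched, skipped = run_epochs(5, fake_backpressure, sent.append)
# dispatched == [0, 3, 4], skipped == [1, 2]
```

Training never stalls: a saturated epoch simply has no eval entry in the history.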
## Installation

```shell
# Core + HuggingFace integration
pip install 'sakura-ml[huggingface]'

# Everything
pip install 'sakura-ml[huggingface,tensorflow,bench]'

# From source
git clone https://github.com/zakuro-ai/sakura && cd sakura
uv pip install -e '.[huggingface]'
```
Zakuro is pulled in transitively. To run a worker (HTTP or QUIC), install the `[worker]` extra on the `zakuro` package.
PyTorch Lightning
sakura.lightning.SakuraTrainer — a drop-in replacement for the async-eval case:
from sakura.lightning import SakuraTrainer
trainer = SakuraTrainer(
max_epochs=10,
accelerator="auto",
# how the eval worker rebuilds the model:
model_factory=lambda: MyLightningModule(),
# how the eval worker rebuilds the dataloader:
val_loader_factory=lambda: DataLoader(val_ds, batch_size=256),
# optional: where to run eval
val_compute=zk.Compute(uri="quic://eval-worker:4433"),
# optional: where to save the best-loss checkpoint
model_path="checkpoints/best.pth",
)
trainer.run(model, train_loader)
print(trainer.history) # [{epoch, val_loss, worker_name, elapsed_secs}, ...]
print(trainer.best_val_loss)
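Since the history is a plain list of dicts, post-hoc analysis is ordinary Python. A sketch with fabricated entries in the documented shape (the values below are illustrative, not measured):

```python
# Fabricated history in the shape Sakura reports: one dict per completed eval.
history = [
    {"epoch": 0, "val_loss": 0.91, "worker_name": "standalone", "elapsed_secs": 1.2},
    {"epoch": 1, "val_loss": 0.64, "worker_name": "standalone", "elapsed_secs": 1.1},
    {"epoch": 2, "val_loss": 0.71, "worker_name": "standalone", "elapsed_secs": 1.3},
]

best = min(history, key=lambda h: h["val_loss"])     # epoch with lowest val loss
best_val_loss = best["val_loss"]
total_eval_secs = sum(h["elapsed_secs"] for h in history)
```

Because evals are async, entries may arrive out of epoch order; sort on the `epoch` key before plotting curves.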
## HuggingFace Trainer

`sakura.huggingface.SakuraHFCallback` is a `transformers.TrainerCallback` that cloudpickles the `state_dict` on `on_epoch_end`, dispatches a remote eval, and lazily reaps futures as they finish. Knobs:
| parameter | what it does |
|---|---|
| `model_factory` | how the eval worker rebuilds the architecture (weights stream in from the callback) |
| `eval_fn(model, payload)` | the eval routine itself — runs on the worker, returns a dict of metrics |
| `eval_payload` | anything cloudpickle can serialise — dataset, tokenizer, batch size |
| `val_compute` | `zk.Compute` or `zk.AdaptiveCompute`; `None` → standalone |
| `drain="lazy"` (default) / `"strict"` | whether `on_epoch_end` blocks to reap the previous future |
| `cache_key=...` | keep the validator model architecture warm on the worker |
| `fp16_state_dict=True` | halve the wire bytes |
| `async_copy=True` (default, CUDA-only) | GPU→CPU snapshot on a dedicated stream, ~176 → 75 ms per epoch on x399 4090 |
| `on_backpressure={"skip","queue","block"}` | policy when `AdaptiveCompute` reports saturation |
| `max_pending` | cap on in-flight evaluations |
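The `drain` distinction is easy to see with raw futures. A hedged sketch — `reap` is an illustrative helper, not Sakura's internal name:

```python
from concurrent.futures import Future

def reap(pending, results, strict):
    """Collect finished eval futures; with strict=True, wait for all of them."""
    still_pending = []
    for fut in pending:
        if strict:
            results.append(fut.result())   # block until this eval finishes
        elif fut.done():
            results.append(fut.result())   # lazy: only take what's already done
        else:
            still_pending.append(fut)      # not done yet — check again next epoch
    return still_pending

done, running = Future(), Future()
done.set_result({"val_loss": 0.5})         # one eval finished, one still in flight

results = []
pending = reap([done, running], results, strict=False)
# lazy drain: the finished future is reaped, the running one stays pending
```

`"lazy"` keeps epochs from waiting on a straggling eval; `"strict"` guarantees at most one in-flight eval at the cost of a potential stall at each epoch boundary.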
The in-memory fast path is automatic: when `val_compute` resolves to standalone, `torch.save`/`torch.load` are skipped entirely — a measured +23.6 % wall-clock improvement on a 3-epoch distilbert fine-tune vs forced serialisation.
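The gist of that fast path is to serialise only when the snapshot actually leaves the process. A sketch under stated assumptions — `ship_state` is a hypothetical name, and stdlib `pickle` stands in for `torch.save`:

```python
import io
import pickle

def ship_state(state_dict, remote):
    """Serialise only when the snapshot must leave the process."""
    if not remote:
        # Standalone fast path: hand over an in-memory reference; no wire bytes.
        return dict(state_dict)
    # Remote path: stand-in for torch.save — pickle into a buffer for the wire.
    buf = io.BytesIO()
    pickle.dump(state_dict, buf)
    return buf.getvalue()

weights = {"layer.weight": [0.1, 0.2], "layer.bias": [0.0]}
local_handle = ship_state(weights, remote=False)   # dict handle, zero-copy-ish
wire_bytes = ship_state(weights, remote=True)      # bytes ready for transport
restored = pickle.loads(wire_bytes)                # what the worker would see
```

The shallow `dict(...)` copy in the local branch is the point: the eval sees a stable snapshot without any encode/decode round-trip.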
## TensorFlow / Keras

`sakura.tensorflow.SakuraKerasCallback` — a `tf.keras.callbacks.Callback` with the same pattern:

```python
import tensorflow as tf
import zakuro as zk
from sakura.tensorflow import SakuraKerasCallback

model.fit(
    x_train, y_train,
    epochs=10,
    callbacks=[SakuraKerasCallback(
        model_factory=lambda: tf.keras.Sequential([...]),
        val_fn=lambda m, p: m.evaluate(*p, verbose=0, return_dict=True),
        val_payload=(x_val, y_val),
        val_compute=zk.Compute(uri="quic://eval-worker:4433"),
    )],
)
```
Weights are transferred as numpy arrays via `get_weights()` / `set_weights()` — clean cloudpickle, no TF graph-state serialisation.
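That transfer reduces to a pickle round-trip of plain numeric arrays. A stdlib sketch using nested lists in place of numpy arrays (`snapshot_weights` / `restore_weights` are illustrative names, not the callback's API):

```python
import pickle

def snapshot_weights(layers):
    """Stand-in for get_weights(): pull plain numeric arrays out of the model."""
    return [list(w) for w in layers]

def restore_weights(payload):
    """Stand-in for the worker side: deserialise, then feed set_weights()."""
    return pickle.loads(payload)

model_layers = [[0.5, -0.5], [1.0]]
wire = pickle.dumps(snapshot_weights(model_layers))  # what crosses the wire
rebuilt = restore_weights(wire)                      # fresh arrays on the worker
```

Shipping plain arrays sidesteps pickling live TF objects, whose graph and session state generally do not serialise cleanly.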
## Generic async trainer (framework-agnostic)

`sakura.ml.async_trainer.AsyncTrainer` — for training loops that aren't Lightning / HF / Keras. It takes any object implementing `train(loader)`, `serialized_state_dict()`, `_epochs`, and `_metrics`, plus a `model_factory` and a `test_fn(model) -> dict`. Same dispatch mechanics.
## Benchmarks & notebooks

- `bert_demo/hf_async_features.ipynb` — every `SakuraHFCallback` knob exercised on distilbert / SST-2. Runs in ~1 min on a laptop. Verified via `jupyter nbconvert --execute`.
- `bert_demo/bench_bert.py` — serial `Trainer` vs Sakura async, configurable.
- `bert_demo/bench_in_memory_handle.py` — A/B the in-memory-handle fast path against `torch.save`. +23.6 % end-to-end measured.
### Measured performance wins (distilbert-base-uncased, 268 MB state_dict)

| slice | before | after | measured on |
|---|---|---|---|
| blocking `.cpu()` → async CUDA-stream copy | 176 ms / epoch main-thread | 75 ms / epoch | x399 4090 |
| cloudpickle → `torch.save` for state_dict | 482 ms / epoch pool | 282 ms / epoch | x399 CPU |
| in-memory handle for standalone | 9.12 s wall (3 epochs) | 7.59 s | Mac MPS |
See `zakuro/PLAN.md` for the consolidated numbers across both repos.
## Development

```shell
git clone https://github.com/zakuro-ai/sakura && cd sakura
uv pip install -e '.[bench]'
uv run pytest tests/
```
## License

BSD-3-Clause.
## File details: sakura_ml-0.1.4.tar.gz

- Size: 266.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.13

| Algorithm | Hash digest |
|---|---|
| SHA256 | `7209fbbefa38cc13f4aa62ba6d776370d2ecd4c6a6dd4cf0074f56a468a8280d` |
| MD5 | `e5a73b149293ace5bfcc24f642d17070` |
| BLAKE2b-256 | `050ff038fa0a5eb700494b136859d9eee543faf0e8f87e44d17693bb473352fc` |
## File details: sakura_ml-0.1.4-py3-none-any.whl

- Size: 26.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.13

| Algorithm | Hash digest |
|---|---|
| SHA256 | `a4a380dae3befd32ec207ebff9957b44f0ade0bb42920cb90f05fead4198b52f` |
| MD5 | `8ec45f8a5724ecd4cc96bd99b9558dfc` |
| BLAKE2b-256 | `4f2499d8903050b5cc558d5a49443480dd6028869e8e35225b0ce085214e7d0b` |