Sakura provides asynchronous training for DNNs.

Sakura

ML-framework integrations for Zakuro — hide evaluation, logging, and checkpointing behind training so the main loop never waits.


What is Sakura?

Sakura wraps the framework you're already using (PyTorch Lightning, HuggingFace Trainer, TensorFlow Model.fit) with a callback that dispatches evaluation to a Zakuro worker instead of running it inline. Training keeps stepping while eval runs on a side pool; metrics come back through a non-blocking queue.
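The core pattern — dispatch eval to a side worker, keep training, drain a non-blocking queue for metrics — can be sketched framework-free with the standard library. This is an illustrative stand-in, not Sakura's API; `evaluate` and the queue names are invented for the sketch:

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor

def evaluate(epoch):
    # Stand-in for a remote eval: slow relative to a training step.
    time.sleep(0.05)
    return {"epoch": epoch, "val_loss": 1.0 / (epoch + 1)}

metrics_q = queue.Queue()
pool = ThreadPoolExecutor(max_workers=1)   # the "side pool"
history = []

for epoch in range(5):
    # --- training steps for this epoch would run here ---
    fut = pool.submit(evaluate, epoch)
    fut.add_done_callback(lambda f: metrics_q.put(f.result()))
    # Non-blocking drain: collect whatever eval results happen to be ready.
    while True:
        try:
            history.append(metrics_q.get_nowait())
        except queue.Empty:
            break

pool.shutdown(wait=True)                   # reap the remaining evals at the end
while not metrics_q.empty():
    history.append(metrics_q.get_nowait())
```

The training loop never blocks on `evaluate`; metrics for an epoch may land one or more epochs late, which is exactly the trade the callback makes.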

The old Sakura used MPI + Redis for this plumbing. The current Sakura uses Zakuro — one @zk.fn dispatch per epoch, shared connection, context-aware allocation across a pool of workers. No MPI, no Redis, no SAKURA_ROLE fork.

Quick start

Laptop-only (no worker setup)

import lightning as L
from sakura.lightning import SakuraTrainer

trainer = SakuraTrainer(
    max_epochs=10,
    accelerator="auto",
    model_factory=MyLightningModule,     # rebuilds on the eval worker
    val_loader_factory=lambda: val_loader,
)
trainer.run(model, train_loader)          # val_compute=None → Zakuro standalone fallback

No zakuro-worker needed — the eval runs in-process via Zakuro's standalone fallback, but the async dispatch pattern still works.

HuggingFace Trainer with a real worker

from transformers import Trainer, TrainingArguments
from sakura.huggingface import SakuraHFCallback
import zakuro as zk

trainer = Trainer(
    model=model,
    args=TrainingArguments(..., eval_strategy="no"),   # we handle eval
    train_dataset=train_ds,
    callbacks=[
        SakuraHFCallback(
            model_factory=lambda: AutoModelForSequenceClassification.from_config(config),
            eval_fn=my_eval_fn,
            eval_payload=(val_inputs, 32),
            val_compute=zk.Compute(uri="quic://worker:4433"),
            fp16_state_dict=True,
            on_backpressure="skip",
        )
    ],
)
trainer.train()

on_backpressure="skip" makes the callback consult AdaptiveCompute.is_backpressured() before every dispatch — if the allocator reports saturation (the slow eval worker can't keep up), that epoch's eval is dropped rather than blocking training.
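The skip policy boils down to "ask the allocator before dispatching; if saturated, drop this epoch's eval". A minimal stand-alone sketch of that decision, using a count of in-flight futures as a stand-in for Zakuro's backpressure signal (`is_backpressured`, `MAX_PENDING`, and the timings here are all illustrative):

```python
import time
from concurrent.futures import ThreadPoolExecutor

MAX_PENDING = 2          # illustrative cap, in the spirit of max_pending
pending = []

def slow_eval(epoch):
    time.sleep(0.5)      # eval worker much slower than a training epoch
    return epoch

def is_backpressured():
    # Stand-in for AdaptiveCompute.is_backpressured(): too many in-flight evals.
    pending[:] = [f for f in pending if not f.done()]
    return len(pending) >= MAX_PENDING

pool = ThreadPoolExecutor(max_workers=1)
dispatched, skipped = [], []

for epoch in range(6):
    time.sleep(0.01)     # fast "training epoch"
    if is_backpressured():
        skipped.append(epoch)          # on_backpressure="skip": drop this eval
    else:
        pending.append(pool.submit(slow_eval, epoch))
        dispatched.append(epoch)

pool.shutdown(wait=True)
```

Training finishes all six epochs regardless; only the evals that would have queued behind a saturated worker are lost.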

Installation

# Core + HuggingFace integration
pip install 'sakura-ml[huggingface]'

# Everything
pip install 'sakura-ml[huggingface,tensorflow,bench]'

# From source
git clone https://github.com/zakuro-ai/sakura && cd sakura
uv pip install -e '.[huggingface]'

Zakuro is pulled in transitively. For a worker (HTTP or QUIC), install the [worker] extra on the zakuro package.

PyTorch Lightning

sakura.lightning.SakuraTrainer — a drop-in replacement for Lightning's Trainer, covering the async-eval case:

from sakura.lightning import SakuraTrainer

trainer = SakuraTrainer(
    max_epochs=10,
    accelerator="auto",
    # how the eval worker rebuilds the model:
    model_factory=lambda: MyLightningModule(),
    # how the eval worker rebuilds the dataloader:
    val_loader_factory=lambda: DataLoader(val_ds, batch_size=256),
    # optional: where to run eval
    val_compute=zk.Compute(uri="quic://eval-worker:4433"),
    # optional: where to save the best-loss checkpoint
    model_path="checkpoints/best.pth",
)
trainer.run(model, train_loader)

print(trainer.history)         # [{epoch, val_loss, worker_name, elapsed_secs}, ...]
print(trainer.best_val_loss)

HuggingFace Trainer

sakura.huggingface.SakuraHFCallback is a transformers.TrainerCallback that cloudpickles the state_dict in on_epoch_end, dispatches a remote eval, and lazily reaps futures as they finish. Knobs:

model_factory: how the eval worker rebuilds the architecture (weights stream in from the callback)
eval_fn(model, payload): the eval routine itself — runs on the worker, returns a dict of metrics
eval_payload: anything cloudpickle can serialise — dataset, tokenizer, batch size
val_compute: zk.Compute or zk.AdaptiveCompute; None → standalone
drain="lazy" (default) / "strict": whether on_epoch_end blocks to reap the previous future
cache_key=...: keep the validator model architecture warm on the worker
fp16_state_dict=True: halve the wire bytes
async_copy=True (default, CUDA-only): GPU→CPU snapshot on a dedicated stream, 176 → 75 ms per epoch on x399 4090
on_backpressure={"skip","queue","block"}: policy when AdaptiveCompute reports saturation
max_pending: cap on in-flight evaluations
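The fp16_state_dict saving is plain byte arithmetic: each float32 element becomes 2 bytes on the wire instead of 4. The stdlib's half-precision pack format ('e') makes the halving concrete; the parameter count below is a rough figure for distilbert-base-uncased, assumed for illustration:

```python
import struct

n_params = 66_000_000                           # rough distilbert-base-uncased size
fp32_bytes = n_params * struct.calcsize("f")    # 4 bytes per element
fp16_bytes = n_params * struct.calcsize("e")    # 2 bytes per element (IEEE 754 half)

print(fp32_bytes // 2**20, "MiB vs", fp16_bytes // 2**20, "MiB")
```

The cast costs some precision in the snapshotted weights, which is usually acceptable for a validation pass.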

The in-memory fast path is automatic: when val_compute resolves to standalone, torch.save/torch.load are skipped entirely — measured as a 23.6 % wall-clock win over forced serialisation on a 3-epoch distilbert fine-tune.
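The fast path is just a branch on where eval will run: if it stays in-process, hand over the live object; only serialise when the payload must cross a process or network boundary. A toy illustration with pickle standing in for torch.save (`dispatch_eval` and the byte counting are invented for the sketch, not Sakura internals):

```python
import io
import pickle

def dispatch_eval(state, standalone):
    """Hand `state` to the eval path, serialising only when it must travel."""
    if standalone:
        # In-memory fast path: pass the live object, zero bytes produced.
        return state, 0
    buf = io.BytesIO()
    pickle.dump(state, buf)                 # stand-in for torch.save
    wire = buf.getvalue()
    return pickle.loads(wire), len(wire)    # stand-in for torch.load on the worker

state = {"layer.weight": [0.1] * 1000}

same, n_local = dispatch_eval(state, standalone=True)
copy, n_remote = dispatch_eval(state, standalone=False)

print(same is state, n_local, n_remote > 0)
```

The standalone branch returns the identical object with no serialisation cost, which is where the measured wall-clock saving comes from.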

TensorFlow / Keras

sakura.tensorflow.SakuraKerasCallback — a tf.keras.callbacks.Callback with the same pattern:

from sakura.tensorflow import SakuraKerasCallback

model.fit(
    x_train, y_train,
    epochs=10,
    callbacks=[SakuraKerasCallback(
        model_factory=lambda: tf.keras.Sequential([...]),
        val_fn=lambda m, p: m.evaluate(*p, verbose=0, return_dict=True),
        val_payload=(x_val, y_val),
        val_compute=zk.Compute(uri="quic://eval-worker:4433"),
    )],
)

Weights are transferred as numpy arrays via get_weights() / set_weights() — clean cloudpickle, no TF graph-state serialisation.

Generic async trainer (framework-agnostic)

sakura.ml.async_trainer.AsyncTrainer — for training loops that aren't Lightning / HF / Keras. Takes any object implementing train(loader), serialized_state_dict(), _epochs, _metrics, plus a model_factory and test_fn(model) -> dict. Same dispatch mechanics.

Benchmarks & notebooks

Measured performance wins (distilbert-base-uncased, 268 MB state_dict)

slice                                       before                          after            measured on
blocking .cpu() → async CUDA-stream copy    176 ms / epoch (main thread)    75 ms / epoch    x399 4090
cloudpickle → torch.save for state_dict     482 ms / epoch (pool)           282 ms / epoch   x399 CPU
in-memory handle for standalone             9.12 s wall (3 epochs)          7.59 s           Mac MPS

See zakuro/PLAN.md for the consolidated numbers across both repos.

Development

git clone https://github.com/zakuro-ai/sakura && cd sakura
uv pip install -e '.[bench]'
uv run pytest tests/

License

BSD-3-Clause.
