
coffeetrain

Lightweight event-driven PyTorch trainer with composable callbacks. Inspired by MosaicML Composer, but implemented with fewer external dependencies and with support for newer PyTorch versions (2.10 as of this writing).

Features

  • Event-driven lifecycle: fit_start, epoch_start, batch_start, before_forward, after_forward, before_loss, after_loss, before_backward, after_backward, batch_end, eval_*, etc.
  • Composable callbacks: EMA, SWA, checkpointing, W&B, Comet, early stopping, batch size scheduling, and more.
  • TrainerModel protocol: Simple interface (forward, loss) for model integration.
  • Accelerate support: Distributed training via HuggingFace Accelerator.
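The TrainerModel protocol above can be sketched with `typing.Protocol`. The method names (`forward`, `loss`) come from the feature list; the signatures and the example class below are illustrative assumptions, not coffeetrain's actual definitions.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TrainerModel(Protocol):
    """Illustrative sketch: a trainable model exposes forward() and loss()."""

    def forward(self, batch: Any) -> Any: ...

    def loss(self, outputs: Any, batch: Any) -> Any: ...


class TinyModel:
    """Matches the protocol structurally: no inheritance, just the two methods."""

    def forward(self, batch):
        # Stand-in "inference": double every input value.
        return [2 * x for x in batch]

    def loss(self, outputs, batch):
        # Stand-in objective: mean squared error against the raw batch.
        return sum((o - b) ** 2 for o, b in zip(outputs, batch)) / len(batch)
```

Because the protocol is structural, any object with matching `forward` and `loss` methods satisfies it without subclassing.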

Installation

pip install coffeetrain

Optional extras:

pip install coffeetrain[wandb,comet,optimi]

Quick Start

from coffeetrain import Trainer, CosineWarmupScheduler, HistoryCallback, BestModelCheckpointer
from coffeetrain import create_optimizer
from coffeetrain.optimizers import OptimizerConfig

model = MyModel()
optimizer = create_optimizer(model.parameters(), OptimizerConfig(name="adamw", lr=1e-4, weight_decay=0.01))
scheduler = CosineWarmupScheduler(optimizer, warmup_steps=100, total_steps=1000)

trainer = Trainer(
    model=model,
    train_dataloader=train_loader,
    optimizers=optimizer,
    schedulers=scheduler,
    max_epochs=10,
    callbacks=[
        HistoryCallback(save_dir="output"),
        BestModelCheckpointer(save_dir="output", metric_name="loss", mode="min"),
    ],
)
trainer.fit()
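The event-driven lifecycle can be illustrated with a standalone sketch. This is not coffeetrain's actual `Callback` base class; the hook names are taken from the event list above, but the dispatch mechanism here is an assumption.

```python
class Callback:
    """Sketch of an event-hook base class: override only the events you need."""

    def fit_start(self, trainer): pass
    def epoch_start(self, trainer): pass
    def batch_end(self, trainer): pass
    def epoch_end(self, trainer): pass


class EventRecorder(Callback):
    """Records every event it sees; handy for testing callback wiring."""

    def __init__(self):
        self.events = []

    def fit_start(self, trainer):
        self.events.append("fit_start")

    def batch_end(self, trainer):
        self.events.append("batch_end")


def fire(callbacks, event, trainer=None):
    # Dispatch one lifecycle event to every registered callback.
    for cb in callbacks:
        getattr(cb, event)(trainer)
```

A training loop then reduces to a sequence of `fire(...)` calls interleaved with the forward/backward steps, with all side effects (checkpointing, logging, averaging) living in callbacks.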

Callbacks

Callback                    Description
--------------------------  ---------------------------------------
BestModelCheckpointer       Save best model by metric
HistoryCallback             Track and save training history to JSON
EMACallback                 Exponential moving average of weights
SWACallback                 Stochastic weight averaging
EarlyStoppingCallback       Stop when metric stops improving
WandbCallback               Log to Weights & Biases
CometCallback               Log to Comet.ml
BatchSizeSchedulerCallback  Batch size warmup
ScheduleLoggerCallback      Log LR schedule phase transitions
ParameterCounter            Print parameter counts at start
SpeedMonitor                Track samples/sec
ProgressCallback            Print epoch summaries
LRMonitor                   Log learning rates
TorchMetricsCallback        Integrate torchmetrics
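For instance, an EMA callback maintains a shadow copy of the weights updated as shadow <- decay * shadow + (1 - decay) * param after each step. A minimal sketch of that update rule (plain floats here, not coffeetrain's actual tensor implementation):

```python
def ema_update(shadow, params, decay=0.999):
    """One EMA step: blend the shadow weights toward the current weights."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]
```

At evaluation or checkpoint time, the shadow weights are typically swapped in place of the raw weights, since the averaged copy usually generalizes better.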

License

Apache-2.0

Tests

From repository root:

uv run pytest packages/coffeetrain/tests -q
