Lightweight event-driven PyTorch trainer with composable callbacks
coffeetrain
Lightweight event-driven PyTorch trainer with composable callbacks. Inspired by MosaicML Composer, but implemented with fewer external dependencies and with support for newer PyTorch versions (2.10 as of this writing).
Features
- Event-driven lifecycle: `fit_start`, `epoch_start`, `batch_start`, `before_forward`, `after_forward`, `before_loss`, `after_loss`, `before_backward`, `after_backward`, `batch_end`, `eval_*`, etc.
- Composable callbacks: EMA, SWA, checkpointing, W&B, Comet, early stopping, batch size scheduling, and more.
- TrainerModel protocol: a simple interface (`forward`, `loss`) for model integration.
- Accelerate support: distributed training via Hugging Face Accelerate (`Accelerator`).
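To make the event-driven lifecycle concrete, here is a minimal, self-contained sketch of the general pattern: a loop fires named events and dispatches each one to every callback that defines a matching `on_<event>` method. This is plain Python, not coffeetrain's implementation. The `Callback`, `EventRecorder`, `LossLogger`, `ToyModel`, and `run_loop` names, the `on_<event>(state)` hook convention, and the `forward(batch)` / `loss(outputs, batch)` signatures are all hypothetical; only the two-method model shape and the event names come from the feature list above (the `eval_*` events are omitted).

```python
from typing import Any, Protocol


class TrainerModelLike(Protocol):
    """Hypothetical stand-in for a forward/loss model interface."""

    def forward(self, batch: Any) -> Any: ...

    def loss(self, outputs: Any, batch: Any) -> float: ...


class Callback:
    """Base class: subclasses override only the hooks they care about."""

    def on_event(self, event: str, state: dict) -> None:
        # Dispatch "after_loss" to on_after_loss(state), if defined.
        handler = getattr(self, f"on_{event}", None)
        if handler is not None:
            handler(state)


class EventRecorder(Callback):
    """Records every event name it receives, in order."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def on_event(self, event: str, state: dict) -> None:
        self.events.append(event)
        super().on_event(event, state)


class LossLogger(Callback):
    """Hooks a single event to collect the per-batch loss."""

    def __init__(self) -> None:
        self.losses: list[float] = []

    def on_after_loss(self, state: dict) -> None:
        self.losses.append(state["loss"])


class ToyModel:
    """Predicts a constant; the loss is squared error against the batch value."""

    def __init__(self, value: float) -> None:
        self.value = value

    def forward(self, batch: float) -> float:
        return self.value

    def loss(self, outputs: float, batch: float) -> float:
        return (outputs - batch) ** 2


def run_loop(model: TrainerModelLike, data: list, epochs: int, callbacks: list[Callback]) -> dict:
    state: dict = {"epoch": 0, "loss": None}

    def fire(event: str) -> None:
        for cb in callbacks:
            cb.on_event(event, state)

    fire("fit_start")
    for epoch in range(epochs):
        state["epoch"] = epoch
        fire("epoch_start")
        for batch in data:
            fire("batch_start")
            fire("before_forward")
            outputs = model.forward(batch)
            fire("after_forward")
            fire("before_loss")
            state["loss"] = model.loss(outputs, batch)
            fire("after_loss")
            # A real trainer would call backward() and step the optimizer here;
            # the toy model has no gradients, so only the events are fired.
            fire("before_backward")
            fire("after_backward")
            fire("batch_end")
    return state


recorder, logger = EventRecorder(), LossLogger()
run_loop(ToyModel(1.0), [0.0, 2.0], epochs=1, callbacks=[recorder, logger])
print(logger.losses)  # [1.0, 1.0]
```

The design choice worth noting is that callbacks never touch the loop itself: each one reads (or mutates) a shared `state` and reacts only to the events it names, which is what lets independent callbacks be stacked freely.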
Installation
```bash
pip install coffeetrain
```
Optional extras:
```bash
pip install "coffeetrain[wandb,comet,optimi]"
```
Quick Start
```python
from coffeetrain import Trainer, CosineWarmupScheduler, HistoryCallback, BestModelCheckpointer
from coffeetrain import create_optimizer
from coffeetrain.optimizers import OptimizerConfig

# Placeholders defined elsewhere: MyModel is your model (it must expose
# .parameters() and implement the TrainerModel protocol: forward, loss),
# and train_loader is your training DataLoader.
model = MyModel()
optimizer = create_optimizer(model.parameters(), OptimizerConfig(name="adamw", lr=1e-4, weight_decay=0.01))
scheduler = CosineWarmupScheduler(optimizer, warmup_steps=100, total_steps=1000)

trainer = Trainer(
    model=model,
    train_dataloader=train_loader,
    optimizers=optimizer,
    schedulers=scheduler,
    max_epochs=10,
    callbacks=[
        # Track and save training history to JSON under output/
        HistoryCallback(save_dir="output"),
        # Keep the checkpoint with the lowest "loss"
        BestModelCheckpointer(save_dir="output", metric_name="loss", mode="min"),
    ],
)
trainer.fit()
```
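The Quick Start pairs AdamW with `CosineWarmupScheduler(optimizer, warmup_steps=100, total_steps=1000)`. As a rough guide to what such a schedule does, here is a sketch of the widely used linear-warmup-then-cosine-decay learning-rate multiplier in plain Python. This is the common formulation, assumed rather than taken from coffeetrain's source; the library's exact behavior (for example, whether it decays to a nonzero minimum LR) may differ. The function name `cosine_warmup_factor` is hypothetical.

```python
import math


def cosine_warmup_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear ramp 0 -> 1 during warmup, then cosine decay 1 -> 0."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


base_lr = 1e-4
for step in (0, 50, 100, 550, 1000):
    print(step, base_lr * cosine_warmup_factor(step, warmup_steps=100, total_steps=1000))
```

With these settings the LR climbs linearly to `1e-4` over the first 100 steps, reaches half of that midway through the decay (step 550), and approaches zero at step 1000. In plain PyTorch, a multiplier like this can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`.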
Callbacks
| Callback | Description |
|---|---|
| `BestModelCheckpointer` | Save best model by metric |
| `HistoryCallback` | Track and save training history to JSON |
| `EMACallback` | Exponential moving average of weights |
| `SWACallback` | Stochastic weight averaging |
| `EarlyStoppingCallback` | Stop when metric stops improving |
| `WandbCallback` | Log to Weights & Biases |
| `CometCallback` | Log to Comet.ml |
| `BatchSizeSchedulerCallback` | Batch size warmup |
| `ScheduleLoggerCallback` | Log LR schedule phase transitions |
| `ParameterCounter` | Print parameter counts at start |
| `SpeedMonitor` | Track samples/sec |
| `ProgressCallback` | Print epoch summaries |
| `LRMonitor` | Log learning rates |
| `TorchMetricsCallback` | Integrate torchmetrics |
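Two of these callbacks rest on simple, well-known update rules. The sketch below shows the standard formulations on plain Python floats: an exponential moving average of weights (`ema = decay * ema + (1 - decay) * w`) and patience-based early stopping on a metric to be minimized. The `WeightEMA` and `EarlyStopper` classes and their `decay`, `patience`, and `min_delta` parameters are hypothetical illustrations, not the constructor arguments of coffeetrain's `EMACallback` or `EarlyStoppingCallback`.

```python
class WeightEMA:
    """Exponential moving average over a dict of scalar 'weights'."""

    def __init__(self, weights: dict[str, float], decay: float = 0.9) -> None:
        self.decay = decay
        self.shadow = dict(weights)  # EMA copy, initialized to the current weights

    def update(self, weights: dict[str, float]) -> None:
        for name, value in weights.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value


class EarlyStopper:
    """Signals a stop after `patience` checks that fail to improve by more than min_delta."""

    def __init__(self, patience: int = 2, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, metric: float) -> bool:
        if metric < self.best - self.min_delta:
            self.best = metric  # improvement: remember it and reset the counter
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


ema = WeightEMA({"w": 0.0}, decay=0.9)
ema.update({"w": 1.0})  # shadow["w"] moves 10% of the way toward 1.0

stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

Here training stops at epoch 3, after two consecutive checks fail to beat the best loss of 0.8, so the later 0.7 is never seen. The "lower is better" comparison mirrors `mode="min"` in the Quick Start's `BestModelCheckpointer`.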
License
Apache-2.0
Tests
From the repository root:
```bash
uv run pytest packages/coffeetrain/tests -q
```
Download files
File details
Details for the file coffeetrain-0.1.0.tar.gz.
File metadata
- Download URL: coffeetrain-0.1.0.tar.gz
- Upload date:
- Size: 95.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8d27060cfd885f81b63c7d2c48a8c58ccfc732b94d046c10f8c7c25247e5d6b4 |
| MD5 | 3dce1d76d92dfe461e166840a1820f83 |
| BLAKE2b-256 | fb17bb91744e3f83fa12b92d53854528f30bf796afcd462b5d5de2dfc18f2b1b |
Provenance
The following attestation bundles were made for coffeetrain-0.1.0.tar.gz:
- Publisher: publish.yml on paul-english/coffeetrain
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: coffeetrain-0.1.0.tar.gz
- Subject digest: 8d27060cfd885f81b63c7d2c48a8c58ccfc732b94d046c10f8c7c25247e5d6b4
- Sigstore transparency entry: 1097009707
- Sigstore integration time:
- Permalink: paul-english/coffeetrain@0a9382456c14a82a4f635da4e53e4a2cffb36ef1
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/paul-english
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@0a9382456c14a82a4f635da4e53e4a2cffb36ef1
- Trigger Event: push
File details
Details for the file coffeetrain-0.1.0-py3-none-any.whl.
File metadata
- Download URL: coffeetrain-0.1.0-py3-none-any.whl
- Upload date:
- Size: 39.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 06e90fc624750ce28adcba1b056572423a8850fca52348eb47b390484f6dd690 |
| MD5 | 60f9a30668be461123ee6444d15fadc3 |
| BLAKE2b-256 | 214e4c61951781c24afe4c985ed1ffbe9b17e18d88bc36165f7023d6adf83713 |
Provenance
The following attestation bundles were made for coffeetrain-0.1.0-py3-none-any.whl:
- Publisher: publish.yml on paul-english/coffeetrain
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: coffeetrain-0.1.0-py3-none-any.whl
- Subject digest: 06e90fc624750ce28adcba1b056572423a8850fca52348eb47b390484f6dd690
- Sigstore transparency entry: 1097009714
- Sigstore integration time:
- Permalink: paul-english/coffeetrain@0a9382456c14a82a4f635da4e53e4a2cffb36ef1
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/paul-english
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@0a9382456c14a82a4f635da4e53e4a2cffb36ef1
- Trigger Event: push