General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui.
👟 Trainer
An opinionated general purpose model trainer on PyTorch with a simple code base.
Installation
From GitHub:
git clone https://github.com/coqui-ai/Trainer
cd Trainer
make install
From PyPI:
pip install trainer
Prefer installing from GitHub, as that version is more stable.
Implementing a model
Subclass TrainerModel and override its functions.
Training a model
See the test script in the repository for an example that trains a basic MNIST model.
Training with DDP
$ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1"
We don't use .spawn() to initiate multi-GPU training since it imposes certain limitations:
- Everything must be picklable.
- .spawn() trains the model in subprocesses, and the model in the main process is not updated.
- A DataLoader with N processes gets really slow when N is large.
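The picklability limitation can be checked with the standard library alone. The sketch below is only an illustration and is not part of Trainer: `is_picklable` and `scale_by_two` are hypothetical names. It shows that a lambda (a common choice for small callbacks such as a collate function) cannot be pickled, while an equivalent `functools.partial` over an importable function can.

```python
import operator
import pickle
from functools import partial


def is_picklable(obj):
    """Return True if pickle.dumps(obj) succeeds, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A partial over an importable function is pickled by reference to that function.
scale_by_two = partial(operator.mul, 2)

# A lambda has no importable name, so pickle cannot serialize it.
anonymous_scale = lambda x: 2 * x  # noqa: E731

print("partial picklable:", is_picklable(scale_by_two))
print("lambda picklable:", is_picklable(anonymous_scale))

# The partial survives a round trip, as it would when sent to a subprocess.
restored = pickle.loads(pickle.dumps(scale_by_two))
print("restored(21) =", restored(21))
```

Launching each process from the training script itself, as trainer.distribute does, means objects like these never need to cross a process boundary through pickle.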
Profiling example
- Create the torch profiler as you like and pass it to the trainer.
import torch

profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)
# `trainer` is an already constructed Trainer instance.
prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
- Run TensorBoard.
tensorboard --logdir="./profiler/"
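To see which training steps the schedule above actually records, the following pure-Python sketch mimics the cycle logic of `torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2)`. It is a simplified re-implementation for illustration, not PyTorch code: `schedule_phase` is a hypothetical helper, and it assumes the default of no skipped initial steps.

```python
def schedule_phase(step, wait=1, warmup=1, active=3, repeat=2):
    """Return the profiler phase for a 0-based step index.

    Simplified model: each cycle is `wait` idle steps, then `warmup`
    steps whose results are discarded, then `active` recorded steps.
    After `repeat` cycles the profiler stays idle.
    """
    cycle_length = wait + warmup + active
    if repeat > 0 and step >= cycle_length * repeat:
        return "idle"
    position = step % cycle_length
    if position < wait:
        return "wait"
    if position < wait + warmup:
        return "warmup"
    return "record"


phases = [schedule_phase(step) for step in range(12)]
for step, phase in enumerate(phases):
    print(step, phase)

recorded_steps = [step for step, phase in enumerate(phases) if phase == "record"]
print("recorded steps:", recorded_steps)
```

Under these assumptions, two cycles of five steps each are profiled, and only three steps per cycle contribute to the traces written to ./profiler/.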
Supported Experiment Loggers
- Tensorboard - actively maintained
- ClearML - actively maintained
- MLFlow
- Aim
- WandB
To add a new logger, subclass BaseDashboardLogger and override its functions.