Helper to train deep neural networks
deep-trainer
Baseline code to train deep neural networks. Currently only available for the PyTorch framework.
Install
Pip
$ pip install deep-trainer
Conda
Not yet available
Getting started
import torch
from deep_trainer import PytorchTrainer
# Datasets
trainset = ...  # Your training dataset (a torch.utils.data.Dataset)
valset = ...  # Your validation dataset
testset = ...  # Your test dataset
# Dataloaders
train_loader = torch.utils.data.DataLoader(trainset, 64, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, 256)
test_loader = torch.utils.data.DataLoader(testset, 256)
# Model & device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # Your torch.nn.Module
model.to(device)
# Optimizer & Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * 50, gamma=0.1)  # Decay by 10 every 50 epochs (step_size counted in batches, assuming one scheduler step per batch)
# Criterion
criterion = torch.nn.CrossEntropyLoss() # For classification for instance
# Training
trainer = PytorchTrainer(model, optimizer, scheduler, save_mode="small", device=device)
trainer.train(150, train_loader, criterion, val_loader=val_loader)
# Testing
trainer.load("experiments/checkpoints/best.ckpt")
trainer.evaluate(test_loader, criterion)
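For reference, `StepLR` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`, i.e. lr = base_lr * gamma ** (steps // step_size). The pure-Python sketch below reproduces that closed form to check the "decay by 10 every 50 epochs" arithmetic. It assumes (not confirmed here) that the trainer calls `scheduler.step()` once per batch, so `step_size` has to be expressed in batches rather than samples. The helper names and the 50,000-sample / batch-size-64 figures are illustrative only.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches a map-style DataLoader yields per epoch with drop_last=False."""
    return math.ceil(num_samples / batch_size)


def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    """Closed form of StepLR: lr = base_lr * gamma ** (step // step_size)."""
    return base_lr * gamma ** (step // step_size)


# Illustrative numbers: 50,000 training samples, batch size 64 (as in train_loader).
n_batches = batches_per_epoch(50_000, 64)
step_size = n_batches * 50  # "every 50 epochs", counted in scheduler steps (batches)

for epoch in (0, 49, 50, 100, 149):
    step = epoch * n_batches  # scheduler steps taken before this epoch starts
    print(f"epoch {epoch:3d}: lr = {step_lr(1e-3, step, step_size):.0e}")
```

With per-batch stepping, a `step_size` based on the number of samples would delay each decay by a factor of the batch size, which is why the batch count is used here.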
Example
example/example.py shows how to train a PreActResNet with Deep Trainer.
Install the additional requirements and run it with:
$ # See hyperparameters available
$ python example.py -h
$
$ # Launch the default training
$ python example.py
$
$ # Once done (or during training), inspect the default TensorBoard logs
$ tensorboard --logdir experiments/logs/
This script reaches around 94-95% validation accuracy on CIFAR-10 with a PreActResNet18.
Here are the training logs:
Build and Deploy
$ pip install build twine
$ python -m build
$ python -m twine upload dist/*
Download files
Source Distribution
deep-trainer-0.0.13.dev0.tar.gz (12.8 kB)
Built Distribution
deep_trainer-0.0.13.dev0-py3-none-any.whl (13.2 kB)
File details
Details for the file deep-trainer-0.0.13.dev0.tar.gz.
File metadata
- Download URL: deep-trainer-0.0.13.dev0.tar.gz
- Upload date:
- Size: 12.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.8.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8184715cfaaf8252ecb3eaa790caa0db050590ed411f0d4c1fd25a8d195fd1d1
MD5 | 33ab519e4ad72150af8ea656fff884db
BLAKE2b-256 | 1d6dfc347e0beb01822fa83b24f53f3b095d7325f568710fa9d85bd4f68e9c52
File details
Details for the file deep_trainer-0.0.13.dev0-py3-none-any.whl.
File metadata
- Download URL: deep_trainer-0.0.13.dev0-py3-none-any.whl
- Upload date:
- Size: 13.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.8.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | d9c5004aa5d5b226b6b9544660394df5eda2a530b56d9004cddadd84bd7ca374
MD5 | 81bfa42dccd4a05c20c755ead33427a1
BLAKE2b-256 | a26415894748c67b017793abcb2fcabc082e1d491ecf5121539fd59eee810494
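To check a downloaded file against the SHA256 digests listed in the File hashes tables, a minimal sketch using only Python's standard `hashlib` module could look like this. `sha256_of` and `verify` are hypothetical helper names for illustration, not part of deep-trainer.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_sha256: str) -> bool:
    """Compare a file's SHA256 digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Usage, assuming the source distribution was downloaded to the current directory:
# verify("deep-trainer-0.0.13.dev0.tar.gz",
#        "8184715cfaaf8252ecb3eaa790caa0db050590ed411f0d4c1fd25a8d195fd1d1")
```

Alternatively, pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file) performs the same check at install time.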