Helper to train deep neural networks

Project description

deep-trainer

Baseline code to train deep neural networks. Currently only available for PyTorch Framework.

Install

Pip

$ pip install deep-trainer

Conda

Not yet available

Getting started

import torch
from deep_trainer import PytorchTrainer


# Datasets (placeholders: replace with your own torch.utils.data.Dataset objects)
trainset = ...
valset = ...
testset = ...

# Dataloaders
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=256)
test_loader = torch.utils.data.DataLoader(testset, batch_size=256)

# Model & device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # Placeholder: your torch.nn.Module
model.to(device)

# Optimizer & Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
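# Decay the learning rate by a factor of 10 every 50 epochs.
# step_size counts scheduler.step() calls; this assumes the trainer steps the scheduler once per batch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * 50, gamma=0.1)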

# Criterion
criterion = torch.nn.CrossEntropyLoss()  # e.g. for classification

# Training
trainer = PytorchTrainer(model, optimizer, scheduler, save_mode="small", device=device)
trainer.train(150, train_loader, criterion, val_loader=val_loader)

# Testing
trainer.load("experiments/checkpoints/best.ckpt")
trainer.evaluate(test_loader, criterion)
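The `step_size` of `StepLR` is counted in calls to `scheduler.step()`, not in epochs or samples. Assuming the trainer calls `scheduler.step()` once per batch (an assumption; check the trainer's source), "every 50 epochs" means 50 times the number of batches per epoch. The sketch below runs without PyTorch: it mirrors `len(DataLoader)` for a map-style dataset and the closed form of `StepLR` (`lr = base_lr * gamma ** (step // step_size)`). The helper names `batches_per_epoch` and `step_lr` and the dataset size are hypothetical.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches yielded per epoch, like len(DataLoader) for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    """Closed form of StepLR: multiply base_lr by gamma once every step_size steps."""
    return base_lr * gamma ** (step // step_size)


# Hypothetical setting: 50,000 training samples, batch size 64.
steps_per_epoch = batches_per_epoch(50_000, 64)
# Decay every 50 epochs, assuming one scheduler step per batch.
step_size = steps_per_epoch * 50

for epoch in (0, 49, 50, 100, 149):
    lr = step_lr(0.001, epoch * steps_per_epoch, step_size)
    print(f"epoch {epoch:3d}: lr = {lr:.0e}")
```

If the scheduler were instead stepped once per epoch, `step_size=50` would give the same decay schedule.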

Example

example/example.py shows how to train a PreActResNet with Deep Trainer.

Install the additional requirements, then run it with:

$ # See hyperparameters available
$ python example.py -h
$
$ # Launch the default training
$ python example.py
$
$ # Once done (or during the training), look for default tensorboard logs
$ tensorboard --logdir experiments/logs/

This script reaches around 94-95% validation accuracy on CIFAR-10 with a PreActResNet18.
Here are the training logs:

Build and Deploy

$ pip install build twine
$ python -m build
$ python -m twine upload dist/*


Download files


Source Distribution

deep-trainer-0.1.0.tar.gz (14.4 kB)

Uploaded Source

Built Distribution

deep_trainer-0.1.0-py3-none-any.whl (14.8 kB)

Uploaded Python 3

File details

Details for the file deep-trainer-0.1.0.tar.gz.

File metadata

  • Download URL: deep-trainer-0.1.0.tar.gz
  • Upload date:
  • Size: 14.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.8.13

File hashes

Hashes for deep-trainer-0.1.0.tar.gz
Algorithm Hash digest
SHA256 92402c015f518f4c79f232c0e5879eb1dfc02acc242499a6163b3bbff195c659
MD5 7c162a1bec628aa0b6f17603aa60caa5
BLAKE2b-256 3bbb7b433a87b5ddb3a7e23b6c5219ae120ed9a789d8bec03402d8287769cd68


File details

Details for the file deep_trainer-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: deep_trainer-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 14.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.8.13

File hashes

Hashes for deep_trainer-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 fb9afb6921b26d4049d52437405edc9741ee16a67e9b4cd97507c7d93515c681
MD5 548b8ce4269b8c6bf194224cfe8f38db
BLAKE2b-256 f3e27a285e552ba273b826c6dfaf1624ad465733f2e5ac0ad4df1b6793d94597
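To check a downloaded file against the published SHA256 digest, a minimal sketch using only Python's standard `hashlib` could look like this. It assumes the wheel was downloaded into the current directory. The helpers `sha256_of` and `verify` are hypothetical names, not part of deep-trainer.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> bool:
    """Compare a file's SHA256 digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


if __name__ == "__main__":
    # Assumption: the wheel sits in the current working directory.
    wheel = Path("deep_trainer-0.1.0-py3-none-any.whl")
    expected = "fb9afb6921b26d4049d52437405edc9741ee16a67e9b4cd97507c7d93515c681"
    if wheel.exists():
        print("OK" if verify(wheel, expected) else "MISMATCH")
    else:
        print(f"{wheel} not found")
```

pip can also enforce this at install time through its hash-checking mode: use `--hash=sha256:...` entries in a requirements file together with `--require-hashes`.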