
Helper to train deep neural networks


deep-trainer


Baseline code to train deep neural networks. Currently only available for the PyTorch framework.

Install

Pip

$ pip install deep-trainer

Conda

Not yet available

Getting started

import torch
from deep_trainer import PytorchTrainer


# Datasets
trainset = ...  # Your training torch.utils.data.Dataset
valset = ...  # Your validation Dataset
testset = ...  # Your test Dataset

# Dataloaders
train_loader = torch.utils.data.DataLoader(trainset, 64, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, 256)
test_loader = torch.utils.data.DataLoader(testset, 256)

# Model & device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # Your torch.nn.Module
model.to(device)

# Optimizer & Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
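# StepLR counts step_size in scheduler.step() calls. Assuming the trainer steps the scheduler
# once per batch, len(train_loader) * 50 decays the learning rate by 10 every 50 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * 50, gamma=0.1)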

# Criterion
criterion = torch.nn.CrossEntropyLoss()  # For classification, for instance

# Training
trainer = PytorchTrainer(model, optimizer, scheduler, save_mode="small", device=device)
trainer.train(150, train_loader, criterion, val_loader=val_loader)

# Testing
trainer.load("experiments/checkpoints/best.ckpt")
trainer.evaluate(test_loader, criterion)
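`StepLR` counts its `step_size` in calls to `scheduler.step()`, not in epochs, so the value to pass depends on how often the trainer steps the scheduler. The sketch below reproduces the decay with plain arithmetic (no PyTorch needed), assuming the scheduler is stepped once per batch; if it were stepped once per epoch, `step_size=50` would be the equivalent. The dataset size of 50,000 samples is a hypothetical value, and `steps_per_epoch` / `step_lr` are illustrative helpers, not part of deep-trainer.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch, i.e. len(loader)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    """Closed form of StepLR: base_lr * gamma ** (step // step_size)."""
    return base_lr * gamma ** (step // step_size)


# Hypothetical dataset size: replace with len(trainset).
num_samples = 50_000
batch_size = 64

# Assumption: the scheduler is stepped once per batch.
per_epoch = steps_per_epoch(num_samples, batch_size)
step_size = per_epoch * 50  # 50 epochs' worth of scheduler steps

for epoch in (0, 49, 50, 100, 149):
    step = epoch * per_epoch  # scheduler steps taken before this epoch starts
    print(f"epoch {epoch:3d}: lr = {step_lr(0.001, step, step_size):.0e}")
```

Over the 150 epochs of the training call above, this gives three plateaus: 1e-3, then 1e-4 from epoch 50, then 1e-5 from epoch 100.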

Example

example/example.py shows how to train a PreActResNet with Deep Trainer.

Install the additional requirements, then run it with:

$ # See hyperparameters available
$ python example.py -h
$
$ # Launch the default training
$ python example.py
$
$ # Once done (or during the training), look for default tensorboard logs
$ tensorboard --logdir experiments/logs/

This script reaches around 94-95% validation accuracy on CIFAR-10 with a PreActResNet18.
Here are the training logs:

Build and Deploy

$ python -m build
$ python -m twine upload dist/*

Download files


Source Distribution

deep_trainer-0.1.4.tar.gz (14.7 kB, Source)

Built Distribution

deep_trainer-0.1.4-py3-none-any.whl (14.9 kB, Python 3)

File details

Details for the file deep_trainer-0.1.4.tar.gz.

File metadata

  • Download URL: deep_trainer-0.1.4.tar.gz
  • Upload date:
  • Size: 14.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.19

File hashes

Hashes for deep_trainer-0.1.4.tar.gz
Algorithm Hash digest
SHA256 9c729cafec43f74d8675b8237648f0d8b35718520f87eb0df5c7bf89f8b39cd1
MD5 ccc5cd88bb7a5c3f746f6c6c907b5e11
BLAKE2b-256 f1e63ab0e6d3995f4c69edea908ddaf95679b92ef069666803533695db4f7b04


File details

Details for the file deep_trainer-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: deep_trainer-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 14.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.19

File hashes

Hashes for deep_trainer-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 6ecf2db68030d4a59aaab1025ef7a539ea3eb5203f5b807d92441c1843bfdef2
MD5 d723d148d758ac8bced1177ea673d98b
BLAKE2b-256 8863f45e75f1bdd58d32c73ce9d4dfa2dc5ca403789cacc7b1f21ce2a9852718
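The SHA256 digests above can be used to check that a downloaded file is intact. A minimal sketch using only the standard library's `hashlib`; the helper names `sha256_of` and `verify` are illustrative, and the expected digests are copied from the tables above.

```python
import hashlib
from pathlib import Path

# SHA256 digests listed above for the deep-trainer 0.1.4 release files.
EXPECTED_SHA256 = {
    "deep_trainer-0.1.4.tar.gz": "9c729cafec43f74d8675b8237648f0d8b35718520f87eb0df5c7bf89f8b39cd1",
    "deep_trainer-0.1.4-py3-none-any.whl": "6ecf2db68030d4a59aaab1025ef7a539ea3eb5203f5b807d92441c1843bfdef2",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in fixed-size chunks so it never has to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str) -> bool:
    """Compare the file's digest with an expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, `verify(Path("deep_trainer-0.1.4.tar.gz"), EXPECTED_SHA256["deep_trainer-0.1.4.tar.gz"])` returns `True` only for an unmodified download. pip can enforce the same check at install time through its hash-checking mode: put `deep-trainer==0.1.4` with one `--hash=sha256:...` option per digest in a requirements file and install it with `pip install --require-hashes -r requirements.txt`.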

