
Helper to train deep neural networks


deep-trainer

Baseline code to train deep neural networks. Currently only available for the PyTorch framework.

Install

Pip

$ pip install deep-trainer

Conda

Not yet available

Getting started

import torch
from deep_trainer import PytorchTrainer


# Datasets
trainset = ...  # your torch.utils.data.Dataset for training
valset = ...  # your validation Dataset
testset = ...  # your test Dataset

# Dataloaders
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=256)
test_loader = torch.utils.data.DataLoader(testset, batch_size=256)

# Model & device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # your torch.nn.Module
model.to(device)

# Optimizer & Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Decay the LR by a factor of 10 every 50 epochs (step_size counts scheduler steps; one step per batch is assumed)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * 50, gamma=0.1)

# Criterion
criterion = torch.nn.CrossEntropyLoss()  # e.g. for classification

# Training
trainer = PytorchTrainer(model, optimizer, scheduler, save_mode="small", device=device)
trainer.train(150, train_loader, criterion, val_loader=val_loader)

# Testing
trainer.load("experiments/checkpoints/best.ckpt")
trainer.evaluate(test_loader, criterion)
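`StepLR` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`, so its value after `step` calls is `base_lr * gamma ** (step // step_size)`. Whether `PytorchTrainer` steps the scheduler once per batch or once per epoch is an assumption here (once per batch); under that assumption, "every 50 epochs" means `step_size` must be 50 times the number of batches per epoch, not 50 times the number of samples. The sketch below reproduces that arithmetic in plain Python; the 50,000-sample training set is hypothetical, and the batch size of 64 matches the loader above:

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch (last partial batch kept unless drop_last)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def step_lr(base_lr, step, step_size, gamma=0.1):
    """Closed form of StepLR: multiply base_lr by gamma once every step_size scheduler steps."""
    return base_lr * gamma ** (step // step_size)


if __name__ == "__main__":
    # Hypothetical training set of 50,000 samples, batch size 64.
    per_epoch = steps_per_epoch(50_000, 64)
    # Assumes one scheduler step per batch: 50 epochs = 50 * per_epoch steps.
    step_size = per_epoch * 50
    for epoch in (0, 49, 50, 100, 149):
        print(epoch, step_lr(0.001, epoch * per_epoch, step_size))
```

With 782 batches per epoch, `step_size` is 39,100 and the learning rate drops from 1e-3 to 1e-4 at epoch 50 and to 1e-5 at epoch 100, which matches the intent of the 150-epoch run above.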

Example

example/example.py shows how to train a PreActResNet with Deep Trainer.

Install the additional requirements, then run it with:

$ # See hyperparameters available
$ python example.py -h
$
$ # Launch the default training
$ python example.py
$
$ # Once done (or during training), inspect the default TensorBoard logs
$ tensorboard --logdir experiments/logs/

This script reaches around 94-95% validation accuracy on CIFAR-10 with a PreActResNet18.
Here are the training logs:

Build and Deploy

$ python -m build
$ python -m twine upload dist/*



Download files

Download the file for your platform.

Source Distribution

deep-trainer-0.1.1.tar.gz (14.5 kB)

Built Distribution

deep_trainer-0.1.1-py3-none-any.whl (14.8 kB)

File details

Details for the file deep-trainer-0.1.1.tar.gz.

File metadata

  • Download URL: deep-trainer-0.1.1.tar.gz
  • Upload date:
  • Size: 14.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.13

File hashes

Hashes for deep-trainer-0.1.1.tar.gz
Algorithm     Hash digest
SHA256        5d0ff96445a80a2b4de4001510acf58981d5c2af83b6fbd62c6ed5f6bed90378
MD5           0cdf7cbe4651c54cdcf018536b89dd04
BLAKE2b-256   ea8bfb5abf7b208d3b0563f9601447d111155067cf055b2c1f0abe6fd391e0bc
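These digests let you check that a downloaded file has not been altered or corrupted. A minimal sketch using Python's standard `hashlib`, assuming the source distribution has been saved as `deep-trainer-0.1.1.tar.gz` in the current directory (the file location and the helper names are illustrative, not part of deep-trainer):

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for the source distribution.
EXPECTED_SDIST_SHA256 = "5d0ff96445a80a2b4de4001510acf58981d5c2af83b6fbd62c6ed5f6bed90378"


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_hex):
    """True if the file's SHA256 matches the expected hex digest (compared case-insensitively)."""
    return sha256_of_file(path) == expected_hex.lower()


if __name__ == "__main__":
    # Hypothetical location of the downloaded archive.
    archive = Path("deep-trainer-0.1.1.tar.gz")
    if archive.exists():
        print("OK" if verify_download(archive, EXPECTED_SDIST_SHA256) else "MISMATCH")
    else:
        print(f"{archive} not found; download it first")
```

Reading in fixed-size chunks keeps memory use constant, which matters little for a 14.5 kB archive but makes the same helper usable for large checkpoint files.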


File details

Details for the file deep_trainer-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: deep_trainer-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 14.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.13

File hashes

Hashes for deep_trainer-0.1.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        680109a16475a3949977b3c80a5297130421d8f4975606fba5caa586b8f0f910
MD5           d647727dc8d8ee2b191a925ea3e58e75
BLAKE2b-256   4b01e21f476ecc3518aac33a33181e5a92ac79ea0e1ac47c1c985268b615efbf
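pip can also enforce these digests at install time through its hash-checking mode: a requirements file line of the form `name==version --hash=sha256:<digest>` (one `--hash` per accepted file), installed with `pip install --require-hashes -r requirements.txt`. The sketch below builds such a line for this release from the SHA256 digests of both the sdist and the wheel, so pip accepts either file. The helper name is hypothetical. Note that in hash-checking mode pip requires every dependency (including `torch`) to be pinned with hashes as well, so this line alone is not a complete requirements file:

```python
# SHA256 digests published above for the deep-trainer 0.1.1 release files.
RELEASE_SHA256 = {
    "deep-trainer-0.1.1.tar.gz": "5d0ff96445a80a2b4de4001510acf58981d5c2af83b6fbd62c6ed5f6bed90378",
    "deep_trainer-0.1.1-py3-none-any.whl": "680109a16475a3949977b3c80a5297130421d8f4975606fba5caa586b8f0f910",
}


def pinned_requirement(name, version, digests):
    """Build a requirements-file line pinning name==version to the given SHA256 digests."""
    hashes = " ".join(f"--hash=sha256:{d}" for d in sorted(digests))
    return f"{name}=={version} {hashes}"


if __name__ == "__main__":
    print(pinned_requirement("deep-trainer", "0.1.1", RELEASE_SHA256.values()))
```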

