
Helper to train deep neural networks


deep-trainer


Baseline code to train deep neural networks. Currently only available for the PyTorch framework.

Install

Pip

$ pip install deep-trainer

Conda

Not yet available

Getting started

import torch
from deep_trainer import PytorchTrainer


# Datasets (replace ... with your own torch.utils.data.Dataset objects)
trainset = ...
valset = ...
testset = ...

# Dataloaders
train_loader = torch.utils.data.DataLoader(trainset, 64, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, 256)
test_loader = torch.utils.data.DataLoader(testset, 256)

# Model & device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # Your torch.nn.Module
model.to(device)

# Optimizer & Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
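# Decay the learning rate by a factor of 10 every 50 epochs.
# step_size is counted in scheduler steps; this assumes the trainer steps the
# scheduler once per batch, i.e. len(train_loader) steps per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * 50, gamma=0.1)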

# Criterion
criterion = torch.nn.CrossEntropyLoss()  # e.g. for classification

# Training
trainer = PytorchTrainer(model, optimizer, scheduler, save_mode="small", device=device)
trainer.train(150, train_loader, criterion, val_loader=val_loader)

# Testing
trainer.load("experiments/checkpoints/best.ckpt")
trainer.evaluate(test_loader, criterion)
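The scheduler comment in Getting started ("Decay by 10 every 50 epochs") holds only if step_size is measured in the unit the scheduler is stepped in. Assuming the trainer calls scheduler.step() once per batch (this README does not say so explicitly), one epoch is len(train_loader) steps, so 50 epochs is 50 * len(train_loader) steps; using len(trainset), the number of samples, would make the interval batch_size times too long. The plain-Python sketch below reproduces StepLR's closed-form schedule, base_lr * gamma ** (step // step_size), to make the arithmetic concrete. The helper names `batches_per_epoch` and `step_lr` are hypothetical, the 50,000-image training set is an illustrative assumption (CIFAR-10's training split size), and the code calls neither deep-trainer nor PyTorch.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch, i.e. len(loader)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    """Closed form of StepLR: multiply by gamma once every step_size steps."""
    return base_lr * gamma ** (step // step_size)


# Assumption: the scheduler is stepped once per batch (hypothetical setup).
steps_per_epoch = batches_per_epoch(50_000, 64)  # 782 batches of size 64
step_size = steps_per_epoch * 50                 # 50 epochs, counted in batches

for epoch in (0, 49, 50, 100, 149):
    step = epoch * steps_per_epoch  # scheduler steps taken before this epoch
    print(f"epoch {epoch:3d}: lr = {step_lr(1e-3, step, step_size):.0e}")
# The learning rate drops by 10x at epochs 50 and 100.

# Counting in samples instead: 50 * 50_000 = 2.5M steps, more than the
# 150 * 782 = 117,300 steps of the whole run, so the LR never decays.
wrong_step_size = 50_000 * 50
print(step_lr(1e-3, 150 * steps_per_epoch - 1, wrong_step_size))  # 0.001
```

Under the per-batch assumption, the sample-count version never reaches its first decay step during a 150-epoch run, so the learning rate stays at its initial value throughout.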

Example

example/example.py shows how to train a PreActResNet with Deep Trainer.

Install the additional requirements, then run it:

$ # See hyperparameters available
$ python example.py -h
$
$ # Launch the default training
$ python example.py
$
$ # Once done (or during the training), look for default tensorboard logs
$ tensorboard --logdir experiments/logs/

This script reaches around 94-95% validation accuracy on CIFAR-10 with a PreActResNet18.



Download files


Source Distribution

deep_trainer-0.1.5.tar.gz (14.7 kB)

Built Distribution

deep_trainer-0.1.5-py3-none-any.whl (15.4 kB, Python 3)

File details

Details for the file deep_trainer-0.1.5.tar.gz.

File metadata

  • Download URL: deep_trainer-0.1.5.tar.gz
  • Size: 14.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for deep_trainer-0.1.5.tar.gz
Algorithm     Hash digest
SHA256        ec0cea4b047c5578418aa90934df7baaf9878b7247813bde33b2b1a517eb9d7f
MD5           94d1c09bc58a7e3c8df2ec2a11051071
BLAKE2b-256   9f22df79a356781e66604d3f8078e55f0fca01159735bf3c9349b44ce2cc6eda


File details

Details for the file deep_trainer-0.1.5-py3-none-any.whl.


File hashes

Hashes for deep_trainer-0.1.5-py3-none-any.whl
Algorithm     Hash digest
SHA256        1a6835709b295794aab72a2a62f4c7b4f3e1390af759682885f0226a599233fd
MD5           3e144d30d85eefea0483ac75e1f82433
BLAKE2b-256   78aec96369cd99f1d0081d3f7b4de2761ec09942223ed19975105b4626a4c5da

