Helper to train deep neural networks

deep-trainer

Baseline code for training deep neural networks. Currently available only for the PyTorch framework.

Install

Pip

$ pip install deep-trainer

Conda

Not yet available

Getting started

import torch
from deep_trainer import PytorchTrainer


# Datasets
trainset = ...  # your training dataset
valset = ...  # your validation dataset
testset = ...  # your test dataset

# Dataloaders
train_loader = torch.utils.data.DataLoader(trainset, 64, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, 256)
test_loader = torch.utils.data.DataLoader(testset, 256)

# Model & device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # your torch.nn.Module
model.to(device)

# Optimizer & Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# step_size counts scheduler steps; this assumes one scheduler step per batch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * 50, gamma=0.1)  # Decay by 10 every 50 epochs

# Criterion
criterion = torch.nn.CrossEntropyLoss()  # For classification for instance

# Training
trainer = PytorchTrainer(model, optimizer, scheduler, save_mode="small", device=device)
trainer.train(150, train_loader, criterion, val_loader=val_loader)

# Testing
trainer.load("experiments/checkpoints/best.ckpt")
trainer.evaluate(test_loader, criterion)
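
The "decay by 10 every 50 epochs" comment above depends on how often the scheduler is stepped. The sketch below works through the arithmetic in plain Python, assuming the scheduler is stepped once per batch (the text does not state this). The helper names `steps_per_epoch` and `step_lr`, and the dataset size of 50,000, are hypothetical and are not part of deep-trainer or PyTorch.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a data loader yields in one epoch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    """Learning rate after `step` scheduler steps under a StepLR-style schedule."""
    return base_lr * gamma ** (step // step_size)


# Hypothetical numbers: 50,000 training samples, batch size 64 as in the example above.
batches = steps_per_epoch(50_000, 64)
step_size = batches * 50  # scheduler steps in 50 epochs, if stepped once per batch

lr_before_decay = step_lr(0.001, step_size - 1, step_size)  # still the base rate
lr_after_decay = step_lr(0.001, step_size, step_size)       # base rate times gamma
```

If the scheduler were stepped once per epoch instead, `step_size` would simply be 50. Using the number of samples rather than the number of batches would delay the first decay by a factor of the batch size.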

Build and Deploy

$ pip install build twine
$ python -m build
$ python -m twine upload dist/*
