Library built on top of PyTorch to fuel productivity
torchfuel
Built on top of PyTorch to fuel productivity.
Features
- Generic Trainer
- Classification Trainer (with cross-entropy loss)
- MSE Trainer
- Additional utility layers
- Better dataloaders (currently only for image datasets)
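The feature list names a generic trainer plus classification (cross-entropy) and MSE variants. The sketch below is not torchfuel's code. It is a plain-PyTorch illustration, assuming such trainers differ mainly in the loss function they apply. `SimpleTrainer` and its `loss_fn` argument are hypothetical names. Unlike the `ClassificationTrainer(device, model, optimiser, scheduler)` call used in the example, this sketch takes the loss explicitly.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class SimpleTrainer:
    """Hypothetical minimal trainer: one training pass and one evaluation pass per epoch."""

    def __init__(self, device, model, optimiser, loss_fn, scheduler=None):
        self.device = device
        self.model = model.to(device)
        self.optimiser = optimiser
        self.loss_fn = loss_fn  # the only thing that changes between task types
        self.scheduler = scheduler

    def _run_epoch(self, dataloader, train):
        """Return the mean per-sample loss over one pass of `dataloader`."""
        self.model.train(train)  # toggles dropout / batch-norm behaviour
        total_loss, n_samples = 0.0, 0
        with torch.set_grad_enabled(train):  # no gradients during evaluation
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                loss = self.loss_fn(self.model(inputs), targets)
                if train:
                    self.optimiser.zero_grad()
                    loss.backward()
                    self.optimiser.step()
                total_loss += loss.item() * inputs.size(0)
                n_samples += inputs.size(0)
        return total_loss / n_samples

    def fit(self, epochs, train_dataloader, eval_dataloader):
        history = []
        for _ in range(epochs):
            train_loss = self._run_epoch(train_dataloader, train=True)
            eval_loss = self._run_epoch(eval_dataloader, train=False)
            if self.scheduler is not None:
                # Assumption: a ReduceLROnPlateau scheduler monitors the eval loss.
                self.scheduler.step(eval_loss)
            history.append((train_loss, eval_loss))
        return self.model, history


torch.manual_seed(0)
device = torch.device('cpu')
x = torch.randn(64, 4)

# Classification: integer class labels with cross-entropy loss.
# (The training loader doubles as the eval loader here only for brevity.)
y_cls = (x[:, 0] > 0).long()
cls_loader = DataLoader(TensorDataset(x, y_cls), batch_size=16)
cls_model = nn.Linear(4, 2)
cls_trainer = SimpleTrainer(
    device, cls_model, torch.optim.SGD(cls_model.parameters(), lr=0.1), nn.CrossEntropyLoss()
)
_, cls_history = cls_trainer.fit(5, cls_loader, cls_loader)

# Regression: float targets with mean squared error loss.
y_reg = x.sum(dim=1, keepdim=True)
reg_loader = DataLoader(TensorDataset(x, y_reg), batch_size=16)
reg_model = nn.Linear(4, 1)
reg_trainer = SimpleTrainer(
    device, reg_model, torch.optim.SGD(reg_model.parameters(), lr=0.1), nn.MSELoss()
)
_, reg_history = reg_trainer.fit(5, reg_loader, reg_loader)
```

Keeping the loss as a constructor argument is what lets one loop serve both task types. A dedicated classification or MSE trainer can then simply fix that argument.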
Classification Example
```python
import torch
import torch.optim as optim
from torchvision import transforms

from torchfuel.data_loaders.image import ImageDataLoader
from torchfuel.trainers.classification import ClassificationTrainer
from torchfuel.transforms.noise import DropPixelNoiser

dl = ImageDataLoader(
    train_data_folder='imgs/train',
    eval_data_folder='imgs/eval',
    pil_transformations=[transforms.RandomHorizontalFlip()],  # applied to PIL images
    tensor_transformations=[DropPixelNoiser()],  # applied to image tensors
    batch_size=64,
    imagenet_format=True,
)
train_dataloader, eval_dataloader, n_classes = dl.prepare()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Placeholder: substitute your own nn.Module (e.g. one with n_classes outputs).
model = Model(...).to(device)

optimiser = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Lower the learning rate once the monitored loss has not improved for 20 scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20)

epochs = 100  # illustrative value; tune for your dataset
trainer = ClassificationTrainer(device, model, optimiser, scheduler)
fitted_model = trainer.fit(epochs, train_dataloader, eval_dataloader)
```
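The example passes `DropPixelNoiser()` as a tensor transformation, but its exact behaviour is not documented here. As an assumption only, a pixel-dropping noiser might zero each pixel location independently with some probability. The hypothetical `DropPixels` class and its `p` parameter below sketch that idea in plain PyTorch; torchfuel's own transform may work differently.

```python
import torch


class DropPixels:
    """Hypothetical pixel-dropout noise for a (C, H, W) image tensor.

    Each spatial location is set to zero in every channel with probability `p`.
    """

    def __init__(self, p=0.05):
        if not 0.0 <= p <= 1.0:
            raise ValueError('p must be in [0, 1]')
        self.p = p

    def __call__(self, img):
        # One uniform draw per (H, W) location; the leading 1 broadcasts over channels.
        keep = torch.rand(1, img.shape[-2], img.shape[-1], device=img.device) >= self.p
        return img * keep.to(img.dtype)


torch.manual_seed(0)
img = torch.ones(3, 32, 32)
noisy = DropPixels(p=0.5)(img)
```

Because the transform is a plain callable on tensors, it can follow `transforms.ToTensor()` inside a `transforms.Compose` pipeline.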
How to install
Clone the repository and run:

```shell
pip install .
```

Alternatively, install the release published on PyPI (not up to date with the repository):

```shell
pip install torchfuel
```
Download files
Source Distribution
- torchfuel-0.1.2.tar.gz (11.0 kB)

Built Distribution
- torchfuel-0.1.2-py3-none-any.whl (19.9 kB)
File details
Details for the file torchfuel-0.1.2.tar.gz.
File metadata
- Download URL: torchfuel-0.1.2.tar.gz
- Upload date:
- Size: 11.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.0.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8b90080c5744f3ce94e48235c648a675a91278c84a9e2d308463411d32466db4
MD5 | b1e98332197cc27a471fd5fb514dbcfb
BLAKE2b-256 | 60ce0822f9993617f56e0c1724f990757a72907f082629359bf5b7abf92ec5d1
File details
Details for the file torchfuel-0.1.2-py3-none-any.whl.
File metadata
- Download URL: torchfuel-0.1.2-py3-none-any.whl
- Upload date:
- Size: 19.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.0.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 11fcaefe375fd3788008c26144b6b7f59795a869c0d81a419aa9200734b5ed07
MD5 | 4ec49e25083e77bd842d1f51556769bf
BLAKE2b-256 | 30f22735f13ee2eeb185f79072c4720507d74a3f0683316e65505b42a931cc57
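Each distribution file is published with SHA256, MD5 and BLAKE2b-256 digests. As a sketch using only the standard-library `hashlib`, a downloaded file can be checked against them as follows. `file_digest`, `verify` and `_new_hasher` are hypothetical helper names, not part of torchfuel.

```python
import hashlib


def _new_hasher(algorithm):
    """Create a hashlib object for one of the algorithm names used in the File hashes tables."""
    name = algorithm.lower()
    if name == 'blake2b-256':
        # BLAKE2b with a 32-byte (256-bit) digest.
        return hashlib.blake2b(digest_size=32)
    return hashlib.new(name)  # e.g. 'sha256', 'md5'


def file_digest(path, algorithm='sha256', chunk_size=1 << 16):
    """Return the hex digest of the file at `path`, read in chunks to bound memory use."""
    hasher = _new_hasher(algorithm)
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            hasher.update(chunk)
    return hasher.hexdigest()


def verify(path, expected_hex, algorithm='sha256'):
    """True if the file's digest matches the published hex digest (case-insensitive)."""
    return file_digest(path, algorithm) == expected_hex.strip().lower()


# Example, assuming the source distribution was downloaded to the current directory:
# verify('torchfuel-0.1.2.tar.gz',
#        '8b90080c5744f3ce94e48235c648a675a91278c84a9e2d308463411d32466db4')
```

pip can also enforce this check at install time in hash-checking mode. For example, a requirements line such as `torchfuel==0.1.2 --hash=sha256:<digest>` can be installed with `pip install --require-hashes -r requirements.txt`.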