
Neural network training pipeline based on PyTorch, designed to standardize the training process and to improve coding productivity.


Neural Pipeline



  • The core is about 2K lines of test-covered code that you don't need to write again
  • Flexible and customizable training process
  • Checkpoint management and training resumption (independent of the source and target devices)
  • Metric processing and visualization via built-in (TensorBoard, Matplotlib) or custom monitors
  • Training best practices (e.g. learning rate decay and hard negative mining)
  • Metric logging and comparison (DVC-compatible)
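Learning rate decay is listed among the built-in best practices. A common variant lowers the learning rate when a monitored value (typically validation loss) stops improving for a number of epochs. The plain-Python sketch below illustrates that idea only; `PlateauLRDecay`, `coeff` and `patience` are hypothetical names chosen for illustration, not neural-pipeline's API, and the library's own decay logic may differ.

```python
# Hypothetical helper for illustration; not part of neural-pipeline.
class PlateauLRDecay:
    """Multiply the learning rate by `coeff` once the monitored value
    has failed to improve for `patience` consecutive epochs."""

    def __init__(self, lr: float, coeff: float = 0.5, patience: int = 3):
        self.lr = lr
        self.coeff = coeff
        self.patience = patience
        self.best = float("inf")  # lower is better, e.g. validation loss
        self.bad_epochs = 0       # epochs since the last improvement

    def step(self, value: float) -> float:
        """Record one epoch's monitored value; return the (possibly decayed) LR."""
        if value < self.best:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                self.lr *= self.coeff
                self.bad_epochs = 0  # restart the patience window after decaying
        return self.lr


if __name__ == "__main__":
    decay = PlateauLRDecay(lr=1e-4, coeff=0.5, patience=2)
    for val_loss in [0.9, 0.8, 0.85, 0.84, 0.83, 0.7]:
        print(decay.step(val_loss))
```

In plain PyTorch, the returned value could be written into each `optimizer.param_groups[i]['lr']`; PyTorch also ships `torch.optim.lr_scheduler.ReduceLROnPlateau`, which implements this policy with more options.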

Train MNIST example:

This code runs MNIST image classification with TensorBoard monitoring. It is based on the PyTorch MNIST example.

See the full example there.

from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor
from neural_pipeline import DataProducer, AbstractDataset, TrainConfig, TrainStage,\
    ValidationStage, Trainer, FileStructManager

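import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms

class Net(nn.Module):
    # Small CNN in the style of the PyTorch MNIST example; it returns
    # log-probabilities, which is what NLLLoss below expects
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)  # 28x28 -> 24x24 -> 12x12
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)  # 12x12 -> 8x8 -> 4x4
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)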

class MNISTDataset(AbstractDataset):
    transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def __init__(self, data_dir: str, is_train: bool):
        self.dataset = datasets.MNIST(data_dir, train=is_train, download=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data, target = self.dataset[item]
        return {'data': self.transforms(data), 'target': target}

fsm = FileStructManager(base_dir='data', is_continue=False)
model = Net()

train_dataset = DataProducer([MNISTDataset('data/dataset', True)], batch_size=4, num_workers=2)
validation_dataset = DataProducer([MNISTDataset('data/dataset', False)], batch_size=4, num_workers=2)

train_config = TrainConfig([TrainStage(train_dataset), ValidationStage(validation_dataset)], torch.nn.NLLLoss(),
                           torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.5))

trainer = Trainer(model, train_config, fsm, torch.device('cuda:0')).set_epoch_num(50)
trainer.monitor_hub.add_monitor(TensorboardMonitor(fsm, is_continue=False))
trainer.train()
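`DataProducer` receives a list of datasets. One common way to present several map-style datasets as a single indexable sequence (the approach `torch.utils.data.ConcatDataset` takes) is to keep cumulative lengths and binary-search them. The plain-Python sketch below only illustrates that idea; `ConcatIndex` is a hypothetical name, not part of neural-pipeline, and it is not claimed to be how `DataProducer` works internally.

```python
import bisect
from itertools import accumulate
from typing import Any, List, Sequence, Tuple


class ConcatIndex:
    """Hypothetical helper (not neural-pipeline API): view several
    datasets as one flat sequence."""

    def __init__(self, datasets: List[Sequence[Any]]):
        self.datasets = datasets
        # cumulative[i] = total number of items in datasets[0..i]
        self.cumulative = list(accumulate(len(d) for d in datasets))

    def __len__(self) -> int:
        return self.cumulative[-1] if self.cumulative else 0

    def locate(self, index: int) -> Tuple[int, int]:
        """Map a global index to (dataset number, index inside that dataset)."""
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError(index)
        # bisect_right skips empty datasets, whose cumulative length repeats
        ds = bisect.bisect_right(self.cumulative, index)
        start = self.cumulative[ds - 1] if ds > 0 else 0
        return ds, index - start

    def __getitem__(self, index: int) -> Any:
        ds, local = self.locate(index)
        return self.datasets[ds][local]
```

Items here can be any objects; in the MNIST example each item is a dict with `'data'` and `'target'` keys, and the combined view returns them unchanged.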

Installation:


pip install neural-pipeline

To use the built-in modules (e.g. the TensorBoard and Matplotlib monitors), also install:

pip install tensorboardX matplotlib

To install the latest version before it is published on PyPI:

pip install -U git+https://github.com/toodef/neural-pipeline

Getting started:

Documentation

See the full documentation there.

Data flow scheme: [Data flow diagram]
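The MNIST example configures a list of stages (`TrainStage`, `ValidationStage`) and an epoch count. The plain-Python sketch below shows the kind of loop such a configuration implies: every epoch runs each stage in order, and only training stages are expected to update the model. `Stage`, `run_epochs` and the `step` callback are hypothetical names for illustration, not neural-pipeline's API; the library's actual data flow is the one in its diagram.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List


@dataclass
class Stage:
    """Hypothetical stage: a name, a batch source, and whether it trains."""
    name: str
    batches: Callable[[], Iterable[dict]]
    is_train: bool


def run_epochs(stages: List[Stage],
               step: Callable[[dict, bool], float],
               epoch_num: int) -> List[Dict[str, float]]:
    """Run every stage once per epoch; return the mean loss per stage per epoch.

    `step(batch, is_train)` computes the loss for one batch and is expected
    to update the model only when `is_train` is True.
    """
    history = []
    for _ in range(epoch_num):
        epoch_metrics = {}
        for stage in stages:  # stages run in the order they were configured
            losses = [step(batch, stage.is_train) for batch in stage.batches()]
            epoch_metrics[stage.name] = sum(losses) / len(losses) if losses else float("nan")
        history.append(epoch_metrics)
    return history
```

Keeping the per-epoch metrics as a list of dicts makes them easy to hand to monitors or to dump as a metrics file.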

See the examples



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

neural_pipeline-0.1.0-py3-none-any.whl (30.3 kB)

Uploaded Python 3

File details

Details for the file neural_pipeline-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: neural_pipeline-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 30.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.11.0 pkginfo/1.4.2 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.4

File hashes

Hashes for neural_pipeline-0.1.0-py3-none-any.whl

Algorithm    Hash digest
SHA256       74a32a7fe0d33efb1ae36fc7a223a93139d9b1bd68ccb998ae3908357a02d8ac
MD5          3da235fbbbdbcf079f4a439d0114bbbb
BLAKE2b-256  f6ae440a1d20745d5de34c3a79605a63ad00669b93ffa0125f01535c9c72134c

