=======
Inferno
=======



.. image:: https://img.shields.io/pypi/v/inferno.svg
    :target: https://pypi.python.org/pypi/pytorch-inferno

.. image:: https://img.shields.io/travis/nasimrahaman/inferno.svg
    :target: https://travis-ci.org/nasimrahaman/inferno

.. image:: https://readthedocs.org/projects/inferno-pytorch/badge/?version=latest
    :target: http://inferno-pytorch.readthedocs.io/en/latest/?badge=latest
    :alt: Documentation Status

.. image:: https://pyup.io/repos/github/nasimrahaman/inferno/shield.svg
    :target: https://pyup.io/repos/github/nasimrahaman/inferno/
    :alt: Updates



.. image:: http://svgshare.com/i/2j7.svg





Inferno is a little library providing utilities and convenience functions/classes around
`PyTorch <https://github.com/pytorch/pytorch>`_.
It's a work-in-progress, but the first release is underway!



* Free software: Apache Software License 2.0
* Documentation: https://pytorch-inferno.readthedocs.io (Work in progress).


Features
--------

Current features include:

* a basic
  `Trainer class <https://github.com/nasimrahaman/inferno/tree/master/docs#preparing-the-trainer>`_
  to encapsulate the training boilerplate (iteration/epoch loops, validation and checkpoint creation),
* a `graph API <https://github.com/nasimrahaman/inferno/blob/master/inferno/extensions/containers/graph.py>`_ for building models with complex architectures, powered by `networkx <https://github.com/networkx/networkx>`_,
* `easy data-parallelism <https://github.com/nasimrahaman/inferno/tree/master/docs#using-gpus>`_ over multiple GPUs,
* `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/extensions/initializers>`__ for ``torch.nn.Module``-level parameter initialization,
* `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/io/transform>`__ for data preprocessing / transforms,
* `support <https://github.com/nasimrahaman/inferno/tree/master/docs#using-tensorboard>`_ for `Tensorboard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`_ (best with at least `tensorflow-cpu <https://github.com/tensorflow/tensorflow>`_ installed),
* `a callback API <https://github.com/nasimrahaman/inferno/tree/master/docs#setting-up-callbacks>`_ to enable flexible interaction with the trainer,
* `various utility layers <https://github.com/nasimrahaman/inferno/tree/master/inferno/extensions/layers>`_ with more underway,
* `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/io/volumetric>`__ for volumetric datasets, and more!
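
To make "training boilerplate" concrete, here is a minimal, framework-free sketch of the kind of loop a trainer class wraps: iterate over epochs and batches, validate every few epochs, and write a checkpoint every few epochs. It is an illustration only, not Inferno's implementation; ``run_training`` and its callable arguments are hypothetical names chosen for this example.

.. code:: python

    def run_training(batches, max_num_epochs, validate_every, save_every,
                     train_step, validate, save_checkpoint):
        """Hypothetical stand-in for the loop a trainer encapsulates."""
        iteration_count = 0
        for epoch in range(1, max_num_epochs + 1):
            # Inner loop: one training step per batch.
            for batch in batches:
                train_step(batch)
                iteration_count += 1
            # Periodic validation and checkpointing, counted in epochs.
            if epoch % validate_every == 0:
                validate(epoch)
            if epoch % save_every == 0:
                save_checkpoint(epoch)
        return iteration_count


    # Record which epochs trigger validation and checkpointing.
    events = []
    num_iterations = run_training(
        batches=[0, 1, 2],
        max_num_epochs=10,
        validate_every=2,
        save_every=5,
        train_step=lambda batch: None,
        validate=lambda epoch: events.append(('validate', epoch)),
        save_checkpoint=lambda epoch: events.append(('save', epoch)),
    )

With these settings, validation runs after every second epoch and checkpoints are written after epochs 5 and 10.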





.. code:: python

    import torch.nn as nn
    from inferno.io.box.cifar10 import get_cifar10_loaders
    from inferno.trainers.basic import Trainer
    from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
    from inferno.extensions.layers.convolutional import ConvELU2D
    from inferno.extensions.layers.reshape import Flatten

    # Fill these in:
    LOG_DIRECTORY = '...'
    SAVE_DIRECTORY = '...'
    DATASET_DIRECTORY = '...'
    DOWNLOAD_CIFAR = True
    USE_CUDA = True

    # Build torch model
    model = nn.Sequential(
        ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
        nn.MaxPool2d(kernel_size=2, stride=2),
        ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
        nn.MaxPool2d(kernel_size=2, stride=2),
        ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
        nn.MaxPool2d(kernel_size=2, stride=2),
        Flatten(),
        nn.Linear(in_features=(256 * 4 * 4), out_features=10),
        nn.Softmax()
    )

    # Load loaders
    train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                        download=DOWNLOAD_CIFAR)

    # Build trainer
    trainer = Trainer(model) \
        .build_criterion('CrossEntropyLoss') \
        .build_metric('CategoricalError') \
        .build_optimizer('Adam') \
        .validate_every((2, 'epochs')) \
        .save_every((5, 'epochs')) \
        .save_to_directory(SAVE_DIRECTORY) \
        .set_max_num_epochs(10) \
        .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                        log_images_every='never'),
                      log_directory=LOG_DIRECTORY)

    # Bind loaders
    trainer \
        .bind_loader('train', train_loader) \
        .bind_loader('validate', validate_loader)

    if USE_CUDA:
        trainer.cuda()

    # Go!
    trainer.fit()
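
The ``in_features=(256 * 4 * 4)`` of the final linear layer can be checked with a little arithmetic. The sketch below assumes 32x32 CIFAR-10 inputs and that each ``ConvELU2D`` layer preserves the spatial size, so only the pooling layers shrink the feature maps. That second assumption is inferred from the example and has not been checked against the library. ``flattened_features`` is a hypothetical helper, not part of Inferno.

.. code:: python

    def flattened_features(input_size, channels, num_pool_layers, pool_stride=2):
        """Number of features after flattening, assuming size-preserving convolutions."""
        size = input_size
        for _ in range(num_pool_layers):
            # Each MaxPool2d(kernel_size=2, stride=2) halves height and width.
            size //= pool_stride
        return channels * size * size


    # 32 -> 16 -> 8 -> 4 after three pooling layers, with 256 channels.
    in_features = flattened_features(input_size=32, channels=256, num_pool_layers=3)

Under these assumptions the result equals ``256 * 4 * 4``, matching the value passed to ``nn.Linear``.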




To visualize the training progress, navigate to ``LOG_DIRECTORY`` and fire up tensorboard with

.. code:: bash

    $ tensorboard --logdir=${PWD} --port=6007


and navigate to ``localhost:6007`` with your browser.



Future Features
---------------

Planned features include:

* a class to encapsulate Hogwild! training over multiple GPUs,
* minimal shape inference with a dry-run,
* proper packaging and documentation,
* cutting-edge fresh-off-the-press implementations of what the future has in store. :)



Credits
---------
All contributors are listed here_.

.. _here: https://pytorch-inferno.readthedocs.io/en/latest/authors.html

This package was partially generated with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template, plus lots of work by Thorsten.

.. _Cookiecutter: https://github.com/audreyr/cookiecutter
.. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage



=======
History
=======

0.1.0 (2017-08-24)
------------------

* First early release on PyPI

0.1.1 (2017-08-24)
------------------

* Version Increment

0.1.2 (2017-08-24)
------------------

* Version Increment


0.1.3 (2017-08-24)
------------------

* Updated Documentation

0.1.4 (2017-08-24)
------------------

* travis auto-deployment on pypi


0.1.5 (2017-08-24)
------------------

* travis changes to run unittest


0.1.6 (2017-08-24)
------------------

* travis missing packages for unittesting
* fixed inconsistent version numbers

0.1.7 (2017-08-25)
------------------

* setup.py critical bugfix in install procedure
