
Functional & flexible neural network training with PyTorch.


So you want to train neural nets with PyTorch? Here are your options:

Enter torchtraining - we try to take the best of both worlds while adding explicitness, a functional approach, easy extensions and the freedom to structure your code!

All of that using a single ** piping operator!


Tutorials

See the tutorials to get a grasp of what the fuss is all about:

  • Introduction - quick tour of the functionality with CIFAR100 classification and tensorboard.
  • GAN training - more advanced example and creating your own pipeline components.

Installation

See the documentation for the full list of extras (e.g. installation with integrations like horovod).

To get started, you can install via pip:

pip install --user torchtraining

Why torchtraining?

There are a lot of training libraries around for a lot of frameworks. Why would you choose this one?

torchtraining fits you, not the other way around

We think it's impossible to squeeze users' code into an overly strict API. We are not trying to fit everything into a single .fit() method (or a Trainer god class; see the 40 (!) arguments of the PyTorch-Lightning trainer). That approach has shown time and again that it does not work for more complicated use cases, since one cannot foresee the endless variations of neural network training and data generation a user might require. Instead, torchtraining gives you building blocks to calculate metrics, log results and distribute training.

Implement a single forward instead of 40 methods

Implementing forward with a data argument is all you will ever need (okay, accumulators also need calculate, but that's it); we add a thin __call__ on top. Compare that to the methods of PyTorch-Lightning's LightningModule (source code here):

  • training_step
  • training_step_end
  • training_epoch_end (repeat all the above for validation and test)
  • validation_end, test_end
  • configure_sync_batchnorm
  • configure_ddp
  • init_ddp_connection
  • configure_apex
  • configure_optimizers
  • optimizer_step
  • optimizer_zero_grad
  • tbptt_split_batch (?)
  • prepare_data
  • train_dataloader
  • tng_dataloader
  • test_dataloader
  • val_dataloader

This list could go on (and will probably grow even bigger as time passes). We believe in a functional approach and in using only what you need (a lot of decoupled building blocks instead of gigantic god classes trying to do everything). Once again: we can't foresee the future and won't squash everything into a single class.

Explicitness

You are offered building blocks, and it's up to you which ones to use. You remain explicit about everything going on in your code, for example:

  • when, where and what to log to tensorboard
  • when and how often to run optimization
  • what neural network(s) go into what step
  • what data you choose to accumulate and how often
  • which component of your pipeline should log via loguru
  • and how to log (e.g. to stdout and file or maybe over the web?)

See the introduction tutorial to see how it's done.

Neural network != training

We don't think your neural network source code should be polluted with training code. We think it's better to have data preparation in a data.py module, optimizers in optimizers.py and so on. With torchtraining you don't have to cram all functionality into a single god class.

Nothing under the hood (almost)

~3000 lines of code (including comet-ml, neptune and horovod integrations) and short functions/classes allow you to quickly dig into the source if you find something odd or not working. It leverages what already exists instead of reinventing the wheel.

PyTorch first

We don't force you to jump to and from NumPy, as most tasks can already be done in PyTorch. We are PyTorch first, unless we have to integrate a third-party tool. In that case, you don't pay for the feature if you don't use it!

Easy integration with other tools

If we don't provide an integration out of the box, you can request it via issues or make your own PR. Any code you want can almost always be integrated via the following steps:

  • make a new module (say amazing.py)
  • create new classes inheriting from torchtraining.Operation
  • implement forward for each operation; it takes a single argument, data, which can be anything (Tuple, List, torch.Tensor, str, whatever really)
  • process this data in forward and return the results
  • you now have your own operation compatible with **!
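The steps above can be sketched roughly as follows. To stay runnable without torchtraining installed, this sketch uses a small stand-in Operation base class. In real code you would inherit from torchtraining.Operation instead, which is also what provides ** support. The names Mean and Remember are hypothetical, and the stand-in only assumes what is described above: a thin __call__ wrapped around forward.

```python
# amazing.py (hypothetical module name taken from the steps above)


class Operation:
    """Stand-in for torchtraining.Operation so this sketch runs on its own.

    Assumption: the real base class adds a thin __call__ around forward.
    This stand-in does NOT implement the ** operator.
    """

    def __call__(self, data):
        return self.forward(data)

    def forward(self, data):
        raise NotImplementedError


class Mean(Operation):
    """Hypothetical operation: average a sequence of numbers."""

    def forward(self, data):
        return sum(data) / len(data)


class Remember(Operation):
    """Hypothetical operation: store every value it sees, then pass it on."""

    def __init__(self):
        self.seen = []

    def forward(self, data):
        self.seen.append(data)
        return data  # forward the data so later steps can reuse it


mean = Mean()
remember = Remember()
result = remember(mean([1.0, 2.0, 3.0]))
print(result, remember.seen)
```

Each operation only knows about its single data argument, so it can be tested in isolation before it is wired into a pipeline.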

Other tools integrate components by trying to squash them into their predefined APIs and/or trying to be smart and guess what the user does (which often fails). Here's how we do it:

Example of integration of neptune image logging:

import torchtraining as tt


class Image(tt.Operation):
    """Log incoming image data to a neptune experiment."""

    def __init__(
        self,
        experiment,
        log_name: str,
        image_name: str = None,
        description: str = None,
        timestamp=None,
    ):
        # `experiment` is declared once; a duplicate parameter name
        # in the signature is a SyntaxError in Python
        super().__init__()
        self.experiment = experiment
        self.log_name = log_name
        self.image_name = image_name
        self.description = description
        self.timestamp = timestamp

    # Always forward some data so it can be reused
    def forward(self, data):
        self.experiment.log_image(
            self.log_name, data, self.image_name, self.description, self.timestamp
        )
        return data

Contributing

This project is currently in its infancy and we would love to get some help from you! You can find current ideas in the issues tagged [DISCUSSION] (see here).

Also feel free to make your own feature requests and give us your thoughts in issues!

Remember: this is only version 0.0.1. The direction is set, but you can be sure to encounter a lot of bugs along the way at the moment.

Why ** as an operator?

Indeed, operators like |, >> or > would be way more intuitive, but:

  • Those are left-associative and would require users to explicitly use parentheses around pipes
  • > cannot be piped as easily
  • Handling >> or | would require way more complicated code on our side

Currently ** seems like a reasonable trade-off, though it may be subject to change in the future.
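The associativity difference can be seen in a minimal sketch. The Op class below is hypothetical and is not torchtraining's implementation. It only illustrates that Python groups a chain of ** from the right and a chain of | from the left, by recording the grouping in each composed object's name.

```python
class Op:
    """Hypothetical pipeline component (not torchtraining's class)."""

    def __init__(self, name, fn):
        self.name = name
        self.fn = fn

    def __call__(self, data):
        return self.fn(data)

    def __pow__(self, other):
        # self ** other: feed the output of self into other
        return Op(
            f"({self.name} ** {other.name})",
            lambda data: other(self(data)),
        )

    def __or__(self, other):
        # self | other: same composition, different operator
        return Op(
            f"({self.name} | {other.name})",
            lambda data: other(self(data)),
        )


double = Op("double", lambda x: x * 2)
increment = Op("increment", lambda x: x + 1)
square = Op("square", lambda x: x * x)

pow_pipe = double ** increment ** square
or_pipe = double | increment | square

print(pow_pipe.name)  # (double ** (increment ** square))
print(or_pipe.name)   # ((double | increment) | square)
print(pow_pipe(3), or_pipe(3))  # 49 49
```

With ** the tail of the chain (increment ** square) is built first and then attached to double without any parentheses, whereas | builds the head of the chain first. For plain function composition, as in this sketch, both groupings give the same result. The grouping starts to matter once the left-most object treats whatever is piped into it specially.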
