The flexible training toolbox

Trainable: The Flexible PyTorch Training Toolbox

If you're sick of all the boilerplate code involved in training, evaluating, visualizing, and saving your models, you're in luck. Trainable offers a simple yet extensible framework that makes understanding the latest papers the only headache of neural network training.

Installation

pip install trainable

Usage

The typical workflow for trainable involves defining a callable Algorithm to describe how to train your network on a batch, and how you'd like to label your losses:

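import torch.nn as nn
# Algorithm is trainable's base class; import it from the package before use.

class MSEAlgorithm(Algorithm):
    def __init__(self, eval=False, **args):
        super().__init__(eval)
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        # Unpack the batch and move it to the training device.
        x, target = batch
        x, target = x.to(device), target.to(device)
        y = model(x)

        # Compute the loss and backpropagate.
        loss = self.mse(y, target)
        loss.backward()

        # Label the loss with self.key() and report it as a plain float.
        metrics = {self.key("MSE Loss"): loss.item()}
        return metrics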

Then you simply instantiate your model, dataset, and optimizer...

import torch
from torch.utils.data import DataLoader

device = torch.device('cuda')

model = MyModel().to(device)
optim = FancyPantsOptimizer(model.parameters(), lr=1e-4)

train_data = DataLoader(SomeTrainingDataset('path/to/your/data'), batch_size=32)
test_data = DataLoader(SomeTestingDataset('path/to/your/data'), batch_size=32)

...and let trainable take care of the rest!

trainer = Trainer(
  visualizer=MyVisualizer(),  # Typically Plotter() or Saver()
  train_alg=MyFancyAlgorithm(),
  test_alg=MyFancyAlgorithm(eval=True),
  display_freq=1,
  visualize_freq=10,
  validate_freq=10,
  autosave_freq=10,
  device=device
)

save_path = "desired/save/path/for/your/session.sesh"
trainer.start_session(model, optim, save_path)

trainer.name_session('Name')

trainer.describe_session("""
A beautiful detailed description of what the heck 
you were trying to accomplish with this training.
""")

metrics = trainer.train(train_data, test_data, epochs=200)

Plotting your data is simple as well:

import matplotlib.pyplot as plt

for key in metrics:
    plt.plot(metrics[key])
    plt.show()
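
Per-batch losses tend to be noisy, so it can help to smooth each series before plotting it. The following is only a sketch, assuming that each entry in metrics is a plain list of numbers (as the plt.plot call above suggests). moving_average is a hypothetical helper, not part of trainable, and the metrics dict here is a made-up stand-in for the one returned by trainer.train().

```python
def moving_average(values, window):
    """Trailing moving average; early entries average over what is available."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        # Drop the value that just fell out of the trailing window.
        if i >= window:
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed


# Hypothetical metrics dict standing in for the one returned by trainer.train().
metrics = {"MSE Loss": [4.0, 2.0, 3.0, 1.0]}
smoothed = {key: moving_average(series, window=2) for key, series in metrics.items()}
```

Each smoothed[key] can then be passed to plt.plot in place of metrics[key].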

Tunable Options

The Trainer interface gives you a nice handful of options to configure your training experience. They include:

  • Display Frequency: How often (in batches) information such as your training loss is updated in your progress bar.
  • Visualization Frequency: How often (in batches) the training produces a visualization of your model's outputs.
  • Validation Frequency: How often (in epochs) the trainer performs validation with your test data.
  • Autosave Frequency: How often your session is saved out to disk.
  • Device: On which hardware your training should occur.
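
To get a feel for how the batch-based and epoch-based frequencies interact, here is a rough sketch of when each event would fire. It is not trainable's actual implementation: due and schedule are hypothetical helpers, and the sketch assumes that batches are counted across the whole run rather than reset each epoch. Autosave is left out because its unit is not stated above.

```python
def due(count, freq):
    """Return True when a 1-based counter lands on a multiple of freq."""
    return freq > 0 and count % freq == 0


def schedule(epochs, batches_per_epoch, display_freq=1, visualize_freq=10, validate_freq=10):
    """List (event, epoch, batch_count) tuples in the order they would fire."""
    events = []
    batch_count = 0  # assumption: counted over the whole run, not per epoch
    for epoch in range(1, epochs + 1):
        for _ in range(batches_per_epoch):
            batch_count += 1
            # Display and visualization are measured in batches.
            if due(batch_count, display_freq):
                events.append(("display", epoch, batch_count))
            if due(batch_count, visualize_freq):
                events.append(("visualize", epoch, batch_count))
        # Validation is measured in epochs.
        if due(epoch, validate_freq):
            events.append(("validate", epoch, batch_count))
    return events


events = schedule(epochs=2, batches_per_epoch=5, visualize_freq=5, validate_freq=2)
counts = {}
for name, _, _ in events:
    counts[name] = counts.get(name, 0) + 1
```

With 2 epochs of 5 batches each, this gives a display update on every batch, a visualization on every fifth batch, and a single validation pass after the second epoch.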

Customization

Do you want a little more granularity in how you visualize your data? Or perhaps running an epoch with your model is a little more involved than just training on each batch of data? Wondering why the heck PyTorch doesn't have a built-in dataset for unsupervised images? Maybe your training algorithm involves VGG? Got you covered. Check out the source for the various submodules.

Contributing

Found any other headaches in neural net training that you think you can simplify with Trainable? Feel free to open a pull request on my GitHub repo.

Contact

Email me anytime at jeffhilton.code@gmail.com.
