
The flexible training toolbox

Project description

Trainable: The Flexible PyTorch Training Toolbox

If you're sick of dealing with all the boilerplate code involved in training, evaluating, visualizing, and saving your models, then you're in luck. Trainable offers a simple yet extensible framework so that understanding the latest papers becomes the only headache of neural network training.

Installation

pip install trainable

Usage

The typical Trainable workflow starts with defining a callable Algorithm that describes how to train your network on a single batch and how you'd like to label your losses:

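import torch.nn as nn
# Algorithm is provided by trainable; its import path is not shown in this README.

class MSEAlgorithm(Algorithm):
    def __init__(self, eval=False, **kwargs):
        super().__init__(eval)
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        # Move the (input, target) pair onto the training device
        x, target = batch
        x, target = x.to(device), target.to(device)
        y = model(x)

        # Compute the loss and backpropagate; the optimizer is handed to the Trainer, not here
        loss = self.mse(y, target)
        loss.backward()

        # self.key() comes from the Algorithm base class and labels each metric
        metrics = {self.key("MSE Loss"): loss.item()}
        return metrics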

Then you simply instantiate your model, dataset, and optimizer...

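import torch
from torch.utils.data import DataLoader

# Fall back to the CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MyModel, FancyPantsOptimizer, and the datasets stand in for your own classes
model = MyModel().to(device)
optim = FancyPantsOptimizer(model.parameters(), lr=1e-4)

train_data = DataLoader(SomeTrainingDataset('path/to/your/data'), batch_size=32)
test_data = DataLoader(SomeTestingDataset('path/to/your/data'), batch_size=32)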

...and let trainable take care of the rest!

trainer = Trainer(
  visualizer=MyVisualizer(),  # Typically Plotter() or Saver()
  train_alg=MSEAlgorithm(),
  test_alg=MSEAlgorithm(eval=True),
  display_freq=1,      # in batches
  visualize_freq=10,   # in batches
  validate_freq=10,    # in epochs
  autosave_freq=10,
  device=device
)

save_path = "desired/save/path/for/your/session.sesh"
trainer.start_session(model, optim, save_path)

trainer.name_session('Name')

trainer.describe_session("""
A beautiful detailed description of what the heck 
you were trying to accomplish with this training.
""")

metrics = trainer.train(train_data, test_data, epochs=200)

Plotting your data is simple as well:

import matplotlib.pyplot as plt

# One figure per labelled metric
for key in metrics:
    plt.plot(metrics[key])
    plt.title(key)
    plt.show()
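Trainable's internals aren't shown here, so the following is only a sketch of the contract the examples above imply: the trainer calls an algorithm with `(model, batch, device)`, gets back a flat dict of labelled floats, and each label ends up mapping to a sequence you can plot. `ToyAlgorithm` and `run_epoch` are hypothetical, plain-Python stand-ins (no PyTorch) meant only to make the data shapes visible. We also assume, purely for illustration, that `key()` prefixes the label with "Train" or "Test" depending on `eval`; trainable's real labelling may differ.

```python
from collections import defaultdict


class ToyAlgorithm:
    """Hypothetical stand-in for a trainable Algorithm (not the real base class)."""

    def __init__(self, eval=False):
        self.eval = eval

    def key(self, label):
        # Assumption: labels are namespaced by phase so train and test metrics don't collide
        return ("Test " if self.eval else "Train ") + label

    def __call__(self, model, batch, device):
        # device is ignored in this plain-Python sketch
        x, target = batch
        y = model(x)
        # Mean squared error computed by hand in place of nn.MSELoss
        mse = sum((a - b) ** 2 for a, b in zip(y, target)) / len(target)
        return {self.key("MSE Loss"): mse}


def run_epoch(alg, model, batches, device=None):
    """Hypothetical helper: call the algorithm once per batch and collect each metric into a list."""
    history = defaultdict(list)
    for batch in batches:
        for name, value in alg(model, batch, device).items():
            history[name].append(value)
    return dict(history)
```

For example, with a "model" that doubles its input, `run_epoch(ToyAlgorithm(), lambda xs: [2 * x for x in xs], [([1.0, 2.0], [2.0, 4.0]), ([1.0, 1.0], [3.0, 3.0])])` returns `{'Train MSE Loss': [0.0, 1.0]}`, a dict of lists in the same shape the plotting loop iterates over.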

Tunable Options

The Trainer interface gives you a nice handful of options to configure your training experience. They include:

  • Display Frequency (display_freq): How often, in batches, information such as your training loss is updated in your progress bar.
  • Visualization Frequency (visualize_freq): How often, in batches, the trainer produces a visualization of your model's outputs.
  • Validation Frequency (validate_freq): How often, in epochs, the trainer validates your model on the test data.
  • Autosave Frequency (autosave_freq): How often your session is saved to disk.
  • Device (device): The hardware on which training runs.
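Because some frequencies count batches and others count epochs, it can help to see how they gate work inside a loop. The `Schedule` class below is a hypothetical helper, not part of trainable; it assumes a 1-based counter and a simple "every N" modulo rule. The README doesn't say whether the batch counter resets each epoch or what unit `autosave_freq` uses, so treating autosave as per-epoch is an assumption.

```python
from dataclasses import dataclass


@dataclass
class Schedule:
    """Hypothetical helper mirroring the Trainer frequency options (not trainable's API)."""

    display_freq: int = 1      # in batches
    visualize_freq: int = 10   # in batches
    validate_freq: int = 10    # in epochs
    autosave_freq: int = 10    # unit undocumented; assumed epochs here

    def batch_actions(self, batch_idx):
        """Actions due after the given 1-based batch index."""
        actions = []
        if batch_idx % self.display_freq == 0:
            actions.append("display")
        if batch_idx % self.visualize_freq == 0:
            actions.append("visualize")
        return actions

    def epoch_actions(self, epoch):
        """Actions due after the given 1-based epoch."""
        actions = []
        if epoch % self.validate_freq == 0:
            actions.append("validate")
        if epoch % self.autosave_freq == 0:
            actions.append("autosave")
        return actions
```

With the values from the Trainer example, `Schedule().batch_actions(10)` returns `['display', 'visualize']`, and over `epochs=200` validation would run 20 times under these assumptions.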

Customization

Do you want more granular control over how you visualize your data? Or perhaps running an epoch with your model is more involved than just training on each batch? Wondering why the heck PyTorch doesn't have a built-in dataset for unsupervised images? Maybe your training algorithm involves VGG? Got you covered. Check out the source for the various submodules:

Contributing

Found any other headaches in neural net training that you think Trainable could simplify? Feel free to open a pull request on my GitHub repo.

Contact

Email me anytime at jeffhilton.code@gmail.com.



Download files


Source Distribution

trainable-0.1.4.dev11.tar.gz (21.7 kB)

Uploaded Source

Built Distribution

trainable-0.1.4.dev11-py3-none-any.whl (27.1 kB)

Uploaded Python 3

File details

Details for the file trainable-0.1.4.dev11.tar.gz.

File metadata

  • Download URL: trainable-0.1.4.dev11.tar.gz
  • Upload date:
  • Size: 21.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.5

File hashes

Hashes for trainable-0.1.4.dev11.tar.gz
Algorithm Hash digest
SHA256 873834ee8389a404915efec5a5e2c0a00374fef896a9919b869ebaf5b2f52242
MD5 900f3b857dbb5b6541ffa881b929ee70
BLAKE2b-256 2f2ed7c6f4302849ede0fb6ecd9c36152cea423f5fb12c1f6b73d68a56738e5c


File details

Details for the file trainable-0.1.4.dev11-py3-none-any.whl.

File metadata

  • Download URL: trainable-0.1.4.dev11-py3-none-any.whl
  • Upload date:
  • Size: 27.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.5

File hashes

Hashes for trainable-0.1.4.dev11-py3-none-any.whl
Algorithm Hash digest
SHA256 857448a873609a3caa62174b22d0447522fe75c0337d49f5173b39783f5dd9e7
MD5 6371a8808330f482ea8c117fb5aaa77c
BLAKE2b-256 ee652d0d9f61baf89d27d57c58278c7d5c80813369e9ce93f9dd98a559b8fa0d

