
Trainable: The Flexible PyTorch Training Toolbox

If you're sick of the boilerplate code involved in training, evaluating, visualizing, and saving your models, then you're in luck. Trainable offers a simple yet extensible framework that makes understanding the latest papers the only headache of neural network training.

Installation

pip install trainable

Usage

The typical workflow for Trainable involves defining a callable Algorithm that describes how to train your network on a batch and how you'd like to label your losses:

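import torch.nn as nn
# Algorithm is the base class provided by trainable.

class MSEAlgorithm(Algorithm):
    def __init__(self, eval=False, **args):
        super().__init__(eval)
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        # Move the batch onto the training device
        x, target = batch
        x, target = x.to(device), target.to(device)
        y = model(x)

        # Compute the loss and backpropagate
        loss = self.mse(y, target)
        loss.backward()

        # Label the loss so it can be tracked and plotted later
        metrics = {self.key("MSE Loss"): loss.item()}
        return metrics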
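To make the shape of that contract concrete, here is a minimal, framework-free sketch of the callable pattern. It uses neither Trainable nor PyTorch: `SketchAlgorithm`, its phase-prefixing `key` method, and `mean_squared_error` are hypothetical stand-ins, and Trainable's real `Algorithm.key` may label metrics differently.

```python
class SketchAlgorithm:
    """Hypothetical stand-in for an Algorithm base class (not Trainable's code)."""

    def __init__(self, eval=False):
        self.eval = eval

    def key(self, label):
        # Assumed behavior: tag each metric with the phase it came from.
        prefix = "Validation" if self.eval else "Training"
        return f"{prefix} {label}"


def mean_squared_error(predictions, targets):
    """Plain-Python MSE over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


class SketchMSE(SketchAlgorithm):
    def __call__(self, model, batch):
        # A batch is an (inputs, targets) pair, as in the example above.
        inputs, targets = batch
        predictions = [model(x) for x in inputs]
        # Return a dict mapping a labeled key to a plain number.
        return {self.key("MSE Loss"): mean_squared_error(predictions, targets)}


def double(x):
    """Toy 'model' that doubles its input."""
    return 2 * x


batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 7.0])
train_metrics = SketchMSE()(double, batch)
val_metrics = SketchMSE(eval=True)(double, batch)
print(train_metrics)
print(val_metrics)
```

The same object can be reused for training and validation by flipping `eval`, which is presumably why the example above accepts that flag in its constructor.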

Then you simply instantiate your model, dataset, and optimizer...

import torch
from torch.utils.data import DataLoader

device = torch.device('cuda')

model = MyModel().to(device)
optim = FancyPantsOptimizer(model.parameters(), lr=1e-4)

train_data = DataLoader(SomeTrainingDataset('path/to/your/data'), batch_size=32)
test_data = DataLoader(SomeTestingDataset('path/to/your/data'), batch_size=32)

...and let trainable take care of the rest!

trainer = Trainer(
  visualizer=MyVisualizer(),  # Typically Plotter() or Saver()
  train_alg=MyFancyAlgorithm(),
  test_alg=MyFancyAlgorithm(eval=True),
  display_freq=1,
  visualize_freq=10,
  validate_freq=10,
  autosave_freq=10,
  device=device
)

save_path = "desired/save/path/for/your/session.sesh"
trainer.start_session(model, optim, save_path)

trainer.name_session('Name')

trainer.describe_session("""
A beautiful detailed description of what the heck 
you were trying to accomplish with this training.
""")

metrics = trainer.train(train_data, test_data, epochs=200)

Plotting your data is simple as well:

import matplotlib.pyplot as plt

# Show one figure per tracked metric
for key in metrics:
    plt.plot(metrics[key])
    plt.title(key)
    plt.show()
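If a quick text summary is handier than a plot, the same dictionary can be condensed without matplotlib. This is only a sketch: it assumes, as the plotting loop suggests, that each value in `metrics` is a list of numbers, and both `summarize` and the sample data are hypothetical.

```python
def summarize(metrics):
    """Return the first, last, and minimum value of each non-empty metric series."""
    summary = {}
    for key, values in metrics.items():
        if not values:
            continue  # Nothing was recorded for this metric.
        summary[key] = {
            "first": values[0],
            "last": values[-1],
            "min": min(values),
        }
    return summary


# Made-up numbers standing in for the dictionary returned by trainer.train(...).
example_metrics = {
    "Training MSE Loss": [0.9, 0.5, 0.3, 0.35],
    "Validation MSE Loss": [1.0, 0.6],
    "Unused Metric": [],
}

summary = summarize(example_metrics)
for key, stats in summary.items():
    print(f"{key}: first={stats['first']} last={stats['last']} min={stats['min']}")
```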

Tunable Options

The Trainer interface gives you a nice handful of options to configure your training experience. They include:

  • Display Frequency: How often (in batches) information such as your training loss is updated in your progress bar.
  • Visualization Frequency: How often (in batches) the training produces a visualization of your model's outputs.
  • Validation Frequency: How often (in epochs) the trainer performs validation with your test data.
  • Autosave Frequency: How often your session is saved out to disk.
  • Device: On which hardware your training should occur.
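As a rough illustration of how frequency options like these can gate work inside a training loop, the sketch below records when each action would fire. It is not Trainable's implementation: `scheduled_events` is a hypothetical helper, counting batches across the whole run (rather than per epoch) is an assumption, and autosaving is left out because its unit is not stated above.

```python
def scheduled_events(epochs, batches_per_epoch,
                     display_freq=1, visualize_freq=10, validate_freq=10):
    """List (action, epoch, batch_count) tuples in the order they would fire."""
    events = []
    batch_count = 0  # Assumption: batches are counted across the whole run.
    for epoch in range(1, epochs + 1):
        for _ in range(batches_per_epoch):
            batch_count += 1
            # Display and visualization are measured in batches.
            if batch_count % display_freq == 0:
                events.append(("display", epoch, batch_count))
            if batch_count % visualize_freq == 0:
                events.append(("visualize", epoch, batch_count))
        # Validation is measured in epochs, so it is checked once per epoch.
        if epoch % validate_freq == 0:
            events.append(("validate", epoch, batch_count))
    return events


events = scheduled_events(epochs=2, batches_per_epoch=4,
                          display_freq=2, visualize_freq=4, validate_freq=2)
for event in events:
    print(event)
```

With small frequencies the progress bar and visualizations update often at the cost of extra overhead; larger values trade feedback for speed.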

Customization

Do you want a little more granularity in how you visualize your data? Or perhaps running an epoch with your model is a little more involved than just training on each batch of data? Wondering why the heck PyTorch doesn't have a built-in dataset for unsupervised images? Maybe your training algorithm involves VGG? Got you covered. Check out the source for the various submodules.

Contributing

Found any other headaches in neural net training that you think you can simplify with Trainable? Feel free to make a pull request on my GitHub repo.

Contact

Email me anytime at jeffhilton.code@gmail.com.


