The flexible training toolbox

Trainable: The Flexible PyTorch Training Toolbox

If you're sick of the boilerplate code involved in training, evaluating, visualizing, and saving your models, then you're in luck. Trainable offers a simple yet extensible framework, so that understanding the latest papers is the only headache left in neural network training.

Installation

pip install trainable

Usage

A typical Trainable workflow starts by defining a callable Algorithm that describes how to train your network on one batch and how to label the resulting losses:

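import torch.nn as nn
# Algorithm is the base class provided by trainable.

class MSEAlgorithm(Algorithm):
    def __init__(self, eval=False, **kwargs):
        super().__init__(eval)
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        # Unpack the batch and move both tensors to the target device.
        x, target = batch
        x, target = x.to(device), target.to(device)
        y = model(x)

        # Compute the loss and backpropagate it.
        loss = self.mse(y, target)
        loss.backward()

        # Report the loss under a labeled key.
        metrics = {self.key("MSE Loss"): loss.item()}
        return metrics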
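The callable pattern above can be tried without PyTorch at all. The following is a minimal sketch, not trainable's implementation: `SketchAlgorithm`, `MeanSquaredError`, and `double` are hypothetical names, and the behavior of `key` (prefixing each label with the current mode) is an assumption made purely for illustration.

```python
class SketchAlgorithm:
    """Hypothetical stand-in for the callable pattern; not trainable's Algorithm."""

    def __init__(self, eval=False):
        self.eval = eval

    def key(self, label):
        # Assumption: labels are prefixed by mode so that training and
        # validation metrics end up under different keys.
        prefix = "Validation" if self.eval else "Training"
        return f"{prefix} {label}"


class MeanSquaredError(SketchAlgorithm):
    def __call__(self, model, batch):
        # A batch is a pair of equally sized lists: inputs and targets.
        inputs, targets = batch
        predictions = [model(x) for x in inputs]
        squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
        loss = sum(squared_errors) / len(targets)
        # Return a dictionary of labeled metrics, like the example above.
        return {self.key("MSE Loss"): loss}


def double(x):
    # Toy "model" used only for this sketch.
    return 2 * x


batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 7.0])
train_metrics = MeanSquaredError()(double, batch)
eval_metrics = MeanSquaredError(eval=True)(double, batch)
print(train_metrics)
print(eval_metrics)
```

The point of the sketch is only the shape of the contract: an object that is called once per batch and hands back a dictionary of labeled numbers.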

Then you simply instantiate your model, dataset, and optimizer...

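import torch
from torch.utils.data import DataLoader

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MyModel().to(device)
optim = FancyPantsOptimizer(model.parameters(), lr=1e-4)

train_data = DataLoader(SomeTrainingDataset('path/to/your/data'), batch_size=32)
test_data = DataLoader(SomeTestingDataset('path/to/your/data'), batch_size=32)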

...and let trainable take care of the rest!

trainer = Trainer(
  visualizer=MyVisualizer(),  # Typically Plotter() or Saver()
  train_alg=MyFancyAlgorithm(),
  test_alg=MyFancyAlgorithm(eval=True),
  display_freq=1,
  visualize_freq=10,
  validate_freq=10,
  autosave_freq=10,
  device=device
)

save_path = "desired/save/path/for/your/session.sesh"
trainer.start_session(model, optim, save_path)

trainer.name_session('Name')

trainer.describe_session("""
A beautiful detailed description of what the heck 
you were trying to accomplish with this training.
""")

metrics = trainer.train(train_data, test_data, epochs=200)

Plotting your data is simple as well:

import matplotlib.pyplot as plt

# Draw one figure per recorded metric, titled with its label.
for key in metrics:
    plt.plot(metrics[key])
    plt.title(key)
    plt.show()
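A quick text summary can complement the plots. The sketch below assumes that `metrics` maps each label to a list of numbers, which is what the plotting loop above suggests but is not confirmed here. The `summarize` helper and the example labels are hypothetical, and "best" is taken to mean the lowest value, which only makes sense for losses.

```python
def summarize(metrics):
    """Return the last value, lowest value, and count for each metric."""
    summary = {}
    for key, values in metrics.items():
        values = list(values)
        if not values:
            # Skip metrics that were never recorded.
            continue
        summary[key] = {
            "last": values[-1],
            "best": min(values),  # assumes lower is better, as for a loss
            "count": len(values),
        }
    return summary


# Made-up numbers, only to show the expected shape of the input.
example_metrics = {
    "Training MSE Loss": [0.9, 0.5, 0.3, 0.35],
    "Validation MSE Loss": [0.8, 0.6],
}

summary = summarize(example_metrics)
for key, stats in summary.items():
    print(f"{key}: last={stats['last']:.3f} best={stats['best']:.3f} n={stats['count']}")
```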

Tunable Options

The Trainer interface gives you a nice handful of options to configure your training experience. They include:

  • Display Frequency: How often (in batches) information such as your training loss is updated in your progress bar.
  • Visualization Frequency: How often (in batches) the training produces a visualization of your model's outputs.
  • Validation Frequency: How often (in epochs) the trainer performs validation with your test data.
  • Autosave Frequency: How often your session is saved out to disk.
  • Device: On which hardware your training should occur.
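To make the units concrete, here is a minimal sketch of how batch-based and epoch-based frequencies could gate actions in a training loop. It is an illustration under stated assumptions, not how Trainer is implemented: `plan_events` is a hypothetical helper, batches are assumed to be counted across the whole run rather than per epoch, and autosaving is left out because its unit is not stated above.

```python
def plan_events(epochs, batches_per_epoch, display_freq=1,
                visualize_freq=10, validate_freq=10):
    """List the (epoch, batch, action) events that a run would trigger."""
    events = []
    step = 0  # assumption: batches are counted across the whole run
    for epoch in range(1, epochs + 1):
        for batch in range(1, batches_per_epoch + 1):
            step += 1
            # Batch-based options fire every N batches.
            if step % display_freq == 0:
                events.append((epoch, batch, "display"))
            if step % visualize_freq == 0:
                events.append((epoch, batch, "visualize"))
        # The epoch-based option fires at the end of every N-th epoch.
        if epoch % validate_freq == 0:
            events.append((epoch, batches_per_epoch, "validate"))
    return events


events = plan_events(epochs=4, batches_per_epoch=5,
                     display_freq=2, visualize_freq=10, validate_freq=2)
for event in events:
    print(event)
```

With 4 epochs of 5 batches, this plan updates the display every second batch, visualizes twice, and validates after epochs 2 and 4.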

Customization

Do you want a little more granularity in how you visualize your data? Or perhaps running an epoch with your model is a little more involved than just training on each batch of data? Wondering why the heck PyTorch doesn't have a built-in dataset for unsupervised images? Maybe your training algorithm involves VGG? Trainable has you covered. Check out the source for the various submodules.

Contributing

Have you found other headaches in neural net training that you think Trainable could simplify? Feel free to open a pull request on my GitHub repo.

Contact

Email me anytime at jeffhilton.code@gmail.com.
