The flexible training toolbox

Trainable: The Flexible PyTorch Training Toolbox

If you're sick of dealing with all the boilerplate code involved in training, evaluating, visualizing, and saving your models, then you're in luck. Trainable offers a simple yet extensible framework, so that understanding the latest papers becomes the only headache of neural network training.

Installation

pip install trainable

Usage

The typical workflow for Trainable starts with defining a callable Algorithm that describes how to train your network on a single batch and how you'd like to label your losses:

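import torch.nn as nn

from trainable import Algorithm  # assumed import path; check the package source


class MSEAlgorithm(Algorithm):
    def __init__(self, eval=False, **kwargs):
        super().__init__(eval)  # eval=True is passed for the validation instance
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        # Move the (input, target) pair onto the training device
        x, target = batch
        x, target = x.to(device), target.to(device)
        y = model(x)

        # Compute the loss and backpropagate; stepping the optimizer is presumably left to the Trainer
        loss = self.mse(y, target)
        loss.backward()

        # self.key() produces the label under which this metric is reported
        metrics = {self.key("MSE Loss"): loss.item()}
        return metrics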
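The MSEAlgorithm above only calls backward(); zeroing gradients and stepping the optimizer presumably happen inside the Trainer. The sketch below makes that division of labor explicit with a hypothetical train_step function and a PlainMSEAlgorithm stand-in that does not depend on Trainable's Algorithm base class. Neither name is part of Trainable's API; this is only an illustration of the pattern.

```python
import torch
import torch.nn as nn


class PlainMSEAlgorithm:
    """Hypothetical stand-in for MSEAlgorithm without Trainable's Algorithm base class."""

    def __init__(self):
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        x, target = batch
        x, target = x.to(device), target.to(device)
        loss = self.mse(model(x), target)
        loss.backward()  # fill in .grad; the optimizer step happens in train_step
        return {"MSE Loss": loss.item()}


def train_step(model, optim, algorithm, batch, device):
    """One optimization step around a callable algorithm (assumed, not Trainable's code)."""
    model.train()
    optim.zero_grad()                          # discard gradients from the previous step
    metrics = algorithm(model, batch, device)  # forward pass, loss, and backward()
    optim.step()                               # update parameters from the new gradients
    return metrics


# Usage: fit y = 2x with a single linear layer on the CPU
torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(1, 1).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.linspace(-1, 1, 32).unsqueeze(1)
batch = (x, 2 * x)
algorithm = PlainMSEAlgorithm()
for step in range(100):
    metrics = train_step(model, optim, algorithm, batch, device)
print(metrics)  # the most recent MSE, close to zero after 100 steps
```

Keeping backward() inside the algorithm but the optimizer step outside it lets one loop drive any algorithm, however it computes its losses.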

Then simply instantiate your model, datasets, and optimizer...

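import torch
from torch.utils.data import DataLoader

device = torch.device('cuda')

# MyModel, FancyPantsOptimizer, and the datasets are placeholders for your own classes
model = MyModel().to(device)
optim = FancyPantsOptimizer(model.parameters(), lr=1e-4)

train_data = DataLoader(SomeTrainingDataset('path/to/your/data'), batch_size=32)
test_data = DataLoader(SomeTestingDataset('path/to/your/data'), batch_size=32)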
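MyModel, FancyPantsOptimizer, SomeTrainingDataset, and SomeTestingDataset are placeholders. As a runnable sketch, the stand-ins below use standard PyTorch pieces (nn.Sequential, torch.optim.Adam, TensorDataset) filled with random data; the layer sizes and data shapes are arbitrary assumptions. The sketch also falls back to the CPU when CUDA is unavailable, because moving a model to 'cuda' fails on machines without a GPU.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for MyModel: a small fully connected regressor (sizes are arbitrary)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)

# Stand-in for FancyPantsOptimizer: any torch.optim optimizer is constructed the same way
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

# Stand-ins for the datasets: random (input, target) pairs held in memory
torch.manual_seed(0)
train_set = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
test_set = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))

# Shuffle the training data each epoch; keep the test order fixed
train_data = DataLoader(train_set, batch_size=32, shuffle=True)
test_data = DataLoader(test_set, batch_size=32)
```

Each batch these loaders yield is an (input, target) pair, which is the shape MSEAlgorithm unpacks with `x, target = batch`.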

...and let trainable take care of the rest!

from trainable import Trainer  # assumed import path; check the package source

trainer = Trainer(
  visualizer=MyVisualizer(),  # Typically Plotter() or Saver()
  train_alg=MSEAlgorithm(),
  test_alg=MSEAlgorithm(eval=True),
  display_freq=1,
  visualize_freq=10,
  validate_freq=10,
  autosave_freq=10,
  device=device
)

save_path = "desired/save/path/for/your/session.sesh"
trainer.start_session(model, optim, save_path)

trainer.name_session('Name')

trainer.describe_session("""
A beautiful detailed description of what the heck 
you were trying to accomplish with this training.
""")

metrics = trainer.train(train_data, test_data, epochs=200)

Plotting the returned metrics is simple as well:

import matplotlib.pyplot as plt

# One figure per metric; plt.show() blocks until each window is closed
for key in metrics:
    plt.plot(metrics[key])
    plt.title(key)
    plt.show()
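If you'd rather get a single image file than a sequence of blocking windows, a sketch like the following draws every metric in its own subplot and saves the figure. It assumes metrics maps metric names to sequences of numbers, which is what the plotting loop above implies; plot_metrics is a hypothetical helper, not part of Trainable.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen so this also works on headless machines
import matplotlib.pyplot as plt


def plot_metrics(metrics, path):
    """Draw each metric curve in its own subplot, save to `path`, return the subplot count.

    Assumes `metrics` maps metric names to sequences of numbers.
    """
    keys = list(metrics)
    fig, axes = plt.subplots(len(keys), 1, figsize=(6, 3 * len(keys)), squeeze=False)
    for ax, key in zip(axes[:, 0], keys):
        ax.plot(metrics[key])
        ax.set_title(key)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)  # free the figure's memory once it is written
    return len(keys)
```

Calling plot_metrics(metrics, "metrics.png") after training writes one stacked figure.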

Tunable Options

The Trainer interface gives you a nice handful of options to configure your training experience. They include:

  • Display Frequency (display_freq): How often (in batches) information such as your training loss is updated in your progress bar.
  • Visualization Frequency (visualize_freq): How often (in batches) the trainer produces a visualization of your model's outputs.
  • Validation Frequency (validate_freq): How often (in epochs) the trainer runs validation on your test data.
  • Autosave Frequency (autosave_freq): How often your session is saved to disk.
  • Device (device): The hardware on which training runs.
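To make the units concrete, the hypothetical helper below lists which events a loop would fire for given frequencies: display and visualization counted in batches, validation counted in epochs. It is a sketch of the semantics described above, not Trainable's internal loop. It assumes an event fires whenever its counter is a multiple of its frequency, and it leaves autosave out because the unit for autosave_freq isn't stated.

```python
def scheduled_events(num_epochs, batches_per_epoch, display_freq, visualize_freq, validate_freq):
    """Return (event, epoch, global_batch) tuples in the order they would fire.

    display_freq and visualize_freq count batches across all epochs;
    validate_freq counts epochs. Counters start at 1.
    """
    events = []
    global_batch = 0
    for epoch in range(1, num_epochs + 1):
        for _ in range(batches_per_epoch):
            global_batch += 1
            if global_batch % display_freq == 0:
                events.append(("display", epoch, global_batch))
            if global_batch % visualize_freq == 0:
                events.append(("visualize", epoch, global_batch))
        # Validation is checked once, at the end of each epoch
        if epoch % validate_freq == 0:
            events.append(("validate", epoch, global_batch))
    return events
```

For example, with the settings from the Trainer example (display_freq=1, visualize_freq=10, validate_freq=10), 20 epochs of 5 batches each would produce 100 display updates, 10 visualizations, and 2 validation passes.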

Customization

Do you want a little more granularity in how you visualize your data? Or perhaps running an epoch with your model is a little more involved than just training on each batch of data? Wondering why the heck PyTorch doesn't have a built-in dataset of unlabeled images for unsupervised learning? Maybe your training algorithm involves VGG? Got you covered. Check out the source for the various submodules.
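For a sense of what an unlabeled-image dataset involves, here is a minimal sketch built on torch.utils.data.Dataset and Pillow. UnlabeledImageFolder is a hypothetical name, not the class Trainable ships; it assumes all images sit in one flat directory and returns each as a float tensor in [0, 1] with shape (3, H, W).

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")


class UnlabeledImageFolder(Dataset):
    """Every image file in one flat directory, returned without labels."""

    def __init__(self, root):
        # Sort so that indices map to files deterministically
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        with Image.open(self.paths[index]) as img:
            array = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
        # (H, W, C) array -> (C, H, W) tensor, the layout PyTorch conv layers expect
        return torch.from_numpy(array).permute(2, 0, 1)
```

Because each item is a bare tensor rather than an (input, target) pair, an algorithm for this data would move the batch to the device directly instead of unpacking it. The images must also share one size for DataLoader's default collation to stack them into batches.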

Contributing

Found other headaches in neural net training that you think Trainable could simplify? Feel free to open a pull request on my GitHub repo.

Contact

Email me anytime at jeffhilton.code@gmail.com.


Download files

Source Distribution

trainable-0.1.4.dev5.tar.gz (18.5 kB)

Built Distribution

trainable-0.1.4.dev5-py3-none-any.whl (23.5 kB)

File details

Details for the file trainable-0.1.4.dev5.tar.gz.

File metadata

  • Download URL: trainable-0.1.4.dev5.tar.gz
  • Upload date:
  • Size: 18.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.7

File hashes

Hashes for trainable-0.1.4.dev5.tar.gz
Algorithm Hash digest
SHA256 5516bb073aeedeb86f326a388f1bbf249218f212278a0bd1f87f4de3b77a3feb
MD5 bf3d1ac67fd3f016e574531841ae4300
BLAKE2b-256 b109b6ec60ddf1edf9190e6a38543b6013f1b3c6e459ce183e3ffcc670fded4d

File details

Details for the file trainable-0.1.4.dev5-py3-none-any.whl.

File metadata

  • Download URL: trainable-0.1.4.dev5-py3-none-any.whl
  • Upload date:
  • Size: 23.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.7

File hashes

Hashes for trainable-0.1.4.dev5-py3-none-any.whl
Algorithm Hash digest
SHA256 ff2e68086f5287da153198451717d8ddfea5ed1e006260673bf004078bcc5f0f
MD5 a0ce93a622e124fc409c9222022431fb
BLAKE2b-256 d18917b6e2d06f15f69af1761d14552584e22790eb5819db98c8da42c0173df3
