Trainable: The Flexible PyTorch Training Toolbox
If you're sick of dealing with all of the boilerplate code involved in training, evaluating, visualizing, and saving your models, then you're in luck. Trainable offers a simple yet extensible framework that makes understanding the latest papers the only headache of neural network training.
Installation
pip install trainable
Usage
The typical workflow for trainable involves defining a callable Algorithm to describe how to train your network on a batch, and how you'd like to label your losses:
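import torch.nn as nn
# Algorithm is Trainable's base class; its import path is not shown in this README.

class MSEAlgorithm(Algorithm):
    def __init__(self, eval=False, **args):
        super().__init__(eval)
        self.mse = nn.MSELoss()

    def __call__(self, model, batch, device):
        # Move the batch to the training device.
        x, target = batch
        x, target = x.to(device), target.to(device)
        # Forward pass, loss, and backward pass.
        y = model(x)
        loss = self.mse(y, target)
        loss.backward()
        # Label the loss with self.key() and return it as a metric.
        metrics = {self.key("MSE Loss"): loss.item()}
        return metrics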
Then you simply instantiate your model, dataset, and optimizer...
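import torch
from torch.utils.data import DataLoader

# Fall back to the CPU when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device)
optim = FancyPantsOptimizer(model.parameters(), lr=1e-4)
train_data = DataLoader(SomeTrainingDataset('path/to/your/data'), batch_size=32)
test_data = DataLoader(SomeTestingDataset('path/to/your/data'), batch_size=32)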
...and let trainable take care of the rest!
trainer = Trainer(
    visualizer=MyVisualizer(),  # Typically Plotter() or Saver()
    train_alg=MyFancyAlgorithm(),
    test_alg=MyFancyAlgorithm(eval=True),
    display_freq=1,
    visualize_freq=10,
    validate_freq=10,
    autosave_freq=10,
    device=device,
)
save_path = "desired/save/path/for/your/session.sesh"
trainer.start_session(model, optim, save_path)
trainer.name_session('Name')
trainer.describe_session("""
A beautiful detailed description of what the heck
you were trying to accomplish with this training.
""")
metrics = trainer.train(train_data, test_data, epochs=200)
Plotting your data is simple as well:
import matplotlib.pyplot as plt

# Draw one labeled curve per metric, then show them together.
for key in metrics:
    plt.plot(metrics[key], label=key)
plt.legend()
plt.show()
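Per-batch losses tend to be noisy, so it can help to smooth each curve before plotting it. Here is a minimal sketch of a trailing moving average. It assumes each entry of metrics is a plain list of numbers, which the plotting loop above suggests but this README does not spell out. The moving_average helper and the sample values are hypothetical and not part of Trainable.

```python
def moving_average(values, window=10):
    """Return the trailing moving average of a sequence of numbers."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        # Drop the value that just slid out of the window.
        if i >= window:
            running_sum -= values[i - window]
        # Early entries average over fewer than `window` values.
        count = min(i + 1, window)
        smoothed.append(running_sum / count)
    return smoothed


# Hypothetical metrics dict (assumed shape: label -> list of numbers).
metrics = {"MSE Loss": [4.0, 2.0, 3.0, 1.0]}
smoothed = {key: moving_average(series, window=2) for key, series in metrics.items()}
```

You can then pass smoothed[key] to plt.plot in place of metrics[key].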
Tunable Options
The Trainer interface gives you a nice handful of options to configure your training experience. They include:
- Display Frequency: How often (in batches) information such as your training loss is updated in your progress bar.
- Visualization Frequency: How often (in batches) the training produces a visualization of your model's outputs.
- Validation Frequency: How often (in epochs) the trainer performs validation with your test data.
- Autosave Frequency: How often your session is saved out to disk.
- Device: On which hardware your training should occur.
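To get a feel for how the batch-based and epoch-based frequencies interact, here is a small sketch of the kind of bookkeeping a trainer might do. It is not Trainable's actual implementation: the is_due and schedule helpers are hypothetical, and counting batches across the whole run (rather than restarting each epoch) is an assumption. Autosave is left out because its unit is not stated above.

```python
def is_due(step, freq):
    """Return True when a 1-indexed step count falls on a multiple of freq."""
    return freq > 0 and step % freq == 0


def schedule(epochs, batches_per_epoch, display_freq=1, visualize_freq=10, validate_freq=10):
    """List the (epoch, batch, event) triples these settings would trigger."""
    events = []
    batch_count = 0  # Assumption: batches are counted across the whole run.
    for epoch in range(1, epochs + 1):
        for batch in range(1, batches_per_epoch + 1):
            batch_count += 1
            # Display and visualization are measured in batches.
            if is_due(batch_count, display_freq):
                events.append((epoch, batch, "display"))
            if is_due(batch_count, visualize_freq):
                events.append((epoch, batch, "visualize"))
        # Validation is measured in epochs, so it is checked once per epoch.
        if is_due(epoch, validate_freq):
            events.append((epoch, batches_per_epoch, "validate"))
    return events


events = schedule(epochs=2, batches_per_epoch=3,
                  display_freq=2, visualize_freq=3, validate_freq=2)
```

With these toy settings, validation only happens at the end of the second epoch, while display and visualization fire partway through each epoch.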
Customization
Do you want a little more granularity in how you visualize your data? Or perhaps running an epoch with your model is a little more involved than just training on each batch of data? Wondering why the heck PyTorch doesn't have a built-in dataset for unsupervised images? Maybe your training algorithm involves VGG? We've got you covered. Check out the source for the various submodules:
- trainable.visualize -- for customizing visualization.
- trainable.epoch -- for customizing epochs.
- trainable.data -- for common datasets and transforms not found in PyTorch's modules.
- trainable.features -- for working with intermediate activations and features, such as with VGG-based losses.
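If you're curious what working with intermediate activations looks like in plain PyTorch, here is a rough sketch using forward hooks. It does not use or describe the trainable.features API: the FeatureExtractor class is hypothetical, and a tiny nn.Sequential stands in for VGG so that nothing needs to be downloaded.

```python
import torch
import torch.nn as nn


class FeatureExtractor:
    """Collect the outputs of selected submodules during a forward pass."""

    def __init__(self, model, layer_names):
        self.features = {}
        self.handles = []
        modules = dict(model.named_modules())
        for name in layer_names:
            handle = modules[name].register_forward_hook(self._make_hook(name))
            self.handles.append(handle)

    def _make_hook(self, name):
        # Forward hooks receive (module, inputs, output) after each forward call.
        def hook(module, inputs, output):
            self.features[name] = output
        return hook

    def remove(self):
        """Detach all hooks so the model goes back to normal."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


# Small stand-in network; nn.Sequential names its children "0", "1", "2", ...
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)
extractor = FeatureExtractor(model, layer_names=["1", "3"])
output = model(torch.randn(2, 3, 32, 32))
extractor.remove()
```

After the forward pass, extractor.features maps each requested layer name to its activation tensor, which is the raw material for a feature-based loss.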
Contributing
Have you found another headache in neural net training that you think Trainable could simplify? Feel free to open a pull request on my GitHub repo.
Contact
Email me anytime at jeffhilton.code@gmail.com.