Pystematic extension for running experiments in PyTorch.
This is an extension to pystematic that adds functionality for running machine learning experiments in PyTorch. Its main contribution is the Context object and its related classes, which provide an easy way to manage all PyTorch-related objects.
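The sketch below shows what it means for one object to manage all PyTorch-related objects. It is a minimal, hypothetical container, and `MiniContext` and `Counter` are illustrative names that are not part of pystematic-torch. The real Context does more than this; the example further down also calls `ctx.cuda()`. The sketch assumes only that managed objects follow PyTorch's `state_dict()` / `load_state_dict()` convention, which `torch.nn.Module` and `torch.optim.Optimizer` both use. `Counter` stands in for such an object so that the code runs without PyTorch installed.

```python
from typing import Any, Dict


class MiniContext:
    """Hypothetical, simplified stand-in for a context object.

    Attributes are assigned as usual. Values that expose state_dict() /
    load_state_dict() (like torch modules and optimizers) contribute their
    own state; any other value (e.g. an epoch counter) is stored as-is.
    """

    def state_dict(self) -> Dict[str, Any]:
        state = {}
        for name, value in vars(self).items():
            if hasattr(value, "state_dict"):
                state[name] = value.state_dict()
            else:
                state[name] = value
        return state

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Objects must already be attached (as in the example, where the
        # checkpoint is loaded after the model and optimizer are created).
        for name, value in state.items():
            current = getattr(self, name, None)
            if hasattr(current, "load_state_dict"):
                current.load_state_dict(value)
            else:
                setattr(self, name, value)


class Counter:
    """Hypothetical stand-in for a torch module or optimizer."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, int]:
        return {"count": self.count}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.count = state["count"]


# Build a context, record some state and snapshot it.
ctx = MiniContext()
ctx.epoch = 3
ctx.model = Counter()
ctx.model.count = 7
saved = ctx.state_dict()

# Restore the snapshot into a freshly built context.
fresh = MiniContext()
fresh.epoch = 0
fresh.model = Counter()
fresh.load_state_dict(saved)
```

In this sketch, plain values such as the epoch counter round-trip unchanged. Restoring a snapshot into a freshly built context therefore also restores where training left off.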
Installation
All you have to do for pystematic to find the plugin is to install it:
$ pip install pystematic-torch
Example
Here’s a small example that shows how using the Context object, SmartDataLoader and Recorder simplifies setting up and running a training session in PyTorch.
import torch
import pystematic


@pystematic.experiment
def context_example(params):
    ctx = pystematic.torch.Context()

    ctx.epoch = 0

    ctx.recorder = pystematic.torch.Recorder()

    ctx.model = torch.nn.Sequential(
        torch.nn.Linear(2, 1),
        torch.nn.Sigmoid()
    )

    ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), lr=0.01)

    # We use the smart dataloader so that batches are moved to
    # the correct device
    ctx.dataloader = pystematic.torch.SmartDataLoader(
        dataset=Dataset(),  # Dataset is your own torch.utils.data.Dataset
        batch_size=2
    )

    ctx.loss_function = torch.nn.BCELoss()

    ctx.cuda()  # Move everything to cuda
    # ctx.ddp()  # and maybe distributed data-parallel?

    if params["checkpoint"]:
        # Load checkpoint
        ctx.load_state_dict(pystematic.torch.load_checkpoint(params["checkpoint"]))

    # Train one epoch
    for input, lbl in ctx.dataloader:
        # The smart dataloader makes sure the batch is placed on
        # the correct device.
        output = ctx.model(input)

        loss = ctx.loss_function(output, lbl)

        ctx.optimizer.zero_grad()
        loss.backward()
        ctx.optimizer.step()

        ctx.recorder.scalar("train/loss", loss)
        ctx.recorder.step()

    ctx.epoch += 1

    # Save checkpoint
    pystematic.torch.save_checkpoint(ctx.state_dict(), id=ctx.epoch)
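The example relies on SmartDataLoader to place each batch on the correct device. The sketch below illustrates that idea with a wrapper that iterates an inner loader and calls `.to(device)` on anything that supports it. It recurses into tuples, lists and dicts. It assumes nothing about pystematic-torch's actual implementation, which may differ. `DeviceLoader`, `move_to_device` and `FakeTensor` are hypothetical names, and `FakeTensor` only mimics `torch.Tensor.to` so the code runs without PyTorch.

```python
from typing import Any, Iterable, Iterator


def move_to_device(batch: Any, device: Any) -> Any:
    """Recursively move a batch to `device`.

    Anything with a .to() method (such as a torch.Tensor) is moved; tuples,
    lists and dicts are traversed; other values are returned unchanged.
    """
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, tuple):
        return tuple(move_to_device(item, device) for item in batch)
    if isinstance(batch, list):
        return [move_to_device(item, device) for item in batch]
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    return batch


class DeviceLoader:
    """Hypothetical wrapper that moves every batch of an inner loader."""

    def __init__(self, loader: Iterable[Any], device: Any) -> None:
        self.loader = loader
        self.device = device

    def __iter__(self) -> Iterator[Any]:
        for batch in self.loader:
            yield move_to_device(batch, self.device)

    def __len__(self) -> int:
        return len(self.loader)


class FakeTensor:
    """Hypothetical stand-in that mimics torch.Tensor.to(device)."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


batches = [(FakeTensor(), FakeTensor()), (FakeTensor(), FakeTensor())]
loader = DeviceLoader(batches, device="cuda")
devices = [(x.device, y.device) for x, y in loader]
```

With PyTorch installed, the inner loader would typically be a `torch.utils.data.DataLoader` and the device something like `torch.device("cuda")`. In this sketch, tuples (including named tuples) come back as plain tuples.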
Documentation
Reference documentation is available at https://pystematic-torch.readthedocs.io.
File details
Details for the file pystematic_torch-1.2.0-py3-none-any.whl.
File metadata
- Download URL: pystematic_torch-1.2.0-py3-none-any.whl
- Upload date:
- Size: 13.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.2 CPython/3.8.12 Linux/5.8.0-1042-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | ea9699ae052710832c09e6e02030588c5506f1e3d91144203c18b940e80ccdea
MD5 | e5fcd03da89ab8f54f2daf451531efd1
BLAKE2b-256 | 379d110165bd44f66147fc322958afb74e371be6c9e8407a46143db9567daca7