Pystematic extension for running experiments in PyTorch.

Project description

This is an extension to pystematic that adds functionality for running machine learning experiments in PyTorch. Its main contribution is the Context object and related classes, which provide an easy way to manage all PyTorch-related objects.
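To illustrate the general idea behind a context object, here is a minimal pure-Python sketch. It is an assumption about the pattern, not pystematic-torch's actual implementation, and the names `MiniContext` and `FakeModel` are hypothetical. Attributes that expose `state_dict()`/`load_state_dict()` (as PyTorch modules and optimizers do) are checkpointed through those methods, plain values such as an epoch counter are stored directly, and anything else is skipped.

```python
class MiniContext:
    """Hypothetical simplified context (NOT the pystematic-torch implementation).

    Collects the state of its attributes into one checkpointable dict.
    """

    PLAIN_TYPES = (int, float, str, bool)

    def state_dict(self):
        state = {}
        for name, value in vars(self).items():
            if hasattr(value, "state_dict"):
                # Modules, optimizers, etc. describe their own state.
                state[name] = value.state_dict()
            elif isinstance(value, self.PLAIN_TYPES):
                # Plain values such as an epoch counter are stored as-is.
                state[name] = value
            # Anything else (e.g. a data loader) is not checkpointed.
        return state

    def load_state_dict(self, state):
        for name, value in state.items():
            current = getattr(self, name, None)
            if hasattr(current, "load_state_dict"):
                # Restore stateful objects in place.
                current.load_state_dict(value)
            else:
                setattr(self, name, value)


class FakeModel:
    """Stand-in for a torch.nn.Module, used only for the usage example."""

    def __init__(self):
        self.weight = 0.0

    def state_dict(self):
        return {"weight": self.weight}

    def load_state_dict(self, state):
        self.weight = state["weight"]


ctx = MiniContext()
ctx.epoch = 3
ctx.model = FakeModel()
ctx.model.weight = 0.5
ctx.loader = object()  # no state_dict and not a plain value: skipped
checkpoint = ctx.state_dict()  # {"epoch": 3, "model": {"weight": 0.5}}

restored = MiniContext()
restored.epoch = 0
restored.model = FakeModel()
restored.load_state_dict(checkpoint)  # restored.epoch == 3, weight == 0.5
```

The sketch only covers checkpointing; in the example below the Context object also moves everything to CUDA (`ctx.cuda()`) and can optionally wrap things for distributed data-parallel (`ctx.ddp()`), which this sketch leaves out.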

Installation

All you have to do for pystematic to find the plugin is to install it:

$ pip install pystematic-torch
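To check that the package actually landed in the active environment, one option is to query its installed metadata with the standard library. This is a small sketch; the helper name `installed_version` is hypothetical, and `importlib.metadata` requires Python 3.8 or newer.

```python
from importlib import metadata


def installed_version(dist_name="pystematic-torch"):
    """Return the installed version string of dist_name, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


print(installed_version())  # a version string, or None if not installed
```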

Example

Here’s a small example that shows how using the Context object, SmartDataLoader and Recorder simplifies setting up and running a training session in PyTorch.

import pystematic
import torch


class Dataset(torch.utils.data.Dataset):
    # Hypothetical toy dataset (not part of pystematic-torch): random 2-D
    # points labelled 1.0 when x + y > 0, else 0.0.
    def __init__(self, size=100):
        self.inputs = torch.randn(size, 2)
        self.labels = (self.inputs.sum(dim=1, keepdim=True) > 0).float()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return self.inputs[index], self.labels[index]


@pystematic.experiment
def context_example(params):
    ctx = pystematic.torch.Context()

    ctx.epoch = 0

    ctx.recorder = pystematic.torch.Recorder()

    ctx.model = torch.nn.Sequential(
        torch.nn.Linear(2, 1),
        torch.nn.Sigmoid()
    )

    ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), lr=0.01)

    # We use the smart dataloader so that batches are moved to
    # the correct device
    ctx.dataloader = pystematic.torch.SmartDataLoader(
        dataset=Dataset(),
        batch_size=2
    )
    ctx.loss_function = torch.nn.BCELoss()

    ctx.cuda()  # Move everything to cuda
    # ctx.ddp()  # and maybe distributed data-parallel?

    if params["checkpoint"]:
        # Load checkpoint
        ctx.load_state_dict(pystematic.torch.load_checkpoint(params["checkpoint"]))

    # Train one epoch
    for inputs, labels in ctx.dataloader:
        # The smart dataloader makes sure the batch is placed on
        # the correct device.
        output = ctx.model(inputs)

        loss = ctx.loss_function(output, labels)

        ctx.optimizer.zero_grad()
        loss.backward()
        ctx.optimizer.step()

        ctx.recorder.scalar("train/loss", loss)
        ctx.recorder.step()

    ctx.epoch += 1

    # Save checkpoint
    pystematic.torch.save_checkpoint(ctx.state_dict(), id=ctx.epoch)
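The comment about the smart dataloader placing each batch on the correct device can be illustrated with a small sketch. This is not SmartDataLoader's actual code; `move_to_device` and `DeviceLoader` are hypothetical names, and the sketch assumes only that tensors expose a `.to(device)` method, as `torch.Tensor` does. It walks nested tuples, lists and dicts and moves every tensor-like leaf.

```python
from collections.abc import Mapping


def move_to_device(batch, device):
    """Recursively move every leaf with a .to() method (e.g. a tensor) to device."""
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, Mapping):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuples take their fields as positional arguments
        return type(batch)(*(move_to_device(item, device) for item in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    # Anything else (ints, strings, None, ...) is passed through unchanged.
    return batch


class DeviceLoader:
    """Hypothetical wrapper that moves each batch of an iterable loader to device."""

    def __init__(self, loader, device):
        self.loader = loader
        self.device = device

    def __iter__(self):
        for batch in self.loader:
            yield move_to_device(batch, self.device)

    def __len__(self):
        return len(self.loader)
```

With PyTorch installed, a plain loader could be wrapped as `DeviceLoader(torch.utils.data.DataLoader(Dataset(), batch_size=2), "cuda")`, so the training loop never calls `.to()` itself.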

Documentation

Reference documentation is available at https://pystematic-torch.readthedocs.io.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distribution

pystematic_torch-1.3.2-py3-none-any.whl (13.9 kB)

Uploaded Python 3

File details

Details for the file pystematic_torch-1.3.2-py3-none-any.whl.

File metadata

  • Download URL: pystematic_torch-1.3.2-py3-none-any.whl
  • Upload date:
  • Size: 13.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.2 CPython/3.8.12 Linux/5.13.0-1025-azure

File hashes

Hashes for pystematic_torch-1.3.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        9af7286deae583db0aa4ccc8cbead4104ba3ec8b83cd0e18147bb14b3d9372d5
MD5           cecc777d6036776fd0f33ae5df24e6b9
BLAKE2b-256   6c77eb3407dda8a6d692ae0cc15041a076f458061a12d300668048fc051087db