
Lightweight high-level PyTorch framework that runs on potato machines

Project description

PotaTorch

PotaTorch is a lightweight PyTorch framework specifically designed to run on hardware with limited resources.

PotaTorch is a work in progress (WIP).

Install PotaTorch

PotaTorch is not currently on PyPI, so you'll have to install it from source:

git clone https://github.com/crybot/potatorch
pip install -e potatorch
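
A quick sanity check for the editable install (assuming the top-level package imports as potatorch, as the examples below do):

import potatorch  # should import without errors after pip install -e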

Minimal Working Example

You can run the following example directly from examples/mlp.py if you already have PyTorch installed, or you can run it with Docker through the provided scripts:

./build.sh && ./run.sh

The example trains a feed-forward network on a toy problem (fitting sin(x)):

import torch
from torch import nn

from potatorch.training import TrainingLoop, make_optimizer
from potatorch.callbacks import ProgressbarCallback
from torch.utils.data import TensorDataset

# Fix a seed for TrainingLoop to make non-deterministic operations such as
# shuffling reproducible
SEED = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

epochs = 100
lr = 1e-4

# Define your model as a pytorch Module
model = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), 
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, 1))

# Create your dataset as a torch.utils.data.Dataset.
# Note: nn.Linear and torch.sin both require floating-point tensors, so the
# data is built as float32 rather than the integer dtype torch.arange defaults to
x = torch.arange(1000, dtype=torch.float32).view(1000, 1)
dataset = TensorDataset(x, torch.sin(x))

# Provide a loss function and an optimizer
loss_fn = torch.nn.MSELoss()
optimizer = make_optimizer(torch.optim.Adam, lr=lr)
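
# make_optimizer presumably defers binding the model's parameters so that
# TrainingLoop can construct the optimizer itself once it has the model; a
# rough sketch of the idea (an assumption, not potatorch's actual code):
#   def make_optimizer(cls, **kwargs):
#       return lambda params: cls(params, **kwargs)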

# Construct a TrainingLoop object.
# TrainingLoop handles the initialization of dataloaders, dataset splitting,
# shuffling, mixed precision training, etc.
# You can provide callback handles through the `callbacks` argument.
training_loop = TrainingLoop(
        dataset,
        loss_fn,
        optimizer,
        train_p=0.8,
        val_p=0.1,
        test_p=0.1,
        random_split=False,
        batch_size=None,
        shuffle=False,
        device=device,
        num_workers=0,
        seed=SEED,
        val_metrics={'l1': nn.L1Loss(), 'mse': nn.MSELoss()},
        callbacks=[
            ProgressbarCallback(epochs=epochs, width=20),
            ]
        )
# Run the training loop
model = training_loop.run(model, epochs=epochs)
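
The model returned by training_loop.run is an ordinary nn.Module, so inference afterwards is plain PyTorch. A minimal sketch (assuming run leaves the model on the configured device):

model.eval()
with torch.no_grad():
    x = torch.tensor([[3.0]], device=device)
    print(model(x).item())  # should approximate sin(3.0) after training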

Automatic Hyperparameter Optimization

PotaTorch provides a basic set of utilities for hyperparameter optimization. You can choose among grid search, random search and Bayesian search, all provided by potatorch.optimization.tuning.HyperOptimizer. The following is a working example of a simple grid search on a toy problem. You can find the full script under examples/grid_search.py:

from functools import partial

import torch
from torch import nn
from torch.utils.data import TensorDataset

from potatorch.training import TrainingLoop, make_optimizer
from potatorch.callbacks import ProgressbarCallback
from potatorch.optimization.tuning import HyperOptimizer

def train(dataset, device, config):
    """ Your usual training function that runs a TrainingLoop instance """
    # Fix a seed for TrainingLoop to make non-deterministic operations such as
    # shuffling reproducible
    SEED = 42
    # NOTE: `epochs` is a fixed hyperparameter; it won't change among runs
    epochs = config['epochs']

    # Define your model as a pytorch Module
    model = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), 
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 1))

    # Provide a loss function and an optimizer
    loss_fn = torch.nn.MSELoss()
    # NOTE: `lr` is a dynamic hyperparameter; it will change among runs
    optimizer = make_optimizer(torch.optim.Adam, lr=config['lr'])

    # Construct a TrainingLoop object.
    # TrainingLoop handles the initialization of dataloaders, dataset splitting,
    # shuffling, mixed precision training, etc.
    # You can provide callback handles through the `callbacks` argument.
    training_loop = TrainingLoop(
            dataset,
            loss_fn,
            optimizer,
            train_p=0.8,
            val_p=0.1,
            test_p=0.1,
            random_split=False,
            batch_size=None,
            shuffle=False,
            device=device,
            num_workers=0,
            seed=SEED,
            val_metrics={'l1': nn.L1Loss(), 'mse': nn.MSELoss()},
            callbacks=[
                ProgressbarCallback(epochs=epochs, width=20),
                ]
            )
    # Run the training loop
    model = training_loop.run(model, epochs=epochs, verbose=1)
    # Return a dictionary containing the training and validation metrics 
    # calculated during the last epoch of the loop
    return training_loop.get_last_metrics()
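
# The exact keys of get_last_metrics() are not documented here; a plausible
# shape, inferred from the `val_metrics` above and the 'val_loss' metric used
# by the search below (an assumption, not confirmed output):
#   {'loss': ..., 'val_loss': ..., 'val_l1': ..., 'val_mse': ...}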

# Define your search configuration
search_config = {
        'method': 'grid',   # which search method to use: ['grid', 'bayes', 'random']
        'metric': {
            'name': 'val_loss', # the metric you're optimizing
            'goal': 'minimize'  # whether you want to minimize or maximize it
        },
        'parameters': { # the set of hyperparameters you want to optimize
            'lr': {
                'values': [1e-2, 1e-3, 1e-4]    # a range of values for the grid search to try
            }
        },
        'fixed': {      # fixed hyperparameters that won't change among runs
            'epochs': 200
        }
    }
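
# Only 'grid' is exercised in this example; presumably the same schema also
# drives 'random' and 'bayes' (an assumption not verified here), e.g.:
#   random_config = dict(search_config, method='random')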

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Create your dataset as a torch.utils.data.Dataset (float32: nn.Linear
    # and torch.sin both require floating-point tensors)
    x = torch.arange(1000, dtype=torch.float32).view(1000, 1)
    dataset = TensorDataset(x, torch.sin(x))
    # Apply additional parameters to the train function to have f(config) -> {}
    score_function = partial(train, dataset, device)
    # Construct the hyperparameters optimizer
    hyperoptimizer = HyperOptimizer(search_config)
    # Run the optimization over the hyperparameters space
    config, error = hyperoptimizer.optimize(score_function, return_error=True)
    print('Best configuration found: {}\n with error: {}'.format(config, error))

if __name__ == '__main__':
    main()
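
Once the search finishes, optimize returns the best configuration together with its validation error. A possible follow-up, sketched under the assumption (not documented above) that the returned config holds only the searched parameters, so the fixed ones are merged back in before retraining:

best_config = dict(search_config['fixed'], **config)  # hypothetical merge
final_metrics = train(dataset, device, best_config)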

Download files

Download the file for your platform.

Source Distribution

potatorch-0.0.1.tar.gz (296.3 kB)


Built Distribution

potatorch-0.0.1-py3-none-any.whl (19.3 kB)


File details

Details for the file potatorch-0.0.1.tar.gz.

File metadata

  • Download URL: potatorch-0.0.1.tar.gz
  • Upload date:
  • Size: 296.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for potatorch-0.0.1.tar.gz

  • SHA256: c7bde8b1bfd6ce450530a093cc2e0ff6f939aee2fa91ae5676ffd332da6ffdc7
  • MD5: 20bab471a615e82dc2445fbf6e3aa763
  • BLAKE2b-256: d765330596d6abebf7e2a3f5303bb845a0a2078837eccce3e95af9d19e94fb48


File details

Details for the file potatorch-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: potatorch-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 19.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for potatorch-0.0.1-py3-none-any.whl

  • SHA256: 16d5eb9e391a74ab99c9dd7f9f30d2964bde43386e26946a1b94cf9e0f5989ce
  • MD5: 093099727384e70663e852a5a7bc4506
  • BLAKE2b-256: cb4051f036167369edecf7688b4ec8f9de4d6df405fc1ccf0e21da68700d7d08

