Lightweight high-level PyTorch framework that runs on potato machines

Project description

PotaTorch is a lightweight PyTorch framework specifically designed to run on hardware with limited resources.


Installation

PotaTorch is published on PyPI; you can install it with pip:

pip install potatorch

or you can install it from source:

git clone --single-branch -b main https://github.com/crybot/potatorch
pip install -e potatorch

Minimal Working Example

You can run the following example directly from examples/mlp.py if you already have PyTorch installed, or you can run it with Docker through the provided scripts:

./build.sh && ./run.sh

The example trains a feed-forward network on a toy problem (fitting a sine wave):

import torch
from torch import nn

from potatorch.training import TrainingLoop, make_optimizer
from potatorch.callbacks import ProgressbarCallback
from torch.utils.data import TensorDataset

# Fix a seed for TrainingLoop to make non-deterministic operations such as
# shuffling reproducible
SEED = 42
device = 'cuda'

epochs = 100
lr = 1e-4

# Define your model as a pytorch Module
model = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), 
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, 1))

# Create your dataset as a torch.utils.data.Dataset
# (inputs must be floats for nn.Linear, and targets should match the
# model's output shape)
x = torch.arange(1000, dtype=torch.float32).view(1000, 1)
dataset = TensorDataset(x, torch.sin(x))

# Provide a loss function and an optimizer
loss_fn = torch.nn.MSELoss()
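# Note: make_optimizer binds the optimizer class to its hyperparameters
# without instantiating it; presumably TrainingLoop constructs the actual
# optimizer later, once it has access to the model's parameters.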
optimizer = make_optimizer(torch.optim.Adam, lr=lr)

# Construct a TrainingLoop object.
# TrainingLoop handles the initialization of dataloaders, dataset splitting,
# shuffling, mixed precision training, etc.
# You can provide callback handles through the `callbacks` argument.
training_loop = TrainingLoop(
        dataset,
        loss_fn,
        optimizer,
        train_p=0.8,
        val_p=0.1,
        test_p=0.1,
        random_split=False,
        batch_size=None,
        shuffle=False,
        device=device,
        num_workers=0,
        seed=SEED,
        val_metrics={'l1': nn.L1Loss(), 'mse': nn.MSELoss()},
        callbacks=[
            ProgressbarCallback(epochs=epochs, width=20),
            ]
        )
# Run the training loop
model = training_loop.run(model, epochs=epochs)
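
After training, the returned model is a regular PyTorch module. Below is a minimal inference sketch (it assumes run() returns the model already placed on device, which the example above suggests but does not state explicitly):

# Evaluate the trained network on a few inputs; no gradients are needed
model.eval()
with torch.no_grad():
    x = torch.tensor([[0.0], [1.0], [2.0]], device=device)
    predictions = model(x)
print(predictions.squeeze(1))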

Automatic Hyperparameter Optimization

PotaTorch provides a basic set of utilities for hyperparameter optimization. You can choose among grid search, random search, and Bayesian search, all provided by potatorch.optimization.tuning.HyperOptimizer. The following is a working example of a simple grid search on a toy problem; you can find the full script under examples/grid_search.py.

from functools import partial

import torch
from torch import nn
from torch.utils.data import TensorDataset

from potatorch.training import TrainingLoop, make_optimizer
from potatorch.callbacks import ProgressbarCallback
from potatorch.optimization.tuning import HyperOptimizer

def train(dataset, device, config):
    """Your usual training function that runs a TrainingLoop instance."""
    SEED = 42
    # `epochs` is a fixed hyperparameter; it won't change among runs
    epochs = config['epochs']

    # Define your model as a pytorch Module
    model = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), 
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 1))

    loss_fn = torch.nn.MSELoss()
    # `lr` is a dynamic hyperparameter; it will change among runs
    optimizer = make_optimizer(torch.optim.Adam, lr=config['lr'])

    training_loop = TrainingLoop(
            dataset,
            loss_fn,
            optimizer,
            train_p=0.8,
            val_p=0.1,
            test_p=0.1,
            random_split=False,
            batch_size=None,
            shuffle=False,
            device=device,
            num_workers=0,
            seed=SEED,
            val_metrics={'l1': nn.L1Loss(), 'mse': nn.MSELoss()},
            callbacks=[
                ProgressbarCallback(epochs=epochs, width=20),
                ]
            )
    model = training_loop.run(model, epochs=epochs, verbose=1)
    # Return a dictionary containing the training and validation metrics 
    # calculated during the last epoch of the loop
    return training_loop.get_last_metrics()

# Define your search configuration
search_config = {
        'method': 'grid',   # which search method to use: ['grid', 'bayes', 'random']
        'metric': {
            'name': 'val_loss', # the metric you're optimizing
            'goal': 'minimize'  # whether you want to minimize or maximize it
        },
        'parameters': { # the set of hyperparameters you want to optimize
            'lr': {
                'values': [1e-2, 1e-3, 1e-4]    # a range of values for the grid search to try
            }
        },
        'fixed': {      # fixed hyperparameters that won't change among runs
            'epochs': 200
        }
    }

def main():
    device = 'cuda'
    x = torch.arange(1000, dtype=torch.float32).view(1000, 1)
    dataset = TensorDataset(x, torch.sin(x))
    # Apply additional parameters to the train function to have f(config) -> {}
    score_function = partial(train, dataset, device)
    # Construct the hyperparameters optimizer
    hyperoptimizer = HyperOptimizer(search_config)
    # Run the optimization over the hyperparameters space
    config, error = hyperoptimizer.optimize(score_function, return_error=True)
    print('Best configuration found: {}\n with error: {}'.format(config, error))

if __name__ == '__main__':
    main()
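
Switching strategies only requires changing the method field of the search configuration. As a sketch, a random-search variant of the example above could reuse the exact same schema (whether random and Bayesian search also accept richer parameter specifications, such as continuous ranges, is not shown here):

# Same configuration as before, with the search method swapped to 'random'
random_search_config = {
        'method': 'random',
        'metric': {
            'name': 'val_loss',
            'goal': 'minimize'
        },
        'parameters': {
            'lr': {
                'values': [1e-2, 1e-3, 1e-4]
            }
        },
        'fixed': {
            'epochs': 200
        }
    }

hyperoptimizer = HyperOptimizer(random_search_config)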
