Lightweight high-level PyTorch framework that runs on potato machines

PotaTorch is a lightweight PyTorch framework specifically designed to run on hardware with limited resources.


Installation

PotaTorch is published on PyPI, so you can install it with pip:

pip install potatorch

Alternatively, you can install it from source:

git clone --single-branch -b main https://github.com/crybot/potatorch
pip install -e potatorch
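
After either installation method, a quick way to confirm that the package is visible to your interpreter is to look it up on the import path. The snippet below is a minimal sketch that uses only the Python standard library; the helper name `is_installed` is made up for illustration and is not part of PotaTorch.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if `package_name` can be found on the current import path."""
    # find_spec returns None when a top-level package cannot be located
    return importlib.util.find_spec(package_name) is not None


# `is_installed` is a hypothetical helper, not a PotaTorch function
print("potatorch installed:", is_installed("potatorch"))
```

If this prints `False`, check that pip installed the package into the same environment that runs your scripts.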

Minimal Working Example

If PyTorch is already installed, you can run the following example directly from examples/mlp.py. Otherwise, you can run it with Docker through the provided scripts:

./build.sh && ./run.sh

The example trains a feedforward network on a toy problem:

import torch
from torch import nn

from potatorch.training import TrainingLoop, make_optimizer
from potatorch.callbacks import ProgressbarCallback
from torch.utils.data import TensorDataset

# Fix a seed for TrainingLoop to make non-deterministic operations such as
# shuffling reproducible
SEED = 42
# Fall back to the CPU when no CUDA device is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

epochs = 100
lr = 1e-4

# Define your model as a pytorch Module
model = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), 
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, 1))

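# Create your dataset as a torch.utils.data.Dataset.
# Inputs must be floating point for nn.Linear, and targets are shaped (1000, 1)
# to match the model output.
x = torch.arange(1000, dtype=torch.float32).view(1000, 1)
dataset = TensorDataset(x, torch.sin(x))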

# Provide a loss function and an optimizer
loss_fn = torch.nn.MSELoss()
optimizer = make_optimizer(torch.optim.Adam, lr=lr)

# Construct a TrainingLoop object.
# TrainingLoop handles the initialization of dataloaders, dataset splitting,
# shuffling, mixed precision training, etc.
# You can provide callback handles through the `callbacks` argument.
training_loop = TrainingLoop(
        dataset,
        loss_fn,
        optimizer,
        train_p=0.8,
        val_p=0.1,
        test_p=0.1,
        random_split=False,
        batch_size=None,
        shuffle=False,
        device=device,
        num_workers=0,
        seed=SEED,
        val_metrics={'l1': nn.L1Loss(), 'mse': nn.MSELoss()},
        callbacks=[
            ProgressbarCallback(epochs=epochs, width=20),
            ]
        )
# Run the training loop
model = training_loop.run(model, epochs=epochs)
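
To make the `train_p`, `val_p` and `test_p` arguments concrete, the following is a minimal sketch of how a 0.8/0.1/0.1 split of 1000 samples can be cut sequentially, which is what `random_split=False` suggests. It is plain Python written for illustration: the function `sequential_split` is hypothetical, and the way PotaTorch itself computes the split boundaries is an assumption, not something confirmed here.

```python
def sequential_split(n_samples, train_p, val_p, test_p):
    """Partition range(n_samples) into ordered train/val/test index ranges."""
    if abs(train_p + val_p + test_p - 1.0) > 1e-9:
        raise ValueError("split fractions must sum to 1")
    n_train = int(n_samples * train_p)
    n_val = int(n_samples * val_p)
    # The test split takes the remainder so every sample is used exactly once
    train = range(0, n_train)
    val = range(n_train, n_train + n_val)
    test = range(n_train + n_val, n_samples)
    return train, val, test


# Hypothetical helper, shown with the same fractions as the example above
train_idx, val_idx, test_idx = sequential_split(1000, 0.8, 0.1, 0.1)
print(len(train_idx), len(val_idx), len(test_idx))
```

With a sequential split on this toy dataset, the validation and test samples come from the largest input values, so the network is evaluated on inputs outside the range it was trained on.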

Automatic Hyperparameter Optimization

PotaTorch provides a basic set of utilities for hyperparameter optimization. You can choose among grid search, random search and Bayesian search, all of which are provided by potatorch.optimization.tuning.HyperOptimizer. The following is a working example of a simple grid search on a toy problem. You can find the full script under examples/grid_search.py.

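from functools import partial

import torch
from torch import nn
from torch.utils.data import TensorDataset

from potatorch.training import TrainingLoop, make_optimizer
from potatorch.callbacks import ProgressbarCallback
from potatorch.optimization.tuning import HyperOptimizer

def train(dataset, device, config):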
    """ Your usual training function that runs a TrainingLoop instance """
    SEED = 42
    # `epochs` is a fixed hyperparameter; it won't change among runs
    epochs = config['epochs']

    # Define your model as a pytorch Module
    model = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), 
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 1))

    loss_fn = torch.nn.MSELoss()
    # `lr` is a dynamic hyperparameter; it will change among runs
    optimizer = make_optimizer(torch.optim.Adam, lr=config['lr'])

    training_loop = TrainingLoop(
            dataset,
            loss_fn,
            optimizer,
            train_p=0.8,
            val_p=0.1,
            test_p=0.1,
            random_split=False,
            batch_size=None,
            shuffle=False,
            device=device,
            num_workers=0,
            seed=SEED,
            val_metrics={'l1': nn.L1Loss(), 'mse': nn.MSELoss()},
            callbacks=[
                ProgressbarCallback(epochs=epochs, width=20),
                ]
            )
    model = training_loop.run(model, epochs=epochs, verbose=1)
    # Return a dictionary containing the training and validation metrics 
    # calculated during the last epoch of the loop
    return training_loop.get_last_metrics()

# Define your search configuration
search_config = {
        'method': 'grid',   # which search method to use: ['grid', 'bayes', 'random']
        'metric': {
            'name': 'val_loss', # the metric you're optimizing
            'goal': 'minimize'  # whether you want to minimize or maximize it
        },
        'parameters': { # the set of hyperparameters you want to optimize
            'lr': {
                'values': [1e-2, 1e-3, 1e-4]    # a range of values for the grid search to try
            }
        },
        'fixed': {      # fixed hyperparameters that won't change among runs
            'epochs': 200
        }
    }

def main():
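    # Fall back to the CPU when no CUDA device is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Inputs must be floating point for nn.Linear; targets match the model output shape
    x = torch.arange(1000, dtype=torch.float32).view(1000, 1)
    dataset = TensorDataset(x, torch.sin(x))
    # Bind dataset and device so that score_function(config) returns a metrics dictionary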
    score_function = partial(train, dataset, device)
    # Construct the hyperparameters optimizer
    hyperoptimizer = HyperOptimizer(search_config)
    # Run the optimization over the hyperparameters space
    config, error = hyperoptimizer.optimize(score_function, return_error=True)
    print('Best configuration found: {}\n with error: {}'.format(config, error))

if __name__ == '__main__':
    main()
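
To illustrate what a grid search over such a configuration does, here is a minimal sketch in plain Python that enumerates every combination of the listed values, merges in the fixed hyperparameters, and keeps the configuration with the best metric. It is not PotaTorch's implementation: `grid_search` and `toy_score` are hypothetical names, and `toy_score` stands in for a real training run so that the sketch runs without PyTorch.

```python
from itertools import product


def grid_search(search_config, score_function):
    """Try every combination of listed values; return (best_config, best_score)."""
    metric = search_config['metric']['name']
    minimize = search_config['metric']['goal'] == 'minimize'
    names = list(search_config['parameters'])
    value_lists = [search_config['parameters'][n]['values'] for n in names]
    fixed = search_config.get('fixed', {})

    best_config, best_score = None, None
    for combo in product(*value_lists):
        # Each run sees the fixed hyperparameters plus one point of the grid
        config = {**fixed, **dict(zip(names, combo))}
        score = score_function(config)[metric]
        if best_score is None or (score < best_score if minimize else score > best_score):
            best_config, best_score = config, score
    return best_config, best_score


def toy_score(config):
    """Hypothetical stand-in for a training run: pretends lr=1e-3 is optimal."""
    return {'val_loss': abs(config['lr'] - 1e-3)}


toy_search = {
    'method': 'grid',
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': {'lr': {'values': [1e-2, 1e-3, 1e-4]}},
    'fixed': {'epochs': 200},
}
best_config, best_score = grid_search(toy_search, toy_score)
print(best_config, best_score)
```

The number of runs is the product of the lengths of all value lists, so adding a second hyperparameter with three values would turn the three runs above into nine.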
