
A wrapper for fastai projects to create easy command-line interfaces and manage hyper-parameter tuning.

Project description



A wrapper for PyTorch projects to create easy command-line interfaces and manage hyper-parameter tuning.

Documentation at https://rbturnbull.github.io/torchapp/

Installation

The software can be installed using pip:

pip install torchapp

To install the latest version from the repository, you can use this command:

pip install git+https://github.com/rbturnbull/torchapp.git

Writing an App

To make an app, create a subclass of TorchApp. The parent class provides several methods for training and hyper-parameter tuning. At a minimum, you must implement the dataloaders method and the model method.

The dataloaders method must return a fastai DataLoaders object, which is a collection of DataLoader objects. Typically it contains one dataloader for training and another for validation. For more information, see https://docs.fast.ai/data.core.html#DataLoaders. You can add parameters with type hints to the method's signature, and these will automatically be added to the train and show_batch methods.
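The example app below splits the rows with fastai's RandomSplitter(validation_proportion) and batches them with bs=batch_size. The idea behind that split can be sketched with the standard library. This is an illustration of the concept, not fastai's implementation, and random_split and batches are hypothetical helpers:

```python
import random
from typing import List, Sequence, Tuple


def random_split(
    n_items: int, validation_proportion: float = 0.2, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Shuffle item indices and cut off a validation share (illustrative only)."""
    if not 0.0 <= validation_proportion <= 1.0:
        raise ValueError("validation_proportion must be between 0 and 1")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_valid = int(validation_proportion * n_items)
    # The first n_valid shuffled indices go to validation, the rest to training.
    return indices[n_valid:], indices[:n_valid]


def batches(items: Sequence[int], batch_size: int) -> List[List[int]]:
    """Group items into consecutive batches; the last one may be smaller."""
    return [list(items[i : i + batch_size]) for i in range(0, len(items), batch_size)]


train_idx, valid_idx = random_split(100, validation_proportion=0.2)
print(len(train_idx), len(valid_idx))  # 80 20
print(len(batches(train_idx, batch_size=32)))  # 3 (32 + 32 + 16)
```

Seeding the shuffle keeps the split identical across runs, which matters when comparing hyper-parameter settings.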

The model method must return a PyTorch module (an instance of torch.nn.Module). Parameters in its signature will be added to the train method.

Here’s an example for doing logistic regression:

#!/usr/bin/env python3
from pathlib import Path
import pandas as pd
from torch import nn
from fastai.data.block import DataBlock, TransformBlock
from fastai.data.transforms import ColReader, RandomSplitter
import torchapp as ta
from torchapp.blocks import BoolBlock


class LogisticRegressionApp(ta.TorchApp):
    """
    Creates a basic app to do logistic regression.
    """

    def dataloaders(
        self,
        csv: Path = ta.Param(help="The path to a CSV file with the data."),
        x: str = ta.Param(default="x", help="The column name of the independent variable."),
        y: str = ta.Param(default="y", help="The column name of the dependent variable."),
        validation_proportion: float = ta.Param(
            default=0.2, help="The proportion of the dataset to use for validation."
        ),
        batch_size: int = ta.Param(
            default=32,
            help="The number of items to use in each batch.",
        ),
    ):

        datablock = DataBlock(
            blocks=[TransformBlock, BoolBlock],
            get_x=ColReader(x),
            get_y=ColReader(y),
            splitter=RandomSplitter(validation_proportion),
        )
        df = pd.read_csv(csv)

        return datablock.dataloaders(df, bs=batch_size)

    def model(self) -> nn.Module:
        """Builds a simple logistic regression model."""
        return nn.Linear(in_features=1, out_features=1, bias=True)

    def loss_func(self):
        return nn.BCEWithLogitsLoss()


if __name__ == "__main__":
    LogisticRegressionApp.main()
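In this example, model returns a bare nn.Linear layer, so its output is an unbounded score (a logit) rather than a probability, and loss_func pairs it with nn.BCEWithLogitsLoss, which combines the sigmoid and the binary cross-entropy in one numerically stable loss. The standard-library sketch below is not PyTorch's implementation; it computes the same quantity in a commonly used stable form and compares it with the naive sigmoid-then-log form:

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_naive(z: float, y: float) -> float:
    """Binary cross-entropy after explicitly applying the sigmoid."""
    p = sigmoid(z)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def bce_with_logits(z: float, y: float) -> float:
    """Same quantity, rearranged so exp() never overflows for large |z|."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def mean_bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Average over a batch, like the default reduction="mean"."""
    return sum(bce_with_logits(z, y) for z, y in zip(logits, targets)) / len(logits)


# For moderate logits both forms agree.
for z, y in [(-2.0, 0.0), (0.5, 1.0), (3.0, 0.0)]:
    assert math.isclose(bce_naive(z, y), bce_with_logits(z, y))

# For a large logit the naive form fails: sigmoid(40.0) rounds to 1.0,
# so math.log(1.0 - p) raises ValueError, while the stable form is fine.
print(bce_with_logits(40.0, 0.0))  # 40.0
```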

Programmatic Interface

To use the app in Python, simply instantiate it:

app = LogisticRegressionApp()

Then train it with the train method:

app.train(training_csv_path)

This takes the arguments of both the dataloaders method and the train method. The function signature is modified so these arguments show up in auto-completion in a Jupyter notebook.
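How torchapp rewrites the signature is not shown here, but the general technique can be sketched with the standard library's inspect module: build one function that accepts the parameters of two others and advertise the merged list through __signature__, which is what inspect.signature() reports and what notebook tooling typically relies on for hints. The dataloaders, train, and combine functions below are hypothetical stand-ins, not torchapp's methods:

```python
import inspect


def dataloaders(csv: str, batch_size: int = 32):
    """Hypothetical stand-in for an app's dataloaders method."""
    return f"loaders for {csv} with bs={batch_size}"


def train(loaders, epochs: int = 20, lr: float = 1e-3):
    """Hypothetical stand-in for a training routine that consumes the loaders."""
    return f"trained on {loaders} for {epochs} epochs at lr={lr}"


def combine(first, second):
    """Return a function accepting the parameters of `first` plus those of `second`.

    The first parameter of `second` is filled with the result of `first`, so it is
    left out of the merged signature. Assumes the merged parameter order is valid
    (no parameter without a default after one with a default).
    """
    first_params = list(inspect.signature(first).parameters.values())
    second_params = list(inspect.signature(second).parameters.values())[1:]
    merged = inspect.Signature(first_params + second_params)

    def combined(*args, **kwargs):
        bound = merged.bind(*args, **kwargs)  # raises TypeError on bad arguments
        bound.apply_defaults()
        values = bound.arguments
        loaders = first(**{p.name: values[p.name] for p in first_params})
        return second(loaders, **{p.name: values[p.name] for p in second_params})

    combined.__signature__ = merged  # what inspect.signature() will report
    return combined


train_app = combine(dataloaders, train)
print(inspect.signature(train_app))
# (csv: str, batch_size: int = 32, epochs: int = 20, lr: float = 0.001)
print(train_app("data.csv", epochs=5))
# trained on loaders for data.csv with bs=32 for 5 epochs at lr=0.001
```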

Predictions are made by simply calling the app object.

app(data_csv_path)
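To try these calls without real data, you need a CSV file with the columns the example expects: a numeric x and a boolean y. A minimal sketch using only the standard library; the file name logistic_data.csv, the row count, and the coefficients in the data-generating rule are arbitrary assumptions:

```python
import csv
import math
import random
from pathlib import Path


def write_synthetic_csv(path: Path, n_rows: int = 200, seed: int = 0) -> Path:
    """Write a toy dataset with a float column "x" and a boolean column "y".

    y is True with probability sigmoid(2 * x - 1); the coefficients are
    illustrative choices, not values from torchapp.
    """
    rng = random.Random(seed)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["x", "y"])  # header names match the default x/y parameters
        for _ in range(n_rows):
            x = rng.uniform(-3.0, 3.0)
            p = 1.0 / (1.0 + math.exp(-(2.0 * x - 1.0)))
            writer.writerow([f"{x:.4f}", rng.random() < p])  # bools are written as True/False
    return path


csv_path = write_synthetic_csv(Path("logistic_data.csv"))
```

The resulting file can stand in for training_csv_path or data_csv_path above; pandas.read_csv parses the True/False strings in the y column as booleans.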

Command-Line Interface

Command-line interfaces can be created with the Poetry package management tool. Add a line like this to the [tool.poetry.scripts] section of pyproject.toml:

logistic = "logistic.apps:LogisticRegressionApp.main"
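Once the package is installed, each such line becomes a console-script entry point: the part before the colon is a module to import, and the part after it is an attribute path within that module. The standard library's importlib.metadata.EntryPoint resolves such a value the same way. Since the logistic package here is hypothetical, a standard-library target stands in for it in this sketch (Python 3.9 or later):

```python
import json
from importlib.metadata import EntryPoint

# Same "module.path:attribute.path" form as "logistic.apps:LogisticRegressionApp.main",
# but pointing at something that exists in every Python installation.
ep = EntryPoint(name="logistic", value="json:JSONDecoder.decode", group="console_scripts")

print(ep.module)  # json
print(ep.attr)    # JSONDecoder.decode

# load() imports the module, then follows the dotted attribute path with getattr().
target = ep.load()
print(target is json.JSONDecoder.decode)  # True
```

The generated logistic executable does essentially this and then calls the loaded object, which for the line above is LogisticRegressionApp.main.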

Now you can train from the command line:

logistic train training_csv_path

All the arguments of the dataloaders and model methods can be set as command-line options. To see them, run:

logistic train -h
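torchapp generates these options for you. The general idea, turning a method's signature into command-line options, can be sketched with argparse and inspect; this is a swapped-in illustration rather than torchapp's implementation, and dataloaders and parser_from_signature are hypothetical:

```python
import argparse
import inspect


def dataloaders(csv: str, x: str = "x", y: str = "y", batch_size: int = 32):
    """Hypothetical stand-in with the same style of signature as the example app."""
    return csv, x, y, batch_size


def parser_from_signature(func) -> argparse.ArgumentParser:
    """Turn each parameter into a CLI argument, converting strings with its type hint."""
    parser = argparse.ArgumentParser(prog=func.__name__)
    for param in inspect.signature(func).parameters.values():
        has_hint = param.annotation is not inspect.Parameter.empty
        convert = param.annotation if has_hint else str
        if param.default is inspect.Parameter.empty:
            parser.add_argument(param.name, type=convert)  # no default: positional
        else:
            option = "--" + param.name.replace("_", "-")
            parser.add_argument(option, dest=param.name, type=convert, default=param.default)
    return parser


parser = parser_from_signature(dataloaders)
args = parser.parse_args(["data.csv", "--batch-size", "64"])
print(dataloaders(**vars(args)))  # ('data.csv', 'x', 'y', 64)
```

Parameters without defaults become positional arguments, and typed ones are converted with their annotation. A plain type=bool would not work for flags (bool("False") is True), so a fuller implementation needs special handling there.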

Predictions are made like this:

logistic predict data_csv_path

Hyperparameter Tuning

All the arguments of the dataloaders and model methods can be tuned using Weights & Biases (W&B) hyperparameter sweeps (https://docs.wandb.ai/guides/sweeps). In Python, simply run:

app.tune(runs=10)

Or, from the command line, run:

logistic tune --runs 10

These commands will connect with W&B and your runs will be visible on the wandb.ai site.
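W&B carries out the search itself, choosing values according to the sweep's method (for example grid, random, or Bayesian search). As a rough picture of what runs=10 amounts to under plain random search, here is a standard-library sketch; the objective is a hypothetical stand-in for "train once and report a validation loss", and the search ranges are illustrative assumptions:

```python
import math
import random


def objective(batch_size: int, validation_proportion: float) -> float:
    """Hypothetical stand-in for one training run; lower is better."""
    return (math.log2(batch_size) - 5) ** 2 + 10 * (validation_proportion - 0.2) ** 2


def random_search(runs: int, seed: int = 0):
    """Sample one configuration per run and keep the one with the lowest loss."""
    rng = random.Random(seed)
    best_config, best_loss = None, math.inf
    for _ in range(runs):
        config = {
            "batch_size": rng.choice([8, 16, 32, 64, 128]),
            "validation_proportion": rng.uniform(0.1, 0.3),
        }
        loss = objective(**config)
        if loss < best_loss:
            best_config, best_loss = config, loss
    return best_config, best_loss


best_config, best_loss = random_search(runs=10)
print(best_config, best_loss)
```

Each run is independent, which is why a sweep service can distribute runs across machines and record them centrally.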

Project Generation

To use a template to construct a package for your app, simply run:

torchapp

Credits

torchapp was created by Robert Turnbull with contributions from Jonathan Garber and Simone Bae.

Citation details to follow.

Logo elements derived from icons by ProSymbols and Philipp Petzka.



Download files


Source Distribution

torchapp-0.3.6.tar.gz (40.6 kB)

Uploaded Source

Built Distribution


torchapp-0.3.6-py3-none-any.whl (47.7 kB)

Uploaded Python 3

File details

Details for the file torchapp-0.3.6.tar.gz.

File metadata

  • Download URL: torchapp-0.3.6.tar.gz
  • Size: 40.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.6.1 CPython/3.10.12 Linux/6.2.0-1012-azure

File hashes

Hashes for torchapp-0.3.6.tar.gz
Algorithm Hash digest
SHA256 8084f005dbd3548aa1ffb341b6539325adc475b5d9cfcb5967d4af45ab2e9ff2
MD5 c0a5b78c257f172d8a2e253fc32efb52
BLAKE2b-256 bfa9d9b79fdfc91051e7b050d7b4e0f91dc320effbd6efd53486731965e43e2c


File details

Details for the file torchapp-0.3.6-py3-none-any.whl.

File metadata

  • Download URL: torchapp-0.3.6-py3-none-any.whl
  • Size: 47.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.6.1 CPython/3.10.12 Linux/6.2.0-1012-azure

File hashes

Hashes for torchapp-0.3.6-py3-none-any.whl
Algorithm Hash digest
SHA256 30258eb5b1b8576c6cd73d93f7308a018d7c7c3576dc913f78345e22d8fb8546
MD5 7fe56dc6a6da4a27a0f62dcdf4dffd28
BLAKE2b-256 bd338a4bc5190b5739581012b07617c4ccb2530cdc576e4ea9be56bcd5d3ab06

