
TorchABC

A simple abstract class for training and inference in PyTorch.

torchabc is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.

The core of the package is the TorchABC class. It defines the abstract training and inference workflows and must be subclassed to provide the concrete logic.

The package has no dependencies beyond PyTorch and consists of a single self-contained file, which makes it ideal for research, prototyping, and teaching.

Structure

The TorchABC class structures a project into the following main steps:

  1. Dataloaders - load raw data samples.
  2. Preprocess - transform raw samples.
  3. Collate - batch preprocessed samples.
  4. Network - compute model outputs.
  5. Loss - compute error against targets.
  6. Optimizer - update model parameters.
  7. Postprocess - transform outputs into predictions.

Each step corresponds to an abstract method in TorchABC. To use TorchABC, create a concrete subclass and implement these methods.

Quick start

Install the package.

pip install torchabc

Generate a template using the command line interface.

torchabc --create template.py --min

Fill out the template by implementing the methods below. Each method is documented in the API Reference section further down.

import torch
from torchabc import TorchABC
from functools import cached_property


class MyModel(TorchABC):

    @cached_property
    def dataloaders(self):
        # Return a dict of dataloaders, e.g. {"train": ..., "val": ..., "test": ...}.
        raise NotImplementedError

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        # Transform a raw dataset sample into model-ready tensors.
        return sample

    @staticmethod
    def collate(samples):
        # Batch a list of preprocessed samples.
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        # Return a torch.nn.Module with batch-first inputs and outputs.
        raise NotImplementedError

    @staticmethod
    def loss(outputs, targets, hparams):
        # Return a dict containing the key "loss".
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        # Return a torch.optim.Optimizer for self.network.parameters().
        raise NotImplementedError

    @staticmethod
    def postprocess(outputs, hparams):
        # Convert network outputs into final predictions.
        return outputs
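
As a concrete illustration, a subclass for MNIST digit classification might look like the sketch below. This example is hypothetical and not part of the package: it assumes torchvision is installed, that TorchABC applies preprocess and collate to the samples produced by the dataloaders (as the Structure section describes), and the architecture, learning rate, and batch sizes are arbitrary choices.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms  # assumed extra dependency
from torchabc import TorchABC
from functools import cached_property


class MNISTClassifier(TorchABC):

    @cached_property
    def dataloaders(self):
        # Dataloaders keyed by the names used in train() and eval().
        train = datasets.MNIST("data", train=True, download=True)
        val = datasets.MNIST("data", train=False, download=True)
        return {
            "train": DataLoader(train, batch_size=64, shuffle=True),
            "val": DataLoader(val, batch_size=256),
        }

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        # Turn a raw (PIL image, label) pair into tensors.
        image, label = sample
        return transforms.functional.to_tensor(image), label

    @staticmethod
    def collate(samples):
        # Stack preprocessed samples into an (images, labels) batch.
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        # A small fully connected classifier; batch-first in and out.
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    @cached_property
    def optimizer(self):
        return torch.optim.Adam(self.network.parameters(), lr=1e-3)

    @staticmethod
    def loss(outputs, targets, hparams):
        # TorchABC expects a dict containing the key "loss".
        return {"loss": nn.functional.cross_entropy(outputs, targets)}

    @staticmethod
    def postprocess(outputs, hparams):
        # Map logits to predicted class indices.
        return outputs.argmax(dim=1)

Once defined, the subclass is trained and used exactly as shown in the Usage section below.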

Usage

Once a subclass of TorchABC is implemented, it can be used for training, evaluation, checkpointing, and inference.

Initialization

model = MyModel()

Initialize the model.
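
The constructor also accepts a computation device, a logging function, and a dictionary of hyperparameters that is forwarded to methods such as preprocess and loss. For example (the "lr" key is an arbitrary choice for this sketch):

model = MyModel(device="cpu", hparams={"lr": 1e-3})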

Training

model.train(epochs=5, on="train", val="val")

Train the model for 5 epochs using the train and val dataloaders.
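
Gradient accumulation and checkpointing are controlled by the optional arguments. For example, the following accumulates gradients over 4 batches before each optimizer step and writes checkpoints to an arbitrary path:

model.train(epochs=5, gas=4, out="best.pth")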

Evaluation

metrics = model.eval(on="test")

Evaluate on the test dataloader and return metrics.
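
The name passed to on can be any key of the dataloaders dictionary. For instance, to evaluate on the validation split:

metrics = model.eval(on="val")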

Checkpoints

model.save("checkpoint.pth")
model.load("checkpoint.pth")

Save and restore the model state.

Inference

preds = model(samples)

Run predictions on raw input samples.
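
Calling the model runs the inference workflow on raw samples: preprocessing, collation, the network, and postprocessing. Continuing the hypothetical MNIST sketch above, for example:

model = MNISTClassifier()
test_set = datasets.MNIST("data", train=False, download=True)
preds = model([test_set[i][0] for i in range(4)])  # predicted class indices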

API Reference

The TorchABC class defines a standard workflow for PyTorch projects. Some methods are abstract (they must be implemented in subclasses), others have default implementations (they can be overridden), and a few are concrete (they should not be overridden).


Abstract Methods

dataloaders
Must return dict[str, torch.utils.data.DataLoader]. Example keys: "train", "val", "test".

preprocess(sample, hparams, flag='')
Transform a raw dataset sample.
Parameters:
- sample (Any): raw sample.
- hparams (dict): hyperparameters.
- flag (str, optional): mode flag.
Returns: a tensor or an iterable of tensors.

collate(samples)
Collate a batch of preprocessed samples.
Parameters:
- samples (Iterable[Tensor]): preprocessed samples.
Returns: a tensor or an iterable of tensors.

network
Must return a torch.nn.Module. Inputs and outputs must use the (batch_size, ...) format.

optimizer
Must return a torch.optim.Optimizer for self.network.parameters().

loss(outputs, targets, hparams)
Compute the loss for a batch.
Parameters:
- outputs (Tensor or iterable of tensors): network outputs.
- targets (Tensor or iterable of tensors): targets.
- hparams (dict): hyperparameters.
Returns: dict[str, Any] containing the key "loss".

postprocess(outputs, hparams)
Convert network outputs into predictions.
Parameters:
- outputs (Tensor or iterable of tensors): network outputs.
- hparams (dict): hyperparameters.
Returns: the predictions (Any).

Default Methods

scheduler
The learning rate scheduler. May return None, a torch.optim.lr_scheduler.LRScheduler, or a ReduceLROnPlateau. Defaults to None.

backward(batch, gas)
The backpropagation step.
Parameters:
- batch (dict[str, Any]): batch results; must contain the key "loss".
- gas (int): gradient accumulation steps.

metrics(batches, hparams)
Compute evaluation metrics.
Parameters:
- batches (deque[dict[str, Any]]): batch results.
- hparams (dict): hyperparameters.
Returns: dict[str, Any]. The default computes the average loss.

checkpoint(epoch, metrics, out)
The checkpointing step. By default, saves the model when the loss improves.
Parameters:
- epoch (int): epoch number.
- metrics (dict[str, float]): validation metrics.
- out (str or None): output path to save checkpoints.
Returns: bool indicating whether to stop early.

move(data)
Move data to the current device. Supports Tensor, list, tuple, and dict.

detach(data)
Detach data from the computation graph. Supports Tensor, list, tuple, and dict.
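
The default methods can be overridden to customize the workflow. For example, metrics could be overridden to report accuracy alongside the loss. The sketch below is hypothetical: it assumes each batch dict also carries "outputs" and "targets" entries, which is an assumption about the batch contents rather than a documented guarantee.

class MyModelWithAccuracy(MyModel):

    @staticmethod
    def metrics(batches, hparams):
        # Hypothetical override: assumes each batch dict stores "loss",
        # "outputs" (logits), and "targets" (class indices).
        correct, total, loss_sum = 0, 0, 0.0
        for batch in batches:
            loss_sum += float(batch["loss"])
            preds = batch["outputs"].argmax(dim=1)
            correct += int((preds == batch["targets"]).sum())
            total += batch["targets"].numel()
        return {"loss": loss_sum / len(batches), "accuracy": correct / total}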

Concrete Methods

TorchABC(device=None, logger=print, hparams=None, **kwargs)
Initialize the model.
Parameters:
- device (str or torch.device, optional): computation device. Defaults to CUDA if available, otherwise MPS or CPU.
- logger (Callable[[dict], None], optional): logging function. Defaults to print.
- hparams (dict, optional): dictionary of hyperparameters.
- kwargs: additional attributes stored in the instance.

train(epochs, gas=1, mas=None, on='train', val='val', out=None)
Train the model.
Parameters:
- epochs (int): number of training epochs.
- gas (int, optional): gradient accumulation steps. Defaults to 1.
- mas (int, optional): metrics accumulation steps. Defaults to gas.
- on (str, optional): name of the training dataloader. Defaults to "train".
- val (str, optional): name of the validation dataloader. Defaults to "val". If None, validation is skipped.
- out (str, optional): output path to save checkpoints.

eval(on)
Evaluate the model.
Parameters:
- on (str): name of the dataloader to evaluate on.
Returns: dict[str, float] of evaluation metrics.

__call__(samples)
Run inference on raw samples.
Parameters:
- samples (Iterable[Any]): raw samples.
Returns: the postprocessed predictions.

save(path)
Save a checkpoint.
Parameters:
- path (str): file path.

load(path)
Load a checkpoint.
Parameters:
- path (str): file path.
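
Putting the concrete methods together, a typical session with a subclass looks like this (the paths, dataloader names, and hyperparameters are illustrative):

model = MyModel(hparams={"lr": 1e-3})
model.train(epochs=10, val="val", out="best.pth")  # checkpoints when the loss improves
model.load("best.pth")                             # restore the best checkpoint
print(model.eval(on="test"))                       # evaluation metrics on the test set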

Examples

Get started with the simple self-contained examples in the examples folder.

Run the examples

Install the dependencies.

poetry install --with examples

Run an example by replacing <name> with one of the filenames in the examples folder.

poetry run python examples/<name>.py

Contribute

Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.
