
TorchABC

TorchABC is an abstract class for training and inference in PyTorch that helps you keep your code well organized. It is a minimalist alternative to pytorch-lightning: it depends only on torch and consists of a single self-contained file.

Usage

Create a concrete class derived from TorchABC following the template below. Then use your class as follows.

Initialization

Initialize your class with

model = ClassName(device=None, logger=print, hparams=None, **kwargs)

where

  • device is the torch.device to use. Defaults to None, which will try CUDA, then MPS, and finally fall back to CPU.
  • logger is a logging function that takes a dictionary as input. The default prints to standard output. You can easily log with wandb or any other custom logger (see the sketch after this list).
  • hparams is a dictionary of hyperparameters used internally by your class. These hyperparameters are persistent as they will be saved in the model's checkpoints.
  • kwargs are additional arguments to store in the class attributes. These arguments are ephemeral as they will not be saved in the model's checkpoints.
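
For example, here is a minimal sketch of initializing with a Weights & Biases logger instead of the default print. The project name and the hyperparameter names below are made up for illustration.

import wandb

wandb.init(project="my-project")  # hypothetical project name

model = ClassName(
    logger=wandb.log,                       # wandb.log takes a dictionary, as required
    hparams={'lr': 1e-3, 'batch_size': 32}  # illustrative hyperparameters
)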

Training

Train the model with

model.train(epochs, gas=1, on='train', val='val')

where

  • epochs is the number of training epochs to perform.
  • gas is the number of gradient accumulation steps (see the note after this list).
  • on is the name of the training dataloader.
  • val is the name of the validation dataloader.
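
Gradient accumulation trades memory for effective batch size: gradients from gas consecutive batches are accumulated before each optimizer step, so the effective batch size is the dataloader's batch size times gas. For example, the following call (with illustrative numbers) performs an optimizer step every 4 batches.

# with a 'train' dataloader of batch size 32 and gas=4, each
# optimizer step sees an effective batch of 128 samples
model.train(epochs=10, gas=4, on='train', val='val')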

Evaluation

Compute the evaluation metrics with

model.eval(on)

where

  • on is the name of the dataloader to evaluate on.
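
For example, the following sketch assumes that eval returns the dictionary computed by the metrics method and that a dataloader named 'test' exists.

metrics = model.eval(on='test')  # 'test' must be a key of self.dataloaders
print(metrics['loss'])           # 'loss' is always present, per the template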

Inference

Predict with

model.predict(samples)

where

  • samples is an iterable of raw data samples.

Checkpoints

Save and load the model.

model.save(checkpoint)
model.load(checkpoint)

where

  • checkpoint is the path to the model checkpoint.
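
Putting it all together, a typical workflow might look like the following sketch. The checkpoint path, the hyperparameters, and raw_samples are placeholders.

model = ClassName(hparams={'lr': 1e-3})   # illustrative hyperparameters
model.train(epochs=5)                     # train on 'train', validate on 'val'
model.save('checkpoint.pt')               # weights and hparams are persisted

model.load('checkpoint.pt')               # restore the saved checkpoint
predictions = model.predict(raw_samples)  # raw_samples: your iterable of raw data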

Quick start

Install the package.

pip install torchabc

Generate a template using the command line interface.

torchabc --create template.py

Fill out the template.

import torch
from torchabc import TorchABC
from functools import cached_property


class ClassName(TorchABC):
    """A concrete subclass of the TorchABC abstract class.

    Use this template to implement your own model by following these steps:
      - replace ClassName with the name of your model,
      - replace this docstring with a description of your model,
      - implement the methods below to define the core logic of your model.
    """
    
    @cached_property
    def dataloaders(self):
        """The dataloaders.

        Returns a dictionary containing multiple `DataLoader` instances. 
        The keys of the dictionary are custom names (e.g., 'train', 'val', 'test'), 
        and the values are the corresponding `torch.utils.data.DataLoader` objects.
        """
        raise NotImplementedError
    
    @staticmethod
    def preprocess(sample, hparams, flag=''):
        """The preprocessing step.

        Transforms a raw sample from a `torch.utils.data.Dataset`. This method is 
        intended to be passed as the `transform` (or similar) argument of a `Dataset`.

        Parameters
        ----------
        sample : Any
            The raw sample.
        hparams : dict
            The hyperparameters.
        flag : str, optional
            A custom flag indicating how to transform the sample. 
            When flag is the empty string, the sample must be transformed for inference.

        Returns
        -------
        Union[Tensor, Iterable[Tensor]]
            The preprocessed sample.
        """
        return sample

    @staticmethod
    def collate(samples):
        """The collating step.

        Collates a batch of preprocessed samples. This method is intended to be 
        passed as the `collate_fn` argument of a `DataLoader`.

        Parameters
        ----------
        samples : Iterable[Tensor]
            The preprocessed samples.

        Returns
        -------
        Union[Tensor, Iterable[Tensor]]
            The batch of collated samples.
        """
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        """The neural network.

        Returns a `torch.nn.Module` whose input and output tensors assume 
        the batch size is the first dimension: (batch_size, ...).
        """
        raise NotImplementedError
    
    @cached_property
    def optimizer(self):
        """The optimizer for training the network.

        Returns a `torch.optim.Optimizer` configured for 
        `self.network.parameters()`.
        """
        raise NotImplementedError
    
    @cached_property
    def scheduler(self):
        """The learning rate scheduler for the optimizer.

        Returns a `torch.optim.lr_scheduler.LRScheduler` or 
        `torch.optim.lr_scheduler.ReduceLROnPlateau` configured 
        for `self.optimizer`.
        """
        return None
    
    @staticmethod
    def accumulate(outputs, targets, hparams, accumulator=None):
        """The accumulation step.

        Accumulates batch statistics that will be provided when calculating 
        the loss and other metrics.

        Parameters
        ----------
        outputs : Union[Tensor, Iterable[Tensor]]
            The outputs returned by `self.network`.
        targets : Union[Tensor, Iterable[Tensor]]
            The target values.
        hparams : dict
            The hyperparameters.
        accumulator : Any
            The previous return value of this function. 
            If None, this is the first call.

        Returns
        -------
        Any
            The accumulated batch statistics.
        """
        raise NotImplementedError

    @staticmethod
    def metrics(accumulator, hparams):
        """The evaluation metrics.

        Computes the loss and additional evaluation metrics.

        Parameters
        ----------
        accumulator : Any
            The accumulated batch statistics.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        Dict[str, Union[Tensor, float]]
            A dictionary of evaluation metrics. This dictionary must contain
            the key 'loss' whose value is used to train the network.
        """
        raise NotImplementedError

    @staticmethod
    def postprocess(outputs, hparams):
        """The postprocessing step.

        Transforms the outputs into postprocessed predictions. 

        Parameters
        ----------
        outputs : Union[Tensor, Iterable[Tensor]]
            The outputs returned by `self.network`.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        Any
            The postprocessed predictions.
        """
        return outputs

    def checkpoint(self, epoch, metrics):
        """The checkpointing step.

        Performs checkpointing at the end of each epoch.

        Parameters
        ----------
        epoch : int
            The epoch number, starting at 1.
        metrics : Dict[str, float]
            Dictionary of validation metrics.

        Returns
        -------
        bool
            If this function returns True, training stops.
        """
        return False
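
As an illustration, here is a minimal sketch of a filled-out template for a toy regression task. Everything here, from the synthetic dataset to the network size and the learning rate, is made up for the example and is not part of the library.

import torch
from torchabc import TorchABC
from functools import cached_property


class ToyRegressor(TorchABC):
    """A toy linear regressor fit on synthetic data (illustrative only)."""

    @cached_property
    def dataloaders(self):
        # synthetic dataset: y = 3x + noise
        x = torch.randn(1024, 1)
        y = 3 * x + 0.1 * torch.randn(1024, 1)
        dataset = torch.utils.data.TensorDataset(x, y)
        train_set, val_set = torch.utils.data.random_split(dataset, [896, 128])
        return {
            'train': torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True),
            'val': torch.utils.data.DataLoader(val_set, batch_size=32),
        }

    @cached_property
    def network(self):
        return torch.nn.Linear(1, 1)

    @cached_property
    def optimizer(self):
        # a fixed learning rate; in practice you might read it from the hparams dict
        return torch.optim.SGD(self.network.parameters(), lr=0.01)

    @staticmethod
    def accumulate(outputs, targets, hparams, accumulator=None):
        # keep running sums of the squared error and the sample count
        sq_err = torch.nn.functional.mse_loss(outputs, targets, reduction='sum')
        if accumulator is None:
            return sq_err, outputs.shape[0]
        return accumulator[0] + sq_err, accumulator[1] + outputs.shape[0]

    @staticmethod
    def metrics(accumulator, hparams):
        # the 'loss' key is required: its value is used to train the network
        sq_err, count = accumulator
        return {'loss': sq_err / count}  # mean squared error

The methods with default implementations (preprocess, collate, scheduler, postprocess, and checkpoint) are left untouched, since their defaults already suit this toy task. Training then reduces to:

model = ToyRegressor()
model.train(epochs=3)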

Examples

Get started with the simple, self-contained examples in the examples folder.

Run the examples

Install the dependencies.

poetry install --with examples

Run the examples by replacing <name> with one of the filenames in the examples folder.

poetry run python examples/<name>.py

Contribute

Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.
