A simple abstract class for training and inference in PyTorch.

TorchABC

torchabc is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.

The core of the package is the TorchABC class. It defines the abstract training and inference workflows and must be subclassed to implement the concrete logic.

This package has no dependencies beyond PyTorch and consists of a simple, self-contained file. It is ideal for research, prototyping, and teaching.

Structure

The TorchABC class structures a project into the following main steps:

  1. Dataloaders - load raw data.
  2. Preprocess - transform raw data into preprocessed samples.
  3. Collate - batch preprocessed samples.
  4. Network - compute the model's outputs for a single batch.
  5. Loss - compute the loss for a single batch.
  6. Optimizer - update the model's parameters.
  7. Scheduler - update the optimizer's parameters.
  8. Metrics - compute evaluation metrics from multiple batches.
  9. Postprocess - transform outputs into predictions.

Each step corresponds to an abstract method in TorchABC. To use TorchABC, create a concrete subclass.
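
As a rough illustration of the pattern (not the actual TorchABC source), the sketch below uses Python's standard `abc` module. The class `Workflow` and its methods `load`, `transform`, and `run` are hypothetical names chosen for this example: the base class fixes the order of the steps, and the subclass supplies the logic for each one.

```python
from abc import ABC, abstractmethod


class Workflow(ABC):
    """Hypothetical base class: fixes the order of the steps."""

    @abstractmethod
    def load(self):
        """Return an iterable of raw samples."""

    @abstractmethod
    def transform(self, sample):
        """Turn one raw sample into a processed one."""

    def run(self):
        # The workflow is defined once here and reused by every subclass.
        return [self.transform(sample) for sample in self.load()]


class Doubler(Workflow):
    """Concrete subclass: implements every abstract step."""

    def load(self):
        return [1, 2, 3]

    def transform(self, sample):
        return 2 * sample


print(Doubler().run())  # [2, 4, 6]

try:
    Workflow()  # abstract classes cannot be instantiated
except TypeError as error:
    print("TypeError:", error)
```

Leaving any abstract method unimplemented in a subclass produces the same `TypeError` at instantiation time, which is what makes the list of steps act as a checklist.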

Quick start

Install the package.

pip install torchabc

Generate a minimal template to fill in:

torchabc --create template.py --min

import torch
from torchabc import TorchABC
from functools import cached_property


class MyModel(TorchABC):
    
    @cached_property
    def dataloaders(self):
        raise NotImplementedError
    
    @staticmethod
    def preprocess(sample, hparams, flag=''):
        return sample

    @staticmethod
    def collate(samples):
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        raise NotImplementedError
    
    @staticmethod
    def loss(outputs, targets, hparams):
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        raise NotImplementedError
    
    @cached_property
    def scheduler(self):
        return None
    
    @staticmethod
    def metrics(losses, hparams):
        return {"loss": sum(loss["loss"] for loss in losses) / len(losses)}

    @staticmethod
    def postprocess(outputs, hparams):
        return outputs
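
The template declares `dataloaders`, `network`, `optimizer`, and `scheduler` with `functools.cached_property`, so each object is built on first access and then reused. A minimal standalone sketch of that behavior, using a hypothetical `Demo` class unrelated to TorchABC:

```python
from functools import cached_property


class Demo:
    """Hypothetical class used only to show cached_property semantics."""

    def __init__(self):
        self.builds = 0  # counts how many times `network` is constructed

    @cached_property
    def network(self):
        self.builds += 1
        return {"layers": 3}  # stand-in for an expensive object


demo = Demo()
first = demo.network   # built here, on first access
second = demo.network  # returned from the cache, not rebuilt
print(first is second, demo.builds)  # True 1
```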

The full template with the documentation can be created with:

torchabc --create template.py

import torch
from torchabc import TorchABC
from functools import cached_property


class MyModel(TorchABC):
    """A concrete subclass of the TorchABC abstract class.

    Use this template to implement your own model by following these steps:
      - replace MyModel with the name of your model,
      - replace this docstring with a description of your model,
      - implement the methods below to define the core logic of your model.
    """
    
    @cached_property
    def dataloaders(self):
        """The dataloaders.

        Return a dictionary containing multiple `DataLoader` instances. 
        The keys of the dictionary are custom names (e.g., 'train', 'val', 'test'), 
        and the values are the corresponding `torch.utils.data.DataLoader` objects.
        """
        raise NotImplementedError
    
    @staticmethod
    def preprocess(sample, hparams, flag=''):
        """The preprocessing step.

        Transform a raw sample. This method is called when preprocessing raw samples 
        for inference. It can also be used in `self.dataloaders` with custom flags 
        for different behavior (e.g., see examples/mnist.py for data augmentation).

        Parameters
        ----------
        sample : Any
            The raw sample.
        hparams : dict
            The hyperparameters.
        flag : str, optional
            When flag is empty, this method transforms a raw sample for inference.
            A custom flag can be used to specify a different behavior when using
            this method in `self.dataloaders` (e.g., see examples/mnist.py).

        Returns
        -------
        Union[Tensor, Iterable[Tensor]]
            The preprocessed sample.
        """
        return sample

    @staticmethod
    def collate(samples):
        """The collating step.

        Collate a batch of preprocessed samples.

        Parameters
        ----------
        samples : Iterable[Tensor]
            The preprocessed samples.

        Returns
        -------
        Union[Tensor, Iterable[Tensor]]
            The batch of collated samples.
        """
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        """The neural network.

        Return a `torch.nn.Module` whose input and output tensors assume 
        the batch size is the first dimension: (batch_size, ...).
        """
        raise NotImplementedError
    
    @staticmethod
    def loss(outputs, targets, hparams):
        """The loss function.

        Compute the loss and optional extra information for a single batch.
        The loss is used for training and all information is passed to `self.metrics`.

        Parameters
        ----------
        outputs : Union[Tensor, Iterable[Tensor]]
            The outputs returned by `self.network`.
        targets : Union[Tensor, Iterable[Tensor]]
            The target values.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        dict[str, Any]
            Dictionary with key 'loss' and optional extra keys.
        """
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        """The optimizer for training the network.

        Return a `torch.optim.Optimizer` configured for 
        `self.network.parameters()`.
        """
        raise NotImplementedError
    
    @cached_property
    def scheduler(self):
        """The learning rate scheduler for the optimizer.

        Return a `torch.optim.lr_scheduler.LRScheduler` or 
        `torch.optim.lr_scheduler.ReduceLROnPlateau` configured 
        for `self.optimizer`.
        """
        return None
    
    @staticmethod
    def metrics(losses, hparams):
        """The evaluation metrics.

        Compute evaluation metrics from the losses on multiple batches.

        Parameters
        ----------
        losses : deque[dict[str, Any]]
            Deque of dictionaries returned by `self.loss`.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        dict[str, Any]
            Dictionary of evaluation metrics.
        """
        return {"loss": sum(loss["loss"] for loss in losses) / len(losses)}

    @staticmethod
    def postprocess(outputs, hparams):
        """The postprocessing step.

        Transform the outputs into postprocessed predictions. 

        Parameters
        ----------
        outputs : Union[Tensor, Iterable[Tensor]]
            The outputs returned by `self.network`.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        Any
            The postprocessed predictions.
        """
        return outputs
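
The `loss` and `metrics` docstrings above describe a hand-off: `loss` returns one dictionary per batch, and `metrics` receives the collection of those dictionaries. The sketch below illustrates that hand-off in plain Python, with floats in place of tensors so that it runs without PyTorch. The extra keys `correct` and `total`, the `tolerance` hyperparameter, and the `accuracy` metric are assumptions made for this example, not part of the template.

```python
from collections import deque


def loss(outputs, targets, hparams):
    """Toy per-batch loss: mean squared error plus extra bookkeeping."""
    errors = [(o - t) ** 2 for o, t in zip(outputs, targets)]
    # Count predictions that fall within the (hypothetical) tolerance.
    correct = sum(abs(o - t) <= hparams["tolerance"] for o, t in zip(outputs, targets))
    return {"loss": sum(errors) / len(errors), "correct": correct, "total": len(targets)}


def metrics(losses, hparams):
    """Aggregate the per-batch dictionaries into evaluation metrics."""
    total = sum(item["total"] for item in losses)
    return {
        "loss": sum(item["loss"] for item in losses) / len(losses),
        "accuracy": sum(item["correct"] for item in losses) / total,
    }


hparams = {"tolerance": 0.5}
losses = deque()
losses.append(loss([1.0, 2.0], [1.0, 4.0], hparams))  # batch 1
losses.append(loss([3.0, 3.0], [3.0, 3.0], hparams))  # batch 2
print(metrics(losses, hparams))  # {'loss': 1.0, 'accuracy': 0.75}
```

Carrying raw counts (`correct`, `total`) rather than per-batch ratios keeps the aggregate exact when batches differ in size.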

Usage

Once a subclass of TorchABC is implemented, it can be used for training, evaluation, checkpointing, and inference.

Initialization

Initialize the model.

model = MyModel()

Training

Train the model for 5 epochs using the train and val dataloaders.

model.train(epochs=5, on="train", val="val")
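
The internals of `train` are not shown here. As a rough illustration of how the steps listed under Structure usually fit together, the following plain PyTorch loop trains a linear model on synthetic data. The data, model, learning rate, and scheduler settings are all assumptions made for this sketch; it is not the TorchABC implementation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Dataloaders: synthetic regression data, y = 3x + 1 (illustrative only).
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3 * x + 1
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

network = torch.nn.Linear(1, 1)                             # Network
criterion = torch.nn.MSELoss()                              # Loss
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)   # Optimizer
scheduler = torch.optim.lr_scheduler.StepLR(                # Scheduler
    optimizer, step_size=10, gamma=0.5
)

history = []  # mean training loss per epoch
for epoch in range(20):
    batch_losses = []
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(network(inputs), targets)
        loss.backward()
        optimizer.step()      # update the model's parameters
        batch_losses.append(loss.item())
    scheduler.step()          # update the optimizer's learning rate
    history.append(sum(batch_losses) / len(batch_losses))  # Metrics

print(f"first epoch: {history[0]:.4f}, last epoch: {history[-1]:.4f}")
```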

Evaluation

Evaluate on the test dataloader and return metrics.

metrics = model.eval(on="test")

Checkpoints

Save and restore the model state.

model.save("checkpoint.pth")
model.load("checkpoint.pth")

Inference

Run predictions on raw input samples.

preds = model(rawdata)
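
Per the method docstrings, `preprocess` turns a raw sample into tensors, `collate` batches them, `network` produces outputs, and `postprocess` turns outputs into predictions. The sketch below chains equivalent plain functions by hand. How TorchABC wires these steps internally is not shown here; the function bodies, the tiny network, and the `scale` hyperparameter are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
hparams = {"scale": 255.0}  # hypothetical hyperparameter


def preprocess(sample, hparams, flag=""):
    # Raw sample (list of pixel values) -> float tensor scaled to [0, 1].
    return torch.tensor(sample, dtype=torch.float32) / hparams["scale"]


def collate(samples):
    # Stack preprocessed samples along a new batch dimension.
    return torch.utils.data.default_collate(samples)


def postprocess(outputs, hparams):
    # Class scores -> predicted class index for each sample.
    return outputs.argmax(dim=1).tolist()


network = torch.nn.Linear(4, 3)  # stand-in classifier: 4 features, 3 classes
network.eval()

raw = [[0, 64, 128, 255], [255, 128, 64, 0]]
batch = collate([preprocess(sample, hparams) for sample in raw])
with torch.no_grad():  # no gradients are needed for inference
    outputs = network(batch)
preds = postprocess(outputs, hparams)
print(batch.shape, preds)
```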

Examples

Get started with the simple self-contained examples in the examples folder (e.g., examples/mnist.py).

Run the examples

Install the dependencies:

poetry install --with examples

Run an example by replacing <name> with one of the filenames in the examples folder:

poetry run python examples/<name>.py

Contribute

Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.
