TorchABC
A simple abstract class for training and inference in PyTorch.
torchabc is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.
The core of the package is the TorchABC class. This class defines the abstract training and inference workflows and must be subclassed to implement the concrete logic.
The package has no dependencies beyond PyTorch and consists of a single, self-contained file. It is ideal for research, prototyping, and teaching.
Structure
The TorchABC class structures a project into the following main steps:
- Dataloaders - load raw data samples.
- Preprocess - transform raw samples.
- Collate - batch preprocessed samples.
- Network - compute model outputs.
- Loss - compute error against targets.
- Optimizer - update model parameters.
- Postprocess - transform outputs into predictions.
Each step corresponds to an abstract method in TorchABC. To use TorchABC, create a concrete subclass and implement these methods.
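Conceptually, inference chains these steps together: raw samples are preprocessed, collated into a batch, passed through the network, and postprocessed into predictions. The sketch below only illustrates this composition; it is not the actual TorchABC internals (which also handle device placement), and it assumes the hyperparameters are accessible as model.hparams.

```python
# Conceptual composition of the inference steps (illustrative only,
# not the actual TorchABC internals). Assumes `model` is a concrete
# subclass instance exposing its hyperparameters as `model.hparams`.
def infer(model, samples):
    inputs = model.collate([model.preprocess(s, model.hparams) for s in samples])
    outputs = model.network(inputs)
    return model.postprocess(outputs, model.hparams)
```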
Quick start
Install the package.
pip install torchabc
Generate a template using the command line interface.
torchabc --create template.py --min
Fill out the template by implementing the methods below. The documentation of each method is available in the API Reference section.
```python
import torch
from torchabc import TorchABC
from functools import cached_property


class MyModel(TorchABC):

    @cached_property
    def dataloaders(self):
        raise NotImplementedError

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        return sample

    @staticmethod
    def collate(samples):
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        raise NotImplementedError

    @staticmethod
    def loss(outputs, targets, hparams):
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        raise NotImplementedError

    @staticmethod
    def postprocess(outputs, hparams):
        return outputs
```
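To make this concrete, here is an illustrative sketch of a filled-out subclass for digit classification. The dataset, architecture, and hyperparameters are example choices only; the sketch assumes torchvision is installed and that TorchABC routes each raw sample through preprocess and collate before it reaches the network.

```python
import torch
from torch import nn
from torchabc import TorchABC
from functools import cached_property


class MNISTClassifier(TorchABC):
    """Illustrative sketch of a concrete TorchABC subclass."""

    @cached_property
    def dataloaders(self):
        # Assumes torchvision is installed; each raw sample
        # is a (PIL image, label) pair.
        from torchvision import datasets
        return {
            "train": torch.utils.data.DataLoader(
                datasets.MNIST("data", train=True, download=True),
                batch_size=64, shuffle=True),
            "val": torch.utils.data.DataLoader(
                datasets.MNIST("data", train=False, download=True),
                batch_size=256),
        }

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        # Turn the raw (image, label) pair into tensors.
        from torchvision.transforms.functional import to_tensor
        image, label = sample
        return to_tensor(image), label

    @staticmethod
    def collate(samples):
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        # A small multilayer perceptron; the architecture is illustrative.
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    @cached_property
    def optimizer(self):
        return torch.optim.Adam(self.network.parameters(), lr=1e-3)

    @staticmethod
    def loss(outputs, targets, hparams):
        # The returned dict must contain the key "loss".
        return {"loss": nn.functional.cross_entropy(outputs, targets)}

    @staticmethod
    def postprocess(outputs, hparams):
        # Map logits to predicted class indices.
        return outputs.argmax(dim=-1)
```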
Usage
Once a subclass of TorchABC is implemented, it can be used for training, evaluation, checkpointing, and inference.
Initialization
model = MyModel()
Initialize the model.
Training
model.train(epochs=5, on="train", val="val")
Train the model for 5 epochs using the train and val dataloaders.
Evaluation
metrics = model.eval(on="test")
Evaluate on the test dataloader and return metrics.
Checkpoints
model.save("checkpoint.pth")
model.load("checkpoint.pth")
Save and restore the model state.
Inference
preds = model(samples)
Run predictions on raw input samples.
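Putting it all together, a typical workflow chains these calls. MyModel, the hyperparameters, and raw_samples below are placeholders:

```python
model = MyModel(hparams={"lr": 1e-3})  # initialize with hyperparameters
model.train(epochs=5, on="train", val="val", out="checkpoint.pth")
metrics = model.eval(on="test")        # evaluate on the test dataloader
model.save("final.pth")                # save the current state
model.load("final.pth")                # restore it later
preds = model(raw_samples)             # inference on raw input samples
```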
API Reference
The TorchABC class defines a standard workflow for PyTorch projects. Some methods are abstract (must be implemented in subclasses), others are optional (can be overridden but have defaults), and a few are concrete (should not be overridden).
Abstract Methods
| Method | Description |
|---|---|
| `dataloaders` | Must return `dict[str, torch.utils.data.DataLoader]`. Example keys: `"train"`, `"val"`, `"test"`. |
| `preprocess(sample, hparams, flag='')` | Transform a raw dataset sample. Parameters: `sample` (Any): raw sample; `hparams` (dict): hyperparameters; `flag` (str, optional): mode flag. Returns a tensor or an iterable of tensors. |
| `collate(samples)` | Collate a batch of preprocessed samples. Parameters: `samples` (Iterable[Tensor]). Returns a tensor or an iterable of tensors. |
| `network` | Must return a `torch.nn.Module` whose inputs and outputs use the `(batch_size, ...)` format. |
| `optimizer` | Must return a `torch.optim.Optimizer` for `self.network.parameters()`. |
| `loss(outputs, targets, hparams)` | Compute the loss for a batch. Parameters: `outputs` (Tensor or iterable); `targets` (Tensor or iterable); `hparams` (dict). Returns `dict[str, Any]` containing the key `"loss"`. |
| `postprocess(outputs, hparams)` | Convert network outputs into predictions. Parameters: `outputs` (Tensor or iterable); `hparams` (dict). Returns predictions (Any). |
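For example, a minimal loss implementation that satisfies the contract above might look like the following sketch, building on the MyModel template from the Quick start; cross-entropy is an illustrative choice for classification.

```python
import torch.nn.functional as F

class MyClassifier(MyModel):  # MyModel from the Quick start template

    @staticmethod
    def loss(outputs, targets, hparams):
        # The returned dict must contain the key "loss"; extra keys
        # could carry values for custom metrics.
        return {"loss": F.cross_entropy(outputs, targets)}
```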
Default Methods
| Method | Description |
|---|---|
| `scheduler` | Learning rate scheduler. May return `None`, a `torch.optim.lr_scheduler.LRScheduler`, or a `ReduceLROnPlateau`. Default is `None`. |
| `backward(batch, gas)` | Backpropagation step. Parameters: `batch` (dict[str, Any]): must contain the key `"loss"`; `gas` (int): gradient accumulation steps. |
| `metrics(batches, hparams)` | Compute evaluation metrics. Parameters: `batches` (deque[dict[str, Any]]): batch results; `hparams` (dict). Returns `dict[str, Any]`. The default computes the average loss. |
| `checkpoint(epoch, metrics, out)` | Checkpointing step; saves the model if the loss improves. Parameters: `epoch` (int): epoch number; `metrics` (dict[str, float]): validation metrics; `out` (str or None): output path to save checkpoints. Returns a bool indicating early stopping. |
| `move(data)` | Move data to the current device. Supports `Tensor`, `list`, `tuple`, and `dict`. |
| `detach(data)` | Detach data from the computation graph. Supports `Tensor`, `list`, `tuple`, and `dict`. |
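For instance, a subclass could replace the default scheduler analogously to optimizer. This is a sketch under the assumption that scheduler is defined as a cached property like the other components:

```python
import torch
from functools import cached_property

class MyScheduledModel(MyModel):  # MyModel from the Quick start template

    @cached_property
    def scheduler(self):
        # Reduce the learning rate when the validation loss plateaus.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)
```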
Concrete Methods
| Method | Description |
|---|---|
| `TorchABC(device=None, logger=print, hparams=None, **kwargs)` | Initialize the model. Parameters: `device` (str or torch.device, optional): computation device, defaulting to CUDA if available, otherwise MPS or CPU; `logger` (Callable[[dict], None], optional): logging function, defaulting to `print`; `hparams` (dict, optional): dictionary of hyperparameters; `kwargs`: additional attributes stored in the instance. |
| `train(epochs, gas=1, mas=None, on='train', val='val', out=None)` | Train the model. Parameters: `epochs` (int): number of training epochs; `gas` (int, optional): gradient accumulation steps, default 1; `mas` (int, optional): metrics accumulation steps, default `gas`; `on` (str, optional): training dataloader name, default `"train"`; `val` (str, optional): validation dataloader name, default `"val"`, validation is skipped if `None`; `out` (str, optional): output path to save checkpoints. |
| `eval(on)` | Evaluate the model. Parameters: `on` (str): dataloader name. Returns `dict[str, float]` of evaluation metrics. |
| `__call__(samples)` | Run inference on raw samples. Parameters: `samples` (Iterable[Any]): raw samples. Returns postprocessed predictions. |
| `save(path)` | Save a checkpoint. Parameters: `path` (str): file path. |
| `load(path)` | Load a checkpoint. Parameters: `path` (str): file path. |
Examples
Get started with the simple, self-contained examples provided in the examples folder of the repository.
Run the examples
Install the dependencies.
poetry install --with examples
Run an example by replacing <name> with one of the filenames in the examples folder.
poetry run python examples/<name>.py
Contribute
Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.