A simple abstract class for training and inference in PyTorch.
TorchABC

TorchABC is an abstract class for training and inference in PyTorch that helps you keep your code well organized. It is a minimalist alternative to pytorch-lightning: it depends on torch only and consists of a single self-contained file.
Usage
Create a concrete class derived from TorchABC
following the template below. Next, you can use your class as follows.
Initialization
Initialize your class with
model = ClassName(device=None, logger=print, hparams=None, **kwargs)

where

- `device` is the `torch.device` to use. Defaults to `None`, which will try CUDA, then MPS, and finally fall back to CPU.
- `logger` is a logging function that takes a dictionary as input. The default prints to standard output. You can easily log with wandb or with any other custom logger.
- `hparams` is a dictionary of hyperparameters used internally by your class. These hyperparameters are persistent: they will be saved in the model's checkpoints.
- `kwargs` are additional arguments to store in the class attributes. These arguments are ephemeral: they will not be saved in the model's checkpoints.
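For example, a minimal sketch of an initialization, assuming a concrete subclass named ClassName and a hypothetical 'lr' hyperparameter:

import torch

model = ClassName(
    device=torch.device("cuda"),  # or None to try CUDA, then MPS, then CPU
    logger=print,                 # any function that takes a dictionary
    hparams={'lr': 1e-3},         # persistent: saved in checkpoints
    run_name='baseline',          # ephemeral kwarg: stored as an attribute only
)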
Training
Train the model with
model.train(epochs, gas=1, on='train', val='val')

where

- `epochs` is the number of training epochs to perform.
- `gas` is the number of gradient accumulation steps.
- `on` is the name of the training dataloader.
- `val` is the name of the validation dataloader.
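For instance, a sketch of a training run, assuming your dataloaders include the names 'train' and 'val':

# accumulate gradients over 2 batches before each optimizer step
model.train(epochs=10, gas=2, on='train', val='val')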
Evaluation
Compute the evaluation metrics with
model.eval(on)
where
- `on` is the name of the dataloader to evaluate on.
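For example, assuming your dataloaders include a 'test' loader and that eval returns the computed metrics dictionary:

metrics = model.eval(on='test')
print(metrics)  # e.g. {'loss': ..., ...}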
Inference
Predict with
model.predict(samples)
where
- `samples` is an iterable of raw data samples.
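For example, a sketch where raw_samples is a hypothetical list of samples in the same raw format produced by your dataset (preprocess and postprocess handle the conversion to and from tensors):

predictions = model.predict(raw_samples)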
Checkpoints
Save and load the model.
model.save(checkpoint)
model.load(checkpoint)
where
- `checkpoint` is the path to the model checkpoint.
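For example, with an illustrative file path:

model.save('checkpoints/best.pt')  # persists the model state and hparams
model.load('checkpoints/best.pt')  # restores them later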
Quick start
Install the package.
pip install torchabc
Generate a template using the command line interface.
torchabc --create template.py
Fill out the template.
import torch
from torchabc import TorchABC
from functools import cached_property


class ClassName(TorchABC):
    """A concrete subclass of the TorchABC abstract class.

    Use this template to implement your own model by following these steps:

    - replace ClassName with the name of your model,
    - replace this docstring with a description of your model,
    - implement the methods below to define the core logic of your model.
    """

    @cached_property
    def dataloaders(self):
        """The dataloaders.

        Returns a dictionary containing multiple `DataLoader` instances.
        The keys of the dictionary are custom names (e.g., 'train', 'val', 'test'),
        and the values are the corresponding `torch.utils.data.DataLoader` objects.
        """
        raise NotImplementedError

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        """The preprocessing step.

        Transforms a raw sample from a `torch.utils.data.Dataset`. This method is
        intended to be passed as the `transform` (or similar) argument of a `Dataset`.

        Parameters
        ----------
        sample : Any
            The raw sample.
        hparams : dict
            The hyperparameters.
        flag : str, optional
            A custom flag indicating how to transform the sample.
            An empty flag must transform the sample for inference.

        Returns
        -------
        Union[Tensor, Iterable[Tensor]]
            The preprocessed sample.
        """
        return sample

    @staticmethod
    def collate(samples):
        """The collating step.

        Collates a batch of preprocessed samples. This method is intended to be
        passed as the `collate_fn` argument of a `DataLoader`.

        Parameters
        ----------
        samples : Iterable[Tensor]
            The preprocessed samples.

        Returns
        -------
        Union[Tensor, Iterable[Tensor]]
            The batch of collated samples.
        """
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        """The neural network.

        Returns a `torch.nn.Module` whose input and output tensors assume
        the batch size is the first dimension: (batch_size, ...).
        """
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        """The optimizer for training the network.

        Returns a `torch.optim.Optimizer` configured for
        `self.network.parameters()`.
        """
        raise NotImplementedError

    @cached_property
    def scheduler(self):
        """The learning rate scheduler for the optimizer.

        Returns a `torch.optim.lr_scheduler.LRScheduler` or
        `torch.optim.lr_scheduler.ReduceLROnPlateau` configured
        for `self.optimizer`.
        """
        return None

    @staticmethod
    def accumulate(outputs, targets, hparams, accumulator=None):
        """The accumulation step.

        Accumulates batch statistics that will be provided when calculating
        the loss and other metrics.

        Parameters
        ----------
        outputs : Union[Tensor, Iterable[Tensor]]
            The outputs returned by `self.network`.
        targets : Union[Tensor, Iterable[Tensor]]
            The target values.
        hparams : dict
            The hyperparameters.
        accumulator : Any
            The previous return value of this function.
            If None, this is the first call.

        Returns
        -------
        Any
            The accumulated batch statistics.
        """
        raise NotImplementedError

    @staticmethod
    def metrics(accumulator, hparams):
        """The evaluation metrics.

        Computes the loss and additional evaluation metrics.

        Parameters
        ----------
        accumulator : Any
            The accumulated batch statistics.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        Dict[str, Union[Tensor, float]]
            A dictionary of evaluation metrics. This dictionary must contain
            the key 'loss' whose value is used to train the network.
        """
        raise NotImplementedError

    @staticmethod
    def postprocess(outputs, hparams):
        """The postprocessing step.

        Transforms the outputs into postprocessed predictions.

        Parameters
        ----------
        outputs : Union[Tensor, Iterable[Tensor]]
            The outputs returned by `self.network`.
        hparams : dict
            The hyperparameters.

        Returns
        -------
        Any
            The postprocessed predictions.
        """
        return outputs

    def checkpoint(self, epoch, metrics):
        """The checkpointing step.

        Performs checkpointing at the end of each epoch.

        Parameters
        ----------
        epoch : int
            The epoch number, starting at 1.
        metrics : Dict[str, float]
            Dictionary of validation metrics.

        Returns
        -------
        bool
            If this function returns True, training stops.
        """
        return False
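For illustration, here is a minimal, hypothetical way the core methods could be filled out for an MNIST classifier. The dataset, the architecture, the 'lr' hyperparameter, and the assumption that hparams is available as self.hparams are all illustrative choices, not part of the library; preprocess, collate, postprocess, scheduler, and checkpoint keep their defaults.

import torch
import torch.nn.functional as F
from functools import cached_property
from torchvision import datasets, transforms
from torchabc import TorchABC


class MNISTClassifier(TorchABC):
    """A minimal classifier for handwritten digits."""

    @cached_property
    def dataloaders(self):
        train = datasets.MNIST('data', train=True, download=True,
                               transform=transforms.ToTensor())
        val = datasets.MNIST('data', train=False, download=True,
                             transform=transforms.ToTensor())
        return {
            'train': torch.utils.data.DataLoader(train, batch_size=64, shuffle=True),
            'val': torch.utils.data.DataLoader(val, batch_size=64),
        }

    @cached_property
    def network(self):
        return torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )

    @cached_property
    def optimizer(self):
        return torch.optim.Adam(self.network.parameters(), lr=self.hparams['lr'])

    @staticmethod
    def accumulate(outputs, targets, hparams, accumulator=None):
        # keep the logits and labels of every batch seen so far
        logits, labels = accumulator or ([], [])
        return logits + [outputs], labels + [targets]

    @staticmethod
    def metrics(accumulator, hparams):
        logits = torch.cat(accumulator[0])
        labels = torch.cat(accumulator[1])
        return {
            'loss': F.cross_entropy(logits, labels),
            'accuracy': (logits.argmax(1) == labels).float().mean().item(),
        }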
Examples
Get started with the simple, self-contained examples in the examples folder.
Run the examples
Install the dependencies
poetry install --with examples
Run the examples by replacing <name>
with one of the filenames in the examples folder
poetry run python examples/<name>.py
Contribute
Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC
class itself.