⚠️ The package is under development; expect bugs and breaking changes!

Torch Training Loop

Simple Keras-inspired training loop for PyTorch.

Installation

pip install torch-training-loop

Features

  • Simple API for training PyTorch models;
  • Supports training DataParallel and DistributedDataParallel models;
  • Supports Keras-like callbacks for logging metrics to TensorBoard, model checkpointing, and early stopping;
  • Shows training & validation progress via tqdm;
  • Displays metrics during training & validation via torcheval (see the sketch after this list).
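
The metrics the loop displays are ordinary torcheval metric objects, which accumulate state across batches via update() and report a value via compute(). A minimal standalone sketch of that behavior (the tensors below are made-up illustration data):

import torch
from torcheval.metrics import MulticlassAccuracy

# A torcheval metric is stateful: update() accumulates batches, compute() reports.
metric = MulticlassAccuracy(num_classes=3)

# Made-up logits and labels, purely for illustration.
logits = torch.tensor([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]])
labels = torch.tensor([0, 2])

metric.update(logits, labels)  # predictions are argmax(logits, dim=1)
print(metric.compute())        # tensor(0.5000): one of two predictions is correct
metric.reset()                 # clear accumulated state, e.g. between epochs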

Usage

This package provides two main classes for training PyTorch models: TrainingLoop and SimpleTrainingStep. To train a model, instantiate both classes:

import torch
from torch.optim import Adam
from torcheval.metrics import MulticlassAccuracy
from training_loop import TrainingLoop, SimpleTrainingStep
from training_loop.callbacks import EarlyStopping

model = ...
# DataParallel models are supported as well, e.g.:
# from torch.nn import DataParallel
# model = DataParallel(model)

train_dataloader = ...
val_dataloader = ...

loop = TrainingLoop(
    model,
    step=SimpleTrainingStep(
        optimizer_fn=lambda params: Adam(params, lr=0.0001),
        loss=torch.nn.CrossEntropyLoss(),
        metrics=('accuracy', MulticlassAccuracy(num_classes=10)),
    ),
    device='cuda',
)
loop.fit(
    train_dataloader,
    val_dataloader,
    epochs=10,
    callbacks=[
        EarlyStopping(monitor='val_loss', mode='min', patience=20),
    ],
)
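
The model and data loaders are left as placeholders above. For reference, here is a hypothetical toy setup that would slot into those placeholders (the architecture, dataset shapes, and batch size are assumptions for illustration, not part of the package):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical 10-class classifier over 32-dimensional inputs.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# Random tensors stand in for a real dataset; CrossEntropyLoss and
# MulticlassAccuracy both expect integer class targets.
inputs = torch.randn(1024, 32)
targets = torch.randint(0, 10, (1024,))
train_dataloader = DataLoader(TensorDataset(inputs[:896], targets[:896]), batch_size=32, shuffle=True)
val_dataloader = DataLoader(TensorDataset(inputs[896:], targets[896:]), batch_size=32)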

In the example above, initializing SimpleTrainingStep and calling TrainingLoop.fit() closely mirror the Keras API. You can also train DistributedDataParallel models to take advantage of multi-GPU setups; currently, only single-node, multi-GPU training is supported.

from contextlib import contextmanager
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torcheval.metrics import MulticlassAccuracy
from training_loop import SimpleTrainingStep
from training_loop.distributed import DistributedTrainingLoop


@contextmanager
def setup_ddp(rank, world_size):
    # Single-node setup: every process rendezvouses with the master on localhost.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    try:
        yield
    finally:
        os.environ.pop('MASTER_ADDR')
        os.environ.pop('MASTER_PORT')
        dist.destroy_process_group()


def train_ddp(rank, world_size):
    with setup_ddp(rank, world_size):
        model = ...
        model = DDP(model, device_ids=[rank])

        train_loader = ...
        val_loader = ...

        loop = DistributedTrainingLoop(
            model,
            step=SimpleTrainingStep(
                optimizer_fn=lambda params: Adam(params, lr=0.0001),
                loss=torch.nn.CrossEntropyLoss(),
                metrics=('accuracy', MulticlassAccuracy(num_classes=10)),
            ),
            device=rank,
            rank=rank,
        )

        loop.fit(train_loader, val_loader, epochs=1)


def main():
    # One process per available GPU; mp.spawn passes the process rank as the
    # first argument to train_ddp.
    world_size = torch.cuda.device_count()

    mp.spawn(
        train_ddp,
        args=(world_size, ),
        nprocs=world_size,
        join=True,
    )

    return 0


if __name__ == '__main__':
    exit(main())
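
The data loaders are also elided in this example. Under DDP, each process typically trains on its own shard of the data; a common way to do that (a sketch assuming a map-style dataset named train_dataset, which is not defined in the example) is PyTorch's DistributedSampler:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Each of the world_size processes draws a disjoint shard of train_dataset.
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
# With shuffle=True, sampler.set_epoch(epoch) should be called each epoch to reshuffle.
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)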

You can find more examples and documentation in the source code and in the examples folder.

License

Distributed under the MIT License. See LICENSE.txt for more information.

