Accelerator Module and Trainer based on the Accelerate library for simple distributed training processes, inspired by PyTorch Lightning.

Project description

AcceleratorModule

Module similar to Lightning Module for distributed training, but with Accelerator 🤗

AcceleratorModule takes care of the heavy lifting of distributed training on many GPUs. Accelerate is quite simple, and it has many advantages over PyTorch Lightning, mainly because it doesn't abstract away the low-level parts of the training loop, so you can customize them however you want. The main idea of this little project is to have a standard way to do distributed training. This module lets you:

  • Define the logic involved for training.
  • Define the logic involved for evaluation.
  • Save checkpoints to recover training progress.
  • Save best model by evaluating best average validation loss at the end of every epoch.
  • Define the hyperparameters in a simple YAML file.
  • Visualize training progress in TensorBoard (train and validation losses in one graph).
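As a rough illustration of the hyperparameters point above, such a YAML file might look like the sketch below. The exact key names are defined by accmt's own schema and are not documented here, so every key in this fragment is an assumption for illustration only:

```yaml
# Illustrative only — these key names are assumptions, not accmt's documented schema.
hps:
  epochs: 10
  batch_size: 32
  optim:
    type: AdamW
    lr: 0.001
```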

Installation

AcceleratorModule is available via pip:

pip install accmt

Module Structure

Import AcceleratorModule:

from accmt import AcceleratorModule

The AcceleratorModule class has 3 main methods:

  • forward: Defines the flow of data.
  • training_step: Defines the training logic up to the loss function.
  • validation_step: Defines the validation logic up to the loss function.

The structure looks like this:

class ExampleModule(AcceleratorModule):
    def __init__(self):
        self.model = ...

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch):
        x, y = batch
        # ...
        return train_loss

    def validation_step(self, batch):
        x, y = batch
        # ...
        return val_loss

A forward method is not required, but training_step and/or validation_step are mandatory. An AcceleratorModule must have a self.model attribute holding a model derived from nn.Module (PyTorch).
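To make the structure above concrete, here is a hedged sketch of what the body of training_step typically computes, written with plain PyTorch so it runs standalone. The model (a single linear layer) and the loss function are illustrative choices of ours, not anything prescribed by accmt:

```python
import torch
from torch import nn

# Illustrative stand-ins for self.model and the loss — your own choices in practice.
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

def training_step(batch):
    x, y = batch                 # unpack one batch from the DataLoader
    logits = model(x)            # forward pass through self.model
    return criterion(logits, y)  # return the scalar loss; the trainer handles backward/step

# Quick sanity check with a fake batch of 4 samples:
x = torch.randn(4, 10)
y = torch.randint(0, 3, (4,))
loss = training_step((x, y))
print(loss.dim())  # 0 — a scalar loss tensor
```

validation_step has the same shape: unpack the batch, run the model, and return a scalar loss for the trainer to average.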

To train this Module, you need a Trainer class:

from accmt import Trainer

trainer = Trainer(
    hps_file_config="hps_config.yaml",
    checkpoint="checkpoint_folder"
)

Finally, we can train our model with the .fit() method, providing our AcceleratorModule and the train and validation datasets (standard PyTorch Dataset objects):

trainer.fit(module, train_dataset, val_dataset)

Run

Training runs on Accelerate 🤗 under the hood, so you must launch it with the corresponding CLI:

accelerate launch train.py

You can use any Accelerate configuration you want 🤗 (DDP, FSDP or DeepSpeed).

Docs coming soon...

Download files

Source distribution: accmt-0.0.7.tar.gz (11.4 kB)

Built distribution: accmt-0.0.7-py3-none-any.whl (11.0 kB)
