
Flashy

Minimal solver for deep learning.

[tests badge] [linter badge] [docs badge]

Motivations

We noticed that we kept reusing the same structure across all of our research projects. PyTorch-Lightning is vastly over-engineered, and its complexity does not allow the same level of hackability. Flashy aims to be an alternative. We do not claim it will fit all use cases; our first goal is for it to fit ours. We aim to keep the code simple enough that you can just inherit and override behaviors, or even copy-paste what you want into your own project.

Definitions

At the core of Flashy is the Solver. The Solver takes care of only two things:

  • logging metrics to multiple backends (file logs, TensorBoard, or WandB), with custom formatting,
  • checkpointing, and automatically tracking the stateful parts of the solver.

Beyond those core features, Flashy also provides distributed training utilities, in particular alternatives to DistributedDataParallel, which can break with complex workflows, along with simple wrappers around DataLoader to support distributed training.
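
As a rough illustration of how these pieces fit together, here is a minimal sketch. Only flashy.distrib.init() appears elsewhere on this page; the DataLoader wrappers and DistributedDataParallel alternatives also live in flashy.distrib, but since their exact names are not listed here, the sketch leaves them as comments rather than guessing at the API.

import torch
from torch.utils.data import DataLoader, TensorDataset
import flashy

# With Dora, distributed initialization needs no arguments; it is also safe to
# call when running a single, non-distributed process.
flashy.distrib.init()

model = torch.nn.Linear(32, 1)
dataset = TensorDataset(torch.randn(128, 32), torch.randn(128, 1))
loader = DataLoader(dataset, batch_size=16)
# flashy.distrib additionally provides wrappers around DataLoader and
# alternatives to DistributedDataParallel (e.g. explicit gradient
# synchronization); see the API documentation for the exact helper names.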

Flashy is epoch based, which might sound outdated to some of you. Think of an epoch not as a single pass over your dataset, but as the atomic unit of time for workflow management. The end of each epoch is marked by a call to flashy.BaseSolver.commit(save_checkpoint=True).

Each epoch is composed of a number of stages, for instance train, valid, test, etc., and the stages do not need to be the same for every epoch. Stages are a convenience to help with automatically reporting metrics with the appropriate metadata.
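
For instance, since stages can vary from one epoch to the next, a solver can choose to validate only every few epochs. The sketch below reuses only calls that appear in the Getting Started example further down (restore, run_stage, commit); the valid method itself is hypothetical.

    def run(self):
        self.restore()  # load the latest checkpoint, if any
        for epoch in range(self.epoch, self.cfg.epochs):
            self.run_stage('train', self.train)
            # Only run the (hypothetical) valid stage every 5 epochs.
            if epoch % 5 == 4:
                self.run_stage('valid', self.valid)
            self.commit(save_checkpoint=True)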

Dependencies and installation

Flashy assumes PyTorch is used along with Dora. You could use it without PyTorch with minor changes to flashy/state.py. Dora is baked into a few places and shouldn't be too hard to remove, although we warmly recommend using it. Flashy requires at least Python 3.8.

To install Flashy, run the following:

# For the moment we recommend having bleeding edge versions of Dora and Submitit
pip install -U git+https://github.com/facebookincubator/submitit@main#egg=submitit
pip install -U git+https://git@github.com/facebookresearch/dora#egg=dora-search
# Now let's install Flashy!
pip install git+ssh://git@github.com/facebookresearch/flashy.git#egg=flashy

To install Flashy for development, you can clone this repository and run:

make install

Getting Started

We will assume you are using Hydra. You will need to be familiar with Dora. Let's build a very basic project, called basic, with the following structure:

basic/
  conf/
    config.yaml
  train.py
  __init__.py

This project is provided in the examples folder. For config.yaml, we can start with this basic configuration:

epochs: 10
lr: 0.1

dora:
  # Output folder for all the artifacts of an experiment.
  dir: /tmp/flashy_basic_${oc.env:USER}/outputs

__init__.py is just empty. train.py contains most of the logic:

import torch
from dora import hydra_main
import flashy


class Solver(flashy.BaseSolver):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = torch.nn.Linear(32, 1)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=cfg.lr)
        self.best_state = {}
        # register_stateful supports any attribute. On checkpoint loading,
        # it will try to use in-place methods when possible (i.e. for Modules, lists, dicts).
        self.register_stateful('model', 'optim', 'best_state')
        self.init_tensorboard()  # all metrics will be reported to stderr and tensorboard.

    def run(self):
        self.restore()  # load checkpoint
        for epoch in range(self.epoch, self.cfg.epochs):
            # Stages are used for automatic metric reporting to Dora, and they also
            # allow tuning how metrics are formatted.
            self.run_stage('train', self.train)
            # Commit will send the metrics to Dora and save checkpoints by default.
            self.commit(save_checkpoint=epoch % 2 == 1)

    def train(self):
        # This is super dumb, check out `examples/cifar/solver.py` for more advanced usage!
        x = torch.randn(4, 32)
        y = self.model(x)
        loss = y.abs().mean()
        loss.backward()
        self.optim.step()
        self.optim.zero_grad()
        return {'loss': loss.item()}


@hydra_main(config_path='conf', config_name='config', version_base='1.1')
def main(cfg):
    # Set up logging both to the XP-specific folder and to stderr.
    flashy.setup_logging()
    # Initialize distributed training, no need to specify anything when using Dora.
    flashy.distrib.init()
    solver = Solver(cfg)
    solver.run()


if __name__ == '__main__':
    main()

From the folder containing basic, you can launch training with

dora -P basic run
dora run  # sufficient if no other package in the current folder contains a train.py file.
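
Since configuration is handled by Hydra, individual values can also be overridden directly on the command line with Hydra's key=value syntax, which Dora forwards to the experiment. For example, with the config shown above:

dora run epochs=20 lr=0.01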

Example

See examples/cifar/solver.py for a more advanced example, with real training and distributed support. When running examples from the examples/ folder, you must tell Dora which package to run, as there are multiple possibilities:

dora -P [basic|cifar] run

API

Check out the Flashy API Documentation.

License

Flashy is provided under the MIT license, which can be found in the LICENSE file in the root of the repository. Parts of flashy.loggers.utils were adapted from PyTorch-Lightning, originally under the Apache 2.0 License, see flashy/loggers/utils.py for details.
