Skip to main content

A collection of core machine learning tools

Project description

K-Scale Open Source Robotics

License Discord Wiki
python pytorch black ruff
Python Checks


mlfab

What is this?

This is a framework for trying out machine learning ideas.

Getting Started

Install the package using:

pip install mlfab

Or, to install the latest branch:

pip install 'mlfab @ git+https://github.com/kscalelabs/mlfab.git@master'

Simple Example

This framework provides an abstraction for quickly implementing and training PyTorch models. The workhorse for doing this is mlfab.Task, which wraps all of the training logic into a single cohesive unit. We can override functions on that method to get special functionality, but the default functionality is often good enough. Here's an example for training an MNIST model:

from dataclasses import dataclass

import torch.nn.functional as F
from dpshdl.dataset import Dataset
from dpshdl.impl.mnist import MNIST
from torch import Tensor, nn
from torch.optim.optimizer import Optimizer

import mlfab


@dataclass(kw_only=True)
class Config(mlfab.Config):
    in_dim: int = mlfab.field(1, help="Number of input dimensions")
    learning_rate: float = mlfab.field(1e-3, help="Learning rate to use for optimizer")
    betas: tuple[float, float] = mlfab.field((0.9, 0.999), help="Beta values for Adam optimizer")
    weight_decay: float = mlfab.field(1e-4, help="Weight decay to use for the optimizer")
    warmup_steps: int = mlfab.field(100, help="Number of warmup steps to use for the optimizer")


class MnistClassification(mlfab.Task[Config]):
    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self.model = nn.Sequential(
            nn.Conv2d(config.in_dim, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def set_loggers(self) -> None:
        self.add_logger(
            mlfab.StdoutLogger(),
            mlfab.TensorboardLogger(self.exp_dir),
        )

    def get_dataset(self, phase: mlfab.Phase) -> MNIST:
        root_dir = mlfab.get_data_dir() / "mnist"
        return MNIST(root_dir=root_dir, train=phase == "train")

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def get_loss(self, batch: tuple[Tensor, Tensor], state: mlfab.State) -> Tensor:
        x_bhw, y_b = batch
        x_bchw = (x_bhw.float() / 255.0).unsqueeze(1)
        yhat_bc = self(x_bchw)
        self.log_step(batch, yhat_bc, state)
        return F.cross_entropy(yhat_bc, y_b)

    def log_valid_step(self, batch: tuple[Tensor, Tensor], output: Tensor, state: mlfab.State) -> None:
        (x_bhw, y_b), yhat_bc = batch, output

        def get_label_strings() -> list[str]:
            ypred_b = yhat_bc.argmax(-1)
            return [f"ytrue={y_b[i]}, ypred={ypred_b[i]}" for i in range(len(y_b))]

        self.log_labeled_images("images", lambda: (x_bhw, get_label_strings()))


if __name__ == "__main__":
    # python -m examples.mnist
    MnistClassification.launch(Config(batch_size=16))

Let's break down each part individually.

Config

Tasks are parametrized using a config dataclass. The ml.field function is a lightweight wrapper around dataclasses.field which is a bit more ergonomic, and ml.Config is a bigger dataclass which contains a bunch of other options for configuring training.

@dataclass(kw_only=True)
class Config(mlfab.Config):
    in_dim: int = mlfab.field(1, help="Number of input dimensions")

Model

All tasks should subclass ml.Task and override the generic Config with the task-specific config. This is very important, not just because it makes your life easier by working nicely with your typechecker, but because the framework looks at the generic type when resolving the config for the given task.

class MnistClassification(mlfab.Task[Config]):
    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self.model = nn.Sequential(
            nn.Conv2d(config.in_dim, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

Loggers

mlfab supports logging to multiple downstream loggers, and provides a bunch of helper functions for doing common logging operations, like rate limiting, converting image resolution to normal sizes, overlaying captions on images, and more.

If this function is not overridden, the task will just log to stdout.

def set_loggers(self) -> None:
    self.add_logger(
        mlfab.StdoutLogger(),
        mlfab.TensorboardLogger(self.exp_dir),
    )

Datasets

The task should return the dataset used for training, based on the phase. ml.Phase is a string literal with values in ["train", "valid"]. mlfab.get_data_dir() returns the data directory, which can be set in a configuration file which lives in ~/.mlfab.yml. The default configuration file will be written on first run if it doesn't exist yet.

def get_dataset(self, phase: mlfab.Phase) -> Dataset[tuple[Tensor, Tensor]]:
    root_dir = mlfab.get_data_dir() / "mnist"
    return MNIST(root_dir=root_dir, train=phase == "train")

Compute Loss

Each mlfab model should either implement the forward function, which should take a batch from the dataset and return the loss, or, if more control is desired, the get_loss function can be overridden.

def forward(self, x: Tensor) -> Tensor:
    return self.model(x)

def get_loss(self, batch: tuple[Tensor, Tensor], state: mlfab.State) -> Tensor:
    x_bhw, y_b = batch
    x_bchw = (x_bhw.float() / 255.0).unsqueeze(1)
    yhat_bc = self(x_bchw)
    self.log_step(batch, yhat_bc, state)
    return F.cross_entropy(yhat_bc, y_b)

Logging

When we call log_step in the get_loss function, it delegates to either log_train_step or log_valid_step, depending on what state.phase is. In this case, on each validation step we log images of the MNIST digits with the labels that our model predicts.

def log_valid_step(self, batch: tuple[Tensor, Tensor], output: Tensor, state: mlfab.State) -> None:
    (x_bhw, y_b), yhat_bc = batch, output

    def get_label_strings() -> list[str]:
        ypred_b = yhat_bc.argmax(-1)
        return [f"ytrue={y_b[i]}, ypred={ypred_b[i]}" for i in range(len(y_b))]

    self.log_labeled_images("images", lambda: (x_bhw, get_label_strings()))

Running

We can launch a training job using the launch class method. The config can be a Config object, or it can be the path to a config.yaml file located in the same directory as the task file. You can additionally provide the launcher argument, which supports training the model across multiple GPUs or nodes.

if __name__ == "__main__":
    MnistClassification.launch(Config(batch_size=16))

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mlfab-0.2.9.tar.gz (175.5 kB view details)

Uploaded Source

Built Distribution

mlfab-0.2.9-py3-none-any.whl (208.9 kB view details)

Uploaded Python 3

File details

Details for the file mlfab-0.2.9.tar.gz.

File metadata

  • Download URL: mlfab-0.2.9.tar.gz
  • Upload date:
  • Size: 175.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for mlfab-0.2.9.tar.gz
Algorithm Hash digest
SHA256 ff4d694c0893d1c68e83c96a79291cd6771f48e87215d877246b92964efe7bfd
MD5 72b93bf39992c40e5baaf51adedc75a6
BLAKE2b-256 6ba3c7ee9791301b0b5232fdc2c6f6929d9550dc901dcfa0b25f7b59b468fe88

See more details on using hashes here.

Provenance

The following attestation bundles were made for mlfab-0.2.9.tar.gz:

Publisher: publish.yml on kscalelabs/mlfab

Attestations:

File details

Details for the file mlfab-0.2.9-py3-none-any.whl.

File metadata

  • Download URL: mlfab-0.2.9-py3-none-any.whl
  • Upload date:
  • Size: 208.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for mlfab-0.2.9-py3-none-any.whl
Algorithm Hash digest
SHA256 c12b8ef58777239ed3eb4dbde927c05a499165cf9f24f7009e234c38ed021615
MD5 5e51e824cb322fb4a8d4811885cf7772
BLAKE2b-256 14a68dd29a68c36a00508fd752b1bdd3cd21fad0166e9633fb9a94a3f64ef9e3

See more details on using hashes here.

Provenance

The following attestation bundles were made for mlfab-0.2.9-py3-none-any.whl:

Publisher: publish.yml on kscalelabs/mlfab

Attestations:

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page