A collection of core machine learning tools
What is this?
This is a framework for trying out machine learning ideas.
Getting Started
Install the package using:
```bash
pip install mlfab
```
Or, to install the latest version from the `master` branch:
```bash
pip install 'mlfab @ git+https://github.com/kscalelabs/mlfab.git@master'
```
Simple Example
This framework provides an abstraction for quickly implementing and training PyTorch models. The workhorse is `mlfab.Task`, which wraps all of the training logic into a single cohesive unit. We can override methods on that class to get special functionality, but the default functionality is often good enough. Here's an example of training an MNIST model:
```python
from dataclasses import dataclass

import torch.nn.functional as F
from dpshdl.dataset import Dataset
from dpshdl.impl.mnist import MNIST
from torch import Tensor, nn
from torch.optim.optimizer import Optimizer

import mlfab


@dataclass
class Config(mlfab.Config):
    in_dim: int = mlfab.field(1, help="Number of input dimensions")


class MnistClassification(mlfab.Task[Config]):
    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self.model = nn.Sequential(
            nn.Conv2d(config.in_dim, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def set_loggers(self) -> None:
        self.add_logger(
            mlfab.StdoutLogger(),
            mlfab.TensorboardLogger(self.exp_dir),
        )

    def get_dataset(self, phase: mlfab.Phase) -> Dataset[tuple[Tensor, Tensor]]:
        root_dir = mlfab.get_data_dir() / "mnist"
        return MNIST(root_dir=root_dir, train=phase == "train")

    def build_optimizer(self) -> Optimizer:
        return mlfab.Adam.get(self, lr=1e-3)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def get_loss(self, batch: tuple[Tensor, Tensor], state: mlfab.State) -> Tensor:
        x, y = batch
        yhat = self(x)
        self.log_step(batch, yhat, state)
        return F.cross_entropy(yhat, y.squeeze(-1))

    def log_valid_step(self, batch: tuple[Tensor, Tensor], output: Tensor, state: mlfab.State) -> None:
        (x, y), yhat = batch, output

        def get_label_strings() -> list[str]:
            ytrue, ypred = y.squeeze(-1), yhat.argmax(-1)
            return [f"ytrue={ytrue[i]}, ypred={ypred[i]}" for i in range(len(ytrue))]

        self.log_labeled_images("images", lambda: (x, get_label_strings()))


if __name__ == "__main__":
    # python -m examples.mnist
    MnistClassification.launch(Config(batch_size=16))
```
Let's break down each part individually.
Config
Tasks are parametrized using a config dataclass. The `mlfab.field` function is a lightweight, more ergonomic wrapper around `dataclasses.field`, and `mlfab.Config` is a larger dataclass that contains many other options for configuring training.
```python
@dataclass
class Config(mlfab.Config):
    in_dim: int = mlfab.field(1, help="Number of input dimensions")
```
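The text only says that `mlfab.field` wraps `dataclasses.field`. As a rough idea of what such a wrapper can look like, here is a hypothetical `field` helper built on the standard library (not mlfab's actual code): it takes the default positionally, routes mutable defaults through `default_factory`, and keeps the help text in the field's `metadata`.

```python
from dataclasses import dataclass, field as dc_field, fields
from typing import Any


def field(default: Any, *, help: str = "") -> Any:
    """Hypothetical stand-in for mlfab.field: positional default plus help text."""
    # Mutable defaults (lists, dicts, sets) must go through default_factory,
    # so each instance gets its own copy.
    if isinstance(default, (list, dict, set)):
        return dc_field(default_factory=lambda: type(default)(default), metadata={"help": help})
    return dc_field(default=default, metadata={"help": help})


@dataclass
class Config:
    in_dim: int = field(1, help="Number of input dimensions")
    hidden_dims: list = field([32, 64], help="Channel widths")


cfg = Config()
help_text = {f.name: f.metadata["help"] for f in fields(Config)}
print(cfg.in_dim, help_text["in_dim"])  # 1 Number of input dimensions
```

Keeping the help string in `metadata` means other tooling, such as a command-line parser or a config printer, could read it back with `dataclasses.fields`.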
Model
All tasks should subclass `mlfab.Task` and parametrize its generic `Config` with the task-specific config, as in `mlfab.Task[Config]`. This is very important, not just because it makes your life easier by working nicely with your typechecker, but because the framework looks at the generic type when resolving the config for the given task.
```python
class MnistClassification(mlfab.Task[Config]):
    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self.model = nn.Sequential(
            nn.Conv2d(config.in_dim, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
```
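To see why the generic parameter matters, here is a plain-Python sketch of how a base class can recover the config class from `Task[Config]` using `__orig_bases__` and `typing.get_args`. `BaseConfig`, `Task`, and `config_type` are hypothetical names for illustration; this is one possible mechanism, not necessarily how mlfab resolves configs internally.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, get_args, get_origin


@dataclass
class BaseConfig:
    """Hypothetical stand-in for mlfab.Config."""

    batch_size: int = 32


ConfigT = TypeVar("ConfigT", bound=BaseConfig)


class Task(Generic[ConfigT]):
    def __init__(self, config: ConfigT) -> None:
        self.config = config

    @classmethod
    def config_type(cls) -> type:
        """Find the concrete class bound to ConfigT, e.g. Config in Task[Config]."""
        for klass in cls.__mro__:
            # __orig_bases__ keeps the subscripted bases, e.g. (Task[Config],).
            for base in klass.__dict__.get("__orig_bases__", ()):
                if get_origin(base) is Task:
                    (arg,) = get_args(base)
                    if isinstance(arg, type):
                        return arg
        raise TypeError(f"{cls.__name__} does not parametrize Task[...]")


@dataclass
class Config(BaseConfig):
    in_dim: int = 1


class MnistClassification(Task[Config]):
    pass


class Unparametrized(Task):  # forgot the [Config] parameter
    pass


print(MnistClassification.config_type().__name__)  # Config
print(MnistClassification.config_type()())  # Config(batch_size=32, in_dim=1)
```

A subclass that forgets the `[Config]` parameter has nothing to resolve, so in this sketch `Unparametrized.config_type()` raises `TypeError` instead of silently falling back to the base config.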
Loggers
`mlfab` supports logging to multiple downstream loggers and provides helper functions for common logging operations, such as rate limiting, resizing images to standard resolutions, and overlaying captions on images.
If `set_loggers` is not overridden, the task only logs to stdout.
```python
def set_loggers(self) -> None:
    self.add_logger(
        mlfab.StdoutLogger(),
        mlfab.TensorboardLogger(self.exp_dir),
    )
```
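The fan-out and rate-limiting ideas can be sketched without mlfab. Everything below is hypothetical (`ListLogger` stands in for a stdout or TensorBoard sink, and `MultiLogger` is not an mlfab class): each value is forwarded to every sink, but a value for a given key is dropped if it arrives sooner than `min_interval_s` after the last one that was emitted.

```python
import time
from typing import Callable, Protocol


class ScalarLogger(Protocol):
    def log_scalar(self, key: str, value: float, step: int) -> None: ...


class ListLogger:
    """Toy sink that records every call; stands in for a stdout or TensorBoard logger."""

    def __init__(self) -> None:
        self.records: list[tuple[str, float, int]] = []

    def log_scalar(self, key: str, value: float, step: int) -> None:
        self.records.append((key, value, step))


class MultiLogger:
    """Forwards each value to every sink, rate-limited per key."""

    def __init__(
        self,
        *loggers: ScalarLogger,
        min_interval_s: float = 0.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.loggers = list(loggers)
        self.min_interval_s = min_interval_s
        self.clock = clock
        self._last_emit: dict[str, float] = {}

    def log_scalar(self, key: str, value: float, step: int) -> None:
        now = self.clock()
        last = self._last_emit.get(key)
        if last is not None and now - last < self.min_interval_s:
            return  # Too soon since the last value for this key; drop it.
        self._last_emit[key] = now
        for logger in self.loggers:
            logger.log_scalar(key, value, step)


# A fake clock makes the rate limiting deterministic: t = 0.0, 0.5, 1.2 seconds.
ticks = iter([0.0, 0.5, 1.2])
a, b = ListLogger(), ListLogger()
multi = MultiLogger(a, b, min_interval_s=1.0, clock=lambda: next(ticks))
for step, loss in enumerate([0.9, 0.8, 0.7]):
    multi.log_scalar("loss", loss, step)
print(a.records)  # [('loss', 0.9, 0), ('loss', 0.7, 2)]
```

Injecting the clock keeps the sketch testable; in real use the default `time.monotonic` would apply.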
Datasets
`get_dataset` should return the dataset for the given phase. `mlfab.Phase` is a string literal with values in `["train", "valid", "test"]`. `mlfab.get_data_dir()` returns the data directory, which can be set in a configuration file at `~/.mlfab.yml`. A default configuration file is written on first run if it doesn't exist yet.
```python
def get_dataset(self, phase: mlfab.Phase) -> Dataset[tuple[Tensor, Tensor]]:
    root_dir = mlfab.get_data_dir() / "mnist"
    return MNIST(root_dir=root_dir, train=phase == "train")
```
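A small sketch of the phase handling, assuming only the three literal values stated above and the mapping used in `get_dataset`. `Phase` is re-declared locally for illustration, and `check_phase` and `mnist_kwargs` are hypothetical helpers. Note that with this mapping, the validation and test phases both read the non-training MNIST split.

```python
from pathlib import Path
from typing import Literal, get_args

# Hypothetical re-declaration for illustration; mlfab exports its own Phase type.
Phase = Literal["train", "valid", "test"]


def check_phase(phase: str) -> Phase:
    """Reject strings that are not one of the allowed phase literals."""
    if phase not in get_args(Phase):
        raise ValueError(f"Unknown phase {phase!r}; expected one of {get_args(Phase)}")
    return phase  # type: ignore[return-value]


def mnist_kwargs(phase: Phase, data_dir: Path) -> dict:
    """Same phase-to-split mapping as get_dataset: only "train" picks the training split."""
    return {"root_dir": data_dir / "mnist", "train": phase == "train"}


for phase in get_args(Phase):
    print(phase, mnist_kwargs(check_phase(phase), Path("/data"))["train"])
# train True
# valid False
# test False
```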
Optimizers
```python
def build_optimizer(self) -> Optimizer:
    return mlfab.Adam.get(self, lr=1e-3)
```
Compute Loss
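Each `mlfab` model should either implement the `forward` function so that it takes a batch from the dataset and returns the loss, or, if more control is desired, override the `get_loss` function. The MNIST example takes the second route: `forward` only maps images to logits, while `get_loss` unpacks the batch, logs the step, and computes the cross-entropy loss.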
```python
def forward(self, x: Tensor) -> Tensor:
    return self.model(x)

def get_loss(self, batch: tuple[Tensor, Tensor], state: mlfab.State) -> Tensor:
    x, y = batch
    yhat = self(x)
    self.log_step(batch, yhat, state)
    return F.cross_entropy(yhat, y.squeeze(-1))
```
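The two options can be illustrated without PyTorch. In this sketch, the hypothetical base class's default `get_loss` simply calls the model on the whole batch (an assumption based on the description above, not mlfab's source). One subclass computes the loss inside `forward`; the other keeps `forward` for predictions and overrides `get_loss`, as the MNIST example does.

```python
from typing import Any


class Task:
    """Hypothetical base class: by default, the loss is whatever forward(batch) returns."""

    def __call__(self, *args: Any) -> Any:
        return self.forward(*args)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def get_loss(self, batch: Any, state: Any) -> float:
        return self(batch)


def mse(preds: list[float], targets: list[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


class LossInForward(Task):
    # Option 1: forward takes the whole batch and returns the loss.
    def forward(self, batch: tuple[list[float], list[float]]) -> float:
        x, y = batch
        yhat = [2 * v for v in x]  # toy "model"
        return mse(yhat, y)


class LossInGetLoss(Task):
    # Option 2: forward returns predictions only...
    def forward(self, x: list[float]) -> list[float]:
        return [2 * v for v in x]

    # ...and get_loss unpacks the batch and computes the loss itself.
    def get_loss(self, batch: tuple[list[float], list[float]], state: Any) -> float:
        x, y = batch
        return mse(self(x), y)


batch = ([1.0, 2.0], [2.0, 5.0])
print(LossInForward().get_loss(batch, state=None))  # 0.5
print(LossInGetLoss().get_loss(batch, state=None))  # 0.5
```

Overriding `get_loss` is what lets the MNIST task call `log_step` with both the batch and the model output before returning the loss.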
Logging
When we call `log_step` in the `get_loss` function, it delegates to `log_train_step`, `log_valid_step`, or `log_test_step`, depending on the value of `state.phase`. In this case, on each validation step we log images of the MNIST digits with the labels that our model predicts.
```python
def log_valid_step(self, batch: tuple[Tensor, Tensor], output: Tensor, state: mlfab.State) -> None:
    (x, y), yhat = batch, output

    def get_label_strings() -> list[str]:
        ytrue, ypred = y.squeeze(-1), yhat.argmax(-1)
        return [f"ytrue={ytrue[i]}, ypred={ypred[i]}" for i in range(len(ytrue))]

    self.log_labeled_images("images", lambda: (x, get_label_strings()))
```
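The delegation can be pictured as a lookup keyed on `state.phase`. This is a sketch under assumptions: `State` here carries only `phase` and a step counter, `LoggingMixin` is a hypothetical class, and the real `mlfab.State` and dispatch logic may differ.

```python
from dataclasses import dataclass
from typing import Any, Literal

Phase = Literal["train", "valid", "test"]


@dataclass
class State:
    """Hypothetical minimal training state; only the fields this sketch needs."""

    phase: Phase
    num_steps: int = 0


class LoggingMixin:
    def log_step(self, batch: Any, output: Any, state: State) -> None:
        # Route to the phase-specific hook based on state.phase.
        handlers = {
            "train": self.log_train_step,
            "valid": self.log_valid_step,
            "test": self.log_test_step,
        }
        handlers[state.phase](batch, output, state)

    # Default hooks do nothing; subclasses override the ones they care about.
    def log_train_step(self, batch: Any, output: Any, state: State) -> None:
        pass

    def log_valid_step(self, batch: Any, output: Any, state: State) -> None:
        pass

    def log_test_step(self, batch: Any, output: Any, state: State) -> None:
        pass


class Demo(LoggingMixin):
    def __init__(self) -> None:
        self.calls: list[tuple[str, int]] = []

    def log_valid_step(self, batch: Any, output: Any, state: State) -> None:
        self.calls.append(("valid", state.num_steps))


demo = Demo()
demo.log_step(batch=None, output=None, state=State("train", 1))  # default no-op
demo.log_step(batch=None, output=None, state=State("valid", 2))
print(demo.calls)  # [('valid', 2)]
```

Because only `log_valid_step` is overridden, training steps pass through the no-op default, mirroring how the MNIST task only logs images during validation.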
Running
We can launch a training job using the `launch` class method. The config can be a `Config` object, or it can be the path to a `config.yaml` file located in the same directory as the task file. You can additionally provide the `launcher` argument, which supports training the model across multiple GPUs or nodes.
```python
if __name__ == "__main__":
    MnistClassification.launch(Config(batch_size=16))
```
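A sketch of how the two accepted config forms could be told apart: a `Config` instance passes through, and a path is resolved relative to the task file's directory. `resolve_config_source` and the local `Config` are hypothetical, `task_file` would typically be the task module's `__file__`, and parsing the YAML itself is left out because it needs a third-party YAML library.

```python
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path


@dataclass
class Config:
    """Hypothetical stand-in for the task's Config dataclass."""

    batch_size: int = 32
    in_dim: int = 1


def resolve_config_source(config: Config | str | Path, task_file: str) -> Config | Path:
    """Return a Config unchanged, or the config file path resolved next to the task file."""
    if isinstance(config, Config):
        return config
    path = Path(config)
    if not path.is_absolute():
        # Relative names such as "config.yaml" are looked up in the task file's directory.
        path = Path(task_file).parent / path
    return path


print(resolve_config_source(Config(batch_size=16), "examples/mnist.py"))  # Config(batch_size=16, in_dim=1)
print(resolve_config_source("config.yaml", "/home/me/examples/mnist.py"))  # /home/me/examples/config.yaml on POSIX
```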
Download files
Source Distribution
Built Distribution
File details
Details for the file mlfab-0.1.6.tar.gz.
File metadata
- Download URL: mlfab-0.1.6.tar.gz
- Upload date:
- Size: 166.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7a24b6b2f48941fd5263b54236c199ef9cea703b4ec783a0ff982a3dfc230cb3
MD5 | 039c7dafdb41bd9d54a963b436426d28
BLAKE2b-256 | 522af95a1046c9c3ac49b17321e9b9981201fe85297ce9b80ee75bb1318cc19b
File details
Details for the file mlfab-0.1.6-py3-none-any.whl.
File metadata
- Download URL: mlfab-0.1.6-py3-none-any.whl
- Upload date:
- Size: 196.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 961e4f0091852deb15da00d93db512906027aa529f53bb51635f00bf0d29123c
MD5 | fc19d1a1c884e6462dfa5c1bc605ea34
BLAKE2b-256 | f418da2b6d56ab8c650cff7e5ab47178a45fb38833c49cc754c64d5a84e21976
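To check a downloaded file against the published SHA256 digests, a streaming hash with the standard `hashlib` module is enough (pip can also print a file's digest with `pip hash`). The helper names below are illustrative, and the digests are copied from the tables for the two 0.1.6 files.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the file-hash tables for mlfab 0.1.6.
EXPECTED_SHA256 = {
    "mlfab-0.1.6.tar.gz": "7a24b6b2f48941fd5263b54236c199ef9cea703b4ec783a0ff982a3dfc230cb3",
    "mlfab-0.1.6-py3-none-any.whl": "961e4f0091852deb15da00d93db512906027aa529f53bb51635f00bf0d29123c",
}


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in fixed-size chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    """True if the file's SHA256 matches the published digest for its filename."""
    return sha256_of(path) == EXPECTED_SHA256[path.name]


# Usage, after downloading into the current directory:
# verify(Path("mlfab-0.1.6-py3-none-any.whl"))
```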