🔥 pyro

Lightweight machine learning framework allowing plug-and-play training for PyTorch models

  • Lightning-inspired
  • 💾 Support for wandb and checkpoints out-of-the-box
  • 📊 Pretty logs, plots and support for metrics
  • ✨ Fully type-safe
  • 🪶 Lightweight and easy to use

Examples

See 📓 notebooks for examples using pyro. In particular, you can find:

  • Iris: the simplest example, training a small MLP on the Iris dataset.
  • SmolVLM on Flowers102: features extracted from the SmolVLM vision model are used to train a linear classifier on the Flowers102 dataset, reaching a test accuracy of 98.6%.

Usage

You can use 🔥 pyro with minimal code changes and never write a training loop again. Here is an example pyro model and training script to get you started.

1. Define your Model

import torch
import pyroml as p

class MySOTAModel(p.PyroModule):
    def __init__(self):
        super().__init__()
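        # Define your layers and forward() here, as you would for any PyTorch module
        # Any loss function works; cross-entropy is used as a placeholder
        self.loss_fn = torch.nn.CrossEntropyLoss()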

    # Optionally, configure your own optimizer and scheduler, see more in the docs
    def configure_optimizers(self, _):
        # Learning rate set explicitly, matching the lr passed to the Trainer in step 2
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=0.01)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.99
        )

    def step(self, batch, stage: p.Stage):
        # Extract data from your dataset batch
        # Batches and model are moved to the appropriate device automatically
        x, y = batch
        # Forward the model
        preds = self(x)
        # Compute the loss
        loss = self.loss_fn(preds, y)
        # Optionally, register some metrics
        # (compute_accuracy is a helper you define yourself, it is not part of pyro)
        self.log(loss=loss.item(), accuracy=compute_accuracy(preds, y))
        # Return loss when training, otherwise return predictions
        if stage == p.Stage.TRAIN:
            return loss
        return preds    
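
The step method above calls compute_accuracy without defining it. A minimal sketch of such a helper follows, assuming preds holds one row of class scores per sample and y holds integer class labels; the name, signature and behaviour are illustrative assumptions, not part of pyro. It accepts tensors (through .tolist()) as well as plain nested lists.

```python
def compute_accuracy(preds, targets) -> float:
    """Fraction of samples whose highest-scoring class equals the target label.

    Hypothetical helper: assumes `preds` has shape (batch, num_classes)
    and `targets` holds one integer class label per sample.
    """
    # Tensors and arrays expose .tolist(); plain nested lists pass through unchanged
    if hasattr(preds, "tolist"):
        preds = preds.tolist()
    if hasattr(targets, "tolist"):
        targets = targets.tolist()
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    if not preds:
        return 0.0
    correct = 0
    for scores, label in zip(preds, targets):
        # Index of the largest score, i.e. the predicted class
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(preds)


# Two of the three rows below have their largest score at the labelled index
logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 1.5], [0.9, 0.8, 0.1]]
labels = [0, 2, 1]
accuracy = compute_accuracy(logits, labels)
```

For large batches, a vectorised version built on tensor operations would be faster; the loop is kept here so the sketch runs without any dependency.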

2. Instantiate a Trainer

trainer = p.Trainer(
    lr=0.01,
    max_epochs=32,
    batch_size=16,
    # And many other options such as device, precision, callbacks, ...
)
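
To get a feel for how lr=0.01 interacts with the StepLR(step_size=1, gamma=0.99) scheduler from step 1, the decay can be reproduced by hand. The sketch below only restates the StepLR formula in plain Python. Whether pyro steps the scheduler once per epoch or once per batch is not stated here, so the 32 steps are an assumption matching max_epochs.

```python
def step_lr(initial_lr: float, num_steps: int, step_size: int = 1, gamma: float = 0.99) -> float:
    """Learning rate after `num_steps` scheduler steps under a StepLR-style decay."""
    # The rate is multiplied by gamma once every `step_size` steps
    return initial_lr * gamma ** (num_steps // step_size)


# Assumption: one scheduler step per epoch, for 32 epochs
schedule = [step_lr(0.01, epoch) for epoch in range(33)]
final_lr = schedule[-1]
```

Under that assumption the learning rate ends at about 0.0072, so gamma=0.99 is a gentle decay over 32 epochs. If the scheduler were stepped per batch instead, the decay would be far steeper.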

3. Run training, validation and testing

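# Instantiate the model defined in step 1
model = MySOTAModel()

# Fit the model on the given training set and evaluate it during training
# (training_dataset, validation_dataset and test_dataset are your own datasets)
train_tracker = trainer.fit(model, training_dataset, validation_dataset)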
print(train_tracker.records)

# Plot metric curves registered during training 
train_tracker.plot(epoch=True)

# Evaluate your model after training
validation_tracker = trainer.evaluate(model, validation_dataset)
print(validation_tracker.records)

# Test your model on some testing set
_, test_preds = trainer.predict(model, test_dataset)
print("Test Predictions", test_preds)

Requirements

  • Python ^3.10 | ^3.11 | ^3.12
  • Recommended: Poetry v2 (docs)

Installation

pip

# CPU only version
pip install pyroml
# OR with CUDA-enabled PyTorch and torchvision (quoted so shells such as zsh do not expand the brackets)
pip install "pyroml[cuda]"
# Additional dependencies that you might require
pip install "pyroml[extra]"

poetry

# CPU only version
poetry add pyroml
# OR with CUDA-enabled PyTorch and torchvision (quoted so shells such as zsh do not expand the brackets)
poetry add "pyroml[cuda]" --source pytorch-cu124
# Additional dependencies that you might require
poetry add [...] --extras extra

Locally

# Clone the repo
git clone https://github.com/peacefulotter/pyroml.git
cd pyroml

# Install dependencies
poetry config virtualenvs.in-project true  
poetry install --with cpu,dev # ,cuda
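
Whichever route you choose above, you can check that the package is visible to your interpreter. The sketch below uses the standard library's importlib.metadata; the helper name installed_version is made up for illustration, and "pyroml" is the distribution name used in the install commands.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


found = installed_version("pyroml")
if found is None:
    print("pyroml is not installed in this environment")
else:
    print(f"pyroml {found} is installed")
```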

Tests

Tests run with pytest. Install the test dependencies, then run the script:

poetry install --with test
./run_tests.sh

