
Machine Learning framework allowing plug-and-play training for PyTorch models


🔥 pyro

Lightweight Machine Learning framework allowing plug-and-play training for PyTorch models

  • Inspired by PyTorch Lightning
  • 💾 Out-of-the-box support for wandb and checkpointing
  • 📊 Pretty logs, plots and metrics tracking
  • ✨ Fully type-safe
  • 🪶 Lightweight and easy to use

Examples

See 📓 notebooks for examples using pyro. In particular, you can find:

  • Iris: the simplest example, training a small MLP on the Iris dataset.
  • SmolVLM on Flowers102: features extracted with the SmolVLM vision model are used to train a linear classifier on the Flowers102 dataset, reaching 98.6% test accuracy.

Usage

You can use 🔥 pyro with minimal code changes and stop writing training loops by hand. Here is an example pyro model and training script to get you started.

1. Define your Model

import torch
import pyroml as p

class MySOTAModel(p.PyroModel):
    def __init__(self):
        super().__init__()
        # Replace this placeholder network and loss with your own
        # (here: 4 input features and 3 classes, as for the Iris dataset)
        self.net = torch.nn.Linear(4, 3)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    # Optionally, configure your own optimizer and scheduler, see more in the docs
    def configure_optimizers(self, tr):
        # `tr` exposes the trainer settings, such as the learning rate passed to p.Trainer
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=tr.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.99
        )

    def step(self, batch, stage: p.Stage):
        # Extract data from your dataset batch
        # Batches and model are moved to the appropriate device automatically
        x, y = batch
        # Forward the model
        preds = self(x)
        # Compute the loss
        loss = self.loss_fn(preds, y)
        # Optionally, register some metrics (here: classification accuracy)
        accuracy = (preds.argmax(dim=-1) == y).float().mean().item()
        self.log(loss=loss.item(), accuracy=accuracy)
        # Return the loss when training, otherwise return the predictions
        if stage == p.Stage.TRAIN:
            return loss
        return preds
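The `step` example logs a classification accuracy alongside the loss. A minimal sketch of how that accuracy can be computed, written over plain Python lists so it runs without PyTorch. `compute_accuracy` is a hypothetical helper name, not a pyroml function, and it assumes `preds` holds one row of class scores (logits) per sample and `y` holds integer class labels.

```python
def compute_accuracy(preds: list[list[float]], y: list[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label.

    Hypothetical helper for illustration: pyroml does not ship this function.
    """
    if len(preds) != len(y):
        raise ValueError("preds and y must contain the same number of samples")
    if not y:
        return 0.0
    correct = 0
    for scores, label in zip(preds, y):
        # argmax over the class scores of a single sample
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(y)


# Two of the three samples are classified correctly
print(compute_accuracy([[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]], [1, 2, 2]))
```

With PyTorch tensors, the same computation is `(preds.argmax(dim=-1) == y).float().mean().item()`.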

2. Instantiate a Trainer

trainer = p.Trainer(
    lr=0.01,
    max_epochs=32,
    batch_size=16,
    # And many other options such as device, precision, callbacks, ...
)
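The `StepLR` scheduler configured in `configure_optimizers` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`. The small calculation below shows how quickly `lr=0.01` decays with `gamma=0.99`. Whether pyro steps the scheduler once per epoch or once per batch is not stated here, so `num_steps` counts scheduler steps, not epochs; `step_lr` is an illustrative helper, not a pyroml or PyTorch function.

```python
def step_lr(initial_lr: float, gamma: float, step_size: int, num_steps: int) -> float:
    """Learning rate after `num_steps` calls to scheduler.step() under StepLR's rule:
    lr = initial_lr * gamma ** (num_steps // step_size)."""
    return initial_lr * gamma ** (num_steps // step_size)


# Values from the example: Trainer(lr=0.01) and StepLR(step_size=1, gamma=0.99)
for n in (0, 1, 10, 32):
    print(f"after {n:2d} scheduler steps: lr = {step_lr(0.01, 0.99, 1, n):.6f}")
```

Assuming one scheduler step per epoch, `max_epochs=32` would end training at roughly 0.00725, about 72% of the initial rate; with one step per batch the decay would be correspondingly faster.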

3. Run training, validation and testing

# Instantiate the model defined above; training_dataset, validation_dataset and
# test_dataset are your own datasets yielding (x, y) samples
model = MySOTAModel()

# Fit the model on the training set, evaluating it on the validation set during training
train_tracker = trainer.fit(model, training_dataset, validation_dataset)
print(train_tracker.records)

# Plot the metric curves registered during training
train_tracker.plot(epoch=True)

# Evaluate your model after training
validation_tracker = trainer.evaluate(model, validation_dataset)
print(validation_tracker.records)

# Get predictions for a test set
_, test_preds = trainer.predict(model, test_dataset)
print("Test Predictions", test_preds)

Requirements

  • Python ^3.10 | ^3.11 | ^3.12
  • Recommended: Poetry v2

Installation

pip

# CPU-only version
pip install pyroml
# OR with CUDA-enabled PyTorch and torchvision
pip install "pyroml[cuda]"
# Additional dependencies that you might require
pip install "pyroml[extra]"

poetry

# CPU-only version
poetry add pyroml
# OR with CUDA-enabled PyTorch and torchvision
poetry add "pyroml[cuda]" --source pytorch-cu124
# Additional dependencies that you might require
poetry add [...] --extras extra

Locally

# Clone the repo
git clone https://github.com/peacefulotter/pyroml.git
cd pyroml

# Install dependencies
poetry config virtualenvs.in-project true  
poetry install --with dev

Tests

Tests are run with pytest. First install the test dependencies, then run the script:

poetry install --with test
./run_tests.sh


Download files


Source Distribution

pyroml-2.1.1.tar.gz (30.9 kB)

Uploaded Source

Built Distribution

pyroml-2.1.1-py3-none-any.whl (48.5 kB)

Uploaded Python 3

File details

Details for the file pyroml-2.1.1.tar.gz.

File metadata

  • Download URL: pyroml-2.1.1.tar.gz
  • Upload date:
  • Size: 30.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.0.1 CPython/3.11.7 Windows/10

File hashes

Hashes for pyroml-2.1.1.tar.gz
Algorithm Hash digest
SHA256 14dc8ee0dc856a9313c12f223b90a7c9ad0006015f65a8b289c48ecd1c2ee85e
MD5 2ded89d8fbd7c5f8c897fef3e8999fee
BLAKE2b-256 cad702d52f426f048d1701e805dcd07fe72a69df9154efa4d953ee93ff2cfa28


File details

Details for the file pyroml-2.1.1-py3-none-any.whl.

File metadata

  • Download URL: pyroml-2.1.1-py3-none-any.whl
  • Upload date:
  • Size: 48.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.0.1 CPython/3.11.7 Windows/10

File hashes

Hashes for pyroml-2.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 134d59a35bb4f70faafd2d91019dbce588db2add7a2afaabe9a5bcda5aea9226
MD5 2f0b1cac4c401a01c87f02d48d656165
BLAKE2b-256 0e8e9ce07a22ed80f160b3967307061e62eb384ae11855527104abb58ed6a549
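To check that a downloaded file matches the SHA256 digests listed above, its hash can be recomputed with Python's standard `hashlib` module. This is a minimal sketch, not a pyroml feature: the file names and digests are copied from the hash tables above, and `sha256_of` and `verify` are illustrative helper names.

```python
import hashlib
import sys
from pathlib import Path

# SHA256 digests published for pyroml 2.1.1 (copied from the hash tables above)
EXPECTED_SHA256 = {
    "pyroml-2.1.1.tar.gz": "14dc8ee0dc856a9313c12f223b90a7c9ad0006015f65a8b289c48ecd1c2ee85e",
    "pyroml-2.1.1-py3-none-any.whl": "134d59a35bb4f70faafd2d91019dbce588db2add7a2afaabe9a5bcda5aea9226",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    """True only if the file name is known and its digest matches the published one."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected


if __name__ == "__main__":
    # Usage: python verify_pyroml.py pyroml-2.1.1.tar.gz [more files...]
    for name in sys.argv[1:]:
        file_path = Path(name)
        print(f"{file_path.name}: {'OK' if verify(file_path) else 'MISMATCH'}")
```

pip can also enforce digests at install time: in a requirements file, a line such as `pyroml==2.1.1 --hash=sha256:<digest>` switches pip into hash-checking mode.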

