
Machine Learning framework allowing plug-and-play training for PyTorch models

Project description

🔥 pyro

A lightweight machine learning framework for plug-and-play training of PyTorch models

  • Inspired by PyTorch Lightning
  • 💾 Out-of-the-box support for Weights & Biases (wandb) and checkpoints
  • 📊 Pretty logs, plots and metric tracking
  • ✨ Fully type-safe
  • 🪶 Lightweight and easy to use

Usage

You can use 🔥 pyro with minimal code changes and stop writing training loops by hand. Here is an example model and training script for the Iris dataset to get you started.

1. Define your Model

import torch
import pyroml as p


class MySOTAModel(p.PyroModel):
    def __init__(self):
        super().__init__()
        # A small classifier for Iris: 4 input features, 3 classes
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 3),
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    # Optionally, configure your own optimizer and scheduler, see more in the docs
    def configure_optimizers(self, tr):
        # `tr` is passed in by pyroml; this example reads the learning rate from `tr.lr`
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=tr.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.99
        )

    def step(self, batch, stage: p.Stage):
        # Extract data from your dataset batch
        # Batches and model are moved to the appropriate device automatically
        x, y = batch
        # Forward the model
        preds = self(x)
        # Compute the loss
        loss = self.loss_fn(preds, y)
        # Fraction of samples whose highest-scoring class matches the label
        accuracy = (preds.argmax(dim=1) == y).float().mean().item()
        # Optionally, register some metrics
        self.log(loss=loss.item(), accuracy=accuracy)
        # Return loss when training, otherwise return predictions
        if stage == p.Stage.TRAIN:
            return loss
        return preds

2. Instantiate a Trainer

trainer = p.Trainer(
    lr=0.01,
    max_epochs=32,
    batch_size=16,
    # And many other options such as device, precision, callbacks, ...
)
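The scheduler configured in step 1 is `StepLR(step_size=1, gamma=0.99)`, which multiplies the learning rate by `gamma` once every `step_size` scheduler steps, so with `lr=0.01` the rate decays geometrically as `0.01 * 0.99 ** (epoch // 1)`. The sketch below evaluates that closed form in plain Python. It assumes the scheduler is stepped once per epoch; this page does not say whether pyroml steps it per epoch or per batch, and `step_lr_value` is a hypothetical helper, not part of pyroml or PyTorch.

```python
def step_lr_value(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Learning rate after `epoch` scheduler steps under StepLR's decay rule.

    Hypothetical helper: mirrors the closed form of
    torch.optim.lr_scheduler.StepLR, i.e. base_lr * gamma ** (epoch // step_size).
    """
    if step_size < 1:
        raise ValueError("step_size must be >= 1")
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Values from the examples: lr=0.01, step_size=1, gamma=0.99, max_epochs=32
    for epoch in (0, 1, 10, 31):
        print(f"epoch {epoch:2d}: lr = {step_lr_value(0.01, 0.99, 1, epoch):.6f}")
```

Under the once-per-epoch assumption, the 32nd and last epoch (31 scheduler steps in) runs at about 0.0073, roughly 73% of the initial rate; a larger `gamma` or `step_size` slows the decay.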

3. Run training, validation and testing

# Instantiate the model defined in step 1
model = MySOTAModel()

# training_dataset, validation_dataset and test_dataset are your own datasets
# Fit the model on the training set and evaluate it on the validation set during training
train_tracker = trainer.fit(model, training_dataset, validation_dataset)
print(train_tracker.records)

# Plot the metric curves registered during training
train_tracker.plot(epoch=True)

# Evaluate your model after training
validation_tracker = trainer.evaluate(model, validation_dataset)
print(validation_tracker.records)

# Run predictions on a test set
_, test_preds = trainer.predict(model, test_dataset)
print("Test predictions:", test_preds)
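`trainer.fit` and `trainer.predict` replace the loop you would otherwise write around `step`. As a framework-free illustration of the contract from step 1 (during training `step` returns the loss that drives a parameter update, otherwise it returns predictions), here is a minimal plain-Python sketch. Everything in it (`Stage`, `ToyModel`, `fit`, `predict`) is a hypothetical stand-in, with a hand-computed gradient in place of `loss.backward()` and a plain gradient-descent update in place of an optimizer; it is not pyroml's implementation.

```python
from enum import Enum


class Stage(Enum):
    # Hypothetical stand-in for pyroml's p.Stage
    TRAIN = "train"
    PREDICT = "predict"


class ToyModel:
    """One-parameter model y = w * x trained with mean squared error."""

    def __init__(self) -> None:
        self.w = 0.0
        self.grad = 0.0

    def step(self, batch, stage: Stage):
        xs, ys = batch
        preds = [self.w * x for x in xs]
        if stage == Stage.TRAIN:
            n = len(xs)
            loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
            # d(loss)/dw computed by hand, standing in for loss.backward()
            self.grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
            return loss
        return preds


def fit(model: ToyModel, batches, lr: float = 0.1, max_epochs: int = 50) -> list[float]:
    """Run the training loop and return the mean loss of each epoch."""
    epoch_losses = []
    for _ in range(max_epochs):
        losses = []
        for batch in batches:
            losses.append(model.step(batch, Stage.TRAIN))
            # Plain gradient descent, standing in for optimizer.step()
            model.w -= lr * model.grad
        epoch_losses.append(sum(losses) / len(losses))
    return epoch_losses


def predict(model: ToyModel, batches) -> list[float]:
    """Collect the predictions that step returns outside of training."""
    return [p for batch in batches for p in model.step(batch, Stage.PREDICT)]


if __name__ == "__main__":
    # Two batches of points on the line y = 2x
    batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
    model = ToyModel()
    history = fit(model, batches)
    print("learned w:", round(model.w, 4))
    print("predictions:", predict(model, batches))
```

The per-epoch loss list plays the role of the tracker records: it is what you would plot to check that training converges.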

Requirements

  • Python 3.10, 3.11 or 3.12
  • Recommended: Poetry v2

Installation

pip

# CPU-only version
pip install pyroml
# OR with CUDA-enabled PyTorch and torchvision (quotes keep shells such as zsh from expanding the brackets)
pip install "pyroml[cuda]"
# Additional dependencies that you might require
pip install "pyroml[extra]"

poetry

# CPU-only version
poetry add pyroml
# OR with CUDA-enabled PyTorch and torchvision
# (pytorch-cu124 must be a package source configured in your project)
poetry add "pyroml[cuda]" --source pytorch-cu124
# Additional dependencies that you might require
poetry add [...] --extras extra

Locally

# Clone the repo
git clone https://github.com/peacefulotter/pyroml.git
cd pyroml

# Install dependencies
poetry config virtualenvs.in-project true  
poetry install --with dev

Tests

Tests use pytest. Install the test dependencies, then run the script:

poetry install --with test
./run_tests.sh



Download files

Download the file for your platform.

Source Distribution

pyroml-2.1.0.tar.gz (30.2 kB)

Uploaded Source

Built Distribution


pyroml-2.1.0-py3-none-any.whl (47.2 kB)

Uploaded Python 3

File details

Details for the file pyroml-2.1.0.tar.gz.

File metadata

  • Download URL: pyroml-2.1.0.tar.gz
  • Upload date:
  • Size: 30.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.0.1 CPython/3.11.7 Windows/10

File hashes

Hashes for pyroml-2.1.0.tar.gz
Algorithm Hash digest
SHA256 82ac7d6527b0f6fbc372a6542f9cc8f76f37fe8fd69aeeef84cfe8f6941b2539
MD5 3988bf5b19b48dbaf5927e4364c05bdd
BLAKE2b-256 a579d5a4d72df4333d8b79bee7b988d256f618aff4ce4bcce80234960bb108d0


File details

Details for the file pyroml-2.1.0-py3-none-any.whl.

File metadata

  • Download URL: pyroml-2.1.0-py3-none-any.whl
  • Upload date:
  • Size: 47.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.0.1 CPython/3.11.7 Windows/10

File hashes

Hashes for pyroml-2.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b15f283e758a4d90a70a83bdc6d4355bee231fe011a898d060ba02978ed90a90
MD5 e573f4ac3395fe2a6b9a36b8ee4c49d8
BLAKE2b-256 97065bc2ae0fc289e432843ac0b02fff81eb5e2ef9ffd5fd9e41ffa640bd0403
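These digests let you check a downloaded file before installing it. A minimal sketch using Python's standard `hashlib` module; `sha256_of_stream` and `verify_file` are illustrative helpers, not part of pyroml or pip.

```python
import hashlib
from typing import BinaryIO


def sha256_of_stream(stream: BinaryIO, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of a binary stream, read in chunks."""
    digest = hashlib.sha256()
    # Read fixed-size chunks until read() returns b"" so large files never sit fully in memory
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def verify_file(path: str, expected_sha256: str) -> bool:
    """Compare a file's SHA-256 digest with a published hex digest."""
    with open(path, "rb") as f:
        return sha256_of_stream(f) == expected_sha256.strip().lower()
```

For the wheel, `verify_file("pyroml-2.1.0-py3-none-any.whl", "b15f283e758a4d90a70a83bdc6d4355bee231fe011a898d060ba02978ed90a90")` should return `True` for an intact download. pip can also enforce this itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).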

