Machine learning framework allowing plug-and-play training for PyTorch models
🔥 pyro
Lightweight machine learning framework allowing plug-and-play training for PyTorch models
- ⚡ Inspired by PyTorch Lightning
- 💾 Out-of-the-box support for wandb and checkpointing
- 📊 Pretty logs, plots and support for metrics
- ✨ Fully type-safe
- 🪶 Lightweight and easy to use
Usage
You can use 🔥 pyro with minimal code changes and forever forget about writing training loops. Here is an example of a pyro model and training script for the Iris dataset to get you started.
1. Define your Model
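```python
import torch
import pyroml as p


def compute_accuracy(preds: torch.Tensor, y: torch.Tensor) -> float:
    # Fraction of samples whose highest-scoring class matches the label
    return (preds.argmax(dim=-1) == y).float().mean().item()


class MySOTAModel(p.PyroModel):
    def __init__(self):
        super().__init__()
        # Illustrative network: Iris has 4 input features and 3 classes
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 3),
        )
        # Replace with the loss function that fits your task
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    # Optionally, configure your own optimizer and scheduler, see more in the docs
    def configure_optimizers(self, _):
        # Assumption: the attached Trainer is reachable as `self.trainer`
        tr = self.trainer
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=tr.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.99
        )

    def step(self, batch, stage: p.Stage):
        # Extract data from your dataset batch
        # Batches and model are moved to the appropriate device automatically
        x, y = batch
        # Forward the model
        preds = self(x)
        # Compute the loss
        loss = self.loss_fn(preds, y)
        # Optionally, register some metrics
        self.log(loss=loss.item(), accuracy=compute_accuracy(preds, y))
        # Return loss when training, otherwise return predictions
        if stage == p.Stage.TRAIN:
            return loss
        return preds
```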
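To see what `trainer.fit` and `trainer.predict` take off your hands, here is a framework-free sketch of the loop structure behind the `step` contract above: in the training stage the returned loss drives a parameter update, in other stages the returned predictions are collected. This is not pyroml's implementation. `Stage`, `MeanModel`, `fit` and `predict` are hypothetical stand-ins, the toy model computes its own gradient because no autograd is involved, and a real trainer additionally handles backpropagation, devices, schedulers and logging.

```python
from enum import Enum
from typing import Iterator, Sequence


class Stage(Enum):
    # Hypothetical stand-in for p.Stage
    TRAIN = "train"
    PREDICT = "predict"


def batches(data: Sequence, batch_size: int) -> Iterator[list]:
    # Yield consecutive slices of at most batch_size items
    for i in range(0, len(data), batch_size):
        yield list(data[i : i + batch_size])


class MeanModel:
    """Toy model that predicts one learned constant `c` for every input."""

    def __init__(self) -> None:
        self.c = 0.0
        self.grad = 0.0

    def step(self, batch: list, stage: Stage):
        ys = [y for _, y in batch]
        preds = [self.c] * len(ys)
        if stage == Stage.TRAIN:
            # Mean squared error and its derivative with respect to c
            loss = sum((pred - y) ** 2 for pred, y in zip(preds, ys)) / len(ys)
            self.grad = sum(2 * (pred - y) for pred, y in zip(preds, ys)) / len(ys)
            return loss
        return preds


def fit(model: MeanModel, data: Sequence, lr: float, max_epochs: int, batch_size: int) -> list[float]:
    losses = []
    for _ in range(max_epochs):
        for batch in batches(data, batch_size):
            losses.append(model.step(batch, Stage.TRAIN))
            model.c -= lr * model.grad  # plain gradient-descent update
    return losses


def predict(model: MeanModel, data: Sequence, batch_size: int) -> list[float]:
    preds: list[float] = []
    for batch in batches(data, batch_size):
        preds.extend(model.step(batch, Stage.PREDICT))
    return preds


data = [(float(x), 3.0) for x in range(10)]  # every target is 3.0
model = MeanModel()
losses = fit(model, data, lr=0.1, max_epochs=50, batch_size=4)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.2e}, c = {model.c:.4f}")
```

The point of the design is that the model only describes one step; looping, batching and the stage switch live in the trainer, which is why the pyroml `step` method returns the loss in one stage and predictions in the others.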
2. Instantiate a Trainer
```python
trainer = p.Trainer(
    lr=0.01,
    max_epochs=32,
    batch_size=16,
    # And many other options such as device, precision, callbacks, ...
)
```
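In the step 1 example the optimizer reads its starting rate from the Trainer (`tr.lr`), and the `StepLR(step_size=1, gamma=0.99)` scheduler then multiplies that rate by `gamma` once every `step_size` scheduler steps, so after `k` steps the rate is `lr * gamma ** (k // step_size)`. The sketch below evaluates this closed form for the settings above. It assumes the scheduler is stepped once per epoch, which is an assumption about how pyroml drives it; if it is stepped per batch, `num_steps` would be the number of batches instead. `step_lr_schedule` is a hypothetical helper.

```python
def step_lr_schedule(initial_lr: float, gamma: float, step_size: int, num_steps: int) -> list[float]:
    # Closed form of torch.optim.lr_scheduler.StepLR:
    # the rate decays by a factor of gamma every step_size scheduler steps
    return [initial_lr * gamma ** (k // step_size) for k in range(num_steps)]


# 32 epochs, assuming one scheduler step per epoch
lrs = step_lr_schedule(0.01, 0.99, step_size=1, num_steps=32)
print(f"epoch 0: {lrs[0]:.5f}, epoch 31: {lrs[-1]:.5f}")
```

With these settings the rate falls from 0.01000 to about 0.00732 by the last epoch, a gentle decay that keeps the 0.01 starting rate relevant for the whole run.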
3. Run training, validation and testing
```python
# Instantiate the model defined in step 1
model = MySOTAModel()
# training_dataset, validation_dataset and test_dataset are your own datasets
# (for example torch.utils.data.Dataset instances)

# Fit the model on the training set and evaluate it on the validation set during training
train_tracker = trainer.fit(model, training_dataset, validation_dataset)
print(train_tracker.records)
# Plot the metric curves registered during training
train_tracker.plot(epoch=True)
# Evaluate your model after training
validation_tracker = trainer.evaluate(model, validation_dataset)
print(validation_tracker.records)
# Test your model on a test set
_, test_preds = trainer.predict(model, test_dataset)
print("Test Predictions", test_preds)
```
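The training script takes `training_dataset`, `validation_dataset` and `test_dataset` as given. One simple way to carve the 150 Iris samples into three disjoint splits is to shuffle the indices once with a fixed seed and slice them; each index list can then be wrapped as `torch.utils.data.Subset(dataset, indices)` around a full Iris `Dataset` (PyTorch's `torch.utils.data.random_split` is a built-in alternative). The 70/15/15 ratio and the `split_indices` helper are illustrative choices, not part of pyroml.

```python
import random


def split_indices(
    n: int, val_frac: float, test_frac: float, seed: int = 0
) -> tuple[list[int], list[int], list[int]]:
    # Shuffle 0..n-1 deterministically, then slice into test / validation / training
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = round(n * test_frac)
    n_val = round(n * val_frac)
    test = indices[:n_test]
    val = indices[n_test : n_test + n_val]
    train = indices[n_test + n_val :]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(150, val_frac=0.15, test_frac=0.15, seed=42)
print(len(train_idx), len(val_idx), len(test_idx))
```

Fixing the seed makes the split reproducible across runs, so validation and test metrics stay comparable between experiments.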
Requirements
- Python ^3.10 | ^3.11 | ^3.12
- Recommended: Poetry v2 (docs)
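The Python requirement uses Poetry's caret syntax: `^X.Y` allows every version from `X.Y` up to, but excluding, the next major release, so `^3.10 | ^3.11 | ^3.12` reads literally as `>=3.10,<4.0`. The sketch below implements that rule (the upper bound bumps the leftmost non-zero component) and checks the running interpreter against it. `caret_bounds` and `satisfies` are hypothetical helpers, and a version being admitted by the constraint does not mean it is tested.

```python
import sys


def caret_bounds(spec: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # "^X.Y.Z" -> (lower, upper) following Poetry's caret rule:
    # the upper bound bumps the leftmost non-zero component
    parts = [int(part) for part in spec.lstrip("^").split(".")]
    lower = tuple(parts + [0] * (3 - len(parts)))
    # If every given component is zero, bump the last one instead
    idx = next((i for i, part in enumerate(parts) if part != 0), len(parts) - 1)
    upper = parts[:idx] + [parts[idx] + 1]
    upper += [0] * (3 - len(upper))
    return lower, tuple(upper)


def satisfies(version: tuple[int, ...], spec: str) -> bool:
    lower, upper = caret_bounds(spec)
    return lower <= tuple(version) < upper


specs = ["^3.10", "^3.11", "^3.12"]
current = tuple(sys.version_info[:3])
print(current, any(satisfies(current, s) for s in specs))
```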
Installation
pip
```bash
# CPU-only version
pip install pyroml
# OR with CUDA-enabled PyTorch and torchvision
pip install "pyroml[cuda]"
# Additional dependencies that you might require
pip install "pyroml[extra]"
```
poetry
```bash
# CPU-only version
poetry add pyroml
# OR with CUDA-enabled PyTorch and torchvision
# (requires a Poetry package source named pytorch-cu124 to be configured)
poetry add "pyroml[cuda]" --source pytorch-cu124
# Additional dependencies that you might require
poetry add [...] --extras extra
```
Locally
```bash
# Clone the repo
git clone https://github.com/peacefulotter/pyroml.git
cd pyroml
# Install dependencies
poetry config virtualenvs.in-project true
poetry install --with dev
```
Tests
Tests run on pytest. First install the package together with its test dependencies, then run the script:
```bash
poetry install --with test
./run_tests.sh
```
Download files
File details
Details for the file pyroml-2.1.0.tar.gz.
File metadata
- Download URL: pyroml-2.1.0.tar.gz
- Upload date:
- Size: 30.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.0.1 CPython/3.11.7 Windows/10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 82ac7d6527b0f6fbc372a6542f9cc8f76f37fe8fd69aeeef84cfe8f6941b2539 |
| MD5 | 3988bf5b19b48dbaf5927e4364c05bdd |
| BLAKE2b-256 | a579d5a4d72df4333d8b79bee7b988d256f618aff4ce4bcce80234960bb108d0 |
File details
Details for the file pyroml-2.1.0-py3-none-any.whl.
File metadata
- Download URL: pyroml-2.1.0-py3-none-any.whl
- Upload date:
- Size: 47.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.0.1 CPython/3.11.7 Windows/10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b15f283e758a4d90a70a83bdc6d4355bee231fe011a898d060ba02978ed90a90 |
| MD5 | e573f4ac3395fe2a6b9a36b8ee4c49d8 |
| BLAKE2b-256 | 97065bc2ae0fc289e432843ac0b02fff81eb5e2ef9ffd5fd9e41ffa640bd0403 |
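The published digests let you check that a downloaded file was not corrupted or altered in transit. A minimal sketch using the standard library's `hashlib`, reading the file in chunks so large archives need not fit in memory; the wheel is assumed to sit in the current directory, and the helper names are illustrative. (`pip hash <file>` prints the same SHA256 digest, and pip's `--require-hashes` mode can enforce digests at install time.)

```python
import hashlib
from pathlib import Path
from typing import BinaryIO


def sha256_of_stream(stream: BinaryIO, chunk_size: int = 1 << 16) -> str:
    # Feed the stream to SHA256 in fixed-size chunks until EOF
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: Path, expected: str) -> bool:
    with path.open("rb") as f:
        return sha256_of_stream(f) == expected.strip().lower()


wheel = Path("pyroml-2.1.0-py3-none-any.whl")  # assumed download location
expected = "b15f283e758a4d90a70a83bdc6d4355bee231fe011a898d060ba02978ed90a90"
if wheel.exists():
    print("OK" if verify_sha256(wheel, expected) else "MISMATCH: do not install")
else:
    print(f"{wheel} not found")
```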