
Torchmate: A High level PyTorch Training Library

Project description

A High level PyTorch Training and Utility Library


📚 Project Documentation

Visit Torchmate's Read the Docs project page or read the following README to learn more about the Torchmate library.

💡 Introduction

So, why did I write Torchmate? I was a big fan of TensorFlow and Keras, but during my undergrad thesis I needed to use PyTorch. I was astonished by PyTorch's flexibility, yet I kept writing the same boilerplate code, which was quite frustrating. So I decided to use a high-level library like Catalyst or Lightning. Catalyst was great, but I missed Keras's verbose training output (which is cleaner) for better visualization (I know it's not a very good reason for writing a library). PyTorch Lightning is also very good, but it changes a lot about how we usually structure our code. Additionally, I was curious about how high-level frameworks like Keras, Catalyst, or Lightning work internally and use callbacks to extend functionality. Building a minimalistic library myself seemed like the best way to understand these concepts. So that's why I built Torchmate. Torchmate incorporates everything (actually not everything, some functionalities are still under development) I need, the way I prefer it, as a deep learning practitioner.

🔑 Key Features

  • Encapsulate all training essentials: Model, data loaders, loss function, optimizer, and learning rate schedulers.
  • Mixed precision training (AMP): Train faster and potentially achieve better generalization with mixed precision calculations.
  • Gradient Accumulation: Train on virtually larger batches by accumulating gradients, improving memory efficiency (implemented through a callback).
  • Gradient Clipping: Prevent exploding gradients and stabilize training (implemented through a callback; see the sketch after this list for the boilerplate these callbacks replace).
  • Gradient Penalty: Enhance stability in generative models like GANs.
  • Callback Mechanism: Monitor progress, save checkpoints, stop training early, and extend functionality with custom callbacks.
  • Experiment Tracking: Integrate dedicated tools like Weights & Biases or TensorBoard through callbacks.
  • Minimal Dependency: Torchmate needs only four packages: PyTorch (of course), NumPy, Matplotlib, and Weights & Biases (wandb).
  • Research Paper Modules: Implementations of various modules from research papers, facilitating reproducibility and experimentation.
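
To make the gradient-related features above concrete, here is a minimal sketch of what mixed precision, gradient accumulation, and gradient clipping typically look like when written by hand in plain PyTorch. This is the boilerplate a trainer with callbacks is meant to absorb; it does not use Torchmate's API, and names such as accum_steps and max_norm are illustrative only.

import torch

def train_one_epoch(model, dataloader, loss_fn, optimizer,
                    accum_steps=4, max_norm=1.0, device="cuda"):
    # GradScaler performs loss scaling for mixed precision (AMP) training.
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        # Run the forward pass and loss computation in mixed precision.
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(x), y)
        # Divide so the accumulated gradient averages over accum_steps mini-batches.
        scaler.scale(loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0:
            # Unscale before clipping so the norm is computed on the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)  # skips the update if gradients overflowed
            scaler.update()
            optimizer.zero_grad()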

⚙️ Installation

First, make sure PyTorch is installed in your environment. Then install Torchmate via pip.

From PyPI:

$ pip install torchmate

From Source:

$ pip install git+https://github.com/SaihanTaki/torchmate

N.B. Torchmate requires the PyTorch (of course) library to function, but it does not list PyTorch as a dependency, to avoid unnecessary overhead during installation. Excluding PyTorch as a dependency lets users explicitly install the version of PyTorch best suited to their specific needs and environment. For instance, users who don't require GPU acceleration can install the CPU-only version of PyTorch, reducing unnecessary dependencies and installation size. PyTorch Installation Page: https://pytorch.org/get-started/locally/
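
For example, a CPU-only setup could look like the following. The index URL below comes from the PyTorch installation page and may change over time, so check that page for the command matching your platform and CUDA version.

$ pip install torch --index-url https://download.pytorch.org/whl/cpu
$ pip install torchmate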

⏳ Quick Example

import torch
import numpy as np
from torchmate.trainer import Trainer
from sklearn.model_selection import train_test_split

# Create a simple neural network model
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.fc1(x) 

# Create synthetic data
X = torch.tensor(np.random.rand(1000, 1), dtype=torch.float32)
y = 2 * X + 1 + torch.randn(1000, 1) * 0.1  # Adding some noise

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Create DataLoader objects for training and validation
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)

model = SimpleModel()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


# Create a Trainer instance
trainer = Trainer(
    model,
    train_dataloader,
    val_dataloader,
    loss_fn,
    optimizer,
    num_epochs=3,
    device=device
)


# Train the model
history = trainer.fit()
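
Once trainer.fit() returns, the model can be used like any other PyTorch module. The snippet below is a plain-PyTorch follow-up and assumes nothing about Torchmate's API beyond the model having been trained in place; the exact structure of the returned history object is library-specific, so consult the documentation before relying on it.

# Evaluate the trained model on the validation set
model.to(device).eval()  # make sure model and data are on the same device
with torch.no_grad():
    preds = model(X_val.to(device))
print(f"Validation MSE: {loss_fn(preds, y_val.to(device)).item():.4f}")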

🛡️ License

Torchmate is distributed under the MIT License.

Project details


Download files

Download the file for your platform.

Source Distribution

torchmate-0.1.1.tar.gz (26.7 kB)


Built Distribution


torchmate-0.1.1-py2.py3-none-any.whl (30.6 kB)


File details

Details for the file torchmate-0.1.1.tar.gz.

File metadata

  • Download URL: torchmate-0.1.1.tar.gz
  • Upload date:
  • Size: 26.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.19

File hashes

Hashes for torchmate-0.1.1.tar.gz

  • SHA256: a8b64d2a3dfcbadc71f0f78b96a0a3b543855dea7de2639ba11e99a7d8747ba7
  • MD5: 26f50f6df81371cfef8f50bb75aed0df
  • BLAKE2b-256: f99b5965aaa1e0b6e016c2c7a3bf0e33d69bfa409ff5ea812e6cad485311a5d5

See the PyPI documentation for more details on using hashes.
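
As a quick illustration, a downloaded archive can be checked against the published SHA256 digest with Python's standard hashlib module:

import hashlib

# Compare the local file's SHA256 digest with the value published above.
expected = "a8b64d2a3dfcbadc71f0f78b96a0a3b543855dea7de2639ba11e99a7d8747ba7"
with open("torchmate-0.1.1.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "MISMATCH: file may be corrupted or tampered with")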

File details

Details for the file torchmate-0.1.1-py2.py3-none-any.whl.

File metadata

  • Download URL: torchmate-0.1.1-py2.py3-none-any.whl
  • Upload date:
  • Size: 30.6 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.19

File hashes

Hashes for torchmate-0.1.1-py2.py3-none-any.whl

  • SHA256: 44ea44be267940888075c146d1a5f3fa3e16d276d506916c2a6be8d854d9d904
  • MD5: 8ca0da9841a55a05045fc6228a65b6c3
  • BLAKE2b-256: 7d412ee3c2e3280d68f84cd23871fc3817b08484eb98324e1f8e3b65fb4cad5c

See the PyPI documentation for more details on using hashes.
