
A PyTorch wrapper that makes .fit() (and other methods) possible in nn.Module!

Project description

Welcome to torchy

The aim of this project is to create a PyTorch wrapper that wraps torch.nn.Module and adds data preprocessing utilities on top of torch.utils.data. We aim to retain all of PyTorch's functionality, keeping it native, while adding some functionality of our own.

torchy aims to enhance the PyTorch experience, not to replace it. torchy is ready to be used in everyday code but is currently in beta. After additional checks and testing, it will be marked as stable.

Introduction

torchy is a PyTorch wrapper that offers some benefits over plain PyTorch: you get everything in PyTorch plus some additional features found in other libraries. The main difference between torchy and torchfit, or the hundreds of other PyTorch-like modules that exist, is that you don't have to relearn how to use PyTorch.

torchy is a wrapper built on top of PyTorch, so you can keep using your existing PyTorch code and still get the added benefits.
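One way such a wrapper can work is illustrated by the minimal sketch below. It assumes a design in which a subclass of torch.nn.Module adds a fit() method; the class name FitModule, the hist list, and this fit() signature are hypothetical and are not taken from torchy's source. Because the subclass is still an nn.Module, ordinary PyTorch code (calling the model, parameters(), state_dict()) keeps working unchanged.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class FitModule(nn.Module):
    """Hypothetical base class: a regular nn.Module plus a fit() helper."""

    def fit(self, dataloader, loss_fn, opt, epochs):
        # Record the mean training loss of every epoch (hypothetical attribute).
        self.hist = []
        for _ in range(epochs):
            self.train()
            total, count = 0.0, 0
            for xb, yb in dataloader:
                loss = loss_fn(self(xb), yb)
                opt.zero_grad()
                loss.backward()
                opt.step()
                total += loss.item() * len(xb)  # weight by batch size
                count += len(xb)
            self.hist.append(total / count)
        return self  # returning self allows `model = model.fit(...)`


class Model(FitModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


torch.manual_seed(0)
x = torch.tensor([[12.0], [13.0], [15.0]])
y = torch.tensor([[2.0], [3.0], [4.0]])
loader = DataLoader(TensorDataset(x, y), batch_size=2)

model = Model()
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model = model.fit(loader, nn.functional.mse_loss, opt, epochs=20)
print(len(model.hist), isinstance(model, nn.Module))  # 20 True
```

Returning self from fit() is what makes the model = model.fit(...) pattern used in this README possible.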

Installation using pip

It's a good idea to have PyTorch preinstalled in your current virtual environment. See the official guide to install PyTorch.

Python >=3.6 and <=3.8 is recommended, although no problems have been encountered so far with 3.9 and 3.10.

Use pip to install torchy from PyPI.

pip install torchy 

or

pip3 install torchy

PS: If PyTorch is not already installed in your environment, it will be installed automatically, but it's recommended to install it using the official guide.

Additional Functionality

Define a model using nn.Module just like in regular PyTorch, but import torchy.nn instead of torch.nn.

import torchy.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


model = Model()

Now you can use torchy's functionality.

import torch
from torchy.utils.data import TensorDataset, DataLoader

# prepare dummy data
x = torch.tensor([[12.], [13.], [15.]])
y = torch.tensor([[2.], [3.], [4.]])
dataset = TensorDataset(x, y)

# nn is still the same module (torchy.nn)
loss_fn = nn.functional.mse_loss
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=.9)
# Use model.fit() to fit the model on the given TensorDataset
model = model.fit(dataset, loss_fn, opt, epochs=20, valid_pct=25, batch_size=2)
# Now you have a trained model and can also access the model.hist attribute
print(model.hist)

You can also use a dataloader instead of a dataset.

# Use a DataLoader instead of a TensorDataset
dataloader = DataLoader(dataset, batch_size=2)
model = model.fit(dataloader, loss_fn, opt, 20)

If you're using a dataloader and want .fit() to run validation after every epoch, you have to pass valid_dataloader manually.
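Per-epoch validation boils down to running the model over the validation loader without gradient tracking and averaging the loss. The sketch below shows that step in plain PyTorch; the evaluate() helper is hypothetical and is not part of torchy's API, and it does not show how torchy's fit() actually consumes valid_dataloader.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, valid_dataloader, loss_fn):
    """Hypothetical helper: mean loss over a validation DataLoader."""
    model.eval()  # put layers such as dropout/batchnorm into inference mode
    total, count = 0.0, 0
    with torch.no_grad():  # gradients are not needed for validation
        for xb, yb in valid_dataloader:
            total += loss_fn(model(xb), yb).item() * len(xb)
            count += len(xb)
    model.train()  # return to training mode for the next epoch
    return total / count


# A fixed linear layer so the result is predictable: y_hat = 1.0 * x + 0.0
model = nn.Linear(1, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.fill_(0.0)

x_valid = torch.tensor([[1.0], [2.0], [3.0]])
y_valid = torch.tensor([[1.0], [2.0], [5.0]])
valid_dl = DataLoader(TensorDataset(x_valid, y_valid), batch_size=2)

val_loss = evaluate(model, valid_dl, nn.functional.mse_loss)
print(val_loss)  # (0 * 2 + 4 * 1) / 3 = 1.333...
```

Weighting each batch loss by its size keeps the average correct when the last batch is smaller than the others.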

torchy.utils.data can also be used to move your dataloader onto a device and to split your dataset.

from torchy.utils.data import DeviceDL, SplitPCT
# put the dataloader on the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader = DeviceDL(dataloader)

# Split the dataset
dataset = SplitPCT(dataset)
train_ds, valid_ds = dataset.train_ds, dataset.valid_ds
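Both utilities can be approximated in plain PyTorch. This is a sketch under assumptions: DeviceLoader and split_pct are hypothetical stand-ins, not torchy's DeviceDL and SplitPCT, and valid_pct is assumed to be a percentage from 0 to 100, as the valid_pct=25 argument in the fit() example suggests.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


class DeviceLoader:
    """Hypothetical wrapper: yields each batch of a DataLoader moved to `device`."""

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        for batch in self.dataloader:
            # non_blocking only speeds things up for pinned CPU memory sent to CUDA
            yield [t.to(self.device, non_blocking=True) for t in batch]

    def __len__(self):
        return len(self.dataloader)


def split_pct(dataset, valid_pct=20, seed=0):
    """Hypothetical helper: hold out `valid_pct` percent of samples for validation."""
    n_valid = int(len(dataset) * valid_pct / 100)
    n_train = len(dataset) - n_valid
    generator = torch.Generator().manual_seed(seed)  # reproducible shuffle
    return random_split(dataset, [n_train, n_valid], generator=generator)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = torch.arange(8.0).unsqueeze(1)  # 8 samples, 1 feature each
dataset = TensorDataset(data, data)
train_ds, valid_ds = split_pct(dataset, valid_pct=25)
train_dl = DeviceLoader(DataLoader(train_ds, batch_size=2), device)

print(len(train_ds), len(valid_ds))  # 6 2
for xb, yb in train_dl:
    print(xb.device, xb.shape)
```

Seeding the generator makes the split reproducible across runs, which matters when comparing validation losses between experiments.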

Additional features such as get_loss() and _accuracy(), along with full documentation, a user guide, best practices, and tutorials for torchy, can be found in the docs.

To-do

  1. More testing

Feel free to contribute code, and drop a star on the project if you like the idea.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchy-0.1.3.tar.gz (11.6 kB)

Uploaded Source

Built Distribution

torchy-0.1.3-py3-none-any.whl (9.3 kB)

Uploaded Python 3

File details

Details for the file torchy-0.1.3.tar.gz.

File metadata

  • Download URL: torchy-0.1.3.tar.gz
  • Upload date:
  • Size: 11.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.13

File hashes

Hashes for torchy-0.1.3.tar.gz
Algorithm Hash digest
SHA256 9038dfe39e3ea13fc7b06c65d2774ff8e95a9870ca37dd452b9c621619536929
MD5 57263ec36671e9064e7a97a4c61de377
BLAKE2b-256 a573f68b8388afef634abb4d016fcee66342e78f7075f1d15ce8f35d0a51f777


File details

Details for the file torchy-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: torchy-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.13

File hashes

Hashes for torchy-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 981bf5a5fa33f7717e6345c5e54a592a273ee975704d32aaafe8518b3dc7b6a4
MD5 46b3a1cc754cb4661b8718c790d6e3fb
BLAKE2b-256 1f452b8dea42bf10e3deeeafc429c344d8b876045216d626d8cad8be88c22bed

