A PyTorch wrapper that makes .fit() (and other methods) available on nn.Module!
Welcome to torchy
The aim of this project is to create a PyTorch wrapper around torch.nn.Module, together with additional data preprocessing utilities for torch.utils.data. We aim to retain all of PyTorch's functionality, keep it native, and add our own features on top.
torchy is meant to enhance the PyTorch experience, not to replace it. It is ready for everyday use but is currently in beta; it will be marked stable after further checks and testing.
Introduction
torchy is a PyTorch wrapper that offers some benefits over plain PyTorch: you get everything in PyTorch plus some additional features found in other libraries. What sets torchy apart from torchfit and the hundreds of other PyTorch-like libraries is that you don't have to relearn how to use PyTorch.
Because torchy is built on top of PyTorch, your existing PyTorch code keeps working and still gets the added benefits.
Installation using pip
It's a good idea to have PyTorch preinstalled in your current virtual environment. See the official guide to install PyTorch.
Python >=3.6 and <=3.8 is recommended, although no problems have been encountered so far with 3.9 and 3.10.
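If you want to check where the current interpreter falls relative to that recommendation, a small sketch like the one below works. The helper `python_version_status` and its return labels are hypothetical and not part of torchy; the bounds are simply the range stated above.

```python
import sys

# Recommended range stated in this README: Python >= 3.6 and <= 3.8.
RECOMMENDED_MIN = (3, 6)
RECOMMENDED_MAX = (3, 8)


def python_version_status(version=None):
    """Classify a Python version against the recommended range.

    Hypothetical helper for illustration; not part of torchy.
    `version` is any tuple-like (major, minor, ...); defaults to the running interpreter.
    """
    major_minor = tuple((version or sys.version_info)[:2])
    if major_minor < RECOMMENDED_MIN:
        return "too old"
    if major_minor <= RECOMMENDED_MAX:
        return "recommended"
    # Newer versions (e.g. 3.9, 3.10) have reportedly worked so far.
    return "newer than recommended"


print(python_version_status())
```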
Use pip to install torchy from PyPI:
```
pip install torchy
```
or
```
pip3 install torchy
```
PS: PyTorch will be installed automatically if it is not already in your environment, but it's recommended to install it using the official guide.
Additional Functionality
Define a model using nn.Module just like in regular PyTorch, but import torchy.nn instead of torch.nn.
```python
import torchy.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = Model()
```
Now you can use torchy's functionality.
```python
import torch
from torchy.utils.data import TensorDataset, DataLoader

# Prepare dummy data
x = torch.tensor([[12.], [13.], [15.]])
y = torch.tensor([[2.], [3.], [4.]])
dataset = TensorDataset(x, y)

# nn is still torchy.nn, imported above
loss_fn = nn.functional.mse_loss
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Use model.fit() to fit the model on the given TensorDataset
model = model.fit(dataset, loss_fn, opt, epochs=20, valid_pct=25, batch_size=2)

# Now you have a trained model and can also access the model.hist attribute
print(model.hist)
```
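Conceptually, a call like model.fit(dataset, loss_fn, opt, epochs=20, valid_pct=25, batch_size=2) bundles a train/validation split, a minibatch training loop, and per-epoch loss bookkeeping. The framework-free sketch below spells those steps out for the one-feature linear model and dummy data above, using mean squared error and SGD with momentum (the update torch.optim.SGD applies with its default dampening of 0). It is not torchy's implementation: the helper names `split_pct`, `mse` and `fit_linear`, the reading of valid_pct as a percentage of samples held out, and the layout of the history dict are all assumptions.

```python
import random


def split_pct(samples, valid_pct, seed=0):
    """Shuffle, then hold out roughly valid_pct percent of samples for validation."""
    items = list(samples)
    random.Random(seed).shuffle(items)
    n_valid = round(len(items) * valid_pct / 100)
    return items[n_valid:], items[:n_valid]


def mse(w, b, batch):
    """Mean squared error of y_hat = w * x + b over (x, y) pairs."""
    return sum((w * x + b - y) ** 2 for x, y in batch) / len(batch)


def fit_linear(samples, epochs, lr, momentum, valid_pct, batch_size, seed=0):
    """Hypothetical sketch of a fit loop; returns ((w, b), history)."""
    train, valid = split_pct(samples, valid_pct, seed)
    w, b = 0.0, 0.0    # parameters of a 1 -> 1 linear layer
    vw, vb = 0.0, 0.0  # momentum buffers, one per parameter
    hist = {"train_loss": [], "valid_loss": []}
    for _ in range(epochs):
        for start in range(0, len(train), batch_size):
            batch = train[start:start + batch_size]
            n = len(batch)
            # Gradients of the batch MSE with respect to w and b.
            gw = sum(2 * (w * x + b - y) * x for x, y in batch) / n
            gb = sum(2 * (w * x + b - y) for x, y in batch) / n
            # SGD with momentum: v = momentum * v + grad; p = p - lr * v
            vw = momentum * vw + gw
            vb = momentum * vb + gb
            w -= lr * vw
            b -= lr * vb
        # Record losses once per epoch.
        hist["train_loss"].append(mse(w, b, train))
        if valid:
            hist["valid_loss"].append(mse(w, b, valid))
    return (w, b), hist


data = [(12.0, 2.0), (13.0, 3.0), (15.0, 4.0)]
params, hist = fit_linear(data, epochs=20, lr=0.001, momentum=0.9,
                          valid_pct=25, batch_size=2)
print(params, hist["train_loss"][-1])
```

With three samples and valid_pct=25, this sketch holds out one sample for validation and trains on the other two, so each epoch is a single batch of two.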
You can also use a dataloader instead of a dataset.
```python
# Use a DataLoader instead of a TensorDataset
dataloader = DataLoader(dataset, batch_size=2)
model = model.fit(dataloader, loss_fn, opt, 20)
```
If you pass a DataLoader and want .fit() to run validation after every epoch, you have to pass valid_dataloader manually.
torchy.utils.data can also be used to move your DataLoader to a device and to split your dataset.
```python
from torchy.utils.data import DeviceDL, SplitPCT

# Put the dataloader on the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader = DeviceDL(dataloader)

# Split the dataset
dataset = SplitPCT(dataset)
train_ds, valid_ds = dataset.train_ds, dataset.valid_ds
```
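The idea behind a device-aware dataloader is a thin wrapper that moves each batch as it is iterated, so data is transferred lazily rather than all at once. The plain-Python sketch below shows that wrapper pattern; `DeviceLoader` and its `move` argument are hypothetical names, not torchy's DeviceDL. With PyTorch, `move` would typically be something like `lambda t: t.to(device)`.

```python
class DeviceLoader:
    """Wrap an iterable of batches and apply `move` to each batch lazily.

    Hypothetical sketch of the device-loader pattern; not torchy's DeviceDL.
    """

    def __init__(self, loader, move):
        self.loader = loader
        self.move = move

    def __iter__(self):
        for batch in self.loader:
            # Batches are often (inputs, targets) pairs: move every element
            # and return them as a tuple.
            if isinstance(batch, (list, tuple)):
                yield tuple(self.move(item) for item in batch)
            else:
                yield self.move(batch)

    def __len__(self):
        return len(self.loader)


# Usage with plain numbers standing in for tensors:
batches = [(1, 10), (2, 20)]
moved = list(DeviceLoader(batches, move=lambda v: v * 100))
print(moved)  # [(100, 1000), (200, 2000)]
```

Because iteration goes back to the wrapped loader every time, the wrapper can be iterated once per epoch just like the original loader.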
Additional features such as get_loss() and _accuracy(), along with full documentation, a user guide, best practices, and tutorials for torchy, can be found in the docs.
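For classification, an accuracy metric usually compares the highest-scoring class in each row of model outputs with the integer label. The plain-Python sketch below shows that computation; `accuracy` is a hypothetical helper, and torchy's `_accuracy()` may differ in signature and behavior.

```python
def accuracy(outputs, labels):
    """Fraction of rows whose arg-max class matches the label (hypothetical helper)."""
    if len(outputs) != len(labels):
        raise ValueError("outputs and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for scores, label in zip(outputs, labels):
        # Index of the highest score in this row (the predicted class).
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += predicted == label
    return correct / len(labels)


outputs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
print(accuracy(outputs, labels))  # 0.6666666666666666
```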
To-do
- More testing
Feel free to contribute your code, and drop a star on the project if you like the idea.
Distributions: a source distribution (torchy-0.1.3.tar.gz) and a built distribution (torchy-0.1.3-py3-none-any.whl).
File details
Details for the file torchy-0.1.3.tar.gz.
File metadata
- Download URL: torchy-0.1.3.tar.gz
- Upload date:
- Size: 11.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9038dfe39e3ea13fc7b06c65d2774ff8e95a9870ca37dd452b9c621619536929
MD5 | 57263ec36671e9064e7a97a4c61de377
BLAKE2b-256 | a573f68b8388afef634abb4d016fcee66342e78f7075f1d15ce8f35d0a51f777
File details
Details for the file torchy-0.1.3-py3-none-any.whl.
File metadata
- Download URL: torchy-0.1.3-py3-none-any.whl
- Upload date:
- Size: 9.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 981bf5a5fa33f7717e6345c5e54a592a273ee975704d32aaafe8518b3dc7b6a4
MD5 | 46b3a1cc754cb4661b8718c790d6e3fb
BLAKE2b-256 | 1f452b8dea42bf10e3deeeafc429c344d8b876045216d626d8cad8be88c22bed