
Training neural networks in PyTorch

Project description

torchtuples


torchtuples is a small Python package for training PyTorch models. It works equally well with NumPy arrays and torch tensors. One of its main benefits is that it handles data in the form of nested tuples (see the example below).
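To make "nested tuples" concrete: a structure such as (a, (b, c)) is walked recursively, and an operation is applied to every leaf while the tuple nesting is kept intact. The helper below is a hypothetical pure-Python sketch of that idea. The name map_nested is ours, not part of the torchtuples API, and plain lists stand in for arrays or tensors.

```python
from typing import Any, Callable


def map_nested(fn: Callable[[Any], Any], data: Any) -> Any:
    """Apply fn to every non-tuple leaf of a (possibly nested) tuple.

    Hypothetical helper for illustration; the tuple structure is preserved.
    """
    if isinstance(data, tuple):
        # Recurse into each element, rebuilding a tuple with the same nesting
        return tuple(map_nested(fn, item) for item in data)
    # Anything that is not a tuple is treated as a leaf (array, tensor, list, ...)
    return fn(data)


# Usage: double every entry of each leaf in a nested structure
x = ([1, 2], ([3], [4, 5]))
doubled = map_nested(lambda leaf: [2 * v for v in leaf], x)
print(doubled)  # ([2, 4], ([6], [8, 10]))
```

The same recursion works for any leaf operation, for example converting each NumPy array to a tensor, which is the kind of bookkeeping a nested-tuple-aware library has to do.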

Installation

torchtuples depends on PyTorch, which should be installed first by following the instructions on the official PyTorch website.

Next, torchtuples can be installed with pip:

pip install torchtuples

Or, via conda:

conda install -c conda-forge torchtuples

For the bleeding-edge version, install directly from GitHub (consider adding --force-reinstall):

pip install git+https://github.com/havakv/torchtuples.git

or by cloning the repo:

git clone https://github.com/havakv/torchtuples.git
cd torchtuples
pip install .
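Whichever route you use, a quick sanity check is to ask Python's standard importlib.metadata module for the installed distribution version. This is a generic sketch, not a torchtuples feature; it assumes Python 3.8 or newer, and installed_version is a hypothetical helper name.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# After any of the installation routes above, this should print a version string
print(installed_version("torchtuples"))
```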

Example

import torch
from torch import nn
from torchtuples import Model, optim

Make a data set with three sets of covariates x0, x1 and x2, and a target y. The covariates are structured in a nested tuple x.

n = 500
x0, x1, x2 = [torch.randn(n, 3) for _ in range(3)]
y = torch.randn(n, 1)
x = (x0, (x0, x1, x2))
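Because every leaf of x has the same n rows (and x0 even appears twice), a mini-batch has to take the same row indices from every tensor in x so that the rows stay aligned with y. The sketch below illustrates this with a hypothetical helper, index_nested, using plain Python lists in place of tensors; it is not how torchtuples is implemented internally.

```python
import random
from typing import Any, Sequence


def index_nested(data: Any, idx: Sequence[int]) -> Any:
    """Select rows idx from every leaf of a nested tuple, keeping the structure."""
    if isinstance(data, tuple):
        return tuple(index_nested(item, idx) for item in data)
    # Leaves are row-indexable sequences (lists here; a tensor would use data[idx])
    return [data[i] for i in idx]


# Toy stand-ins for x0, x1, x2 and y: 6 rows each, one value per row
n = 6
x0, x1, x2 = ([f"{name}{i}" for i in range(n)] for name in ("a", "b", "c"))
y = list(range(n))
x = (x0, (x0, x1, x2))

batch_idx = random.Random(0).sample(range(n), k=3)  # one shuffled mini-batch
x_batch, y_batch = index_nested(x, batch_idx), index_nested(y, batch_idx)
print(x_batch, y_batch)
```

Row i of every leaf ends up in the same position of the batch, so x_batch keeps the shape (x0, (x0, x1, x2)) and still lines up with y_batch.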

Create a simple ReLU net that takes as input the tensor x_tensor and the tuple x_tuple. Note that x_tuple can be of arbitrary length. Each tensor in x_tuple is passed through the layer lin_tuple followed by a ReLU; the results are averaged and concatenated with x_tensor. We then pass this new tensor through the layer lin_cat. The predict method additionally applies a sigmoid to the output of forward.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin_tuple = nn.Linear(3, 2)
        self.lin_cat = nn.Linear(5, 1)
        self.relu = nn.ReLU()

    def forward(self, x_tensor, x_tuple):
        x = [self.relu(self.lin_tuple(xi)) for xi in x_tuple]
        x = torch.stack(x).mean(0)
        x = torch.cat([x, x_tensor], dim=1)
        return self.lin_cat(x)

    def predict(self, x_tensor, x_tuple):
        x = self.forward(x_tensor, x_tuple)
        return torch.sigmoid(x)
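The layer sizes in Net follow from simple shape bookkeeping: lin_tuple maps 3 features to 2, averaging over the tuple keeps 2 features, and concatenating with the 3 columns of x_tensor gives the 5 inputs of lin_cat. The function below traces those shapes by hand without calling PyTorch. It is a hypothetical helper that assumes every input tensor has 3 columns, as in the data set above.

```python
def net_shapes(n_rows: int, tuple_len: int, in_features: int = 3, hidden: int = 2) -> dict:
    """Trace tensor shapes through Net.forward for one batch (no PyTorch needed)."""
    return {
        # each xi in x_tuple: (n_rows, in_features) -> lin_tuple + ReLU -> (n_rows, hidden)
        "lin_tuple(xi)": (n_rows, hidden),
        # torch.stack adds a leading dimension of length tuple_len
        "stack": (tuple_len, n_rows, hidden),
        # .mean(0) averages over that tuple dimension
        "mean(0)": (n_rows, hidden),
        # torch.cat([x, x_tensor], dim=1) appends the in_features columns of x_tensor
        "cat": (n_rows, hidden + in_features),
        # lin_cat maps hidden + in_features -> 1
        "lin_cat": (n_rows, 1),
    }


shapes = net_shapes(n_rows=64, tuple_len=3)
print(shapes["cat"])  # (64, 5), matching nn.Linear(5, 1)
```

Changing tuple_len only changes the "stack" entry, which is why x_tuple can have arbitrary length.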

We can now fit the model with

model = Model(Net(), nn.MSELoss(), optim.SGD(0.01))
log = model.fit(x, y, batch_size=64, epochs=5)
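Conceptually, fit with batch_size=64, epochs=5 and optim.SGD(0.01) repeats three things: shuffle the rows, split them into mini-batches, and take one gradient step on the MSE loss per batch. The sketch below shows that loop in pure Python for a one-feature linear model y ≈ w * x + b. It is a hypothetical illustration, not torchtuples' training loop, and it assumes the last, smaller batch is kept; with n = 500 and batch_size = 64 that gives 8 updates per epoch.

```python
import random
from typing import List, Tuple


def minibatch_sgd(xs: List[float], ys: List[float], lr: float = 0.01,
                  batch_size: int = 64, epochs: int = 5,
                  seed: int = 0) -> Tuple[float, float, int]:
    """Fit ys ~ w * xs + b by minimizing mean squared error with minibatch SGD.

    Returns (w, b, steps), where steps counts parameter updates.
    """
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    steps = 0
    for _ in range(epochs):
        rng.shuffle(order)  # new row order every epoch
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]  # the last batch may be smaller
            m = len(batch)
            # Gradients of mean((w * x + b - y) ** 2) over the batch
            residuals = [w * xs[i] + b - ys[i] for i in batch]
            grad_w = sum(2 * r * xs[i] for r, i in zip(residuals, batch)) / m
            grad_b = sum(2 * r for r in residuals) / m
            # Plain SGD update: parameter -= lr * gradient
            w -= lr * grad_w
            b -= lr * grad_b
            steps += 1
    return w, b, steps


# 500 noise-free rows of a one-feature problem; the true parameters are w = 3, b = 1
xs = [i / 500 for i in range(500)]
ys = [3 * x + 1 for x in xs]
w, b, steps = minibatch_sgd(xs, ys)  # lr=0.01, batch_size=64, epochs=5 as in the example
print(steps)  # 40: 8 updates per epoch (7 batches of 64 plus one of 52) times 5 epochs
```

With lr=0.01 and only 40 updates, w and b are still far from 3 and 1 on this toy problem; more epochs or a larger learning rate are needed for it to converge.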

and make predictions with either the Net.predict method

preds = model.predict(x)

or with the Net.forward method

preds = model.predict_net(x)
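Since Net.predict is just Net.forward followed by torch.sigmoid, the two prediction calls above should differ only by an element-wise logistic transform. The pure-Python sketch below applies that transform to a few made-up raw outputs standing in for rows of model.predict_net(x); the values are hypothetical and only illustrate the mapping.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function, the transform applied in Net.predict."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)  # avoids overflow in exp(-z) for large negative z
    return e / (1.0 + e)


# Hypothetical raw outputs, standing in for rows of model.predict_net(x)
raw = [-2.0, 0.0, 3.5]
probs = [sigmoid(z) for z in raw]  # the corresponding model.predict(x) rows
print(probs)
```

The mapping is monotonic and squashes every output into (0, 1), so the two calls rank the rows identically.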

For more examples, see the examples folder in the GitHub repository.



Download files


Source Distribution

torchtuples-0.2.1.tar.gz (39.1 kB)

Built Distribution

torchtuples-0.2.1-py3-none-any.whl (41.9 kB)

File details

Details for the file torchtuples-0.2.1.tar.gz.

File metadata

  • Download URL: torchtuples-0.2.1.tar.gz
  • Size: 39.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.9.6

File hashes

Hashes for torchtuples-0.2.1.tar.gz
Algorithm Hash digest
SHA256 6e5c3b7fa1c7a872d2503ea38ffb0d6267ea17501a695f579fa419d2ec83cbe5
MD5 17de1ac9139dc4d3bc8cbc08744024a4
BLAKE2b-256 9475d100ff332a7d7918e3de8b952c1e438a6fc642142ebd9b1a25f096703c73


File details

Details for the file torchtuples-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: torchtuples-0.2.1-py3-none-any.whl
  • Size: 41.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.9.6

File hashes

Hashes for torchtuples-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7ca860d2ee5c4fcf2258dc01d8a43cbdd52bd33f39fd1cf3d5cebcfecc80270d
MD5 c0e6f089113d27c3bcc5847a3682fba5
BLAKE2b-256 e23a4c563d6eee77efeb548a9edc109976514428fd7aa4b5a9b227c94bf30a19

