
Training neural networks in PyTorch


torchtuples


torchtuples is a small Python package for training PyTorch models. It works equally well with NumPy arrays and torch tensors. One of its main benefits is that it handles data in the form of nested tuples (see the example below).
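To illustrate what handling nested tuples involves, here is a minimal sketch (not torchtuples' own implementation) of taking a mini-batch from a nested tuple by applying the same index to every array it contains, so the nesting is preserved. The helper `nested_index` is hypothetical, and NumPy arrays are used so the sketch runs without PyTorch.

```python
import numpy as np


def nested_index(data, idx):
    """Apply the same index to every array in a (possibly nested) tuple.

    Hypothetical helper for illustration; the tuple structure is preserved.
    """
    if isinstance(data, tuple):
        return tuple(nested_index(item, idx) for item in data)
    return data[idx]


rng = np.random.default_rng(0)
a, b, c = (rng.standard_normal((10, 3)) for _ in range(3))
data = (a, (b, c))

# Take the first four rows of every array in the nested tuple.
batch = nested_index(data, slice(0, 4))
print(batch[0].shape, batch[1][0].shape, batch[1][1].shape)  # (4, 3) (4, 3) (4, 3)
```

The same index can be a slice or an integer array of sample indices, which is what a shuffled mini-batch would use.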

Installation

Requires Python 3.6 or 3.7.

torchtuples depends on PyTorch, which should be installed first by following the instructions on the PyTorch website.

Next, torchtuples can be installed using pip:

pip install git+https://github.com/havakv/torchtuples.git

or by cloning the repo:

git clone https://github.com/havakv/torchtuples.git
cd torchtuples
python setup.py install

Example

import torch
from torch import nn
from torchtuples import Model, optim

Create a data set with three sets of covariates, x0, x1 and x2, and a target y. The covariates are structured as a nested tuple x.

n = 500
x0, x1, x2 = [torch.randn(n, 3) for _ in range(3)]
y = torch.randn(n, 1)
x = (x0, (x0, x1, x2))
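The structure of x can be inspected with a small recursive helper that returns the shape of every array while keeping the nesting. This is a sketch: `nested_shapes` is a hypothetical function, and NumPy arrays stand in for the torch tensors (torch tensors also have a `.shape` attribute, so the helper applies to them unchanged). Note that the same array x0 appears both as the first element and inside the inner tuple.

```python
import numpy as np


def nested_shapes(data):
    """Return the shapes of all arrays in a nested tuple, keeping the nesting."""
    if isinstance(data, tuple):
        return tuple(nested_shapes(item) for item in data)
    return data.shape


n = 500
rng = np.random.default_rng(0)
x0, x1, x2 = [rng.standard_normal((n, 3)) for _ in range(3)]
y = rng.standard_normal((n, 1))
x = (x0, (x0, x1, x2))

print(nested_shapes(x))
# ((500, 3), ((500, 3), (500, 3), (500, 3)))
```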

Create a simple ReLU net that takes as input the tensor x_tensor and the tuple x_tuple. Note that x_tuple can be of arbitrary length. Each tensor in x_tuple is passed through the shared layer lin_tuple followed by a ReLU; the results are averaged and concatenated with x_tensor. The new tensor is then passed through the layer lin_cat.

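class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin_tuple = nn.Linear(3, 2)  # shared by every tensor in x_tuple
        self.lin_cat = nn.Linear(5, 1)    # 2 averaged features + 3 from x_tensor
        self.relu = nn.ReLU()

    def forward(self, x_tensor, x_tuple):
        # Transform each tensor in x_tuple: (n, 3) -> (n, 2)
        x = [self.relu(self.lin_tuple(xi)) for xi in x_tuple]
        # Average over the tuple elements: (n, 2)
        x = torch.stack(x).mean(0)
        # Concatenate with x_tensor along the feature dimension: (n, 5)
        x = torch.cat([x, x_tensor], dim=1)
        return self.lin_cat(x)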
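To make the shape arithmetic in Net.forward concrete, here is a NumPy sketch of the same forward pass with randomly drawn weights (an assumption; it does not reproduce PyTorch's initialisation). Each array in x_tuple goes from 3 to 2 features, the average keeps 2, and concatenating with the 3 features of x_tensor gives the 5 inputs that lin_cat expects.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for nn.Linear(3, 2) and nn.Linear(5, 1).
# nn.Linear computes x @ W.T + b with W of shape (out_features, in_features).
W_tuple, b_tuple = rng.standard_normal((2, 3)), rng.standard_normal(2)
W_cat, b_cat = rng.standard_normal((1, 5)), rng.standard_normal(1)


def relu(z):
    return np.maximum(z, 0.0)


def forward(x_tensor, x_tuple):
    # Each (n, 3) array -> (n, 2) after the shared linear layer and ReLU.
    hidden = [relu(xi @ W_tuple.T + b_tuple) for xi in x_tuple]
    # Stack to (k, n, 2) and average over the k tuple elements -> (n, 2).
    pooled = np.stack(hidden).mean(axis=0)
    # Concatenate with x_tensor along the feature axis -> (n, 5).
    joined = np.concatenate([pooled, x_tensor], axis=1)
    # Final linear layer -> (n, 1).
    return joined @ W_cat.T + b_cat


n = 500
x0, x1, x2 = [rng.standard_normal((n, 3)) for _ in range(3)]
out = forward(x0, (x0, x1, x2))
print(out.shape)  # (500, 1)

# x_tuple may have any length; the output shape stays the same.
print(forward(x0, (x1,)).shape)  # (500, 1)
```

Because the tuple elements are averaged, the output does not depend on their order, which is why x_tuple can have arbitrary length.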

We can now fit the model with

model = Model(Net(), nn.MSELoss(), optim.SGD(0.01))
log = model.fit(x, y, batch_size=64, epochs=5)
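As a rough, framework-free illustration of what batch_size=64 and epochs=5 mean for this data set, here is a NumPy sketch of mini-batch gradient descent on the mean squared error. The plain linear model, the reshuffling every epoch, and the helper name `fit_linear_sgd` are assumptions for illustration; this is not how torchtuples is implemented. With n = 500 and batch_size = 64, one epoch in this sketch consists of ceil(500 / 64) = 8 batches, the last holding 52 samples.

```python
import math

import numpy as np


def fit_linear_sgd(X, y, lr=0.01, batch_size=64, epochs=5, seed=0):
    """Mini-batch SGD on the MSE loss for y ~ X @ w + b (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros((d, 1))
    b = 0.0
    batches_per_epoch = math.ceil(n / batch_size)
    losses = []
    for _ in range(epochs):
        perm = rng.permutation(n)  # reshuffle the samples every epoch
        for i in range(batches_per_epoch):
            idx = perm[i * batch_size:(i + 1) * batch_size]
            Xb, yb = X[idx], y[idx]
            err = Xb @ w + b - yb  # residuals, shape (batch, 1)
            # Gradients of mean((Xb @ w + b - yb) ** 2) w.r.t. w and b.
            grad_w = 2.0 * Xb.T @ err / len(idx)
            grad_b = 2.0 * err.mean()
            w -= lr * grad_w
            b -= lr * grad_b
        # Full-data MSE after each epoch.
        losses.append(float(np.mean((X @ w + b - y) ** 2)))
    return w, b, losses, batches_per_epoch


rng = np.random.default_rng(1)
X = rng.standard_normal((500, 3))
y = X @ np.array([[1.0], [-2.0], [0.5]]) + 0.1 * rng.standard_normal((500, 1))
w, b, losses, nb = fit_linear_sgd(X, y)
print(nb)  # 8
```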

and make predictions with

preds = model.predict(x)



Download files


Source Distribution

torchtuples-0.1.2.tar.gz (36.9 kB)

Built Distribution

torchtuples-0.1.2-py3-none-any.whl (40.5 kB)

File details

Details for the file torchtuples-0.1.2.tar.gz.

File metadata

  • Download URL: torchtuples-0.1.2.tar.gz
  • Size: 36.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0 requests-toolbelt/0.9.1 tqdm/4.39.0 CPython/3.7.5

File hashes

Hashes for torchtuples-0.1.2.tar.gz
Algorithm Hash digest
SHA256 92e1e1fd7cbd8a11315e5c8c68e15f1bbe417134a1a665ec093a342a81393474
MD5 a0e00c1163f876029349e2345777a4ca
BLAKE2b-256 a1b8016bff677128cbf7546bb081fbabc000615e0dca00bf41f9e7d585b58b1b
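A downloaded file can be checked against the SHA256 digest listed above with Python's standard `hashlib` module. This is a minimal sketch; the helper name `sha256_of` and the local file path are illustrative assumptions.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED = "92e1e1fd7cbd8a11315e5c8c68e15f1bbe417134a1a665ec093a342a81393474"

path = Path("torchtuples-0.1.2.tar.gz")  # assumed download location
if path.exists():
    print("OK" if sha256_of(path) == EXPECTED else "hash mismatch")
```

Reading in chunks keeps memory use constant regardless of file size. pip can also enforce digests itself through its hash-checking mode (`--require-hashes` with a requirements file).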


File details

Details for the file torchtuples-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: torchtuples-0.1.2-py3-none-any.whl
  • Size: 40.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0 requests-toolbelt/0.9.1 tqdm/4.39.0 CPython/3.7.5

File hashes

Hashes for torchtuples-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 165267f8db221904635524bdef5e0bd0ddcf942bc9f087a672dce6339c69433a
MD5 a125d43c65efbe1a123486f88defb0ff
BLAKE2b-256 014b2c02c5c8cc3735d538fd39c957ac3fb8f4330f7d96401cf0e3974fcbd27a

