torchwrench

Collection of functions and modules to help development in PyTorch.

Installation

With pip:

pip install torchwrench

With uv:

uv add torchwrench

The main requirement is PyTorch.

To check that the package is installed and print its version, run the following command in your terminal:

torchwrench-info
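
The same check can be done from Python with the standard library. This is a minimal sketch that relies only on importlib.metadata, not on any torchwrench API; the helper name installed_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package_name: str) -> Optional[str]:
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# Prints the version string if torchwrench is installed, otherwise None.
print(installed_version("torchwrench"))
```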

This library has been tested on Python 3.8 to 3.14 and PyTorch 1.10 to 2.9, on Linux, macOS and Windows.

Examples

torchwrench functions and modules can be used like their torch counterparts. The conventional import alias for torchwrench is tw.

Label conversions

Supports multiclass label conversions between probabilities, class indices, class names and one-hot encodings.

import torchwrench as tw

probs = tw.as_tensor([[0.9, 0.1], [0.4, 0.6]])
names = tw.probs_to_name(probs, idx_to_name={0: "Cat", 1: "Dog"})
# ["Cat", "Dog"]
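
To make the conversion concrete, here is a plain-Python sketch of what the example above computes: an argmax per row followed by a name lookup. It is an illustration under that assumption, not torchwrench's implementation, and probs_to_names_sketch is a hypothetical name.

```python
def probs_to_names_sketch(probs, idx_to_name):
    """Map each row of probabilities to the name of its most likely class."""
    names = []
    for row in probs:
        # Index of the largest probability in this row (argmax).
        best_idx = max(range(len(row)), key=row.__getitem__)
        names.append(idx_to_name[best_idx])
    return names


names = probs_to_names_sketch([[0.9, 0.1], [0.4, 0.6]], {0: "Cat", 1: "Dog"})
print(names)  # ['Cat', 'Dog']
```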

This package also supports multilabel conversions between probabilities, lists of class indices, lists of class names and multi-hot encodings.

import torchwrench as tw

multihot = tw.as_tensor([[1, 0, 0], [0, 1, 1], [0, 0, 0]])
indices = tw.multihot_to_indices(multihot)
# [[0], [1, 2], []]
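
The multi-hot to indices conversion can be pictured as collecting, for each row, the positions of the non-zero entries. The following plain-Python sketch reproduces the result shown above; the function name is hypothetical and the code is not taken from the library.

```python
def multihot_to_indices_sketch(multihot):
    """Return, for each row, the list of positions holding a non-zero flag."""
    return [[i for i, flag in enumerate(row) if flag] for row in multihot]


indices = multihot_to_indices_sketch([[1, 0, 0], [0, 1, 1], [0, 0, 0]])
print(indices)  # [[0], [1, 2], []]
```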

Finally, this package includes powerset multilabel conversions:

import torchwrench as tw

multihot = tw.as_tensor([[1, 0, 0], [0, 1, 1], [0, 0, 0]])
powerset = tw.multilabel_to_powerset(multihot, num_classes=3, max_set_size=2)
# tensor([[0, 1, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 1],
#         [1, 0, 0, 0, 0, 0, 0]])
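
A powerset encoding treats every subset of classes (up to max_set_size elements) as its own class. The sketch below is a plain-Python illustration that assumes subsets are ordered by size and then lexicographically; that ordering is an assumption chosen because it reproduces the output above, and the helper names are hypothetical.

```python
from itertools import combinations


def build_powerset_classes(num_classes, max_set_size):
    """List all subsets of range(num_classes) with at most max_set_size items."""
    subsets = []
    for size in range(max_set_size + 1):
        subsets.extend(combinations(range(num_classes), size))
    return subsets


def multihot_to_powerset_sketch(multihot, num_classes, max_set_size):
    """One-hot encode each multi-hot row over the enumerated subsets."""
    subsets = build_powerset_classes(num_classes, max_set_size)
    index_of = {subset: i for i, subset in enumerate(subsets)}
    encoded = []
    for row in multihot:
        active = tuple(i for i, flag in enumerate(row) if flag)
        onehot = [0] * len(subsets)
        # Raises KeyError if a row has more than max_set_size active classes.
        onehot[index_of[active]] = 1
        encoded.append(onehot)
    return encoded


encoded = multihot_to_powerset_sketch(
    [[1, 0, 0], [0, 1, 1], [0, 0, 0]], num_classes=3, max_set_size=2
)
for row in encoded:
    print(row)
```

With 3 classes and at most 2 active labels there are 1 + 3 + 3 = 7 subsets, which matches the 7 columns in the output above.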

Typing

Typing with the number of dimensions:

import torchwrench as tw

x1 = tw.as_tensor([1, 2])
print(isinstance(x1, tw.Tensor2D))  # False
x2 = tw.as_tensor([[1, 2], [3, 4]])
print(isinstance(x2, tw.Tensor2D))  # True
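
In Python, isinstance checks that depend on a runtime property can be built with a metaclass that overrides __instancecheck__. The sketch below shows that general technique on a stand-in object; it is not a description of how torchwrench implements Tensor2D, and every name in it is hypothetical.

```python
class NDimMeta(type):
    """Metaclass whose isinstance check compares the object's `ndim`."""

    def __instancecheck__(cls, obj):
        return getattr(obj, "ndim", None) == cls.expected_ndim


class Array2D(metaclass=NDimMeta):
    """Marker type: matches any object exposing ndim == 2."""

    expected_ndim = 2


class FakeArray:
    """Stand-in for a tensor; it only carries a number of dimensions."""

    def __init__(self, ndim):
        self.ndim = ndim


print(isinstance(FakeArray(1), Array2D))  # False
print(isinstance(FakeArray(2), Array2D))  # True
```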

Typing with the tensor dtype:

import torchwrench as tw

x1 = tw.as_tensor([1, 2], dtype=tw.int)
print(isinstance(x1, tw.SignedIntegerTensor))  # True

x2 = tw.as_tensor([1, 2], dtype=tw.long)
print(isinstance(x2, tw.SignedIntegerTensor1D))  # True

x3 = tw.as_tensor([1, 2], dtype=tw.float)
print(isinstance(x3, tw.SignedIntegerTensor))  # False

Padding & cropping

Pad a specific dimension:

import torchwrench as tw

x = tw.rand(10, 3, 1)
padded = tw.pad_dim(x, target_length=5, dim=1, pad_value=-1)
# padded has shape (10, 5, 1), padded with -1

Pad a nested list of tensors into a single tensor:

import torchwrench as tw

tensors = [tw.rand(10, 2), [tw.rand(3)] * 5, tw.rand(0, 5)]
padded = tw.pad_and_stack_rec(tensors, pad_value=0)
# padded has shape (3, 10, 5), padded with 0
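
As a simplified illustration of padding ragged data to a common shape, the plain-Python sketch below handles a single nesting level and assumes padding is appended on the right. Both choices are assumptions made for the example; the library function works recursively on tensors and may align values differently.

```python
def pad_and_stack_sketch(rows, pad_value=0):
    """Pad every row on the right so all rows share the longest length."""
    max_len = max((len(row) for row in rows), default=0)
    return [list(row) + [pad_value] * (max_len - len(row)) for row in rows]


padded = pad_and_stack_sketch([[1, 2, 3], [4], []], pad_value=0)
print(padded)  # [[1, 2, 3], [4, 0, 0], [0, 0, 0]]
```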

Remove values along a specific dimension:

import torchwrench as tw

x = tw.rand(10, 5, 3)
cropped = tw.crop_dim(x, dim=1, target_length=2)
# cropped has shape (10, 2, 3)

Masking

import torchwrench as tw

x = tw.as_tensor([3, 1, 2])
mask = tw.lengths_to_non_pad_mask(x, max_len=4)
# Row i starts with x[i] True values (non-padding positions), followed by False
# tensor([[True, True, True, False],
#         [True, False, False, False],
#         [True, True, False, False]])
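
The mask above follows a simple rule: position j of row i is True when j is smaller than the i-th length. A plain-Python sketch of that rule (hypothetical function name, not the library's code):

```python
def lengths_to_non_pad_mask_sketch(lengths, max_len):
    """Build one boolean row per length: True for positions below the length."""
    return [[pos < length for pos in range(max_len)] for length in lengths]


mask = lengths_to_non_pad_mask_sketch([3, 1, 2], max_len=4)
for row in mask:
    print(row)
```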
Compute the mean of the values selected by a mask:

import torchwrench as tw

x = tw.as_tensor([1, 2, 3, 4])
mask = tw.as_tensor([True, True, False, False])
result = tw.masked_mean(x, mask)
# result contains the mean of the values marked as True: 1.5
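
Conceptually, a masked mean sums the selected values and divides by how many were selected. The sketch below illustrates this in plain Python; raising an error on an all-False mask is a choice made for the example, not documented library behaviour.

```python
def masked_mean_sketch(values, mask):
    """Average only the values whose mask entry is True."""
    selected = [value for value, keep in zip(values, mask) if keep]
    if not selected:
        raise ValueError("mask selects no values")
    return sum(selected) / len(selected)


result = masked_mean_sketch([1, 2, 3, 4], [True, True, False, False])
print(result)  # 1.5
```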

Other tensor manipulations

import torchwrench as tw

x = tw.as_tensor([1, 2, 3, 4])
result = tw.insert_at_indices(x, indices=[0, 2], values=5)
# result contains tensor with inserted values: tensor([5, 1, 2, 5, 3, 4])
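
Judging from the output above, each index refers to a position in the original tensor, and the new value is placed just before the element at that position. The plain-Python sketch below follows that reading; it is an interpretation of the example, uses a hypothetical function name, and does not handle an index equal to the input length.

```python
def insert_at_indices_sketch(values, indices, new_value):
    """Insert new_value before each listed position of the original list."""
    positions = set(indices)
    result = []
    for i, value in enumerate(values):
        if i in positions:
            result.append(new_value)
        result.append(value)
    return result


result = insert_at_indices_sketch([1, 2, 3, 4], indices=[0, 2], new_value=5)
print(result)  # [5, 1, 2, 5, 3, 4]
```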
Compute the inverse of a permutation:

import torchwrench as tw

perm = tw.randperm(10)
inv_perm = tw.get_inverse_perm(perm)

x1 = tw.rand(10)
x2 = x1[perm]
x3 = x2[inv_perm]
# inv_perm are indices that allow us to get x3 from x2, i.e. x1 == x3 here
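
The inverse permutation satisfies inv_perm[perm[i]] == i for every i, which is why indexing with perm and then with inv_perm restores the original order. A plain-Python sketch of that property, using lists instead of tensors and a hypothetical function name:

```python
import random


def inverse_permutation(perm):
    """Return inv such that inv[perm[i]] == i for every position i."""
    inv = [0] * len(perm)
    for position, source in enumerate(perm):
        inv[source] = position
    return inv


perm = list(range(10))
random.shuffle(perm)
inv_perm = inverse_permutation(perm)

x1 = [random.random() for _ in range(10)]
x2 = [x1[i] for i in perm]      # apply the permutation
x3 = [x2[i] for i in inv_perm]  # undo it
print(x3 == x1)  # True
```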

Extra: pre-compute datasets to HDF files

Here is an example of pre-computing spectrograms of the torchaudio SPEECHCOMMANDS dataset, using the pack_to_hdf function:

from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.transforms import Spectrogram
from torchwrench import nn
from torchwrench.extras.hdf import pack_to_hdf

speech_commands_root = "path/to/speech_commands"
packed_root = "path/to/packed_dataset.hdf"

dataset = SPEECHCOMMANDS(speech_commands_root, download=True, subset="validation")
# dataset[0] is a tuple containing the waveform and other metadata

class MyTransform(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.spectrogram_extractor = Spectrogram()

    def forward(self, item):
        waveform = item[0]
        spectrogram = self.spectrogram_extractor(waveform)
        return (spectrogram,) + item[1:]

pack_to_hdf(dataset, packed_root, MyTransform())

Then you can load the pre-computed dataset using HDFDataset:

from torchwrench.extras.hdf import HDFDataset

packed_root = "path/to/packed_dataset.hdf"
packed_dataset = HDFDataset(packed_root)
packed_dataset[0]  # == first transformed item, i.e. transform(dataset[0])

Contact

Maintainer:
