torchwrench
Collection of functions and modules to help development in PyTorch.
Installation
With pip:
pip install torchwrench
With uv:
uv add torchwrench
The main requirement is PyTorch.
To check that the package is installed and display its version, run the following command in your terminal:
torchwrench-info
This library has been tested on Python versions 3.8 to 3.14, PyTorch versions 1.10 to 2.9, and on Linux, macOS, and Windows.
Examples
torchwrench functions and modules can be used like their torch counterparts. The default alias for torchwrench is tw.
Label conversions
Supports multiclass label conversions between probabilities, class indices, class names, and one-hot encoding.
import torchwrench as tw
probs = tw.as_tensor([[0.9, 0.1], [0.4, 0.6]])
names = tw.probs_to_name(probs, idx_to_name={0: "Cat", 1: "Dog"})
# ["Cat", "Dog"]
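Conceptually, this conversion is an argmax per row followed by a name lookup. A stdlib-only sketch of that logic, reusing the values from the example above (not the actual torchwrench implementation):

```python
# Sketch of probs -> names: take the argmax of each row, then map index to name.
probs = [[0.9, 0.1], [0.4, 0.6]]
idx_to_name = {0: "Cat", 1: "Dog"}

indices = [max(range(len(row)), key=row.__getitem__) for row in probs]
names = [idx_to_name[i] for i in indices]
print(names)  # ['Cat', 'Dog']
```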
This package also supports multilabel conversions between probabilities, class multi-indices, class multi-names, and multihot encoding.
import torchwrench as tw
multihot = tw.as_tensor([[1, 0, 0], [0, 1, 1], [0, 0, 0]])
indices = tw.multihot_to_indices(multihot)
# [[0], [1, 2], []]
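Conceptually, multihot_to_indices collects the positions of the nonzero entries in each row. A stdlib-only sketch of the same logic:

```python
# Sketch of multihot -> indices: keep the position of each nonzero entry per row.
multihot = [[1, 0, 0], [0, 1, 1], [0, 0, 0]]
indices = [[i for i, v in enumerate(row) if v] for row in multihot]
print(indices)  # [[0], [1, 2], []]
```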
Finally, this package includes powerset multilabel conversions:
import torchwrench as tw
multihot = tw.as_tensor([[1, 0, 0], [0, 1, 1], [0, 0, 0]])
indices = tw.multilabel_to_powerset(multihot, num_classes=3, max_set_size=2)
# tensor([[0, 1, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 1],
# [1, 0, 0, 0, 0, 0, 0]])
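The 7 output columns correspond to the number of label subsets of size at most max_set_size: for num_classes=3 and max_set_size=2, that is C(3,0) + C(3,1) + C(3,2) = 1 + 3 + 3 = 7. This count can be checked with the stdlib (this assumes the powerset encoding enumerates all subsets up to max_set_size, which matches the output shown above):

```python
from math import comb

# Number of subsets of 3 classes with at most 2 elements.
num_classes, max_set_size = 3, 2
num_powerset_classes = sum(comb(num_classes, k) for k in range(max_set_size + 1))
print(num_powerset_classes)  # 7
```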
Typing
Typing by number of dimensions:
import torchwrench as tw
x1 = tw.as_tensor([1, 2])
print(isinstance(x1, tw.Tensor2D)) # False
x2 = tw.as_tensor([[1, 2], [3, 4]])
print(isinstance(x2, tw.Tensor2D)) # True
Typing by tensor dtype:
import torchwrench as tw
x1 = tw.as_tensor([1, 2], dtype=tw.int)
print(isinstance(x1, tw.SignedIntegerTensor)) # True
x2 = tw.as_tensor([1, 2], dtype=tw.long)
print(isinstance(x2, tw.SignedIntegerTensor1D)) # True
x3 = tw.as_tensor([1, 2], dtype=tw.float)
print(isinstance(x3, tw.SignedIntegerTensor)) # False
Padding & cropping
Pad a specific dimension:
import torchwrench as tw
x = tw.rand(10, 3, 1)
padded = tw.pad_dim(x, target_length=5, dim=1, pad_value=-1)
# padded has shape (10, 5, 1), padded with -1
Pad a nested list of tensors into a single stacked tensor:
import torchwrench as tw
tensors = [tw.rand(10, 2), [tw.rand(3)] * 5, tw.rand(0, 5)]
padded = tw.pad_and_stack_rec(tensors, pad_value=0)
# padded has shape (3, 10, 5), padded with 0
Remove values along a specific dimension:
import torchwrench as tw
x = tw.rand(10, 5, 3)
cropped = tw.crop_dim(x, dim=1, target_length=2)
# cropped has shape (10, 2, 3)
Masking
import torchwrench as tw
x = tw.as_tensor([3, 1, 2])
mask = tw.lengths_to_non_pad_mask(x, max_len=4)
# Each row i contains x[i] True values for non-padding mask
# tensor([[True, True, True, False],
# [True, False, False, False],
# [True, True, False, False]])
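The mask logic itself is simple: position j in row i is True when j < x[i]. A stdlib-only sketch of the same computation:

```python
# Sketch of lengths -> non-pad mask: row i has lengths[i] leading True values.
lengths = [3, 1, 2]
max_len = 4
mask = [[j < n for j in range(max_len)] for n in lengths]
print(mask)
# [[True, True, True, False],
#  [True, False, False, False],
#  [True, True, False, False]]
```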
import torchwrench as tw
x = tw.as_tensor([1, 2, 3, 4])
mask = tw.as_tensor([True, True, False, False])
result = tw.masked_mean(x, mask)
# result contains the mean of the values marked as True: 1.5
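The masked mean averages only the values whose mask entry is True; here that is (1 + 2) / 2 = 1.5. A stdlib-only sketch of the same computation:

```python
# Sketch of masked_mean: keep masked-in values, then average them.
x = [1, 2, 3, 4]
mask = [True, True, False, False]
kept = [v for v, m in zip(x, mask) if m]
result = sum(kept) / len(kept)
print(result)  # 1.5
```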
Other tensor manipulations
import torchwrench as tw
x = tw.as_tensor([1, 2, 3, 4])
result = tw.insert_at_indices(x, indices=[0, 2], values=5)
# result contains tensor with inserted values: tensor([5, 1, 2, 5, 3, 4])
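As the output suggests, the value is inserted before each of the listed positions of the original tensor. A stdlib-only sketch of that behavior (not the actual torchwrench implementation):

```python
# Sketch of insert_at_indices: emit the value before each listed original position.
x = [1, 2, 3, 4]
indices = {0, 2}
value = 5

result = []
for i, v in enumerate(x):
    if i in indices:
        result.append(value)
    result.append(v)
print(result)  # [5, 1, 2, 5, 3, 4]
```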
import torchwrench as tw
perm = tw.randperm(10)
inv_perm = tw.get_inverse_perm(perm)
x1 = tw.rand(10)
x2 = x1[perm]
x3 = x2[inv_perm]
# inv_perm contains the indices that undo perm: applying it to x2 restores the original order, so x1 == x3 here
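An inverse permutation can be built by scattering each index back to its source position, i.e. inv_perm[perm[i]] = i. A stdlib-only sketch with a small fixed permutation (torchwrench's randperm would produce a random one):

```python
# Sketch of an inverse permutation: inv_perm[perm[i]] = i.
perm = [2, 0, 3, 1]  # example permutation
inv_perm = [0] * len(perm)
for i, p in enumerate(perm):
    inv_perm[p] = i

x1 = [10.0, 20.0, 30.0, 40.0]
x2 = [x1[p] for p in perm]      # permute
x3 = [x2[p] for p in inv_perm]  # undo the permutation
print(x3 == x1)  # True
```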
Extra: pre-compute datasets to HDF files
Here is an example of pre-computing spectrograms of the torchaudio SPEECHCOMMANDS dataset, using the pack_to_hdf function:
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.transforms import Spectrogram
from torchwrench import nn
from torchwrench.extras.hdf import pack_to_hdf
speech_commands_root = "path/to/speech_commands"
packed_root = "path/to/packed_dataset.hdf"
dataset = SPEECHCOMMANDS(speech_commands_root, download=True, subset="validation")
# dataset[0] is a tuple, contains waveform and other metadata
class MyTransform(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.spectrogram_extractor = Spectrogram()

    def forward(self, item):
        waveform = item[0]
        spectrogram = self.spectrogram_extractor(waveform)
        return (spectrogram,) + item[1:]
pack_to_hdf(dataset, packed_root, MyTransform())
Then you can load the pre-computed dataset using HDFDataset:
from torchwrench.extras.hdf import HDFDataset
packed_root = "path/to/packed_dataset.hdf"
packed_dataset = HDFDataset(packed_root)
packed_dataset[0] # == first transformed item, i.e. transform(dataset[0])
Contact
Maintainer:
- Étienne Labbé "Labbeti": labbeti.pub@gmail.com