manipulate sets of tensors
tensorset
tensorset is a PyTorch library that lets you perform operations on related sequences using a unified TensorSet object.
It aims to reduce the complexity of using multiple related sequences. Sequences like these are very commonly used as inputs to a transformer model:
import torch
from torch import nn
batch_size = 8
sequence_length = 1024
vocab_size = 256
hidden_size = 768
pad_id = 255
token_embeddings = nn.Embedding(vocab_size, hidden_size)
input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length))
input_embeds = token_embeddings(input_ids)  # Shape: (batch_size, sequence_length, hidden_size)
key_pad_mask = input_ids == pad_id  # Shape: (batch_size, sequence_length)
is_whitespace_mask = (input_ids == 0) | (input_ids == 1)  # Shape: (batch_size, sequence_length)
# These tensors would be used like this:
# logits = transformer_model(input_embeds, key_pad_mask, is_whitespace_mask)
Notice that wherever these tensors are truncated, stacked, or concatenated, the same operation has to be repeated for each tensor, which leads to tedious code like this:
def truncate_inputs(input_ids, key_pad_mask, is_whitespace_mask, length):
    input_ids = input_ids[:, :length]
    key_pad_mask = key_pad_mask[:, :length]
    is_whitespace_mask = is_whitespace_mask[:, :length]
    return input_ids, key_pad_mask, is_whitespace_mask

truncated_inputs = truncate_inputs(input_ids, key_pad_mask, is_whitespace_mask, length=10)
This repetitive code can be avoided. input_ids, input_embeds, key_pad_mask, and is_whitespace_mask are all related: they all have matching leading dimensions for batch size and sequence length.
TensorSet is a container for these related multi-dimensional sequences, making this kind of manipulation very easy and ergonomic.
import tensorset as ts
length = 10
inputs = ts.TensorSet(
    input_ids=input_ids,
    input_embeds=input_embeds,
    key_pad_mask=key_pad_mask,
    is_whitespace_mask=is_whitespace_mask,
)
truncated_inputs = inputs.iloc[:, :length]
print(truncated_inputs)
prints:
TensorSet(
named_columns:
name: input_ids, shape: torch.Size([8, 10]), dtype: torch.int64
name: input_embeds, shape: torch.Size([8, 10, 768]), dtype: torch.float32
name: key_pad_mask, shape: torch.Size([8, 10]), dtype: torch.bool
name: is_whitespace_mask, shape: torch.Size([8, 10]), dtype: torch.bool
)
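For intuition, the effect of a single slice over the whole set can be approximated in plain PyTorch by applying the same index to every tensor in a dict. This is only a sketch of the idea, not tensorset's implementation; the `columns` dict and the `index_all` helper are hypothetical names introduced here.

```python
import torch

# Hypothetical stand-in for a set of related tensors that share
# their leading (batch, sequence) dimensions.
columns = {
    "input_ids": torch.randint(0, 256, (8, 1024)),
    "input_embeds": torch.randn(8, 1024, 768),
    "key_pad_mask": torch.zeros(8, 1024, dtype=torch.bool),
}


def index_all(columns, index):
    """Apply the same index to every tensor (illustrative helper only)."""
    return {name: tensor[index] for name, tensor in columns.items()}


# Keep the first 10 positions of every sequence in one call.
truncated = index_all(columns, (slice(None), slice(None, 10)))
for name, tensor in truncated.items():
    print(name, tuple(tensor.shape), tensor.dtype)
```

Because the index only touches the leading dimensions, trailing dimensions such as the hidden size are carried along unchanged.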
Features
Stack related TensorSets to create larger batches
sequence_length = 20
sequence_1 = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
sequence_2 = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
batch = ts.stack((sequence_1, sequence_2), 0)
print(batch.size(1)) # This is the sequence length, prints 20
print(batch.size(0)) # This is the batch size, prints 2
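One way to picture this is as a `torch.stack` call per column. The plain-PyTorch sketch below assumes each column is stacked independently along the new dimension; the tuples merely stand in for TensorSets and are not part of the library.

```python
import torch

sequence_length = 20

# Plain tuples standing in for the two TensorSets above.
sequence_1 = (torch.randn(sequence_length, 512), torch.randn(sequence_length, 1024))
sequence_2 = (torch.randn(sequence_length, 512), torch.randn(sequence_length, 1024))

# Pair up matching columns, then stack each pair along a new dim 0.
batch = tuple(torch.stack(pair, dim=0) for pair in zip(sequence_1, sequence_2))

print([tuple(t.shape) for t in batch])  # [(2, 20, 512), (2, 20, 1024)]
```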
Pad TensorSets with a specific amount of padding along the sequence dimension
sequence_length = 20
sequence = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
pad_value = -200
padded_sequence = sequence.pad(44, 0, pad_value) # add 44 elements of padding, filled with pad_value, along dimension 0
print(padded_sequence.size(0)) # This is the new sequence length, prints 64
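In plain PyTorch, padding a single column along its first dimension can be written with `torch.nn.functional.pad`. The sketch below assumes the padding is appended at the end of the sequence; the text above does not say which side the library pads on, so treat that as an assumption.

```python
import torch
import torch.nn.functional as F

sequence_length = 20
pad_value = -200
column = torch.randn(sequence_length, 512)

# F.pad lists (before, after) amounts starting from the LAST dimension,
# so (0, 0, 0, 44) leaves dim 1 alone and appends 44 rows to dim 0.
padded = F.pad(column, (0, 0, 0, 44), value=pad_value)

print(padded.shape)  # torch.Size([64, 512])
print(bool((padded[sequence_length:] == pad_value).all()))  # True
```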
Stack TensorSets with irregular shape, using torch.nested
# C, H, W pixel_values, and an additional binary mask
image1 = ts.TensorSet(
    pixel_values=torch.randn(3, 20, 305),
    mask=torch.randn(3, 20, 305) > 0,
)
image2 = ts.TensorSet(
    pixel_values=torch.randn(3, 450, 200),
    mask=torch.randn(3, 450, 200) > 0,
)
images = ts.stack_nt([image1, image2])
print(images)
output:
TensorSet(
named_columns:
name: pixel_values, shape: nested_tensor.Size([2, 3, irregular, irregular]), dtype: torch.float32
name: mask, shape: nested_tensor.Size([2, 3, irregular, irregular]), dtype: torch.bool
)
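For reference, this is roughly what `torch.nested` does for a single column. The sketch uses `torch.nested.nested_tensor` directly; how the library uses it internally is not shown here, and `torch.nested` is a prototype API whose behavior may change between PyTorch versions.

```python
import torch

# Two C, H, W images whose heights and widths differ.
pixels_1 = torch.randn(3, 20, 305)
pixels_2 = torch.randn(3, 450, 200)

# A nested tensor holds both without padding them to a common shape.
nested = torch.nested.nested_tensor([pixels_1, pixels_2])

print(nested.is_nested)  # True
# unbind() recovers the original, differently shaped components.
for component in nested.unbind():
    print(tuple(component.shape))
```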
TODO
- Access by lists of columns
- Enable operations over irregular dims that are not supported yet by torch.nested, such as mean and index select