manipulate sets of tensors

Project description

tensorset

tensorset is a PyTorch library that lets you perform operations on related sequences of tensors through a unified TensorSet object.
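tensorset is published on PyPI and can be installed with pip:

pip install tensorset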

It aims to reduce the complexity of using multiple related sequences. Sequences like these are very commonly used as inputs to a transformer model:

import torch
from torch import nn

batch_size = 8
sequence_length = 1024
vocab_size = 256
hidden_size = 768
pad_id = 255

token_embeddings = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length))
input_embeds = token_embeddings(input_ids)  # Shape: batch_size, sequence_length, hidden_size
key_pad_mask = input_ids == pad_id  # Shape: batch_size, sequence_length
is_whitespace_mask = (input_ids == 0) | (input_ids == 1)  # Shape: batch_size, sequence_length

# These tensors would be used like this:
# logits = transformer_model(input_embeds, key_pad_mask, is_whitespace_mask)

Notice that wherever these tensors are truncated, stacked, or concatenated, tedious, repetitive code like this appears:

def truncate_inputs(input_ids, key_pad_mask, is_whitespace_mask, length):
    input_ids = input_ids[:, :length]
    key_pad_mask = key_pad_mask[:, :length]
    is_whitespace_mask = is_whitespace_mask[:, :length]
    return input_ids, key_pad_mask, is_whitespace_mask

truncated_inputs = truncate_inputs(input_ids, key_pad_mask, is_whitespace_mask, length=10)

This repetitive code can be avoided. input_ids, input_embeds, key_pad_mask, and is_whitespace_mask are all related: they share matching leading dimensions of batch_size and sequence_length.

TensorSet is a container for these related multi-dimensional sequences, making this kind of manipulation very easy and ergonomic.

import tensorset as ts

length = 10
inputs = ts.TensorSet(
    input_ids=input_ids,
    input_embeds=input_embeds,
    key_pad_mask=key_pad_mask,
    is_whitespace_mask=is_whitespace_mask,
)
truncated_inputs = inputs.iloc[:, :length]
print(truncated_inputs)

prints:

TensorSet(
  named_columns:
    name: input_ids, shape: torch.Size([8, 10]), dtype: torch.int64
    name: input_embeds, shape: torch.Size([8, 10, 768]), dtype: torch.float32
    name: key_pad_mask, shape: torch.Size([8, 10]), dtype: torch.bool
    name: is_whitespace_mask, shape: torch.Size([8, 10]), dtype: torch.bool
)
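For comparison, the truncate_inputs helper from earlier collapses into a single slice. A minimal sketch, using only the iloc and size calls shown on this page:

# One slice replaces the whole truncate_inputs helper:
truncated_inputs = inputs.iloc[:, :10]
assert truncated_inputs.size(0) == 8   # the batch dimension is untouched
assert truncated_inputs.size(1) == 10  # every column is truncated together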

Features

Stack related TensorSets to create larger batches

sequence_length = 20
sequence_1 = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
sequence_2 = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
batch = ts.stack((sequence_1, sequence_2), 0)

print(batch.size(1)) # This is the sequence length, prints 20
print(batch.size(0)) # This is the batch size, prints 2
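The stacked batch supports the same iloc slicing as the earlier truncation example, so per-column bookkeeping never reappears. A small sketch that truncates every sequence in the batch at once:

# Slicing applies to all columns of the batch simultaneously:
short_batch = batch.iloc[:, :10]
print(short_batch.size(0))  # batch size is still 2
print(short_batch.size(1))  # sequence length is now 10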

Pad TensorSets with a specific amount of padding along the sequence dimension

sequence_length = 20
sequence = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
pad_value = -200
padded_sequence = sequence.pad(44, 0, pad_value)  # append 44 positions of pad_value along dimension 0
print(padded_sequence.size(0)) # This is the new sequence length, prints 64
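Padding pairs naturally with ts.stack for batching variable-length sequences. A minimal sketch, built only from the TensorSet, pad, size, and stack calls shown above:

# Pad the shorter sequence up to the longer one, then stack into a batch.
seq_a = ts.TensorSet(torch.randn(20, 512), torch.randn(20, 1024))
seq_b = ts.TensorSet(torch.randn(15, 512), torch.randn(15, 1024))

target_length = max(seq_a.size(0), seq_b.size(0))
seq_b = seq_b.pad(target_length - seq_b.size(0), 0, 0.0)  # pad 5 positions with 0.0

batch = ts.stack((seq_a, seq_b), 0)
print(batch.size(0))  # prints 2
print(batch.size(1))  # prints 20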

Stack TensorSets with irregular shape, using torch.nested

# C, H, W pixel_values, and an additional binary mask
image1 = ts.TensorSet(
    pixel_values=torch.randn(3, 20, 305),
    mask=torch.randn(3, 20, 305) > 0,
)
image2 = ts.TensorSet(
    pixel_values=torch.randn(3, 450, 200),
    mask=torch.randn(3, 450, 200) > 0,
)
images = ts.stack_nt([image1, image2])
print(images)

output:

TensorSet(
  named_columns:
    name: pixel_values, shape: nested_tensor.Size([2, 3, irregular, irregular]), dtype: torch.float32
    name: mask, shape: nested_tensor.Size([2, 3, irregular, irregular]), dtype: torch.bool
)

TODO

  • Access by lists of columns
  • Enable operations over irregular dims that torch.nested does not yet support, such as mean and index_select

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tensorset-0.4.2.tar.gz (7.1 kB)

Uploaded Source

Built Distribution

tensorset-0.4.2-py3-none-any.whl (6.5 kB)

Uploaded Python 3

File details

Details for the file tensorset-0.4.2.tar.gz.

File metadata

  • Download URL: tensorset-0.4.2.tar.gz
  • Upload date:
  • Size: 7.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for tensorset-0.4.2.tar.gz:

  • SHA256: a523146fb0a09bf8e14e798edb253fe99d81601d675ae07ccd6c17077c5c8ac3
  • MD5: 168908e984b6a3cf232650f1658f1e3d
  • BLAKE2b-256: 2cc5386218b24ee31455c79b4c80693fa7b6b7aca9429d685b032f9eb2fa9eb6


File details

Details for the file tensorset-0.4.2-py3-none-any.whl.

File metadata

  • Download URL: tensorset-0.4.2-py3-none-any.whl
  • Upload date:
  • Size: 6.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for tensorset-0.4.2-py3-none-any.whl:

  • SHA256: 9c033b6a319fc6489c06f4ac1bee6d24a8a1d618eb9557ebad48f9812d83479e
  • MD5: a5643e08874ab21c1ba7c6663fed6c81
  • BLAKE2b-256: 80a013a44301b33d3f73f8428c62175c5568e4a9e731d275b82a1b03c8bbf24e

