
manipulate sets of tensors

Project description

tensorset

tensorset is a PyTorch library that lets you perform operations on related sequences through a unified TensorSet object.
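
tensorset is published on PyPI under that name (see the distribution files below), so installation is the usual:

pip install tensorset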

It aims to reduce the complexity of working with multiple related sequences. Sequences like these commonly appear as inputs to a transformer model:

import torch
from torch import nn

batch_size = 8
sequence_length = 1024
vocab_size = 256
hidden_size = 768
pad_id = 255

token_embeddings = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length))
input_embeds = token_embeddings(input_ids)  # Shape: batch_size, sequence_length, hidden_size
key_pad_mask = input_ids == pad_id  # Shape: batch_size, sequence_length
is_whitespace_mask = (input_ids == 0) | (input_ids == 1)  # Shape: batch_size, sequence_length

# These tensors would be used like this:
# logits = transformer_model(input_embeds, key_pad_mask, is_whitespace_mask)

Notice that wherever these tensors are truncated, stacked, or concatenated, tedious, repetitive code like this appears:

def truncate_inputs(input_ids, key_pad_mask, is_whitespace_mask, length):
    input_ids = input_ids[:, :length]
    key_pad_mask = key_pad_mask[:, :length]
    is_whitespace_mask = is_whitespace_mask[:, :length]
    return input_ids, key_pad_mask, is_whitespace_mask

truncated_inputs = truncate_inputs(input_ids, key_pad_mask, is_whitespace_mask, length=10)

This repetitive code can be avoided: input_ids, input_embeds, key_pad_mask, and is_whitespace_mask are all related, sharing matching leading dimensions of batch_size and sequence_length.

TensorSet is a container for these related multi-dimensional sequences, making this kind of manipulation very easy and ergonomic.

import tensorset as ts

length = 10
inputs = ts.TensorSet(
    input_ids=input_ids,
    input_embeds=input_embeds,
    key_pad_mask=key_pad_mask,
    is_whitespace_mask=is_whitespace_mask,
)
truncated_inputs = inputs.iloc[:, :length]
print(truncated_inputs)

prints:

TensorSet(
  named_columns:
    name: input_ids, shape: torch.Size([8, 10]), dtype: torch.int64
    name: input_embeds, shape: torch.Size([8, 10, 768]), dtype: torch.float32
    name: key_pad_mask, shape: torch.Size([8, 10]), dtype: torch.bool
    name: is_whitespace_mask, shape: torch.Size([8, 10]), dtype: torch.bool
)
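
The same iloc indexing works along any leading dimension. For example, assuming the standard slice semantics shown above also apply to the batch dimension, this selects the first two examples from every column at once:

# Keep only the first two examples in the batch; every column is
# sliced identically along dimension 0.
first_two = inputs.iloc[:2, :]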

Features

Stack related TensorSets to create larger batches

sequence_length = 20
sequence_1 = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
sequence_2 = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
batch = ts.stack((sequence_1, sequence_2), 0)

print(batch.size(1)) # This is the sequence length, prints 20
print(batch.size(0)) # This is the batch size, prints 2
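
Since the stacked batch is itself a TensorSet, the iloc indexing shown earlier applies to it unchanged; a quick sketch:

# Truncate every sequence in the batch to its first 10 steps, using the
# same iloc slicing as in the truncation example above.
short_batch = batch.iloc[:, :10]
print(short_batch.size(1)) # prints 10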

Pad TensorSets with a specific amount of padding along the sequence dimension

sequence_length = 20
sequence = ts.TensorSet(
    torch.randn(sequence_length, 512),
    torch.randn(sequence_length, 1024),
)
pad_value = -200
padded_sequence = sequence.pad(44, 0, pad_value)  # pad dimension 0 with 44 steps of pad_value
print(padded_sequence.size(0))  # the new sequence length, prints 64
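
pad composes with stack for batching variable-length sequences. Here is a minimal sketch that uses only the pad and stack calls shown above (the lengths and column widths are made up for illustration):

seq_a = ts.TensorSet(torch.randn(15, 512), torch.randn(15, 1024))
seq_b = ts.TensorSet(torch.randn(20, 512), torch.randn(20, 1024))

# Pad the shorter sequence up to the longer one's length, then stack.
target_length = max(seq_a.size(0), seq_b.size(0))
seq_a = seq_a.pad(target_length - seq_a.size(0), 0, pad_value)
batch = ts.stack((seq_a, seq_b), 0)
print(batch.size(0), batch.size(1)) # prints 2 20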

Stack TensorSets with irregular shapes using torch.nested

# C, H, W pixel_values, plus an additional binary mask
image1 = ts.TensorSet(
    pixel_values=torch.randn(3, 20, 305),
    mask=torch.randn(3, 20, 305) > 0,
)
image2 = ts.TensorSet(
    pixel_values=torch.randn(3, 450, 200),
    mask=torch.randn(3, 450, 200) > 0,
)
images = ts.stack_nt([image1, image2])
print(images)

output:

TensorSet(
  named_columns:
    name: pixel_values, shape: nested_tensor.Size([2, 3, irregular, irregular]), dtype: torch.float32
    name: mask, shape: nested_tensor.Size([2, 3, irregular, irregular]), dtype: torch.bool
)

TODO

  • Access by lists of columns
  • Enable operations over irregular dims that torch.nested does not yet support, such as mean and index select (a possible interim workaround is sketched below)
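
Until then, one possible workaround is to unbind the nested tensor and reduce per item with plain torch ops. Note that the named_columns access below is inferred from the printed repr earlier; it is an assumption about internals, not a documented API:

# Hypothetical workaround: `named_columns` is assumed, from the repr,
# to be a mapping of column names to tensors. torch nested tensors
# support unbind(), which yields each irregular item as a regular tensor.
pixel_values = images.named_columns["pixel_values"]
per_image_means = [pv.mean() for pv in pixel_values.unbind()]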

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tensorset-0.4.0.tar.gz (6.8 kB)

Uploaded Source

Built Distribution

tensorset-0.4.0-py3-none-any.whl (6.3 kB)

Uploaded Python 3

File details

Details for the file tensorset-0.4.0.tar.gz.

File metadata

  • Download URL: tensorset-0.4.0.tar.gz
  • Upload date:
  • Size: 6.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for tensorset-0.4.0.tar.gz

  • SHA256: d908b5ce5e25cfdc6fd89e5f74541dfb598a3d5c1e1b84ce2e304ac55126daf1
  • MD5: 55f26d7d36851e1d4078034f886768ab
  • BLAKE2b-256: a688de7a10c3ac0d81032452e303054ab10c7f8db3aea9a9df432ec35ab0ef4d

See more details on using hashes here.

File details

Details for the file tensorset-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: tensorset-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 6.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for tensorset-0.4.0-py3-none-any.whl

  • SHA256: 76c0add1c75f2acf34fb4e2a1c57835edd3b605e6dcde98f0cedfd7bd5e23a41
  • MD5: 6d82ff406038ada1167748d1dbe609dd
  • BLAKE2b-256: 3e522076f2051f579a49cba23a7a6c3963f6cb04d0173f3b3ebc37261d740472

See more details on using hashes here.
