TorchFlint is a collection of advanced utility functions designed for tensor manipulation within the PyTorch ecosystem. It provides granular control over high-dimensional data processing, bridging the gap between high-level neural modules and low-level tensor operations.

TorchFlint: Advanced Tensor Operations and State Management for PyTorch

✨ Key Features

1. Transparent Buffer Management

Simplifies the registration of non-parameter states in nn.Module using a Pythonic assignment syntax.

  • torchflint.buffer(tensor, persistent): Wraps a tensor so it can be assigned directly as a module attribute. Device movement and state_dict saving are handled automatically, with no explicit call to torch.nn.Module.register_buffer.
  • Containers: Includes BufferObject, BufferList and BufferDict for organized state management.
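
For comparison, the behavior described above matches what plain PyTorch provides through an explicit register_buffer call. The sketch below uses only standard torch APIs (the class and buffer names are made up for illustration) and shows how the persistent flag controls whether a buffer ends up in the state_dict. It is not TorchFlint's implementation.

```python
import torch


class PlainModule(torch.nn.Module):  # hypothetical example class
    def __init__(self):
        super().__init__()
        # Persistent buffer: follows the module across devices and is saved in state_dict
        self.register_buffer("running_total", torch.zeros(10), persistent=True)
        # Non-persistent buffer: follows the module across devices but is not saved
        self.register_buffer("scratch", torch.zeros(10), persistent=False)


plain = PlainModule()
saved_keys = list(plain.state_dict().keys())
buffer_names = [name for name, _ in plain.named_buffers()]
print(saved_keys)    # only the persistent buffer is saved
print(buffer_names)  # both buffers are tracked by the module
```

The assignment syntax offered by torchflint.buffer is meant to replace the two register_buffer calls above.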

2. Patchwork

A core module for handling N-dimensional patches, serving as the foundation for convolution and pooling operations.

  • High Performance: Includes torchflint.patchwork.unfold_space, torchflint.patchwork.fold_space, and torchflint.patchwork.fold_stack. Notably, in many scenarios torchflint.patchwork.fold_stack runs at close to the speed of the official torch.nn.functional.fold, which supports only 2D spatial (image-like) data.
  • N-Dimensional Support: Works seamlessly on 1D, 2D, and 3D+ spatial data.
  • Functional Convolution & Pooling: Built on the torchflint.patchwork module, several functions expose the intermediate steps of convolution and pooling, such as torchflint.patchwork.conv_patches, torchflint.patchwork.masked_conv, torchflint.patchwork.pool, torchflint.patchwork.max_pool, and torchflint.patchwork.masked_avg_pool. Many of them work well in the tests run so far, but that testing is limited (the masked versions have been tested only for their forward pass); more rigorous testing is planned.
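
To make the patch round trip concrete, here is a minimal sketch using only the official torch.nn.functional.unfold and torch.nn.functional.fold, which are limited to 2D spatial data. It shows that folding overlapping patches back sums the overlaps, and that dividing by a per-pixel overlap count recovers the original tensor. The variable names are illustrative, and nothing here describes how the torchflint.patchwork functions are implemented.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
kernel_size, stride, padding = 3, 2, 1

# Extract overlapping 3x3 patches: shape (batch, channels * 3 * 3, num_patches)
patches = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=padding)

# Fold back: values from overlapping patches are summed at each pixel
summed = F.fold(patches, output_size=(8, 8), kernel_size=kernel_size,
                stride=stride, padding=padding)

# Count how many patches cover each pixel by folding a tensor of ones
ones_patches = F.unfold(torch.ones_like(x), kernel_size=kernel_size,
                        stride=stride, padding=padding)
counts = F.fold(ones_patches, output_size=(8, 8), kernel_size=kernel_size,
                stride=stride, padding=padding)

# Dividing the summed result by the counts averages the overlaps
averaged = summed / counts
print(patches.shape, torch.allclose(averaged, x))
```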

3. Convenient Tensor Utilities

A suite of tools in torchflint.functional.

🛠️ Installation

pip install torchflint

Requirements: Python >= 3.7 and PyTorch >= 1.6.0.

📚 Documentation

Complete API documentation:
https://torchflint.readthedocs.io/

🚀 Usage Examples

Using Buffers

Register buffers via simple assignment:

import torch
import torchflint


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registers a persistent buffer named 'buffer'
        self.buffer = torchflint.buffer(torch.zeros(10), persistent=True)


model = MyModule()

Patch-based Operations

Using the high-performance patch functions:

import torch
from torchflint import patchwork


tensor = torch.randn(32, 3, 32, 32)


# Extract patches
patches = patchwork.unfold_space(tensor, kernel_size=3, stride=2, padding=1)


# Reconstruct
cumulative_output = patchwork.fold_stack(patches, stride=2)
average_output = patchwork.fold_space(patches, stride=2)

Convolution (Standard & Masked)

Perform convolutions, optionally applying a mask to handle sparse data or boundary validity.

import torch
from torchflint import patchwork


tensor = torch.randn(1, 3, 32, 32)
weight = torch.randn(16, 3, 3, 3) # Standard weight shape: [Out, In, K, K]


# Standard Convolution
output = patchwork.conv(tensor, weight, stride=2, padding=1)


# Masked Convolution
# Apply convolution only where input_mask is valid (True)
input_mask = torch.randn(1, 1, 32, 32) > 0.5
masked_output = patchwork.masked_conv(
    tensor, weight, 
    stride=2, padding=1, 
    input_mask=input_mask
)
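
As background for the patch-based view of convolution, the sketch below checks in plain PyTorch that a 2D convolution equals "extract patches, then multiply by the flattened weights". It uses only torch.nn.functional.unfold and torch.nn.functional.conv2d, with the same shapes as the example above. It is an illustration of the general technique, not a description of how patchwork.conv or patchwork.masked_conv work internally.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
weight = torch.randn(16, 3, 3, 3)  # [Out, In, K, K]

# Step 1: extract 3x3 patches -> (1, In * K * K, num_patches) = (1, 27, 256)
patches = F.unfold(x, kernel_size=3, stride=2, padding=1)

# Step 2: apply the flattened weights to every patch -> (1, 16, 256)
flat_out = weight.view(16, -1) @ patches

# Step 3: reshape the patch axis back into the 16x16 output grid
manual = flat_out.view(1, 16, 16, 16)

# Reference result from the built-in convolution
reference = F.conv2d(x, weight, stride=2, padding=1)
print(torch.allclose(manual, reference, atol=1e-4))
```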

Pooling (Standard & Masked)

Apply pooling operations. Masked pooling allows accurate statistical reduction by ignoring invalid regions instead of padding with zeros or infinity.

import torch
from torchflint import patchwork


tensor = torch.randn(1, 3, 32, 32)


# Standard Max Pooling
output = patchwork.max_pool(tensor, kernel_size=2, stride=2)


# Masked Max Pooling
# Regions marked as 0 in the mask are ignored during max calculation
# (Useful for non-rectangular images or padding handling)
mask = torch.ones_like(tensor)
mask[..., 0:5, 0:5] = 0 # Invalidate top-left corner
masked_output = patchwork.masked_max_pool(
    tensor, kernel_size=2, stride=2, 
    mask=mask.bool()
)
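
The idea of ignoring invalid regions during reduction can be sketched in plain PyTorch. The two helper functions below are hypothetical and written only for illustration: the max version fills invalid positions with negative infinity before pooling, and the average version divides the pooled masked values by the pooled fraction of valid positions. TorchFlint's own masked pooling may differ in both approach and edge-case behavior.

```python
import torch
import torch.nn.functional as F


def masked_max_pool2d_sketch(x, mask, kernel_size, stride):
    # Invalid positions can never win the max; fully invalid windows yield -inf
    filled = x.masked_fill(~mask, float("-inf"))
    return F.max_pool2d(filled, kernel_size, stride)


def masked_avg_pool2d_sketch(x, mask, kernel_size, stride, eps=1e-12):
    m = mask.to(x.dtype)
    # Mean over the window with invalid values zeroed out
    pooled_values = F.avg_pool2d(x * m, kernel_size, stride)
    # Fraction of valid positions in each window
    valid_fraction = F.avg_pool2d(m, kernel_size, stride)
    # Rescale so the mean is taken over valid positions only;
    # fully invalid windows yield 0 in this sketch
    return pooled_values / valid_fraction.clamp_min(eps)


x = torch.arange(16.0).reshape(1, 1, 4, 4)
mask = torch.ones_like(x, dtype=torch.bool)
mask[..., 1, 1] = False  # invalidate the entry holding 5.0

plain_max = F.max_pool2d(x, 2, 2)
masked_max = masked_max_pool2d_sketch(x, mask, 2, 2)
masked_avg = masked_avg_pool2d_sketch(x, mask, 2, 2)
```

In the top-left window, the unmasked maximum is 5.0, while the masked maximum falls back to 4.0 and the masked average is taken over the three remaining values (0, 1, and 4).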

📄 License

This project is licensed under the MIT License.

Citation

@misc{torchflint,
  author = {Juanhua Zhang},
  title = {TorchFlint: Advanced Tensor Operations and State Management for PyTorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Caewinix/torchflint}}
}
