TorchFlint is a collection of advanced utility functions designed for tensor manipulation within the PyTorch ecosystem. It provides granular control over high-dimensional data processing, bridging the gap between high-level neural modules and low-level tensor operations.

TorchFlint: Advanced Tensor Operations and State Management for PyTorch

✨ Key Features

1. Transparent Buffer Management

Simplifies the registration of non-parameter states in nn.Module using a Pythonic assignment syntax.

  • torchflint.buffer(tensor, persistent): Wraps a tensor so it can be assigned directly inside a module. It automatically handles device movement and state_dict saving without using torch.nn.Module.register_buffer.
  • Containers: Includes BufferObject, BufferList and BufferDict for organized state management.
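For comparison, the stock PyTorch pattern that the assignment syntax is meant to replace can be sketched as follows. This uses only torch.nn.Module.register_buffer and no TorchFlint code; the class and buffer names are illustrative.

```python
import torch


class PlainModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: saved in state_dict and moved by .to() / .cuda().
        self.register_buffer("running_mean", torch.zeros(10))
        # Non-persistent buffer: moved with the module but left out of state_dict.
        self.register_buffer("scratch", torch.zeros(10), persistent=False)


module = PlainModule()
state_keys = list(module.state_dict().keys())
buffer_names = [name for name, _ in module.named_buffers()]
print(state_keys)
print(buffer_names)
```

Both tensors show up in named_buffers(), but only the persistent one is written to the state dict, which is the distinction the persistent argument of torchflint.buffer controls.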

2. Patchwork

A core module for handling N-dimensional patches, serving as the foundation for convolution and pooling operations.

  • High Performance: Includes torchflint.patchwork.unfold_space, torchflint.patchwork.fold_space, and torchflint.patchwork.fold_stack. Notably, in many scenarios torchflint.patchwork.fold_stack runs at close to the speed of the official torch.nn.functional.fold, which only supports image-like data with up to two spatial dimensions.
  • N-Dimensional Support: Works seamlessly on 1D, 2D, and 3D+ spatial data.
  • Functional Convolution & Pooling: Built upon the torchflint.patchwork module, many functions expose the intermediate steps of convolution and pooling, like torchflint.patchwork.conv_patches, torchflint.patchwork.masked_conv, torchflint.patchwork.pool, torchflint.patchwork.max_pool, and torchflint.patchwork.masked_avg_pool. Many of them work well in the tests run so far, but that testing is limited (the masked versions have only been tested for their forward pass). They will be tested more rigorously in the future.
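To make "intermediate steps of convolution" concrete, here is a minimal sketch of a 2D convolution written as patch extraction followed by a matrix multiplication. It uses only stock torch.nn.functional calls, not the TorchFlint API, so it illustrates the idea rather than how torchflint.patchwork is implemented.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)        # [N, C_in, H, W]
weight = torch.randn(4, 3, 3, 3)   # [C_out, C_in, K, K]

# Step 1: extract sliding patches -> [N, C_in * K * K, L], L = number of windows.
patches = F.unfold(x, kernel_size=3, stride=2, padding=1)

# Step 2: apply the kernel to every patch with one matrix multiplication.
out = weight.view(4, -1) @ patches  # [N, C_out, L]

# Step 3: reshape the window axis back into the 4 x 4 output grid.
out = out.view(2, 4, 4, 4)

# Compare against the fused convolution.
reference = F.conv2d(x, weight, stride=2, padding=1)
max_error = (out - reference).abs().max().item()
print(max_error)
```

Having the patches as an explicit tensor is what makes variants such as masking or custom reductions straightforward to insert between steps 1 and 2.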

3. Convenient Tensor Utilities

A suite of tools in torchflint.functional.

🛠️ Installation

pip install torchflint

Requirements: Python 3.7+ and PyTorch 1.6.0+.

📚 Documentation

Complete API documentation:
https://torchflint.readthedocs.io/

🚀 Usage Examples

Using Buffers

Register buffers via simple assignment:

import torch
import torchflint


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registers a persistent buffer named 'buffer'
        self.buffer = torchflint.buffer(torch.zeros(10), persistent=True)


model = MyModule()

Patch-based Operations

Using the high-performance patch functions:

import torch
from torchflint import patchwork

tensor = torch.randn(32, 3, 32, 32)

# Extract patches
patches = patchwork.unfold_space(tensor, kernel_size=3, stride=2, padding=1)

# Reconstruct
cumulative_output = patchwork.fold_stack(patches, stride=2)
average_output = patchwork.fold_space(patches, stride=2)
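The variable names above suggest that one reconstruction accumulates overlapping patches and the other averages them; that reading is an assumption, not something confirmed here. The difference between the two can be sketched with stock PyTorch alone (torch.nn.functional.unfold and fold, 2D only):

```python
import torch
import torch.nn.functional as F

x = torch.arange(16.0).view(1, 1, 4, 4)

# Overlapping 2x2 patches with stride 1 -> shape [1, 4, 9].
patches = F.unfold(x, kernel_size=2, stride=1)

# fold sums every patch back into place, so overlapped pixels are counted several times.
summed = F.fold(patches, output_size=(4, 4), kernel_size=2, stride=1)

# Folding a tensor of ones counts how many patches cover each pixel.
counts = F.fold(torch.ones_like(patches), output_size=(4, 4), kernel_size=2, stride=1)

# Dividing by the coverage count gives the average, which recovers the input.
averaged = summed / counts
```

With non-overlapping patches (stride equal to the kernel size) the counts are all one and the two results coincide.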

Convolution (Standard & Masked)

Perform convolutions, optionally applying a mask to handle sparse data or boundary validity.

import torch
from torchflint import patchwork

tensor = torch.randn(1, 3, 32, 32)
weight = torch.randn(16, 3, 3, 3)  # Standard weight shape: [Out, In, K, K]

# Standard convolution
output = patchwork.conv(tensor, weight, stride=2, padding=1)

# Masked convolution
# Apply convolution only where input_mask is valid (True)
input_mask = torch.randn(1, 1, 32, 32) > 0.5
masked_output = patchwork.masked_conv(
    tensor, weight,
    stride=2, padding=1,
    input_mask=input_mask
)
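One simple reading of "convolve only where the input is valid" is to zero out invalid pixels before a regular convolution. The sketch below does that with stock PyTorch. It is an assumption about the general idea and may differ from what patchwork.masked_conv actually computes (for example, it applies no renormalization by the number of valid pixels).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)
weight = torch.randn(16, 3, 3, 3)
input_mask = torch.rand(1, 1, 8, 8) > 0.5  # True marks valid pixels

# Zero out invalid pixels; the [1, 1, H, W] mask broadcasts over channels.
masked_input = x * input_mask

# Invalid pixels now contribute nothing to any output location.
masked_output = F.conv2d(masked_input, weight, stride=2, padding=1)
print(masked_output.shape)
```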

Pooling (Standard & Masked)

Apply pooling operations. Masked pooling allows accurate statistical reduction by ignoring invalid regions instead of padding with zeros or infinity.

import torch
from torchflint import patchwork

tensor = torch.randn(1, 3, 32, 32)

# Standard max pooling
output = patchwork.max_pool(tensor, kernel_size=2, stride=2)

# Masked max pooling
# Regions marked as 0 (False) in the mask are ignored during max calculation
# (Useful for non-rectangular images or padding handling)
mask = torch.ones_like(tensor)
mask[..., 0:5, 0:5] = 0  # Invalidate top-left corner
masked_output = patchwork.masked_max_pool(
    tensor, kernel_size=2, stride=2,
    mask=mask.bool()
)
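The effect of ignoring masked entries during a max reduction can be sketched with stock PyTorch by filling invalid positions with negative infinity before pooling. This is an illustration of the concept on a tiny 2D input, not the TorchFlint implementation.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[1.0, 9.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0, 7.0],
                    [8.0, 1.0, 2.0, 0.0],
                    [3.0, 2.0, 1.0, 4.0]]]])
mask = torch.ones_like(x, dtype=torch.bool)
mask[..., 0, 1] = False  # mark the value 9.0 as invalid

# Replace invalid entries with -inf so they can never win the max.
filled = x.masked_fill(~mask, float("-inf"))
pooled = F.max_pool2d(filled, kernel_size=2, stride=2)
print(pooled)
```

Without the mask the top-left window would return 9.0; with it, the largest valid value in that window (5.0) is returned. A window whose entries are all masked yields -inf under this approach, so such windows need separate handling.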

📄 License

This project is licensed under the MIT License.

Citation

@misc{torchflint,
  author = {Juanhua Zhang},
  title = {TorchFlint: Advanced Tensor Operations and State Management for PyTorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Caewinix/torchflint}}
}
