
TorchFlint: Advanced Tensor Operations and State Management for PyTorch

TorchFlint is a collection of advanced utility functions designed for tensor manipulation within the PyTorch ecosystem. It provides granular control over high-dimensional data processing, bridging the gap between high-level neural modules and low-level tensor operations.

✨ Key Features

1. Transparent Buffer Management

Simplifies the registration of non-parameter states in nn.Module using a Pythonic assignment syntax.

  • torchflint.buffer(tensor, persistent): Wraps a tensor so it can be assigned directly as a module attribute. It automatically handles device movement and state_dict saving, with no need to call torch.nn.Module.register_buffer yourself.
  • Containers: Includes BufferObject, BufferList and BufferDict for organized state management.

2. Patchwork

A core module for handling N-dimensional patches, serving as the foundation for convolution and pooling operations.

  • High Performance: Includes torchflint.patchwork.unfold_space, torchflint.patchwork.fold_space, and torchflint.patchwork.fold_stack. Notably, in many scenarios torchflint.patchwork.fold_stack runs at a speed close to that of the official torch.nn.functional.fold, which only supports 2D (image-like) data.
  • N-Dimensional Support: Works seamlessly on 1D, 2D, and 3D+ spatial data.
  • Functional Convolution & Pooling: Built upon the torchflint.patchwork module, many functions expose the intermediate steps of convolution and pooling, such as torchflint.patchwork.conv_patches, torchflint.patchwork.masked_conv, torchflint.patchwork.pool, torchflint.patchwork.max_pool, and torchflint.patchwork.masked_avg_pool. Many of them behave correctly in the tests run so far, but test coverage is still limited (the masked versions have only been tested in the forward pass); they will be tested more rigorously in the future.

3. Convenient Tensor Utilities

A suite of tools in torchflint.functional.

🛠️ Installation

pip install torchflint

Requirements: Python 3.7+ and PyTorch 1.6.0 or later.

📚 Documentation

Complete API documentation:
https://torchflint.readthedocs.io/

🚀 Usage Examples

Using Buffers

Register buffers via simple assignment:

import torch
import torchflint


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registers a persistent buffer named 'buffer'
        self.buffer = torchflint.buffer(torch.zeros(10), persistent=True)


model = MyModule()
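To illustrate the general idea behind assignment-based registration, here is a stdlib-only sketch that intercepts attribute assignment with `__setattr__`. It is not torchflint's implementation: `Buffer`, `TinyModule`, and `MyTinyModule` are hypothetical names, and device movement is not modeled. Persistent buffers appear in `state_dict()` while non-persistent ones do not, mirroring PyTorch's `register_buffer(..., persistent=False)`.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Buffer:
    """Hypothetical marker: wraps a value that should be tracked as a buffer."""
    value: Any
    persistent: bool = True


class TinyModule:
    """Hypothetical stand-in for nn.Module that only models buffer bookkeeping."""

    def __init__(self) -> None:
        # Bypass our own __setattr__ to create the registry itself.
        object.__setattr__(self, "_buffers", {})

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, Buffer):
            # Record the unwrapped value together with its persistence flag.
            self._buffers[name] = (value.value, value.persistent)
        elif name in self._buffers:
            # Reassigning a plain value to an existing buffer keeps its flag.
            _, persistent = self._buffers[name]
            self._buffers[name] = (value, persistent)
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. for buffer names.
        buffers = object.__getattribute__(self, "_buffers")
        if name in buffers:
            return buffers[name][0]
        raise AttributeError(name)

    def state_dict(self) -> Dict[str, Any]:
        # Only persistent buffers are saved.
        return {k: v for k, (v, p) in self._buffers.items() if p}


class MyTinyModule(TinyModule):
    def __init__(self) -> None:
        super().__init__()
        self.running_mean = Buffer([0.0] * 3, persistent=True)
        self.scratch = Buffer([1.0], persistent=False)


m = MyTinyModule()
print(m.state_dict())  # {'running_mean': [0.0, 0.0, 0.0]}
print(m.scratch)       # [1.0]
```

Intercepting assignment keeps the module definition free of explicit registration calls, which is the convenience the buffer feature above describes.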

Patch-based Operations

Using the high-performance patch functions:

import torch
from torchflint import patchwork


tensor = torch.randn(32, 3, 32, 32)


# Extract patches
patches = patchwork.unfold_space(tensor, kernel_size=3, stride=2, padding=1)


# Reconstruct
cumulative_output = patchwork.fold_stack(patches, stride=2)
average_output = patchwork.fold_space(patches, stride=2)
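The difference between the two reconstructions can be seen in a one-dimensional, single-channel sketch in plain Python. It assumes, as the variable names above suggest, that fold_stack sums overlapping patch contributions while fold_space averages them. The helpers `unfold_1d`, `fold_sum_1d`, and `fold_mean_1d` are hypothetical and are not torchflint's implementation.

```python
from typing import List, Sequence


def unfold_1d(x: Sequence[float], kernel_size: int, stride: int = 1,
              padding: int = 0) -> List[List[float]]:
    """Extract sliding windows (patches) from a zero-padded 1-D signal."""
    padded = [0.0] * padding + list(x) + [0.0] * padding
    n_patches = (len(padded) - kernel_size) // stride + 1
    return [padded[i * stride:i * stride + kernel_size] for i in range(n_patches)]


def fold_sum_1d(patches: Sequence[Sequence[float]], length: int, stride: int = 1,
                padding: int = 0) -> List[float]:
    """Scatter patches back to their positions, summing where they overlap."""
    out = [0.0] * (length + 2 * padding)
    for i, patch in enumerate(patches):
        for k, value in enumerate(patch):
            out[i * stride + k] += value
    # Drop the padded border to recover the original length.
    return out[padding:padding + length]


def fold_mean_1d(patches: Sequence[Sequence[float]], length: int, stride: int = 1,
                 padding: int = 0) -> List[float]:
    """Like fold_sum_1d, but divide each position by how many patches covered it."""
    sums = fold_sum_1d(patches, length, stride, padding)
    ones = [[1.0] * len(p) for p in patches]
    counts = fold_sum_1d(ones, length, stride, padding)
    # Positions covered by no patch (possible when stride > kernel_size) stay 0.0.
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


x = [float(v) for v in range(1, 9)]
patches = unfold_1d(x, kernel_size=3, stride=2, padding=1)
print(fold_sum_1d(patches, len(x), stride=2, padding=1))
# [1.0, 4.0, 3.0, 8.0, 5.0, 12.0, 7.0, 8.0]  (overlapped positions counted twice)
print(fold_mean_1d(patches, len(x), stride=2, padding=1) == x)  # True
```

Summing is what a transposed convolution or gradient accumulation needs, while averaging inverts the unfolding whenever every position is covered by at least one patch.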

Convolution (Standard & Masked)

Perform convolutions, optionally applying a mask to handle sparse data or boundary validity.

import torch
from torchflint import patchwork


tensor = torch.randn(1, 3, 32, 32)
weight = torch.randn(16, 3, 3, 3) # Standard weight shape: [Out, In, K, K]


# Standard Convolution
output = patchwork.conv(tensor, weight, stride=2, padding=1)


# Masked Convolution
# Apply convolution only where input_mask is valid (True)
input_mask = torch.randn(1, 1, 32, 32) > 0.5
masked_output = patchwork.masked_conv(
    tensor, weight, 
    stride=2, padding=1, 
    input_mask=input_mask
)
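Convolution can be viewed as patch extraction followed by a dot product, and a mask can enter that computation in several ways. The plain-Python sketch below shows both ideas for a single-channel 1-D signal. The masked variant follows the partial-convolution scheme: invalid inputs contribute nothing, and each output is rescaled by kernel size over the number of valid inputs. Whether torchflint.patchwork.masked_conv uses the same rescaling is an assumption not confirmed here, and all function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def patches_1d(x: Sequence[float], kernel_size: int, stride: int = 1,
               padding: int = 0) -> List[List[float]]:
    """Sliding windows over a zero-padded 1-D signal."""
    padded = [0.0] * padding + list(x) + [0.0] * padding
    n = (len(padded) - kernel_size) // stride + 1
    return [padded[i * stride:i * stride + kernel_size] for i in range(n)]


def conv1d(x, weight, stride=1, padding=0) -> List[float]:
    """Convolution as in PyTorch (cross-correlation): dot each patch with the kernel."""
    return [sum(w * v for w, v in zip(weight, patch))
            for patch in patches_1d(x, len(weight), stride, padding)]


def partial_conv1d(x, mask, weight, stride=1, padding=0) -> Tuple[List[float], List[bool]]:
    """Partial convolution: ignore invalid inputs, rescale by kernel_size / n_valid."""
    k = len(weight)
    x_patches = patches_1d(x, k, stride, padding)
    # Padding the mask with 0.0 treats the padded border as invalid, too.
    m_patches = patches_1d([1.0 if m else 0.0 for m in mask], k, stride, padding)
    out, out_mask = [], []
    for xp, mp in zip(x_patches, m_patches):
        n_valid = sum(mp)
        if n_valid == 0:
            # No valid input under the kernel: emit 0.0 and mark the output invalid.
            out.append(0.0)
            out_mask.append(False)
        else:
            total = sum(w * v * m for w, v, m in zip(weight, xp, mp))
            out.append(total * k / n_valid)
            out_mask.append(True)
    return out, out_mask


x = [1.0, 2.0, 3.0, 4.0]
w = [1.0, 1.0, 1.0]
print(conv1d(x, w, padding=1))                          # [3.0, 6.0, 9.0, 7.0]
print(partial_conv1d(x, [True] * 4, w, padding=1)[0])   # [4.5, 6.0, 9.0, 10.5]
```

The border outputs show the "boundary validity" effect: zero padding pulls them down (3.0 and 7.0), while the partial convolution rescales them as if the missing neighbors were absent rather than zero.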

Pooling (Standard & Masked)

Apply pooling operations. Masked pooling allows accurate statistical reduction by ignoring invalid regions instead of padding with zeros or infinity.

import torch
from torchflint import patchwork


tensor = torch.randn(1, 3, 32, 32)


# Standard Max Pooling
output = patchwork.max_pool(tensor, kernel_size=2, stride=2)


# Masked Max Pooling
# Regions marked as 0 in the mask are ignored during max calculation
# (Useful for non-rectangular images or padding handling)
mask = torch.ones_like(tensor)
mask[..., 0:5, 0:5] = 0 # Invalidate top-left corner
masked_output = patchwork.masked_max_pool(
    tensor, kernel_size=2, stride=2, 
    mask=mask.bool()
)
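A small 1-D example shows why masking matters for pooling: filling invalid positions with zeros biases both the maximum and the mean, while excluding them does not. The stdlib-only sketch below uses hypothetical function names, assumes no padding, and illustrates the idea rather than torchflint's implementation. Windows with no valid entry return None here; how torchflint.patchwork.masked_max_pool handles such windows is not stated in this README.

```python
from typing import List, Optional


def _windows(length: int, kernel_size: int, stride: int) -> List[range]:
    """Index ranges of each pooling window (no padding)."""
    n = (length - kernel_size) // stride + 1
    return [range(i * stride, i * stride + kernel_size) for i in range(n)]


def masked_max_pool1d(x, mask, kernel_size, stride) -> List[Optional[float]]:
    """Max over valid entries only; None marks a window with no valid entry."""
    out = []
    for win in _windows(len(x), kernel_size, stride):
        valid = [x[i] for i in win if mask[i]]
        out.append(max(valid) if valid else None)
    return out


def masked_avg_pool1d(x, mask, kernel_size, stride) -> List[Optional[float]]:
    """Mean over valid entries only, so invalid entries do not dilute the average."""
    out = []
    for win in _windows(len(x), kernel_size, stride):
        valid = [x[i] for i in win if mask[i]]
        out.append(sum(valid) / len(valid) if valid else None)
    return out


x = [5.0, -1.0, -3.0, -2.0, 4.0, 8.0]
mask = [False, True, True, True, True, False]  # first and last entries are invalid
zero_filled = [v if m else 0.0 for v, m in zip(x, mask)]
all_valid = [True] * len(x)

print(masked_max_pool1d(x, mask, 2, 2))                 # [-1.0, -2.0, 4.0]
print(masked_max_pool1d(zero_filled, all_valid, 2, 2))  # [0.0, -2.0, 4.0]: the filler zero wins
print(masked_avg_pool1d(x, mask, 2, 2))                 # [-1.0, -2.5, 4.0]
print(masked_avg_pool1d(zero_filled, all_valid, 2, 2))  # [-0.5, -2.5, 2.0]
```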

📄 License

This project is licensed under the MIT License.

Citation

@misc{torchflint,
  author = {Juanhua Zhang},
  title = {TorchFlint: Advanced Tensor Operations and State Management for PyTorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Caewinix/torchflint}}
}

