TorchFlint: Advanced Tensor Operations and State Management for PyTorch
TorchFlint is a collection of advanced utility functions designed for tensor manipulation within the PyTorch ecosystem. It provides granular control over high-dimensional data processing, bridging the gap between high-level neural modules and low-level tensor operations.
✨ Key Features
1. Transparent Buffer Management
Simplifies the registration of non-parameter state in nn.Module using Pythonic assignment syntax.
- torchflint.buffer(tensor, persistent): Wraps a tensor so it can be assigned directly as a module attribute. It automatically handles device movement and state_dict saving without calling torch.nn.Module.register_buffer.
- Containers: Includes BufferObject, BufferList, and BufferDict for organized state management.
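To make the assignment-based idea concrete, here is a minimal sketch of how such a wrapper can be built on plain PyTorch. It is not torchflint's implementation: Buffer and BufferAwareModule are hypothetical names, and unlike torchflint.buffer (which the description says works inside any module) this sketch needs a custom base class whose __setattr__ routes wrapped tensors to nn.Module.register_buffer.

```python
import torch
from torch import nn


class Buffer:
    """Hypothetical marker that requests buffer registration on assignment."""

    def __init__(self, tensor: torch.Tensor, persistent: bool = True):
        self.tensor = tensor
        self.persistent = persistent


class BufferAwareModule(nn.Module):
    """Hypothetical base class: `self.x = Buffer(t)` becomes register_buffer("x", t)."""

    def __setattr__(self, name, value):
        if isinstance(value, Buffer):
            # register_buffer stores the tensor in self._buffers, so .to()/.cuda()
            # convert it and state_dict() includes it when persistent=True.
            self.register_buffer(name, value.tensor, persistent=value.persistent)
        else:
            super().__setattr__(name, value)


class Demo(BufferAwareModule):
    def __init__(self):
        super().__init__()
        self.running = Buffer(torch.zeros(10), persistent=True)
        self.scratch = Buffer(torch.ones(3), persistent=False)


model = Demo().to(torch.float64)
print(list(model.state_dict().keys()))  # ['running']
print(model.running.dtype, model.scratch.dtype)  # torch.float64 torch.float64
```

Both tensors follow the module through .to(), but only the persistent one is saved; the persistent flag of register_buffer exists since PyTorch 1.6, matching the stated requirement.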
2. Patchwork
A core module for handling N-dimensional patches, serving as the foundation for convolution and pooling operations.
- High Performance: Includes torchflint.patchwork.unfold_space, torchflint.patchwork.fold_space, and torchflint.patchwork.fold_stack. Notably, torchflint.patchwork.fold_stack is faster than the official torch.nn.functional.fold in many scenarios.
- N-Dimensional Support: Works seamlessly on 1D, 2D, and 3D+ spatial data.
- Functional Convolution & Pooling: Built on the torchflint.patchwork module, many functions expose the intermediate steps of convolution and pooling, such as torchflint.patchwork.conv_patches, torchflint.patchwork.masked_conv, torchflint.patchwork.pool, torchflint.patchwork.max_pool, and torchflint.patchwork.masked_avg_pool. Although many of them work well in tests so far, testing has been limited (the masked variants have only been tested in their forward pass), and they will be tested more rigorously in the future.
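The N-dimensional idea can be illustrated in plain PyTorch by applying Tensor.unfold once per spatial dimension. The sketch below is an assumption-laden illustration, not torchflint's code: unfold_nd is a hypothetical helper, and its output layout [N, C, *out_spatial, *kernel] may differ from what torchflint.patchwork.unfold_space returns. The 2D case is cross-checked against the official torch.nn.functional.unfold.

```python
import torch
import torch.nn.functional as F


def unfold_nd(x: torch.Tensor, kernel_size: int, stride: int, padding: int = 0) -> torch.Tensor:
    """Hypothetical helper: sliding patches over every spatial dim of x.

    x has shape [N, C, *spatial]; the result has shape [N, C, *out_spatial, *kernel].
    """
    num_spatial = x.dim() - 2
    if padding:
        # F.pad takes one (before, after) pair per dimension, starting from the last one.
        x = F.pad(x, [padding] * (2 * num_spatial))
    for dim in range(2, 2 + num_spatial):
        # Tensor.unfold slides a window along `dim` and appends it as a new last dim.
        x = x.unfold(dim, kernel_size, stride)
    return x


x2d = torch.randn(2, 3, 8, 8)
p2d = unfold_nd(x2d, kernel_size=3, stride=2, padding=1)
print(p2d.shape)  # torch.Size([2, 3, 4, 4, 3, 3])

# Cross-check against the official 2D im2col, which returns [N, C*k*k, L].
ref = F.unfold(x2d, kernel_size=3, stride=2, padding=1)
print(torch.equal(p2d.permute(0, 1, 4, 5, 2, 3).reshape(2, 27, 16), ref))  # True

# The same helper works unchanged for 1D and 3D inputs.
print(unfold_nd(torch.randn(1, 1, 10), kernel_size=3, stride=1).shape)  # torch.Size([1, 1, 8, 3])
print(unfold_nd(torch.randn(1, 2, 6, 6, 6), kernel_size=2, stride=2).shape)  # torch.Size([1, 2, 3, 3, 3, 2, 2, 2])
```

Because Tensor.unfold returns strided views, no data is copied until a later operation (such as the reshape in the cross-check) needs a contiguous layout.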
3. Convenient Tensor Utilities
A suite of tools in torchflint.functional.
🛠️ Installation
```bash
pip install torchflint
```
Requirements: Python 3.7+ and PyTorch >= 1.6.0.
📚 Documentation
Complete API documentation:
https://torchflint.readthedocs.io/
🚀 Usage Examples
Using Buffers
Register buffers via simple assignment:
```python
import torch
import torchflint

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registers a persistent buffer named 'buffer'
        self.buffer = torchflint.buffer(torch.zeros(10), persistent=True)

model = MyModule()
```
Patch-based Operations
Using the high-performance patch functions:
```python
import torch
from torchflint import patchwork

tensor = torch.randn(32, 3, 32, 32)

# Extract patches
patches = patchwork.unfold_space(tensor, kernel_size=3, stride=2, padding=1)

# Reconstruct
cumulative_output = patchwork.fold_stack(patches, stride=2)
average_output = patchwork.fold_space(patches, stride=2)
```
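The example above distinguishes a cumulative and an averaged reconstruction. Judging only from the variable names, we assume fold_stack sums overlapping patch entries while fold_space averages them. The sketch below shows those two semantics with the official torch.nn.functional.unfold and fold (2D only), not with torchflint, whose patch layout may differ.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
k, s, p = 3, 2, 1
size = tuple(x.shape[-2:])

# Patches as columns: [N, C*k*k, L].
cols = F.unfold(x, kernel_size=k, stride=s, padding=p)

# Sum-style reconstruction: entries from overlapping patches are added together.
summed = F.fold(cols, output_size=size, kernel_size=k, stride=s, padding=p)

# Coverage: how many patches touch each pixel (fold of unfolded ones).
ones_cols = F.unfold(torch.ones_like(x), kernel_size=k, stride=s, padding=p)
counts = F.fold(ones_cols, output_size=size, kernel_size=k, stride=s, padding=p)

# Average-style reconstruction: divide the sum by the coverage.
averaged = summed / counts
print(torch.allclose(averaged, x))  # True
```

With kernel 3 and stride 2, interior pixels are covered by up to four patches, so the summed output is not equal to the input while the averaged one is (up to rounding).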
Convolution (Standard & Masked)
Perform convolutions, optionally applying a mask to handle sparse data or boundary validity.
```python
import torch
from torchflint import patchwork

tensor = torch.randn(1, 3, 32, 32)
weight = torch.randn(16, 3, 3, 3)  # Standard weight shape: [Out, In, K, K]

# Standard Convolution
output = patchwork.conv(tensor, weight, stride=2, padding=1)

# Masked Convolution
# Apply convolution only where input_mask is valid (1)
input_mask = torch.randn(1, 1, 32, 32) > 0.5
masked_output = patchwork.masked_conv(
    tensor, weight,
    stride=2, padding=1,
    input_mask=input_mask
)
```
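Since the patchwork functions are described as exposing the intermediate steps of convolution, it may help to see those steps in the classic im2col form: unfold the input into patches, then multiply by the flattened weight. The sketch uses only official PyTorch calls; conv2d_via_patches is a hypothetical helper (no bias, dilation, or groups), and the masking rule at the end (invalid inputs contribute zero) is only the simplest option, which torchflint.patchwork.masked_conv may not follow.

```python
import torch
import torch.nn.functional as F


def conv2d_via_patches(x, weight, stride=1, padding=0):
    """Hypothetical helper: convolution as im2col + matrix multiply."""
    n, _, h, w = x.shape
    out_c, _, kh, kw = weight.shape
    # One column per output position: [N, C*kh*kw, L].
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)
    # [Out, C*kh*kw] @ [N, C*kh*kw, L] broadcasts to [N, Out, L].
    out = weight.reshape(out_c, -1) @ cols
    oh = (h + 2 * padding - kh) // stride + 1
    ow = (w + 2 * padding - kw) // stride + 1
    return out.reshape(n, out_c, oh, ow)


x = torch.randn(1, 3, 32, 32)
weight = torch.randn(16, 3, 3, 3)  # [Out, In, K, K]

out = conv2d_via_patches(x, weight, stride=2, padding=1)
ref = F.conv2d(x, weight, stride=2, padding=1)
print(out.shape, torch.allclose(out, ref, atol=1e-4))  # torch.Size([1, 16, 16, 16]) True

# Assumed masking rule for illustration: invalid inputs are zeroed before convolving.
input_mask = torch.rand(1, 1, 32, 32) > 0.5
masked_out = conv2d_via_patches(x * input_mask.to(x.dtype), weight, stride=2, padding=1)

# Number of valid inputs under each 3x3 window (useful if one wanted to renormalize).
valid = F.conv2d(input_mask.to(x.dtype), torch.ones(1, 1, 3, 3), stride=2, padding=1)
```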
Pooling (Standard & Masked)
Apply pooling operations. Masked pooling allows accurate statistical reduction by ignoring invalid regions instead of padding with zeros or infinity.
```python
import torch
from torchflint import patchwork

tensor = torch.randn(1, 3, 32, 32)

# Standard Max Pooling
output = patchwork.max_pool(tensor, kernel_size=2, stride=2)

# Masked Max Pooling
# Regions marked as 0 in the mask are ignored during max calculation
# (Useful for non-rectangular images or padding handling)
mask = torch.ones_like(tensor)
mask[..., 0:5, 0:5] = 0  # Invalidate top-left corner
masked_output = patchwork.masked_max_pool(
    tensor, kernel_size=2, stride=2,
    mask=mask.bool()
)
```
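One common way to get the "ignore invalid regions" behavior with standard PyTorch ops is sketched below: fill invalid entries with negative infinity before a max, and divide masked sums by masked counts for an average. These two helpers are hypothetical and only illustrate the idea; in particular, how torchflint reports windows with no valid entries is not stated here, whereas this sketch returns -inf (max) and NaN (average) for them.

```python
import torch
import torch.nn.functional as F


def masked_max_pool2d(x, mask, kernel_size, stride):
    """Hypothetical sketch: invalid entries are set to -inf so they never win the max."""
    return F.max_pool2d(x.masked_fill(~mask, float("-inf")), kernel_size, stride)


def masked_avg_pool2d(x, mask, kernel_size, stride):
    """Hypothetical sketch: average only over the valid entries in each window."""
    m = mask.to(x.dtype)
    # Both averages divide by k*k, so their ratio is sum(valid values) / count(valid).
    window_mean = F.avg_pool2d(x * m, kernel_size, stride)
    valid_frac = F.avg_pool2d(m, kernel_size, stride)
    return window_mean / valid_frac  # NaN where a window has no valid entries


tensor = torch.randn(1, 3, 32, 32)
mask = torch.ones_like(tensor, dtype=torch.bool)
mask[..., 0:5, 0:5] = False  # Invalidate top-left corner

max_out = masked_max_pool2d(tensor, mask, kernel_size=2, stride=2)
avg_out = masked_avg_pool2d(tensor, mask, kernel_size=2, stride=2)

# Windows lying fully inside the invalid corner have nothing to reduce.
print(torch.isinf(max_out[..., :2, :2]).all().item(), torch.isnan(avg_out[..., :2, :2]).all().item())  # True True

# The window over rows 4-5, cols 0-1 only sees row 5, because row 4 is masked there.
print(torch.allclose(avg_out[..., 2, 0], tensor[..., 5, 0:2].mean(-1)))  # True
```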
📄 License
This project is licensed under the MIT License.
Citation
```bibtex
@misc{torchflint,
  author = {Juanhua Zhang},
  title = {TorchFlint: Advanced Tensor Operations and State Management for PyTorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Caewinix/torchflint}}
}
```
Download files
File details
Details for the file torchflint-0.0.1b16.post1.dev0.tar.gz.
File metadata
- Download URL: torchflint-0.0.1b16.post1.dev0.tar.gz
- Upload date:
- Size: 37.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 38294fcf3d1b2c9519a5ec4e1c5ef0e517d18432f46ca4972c63f3cabbe4208f |
| MD5 | 4a7b30b130a7c1a5e9c78e68e8e67527 |
| BLAKE2b-256 | 34f1a3eeb85a2ea94050e1274282187dc5cb0f8fcd04bc974d4b0a40b078796e |
File details
Details for the file torchflint-0.0.1b16.post1.dev0-py3-none-any.whl.
File metadata
- Download URL: torchflint-0.0.1b16.post1.dev0-py3-none-any.whl
- Upload date:
- Size: 39.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 51a9ec4cd35a8f34170b9521dd3ba234d3f35edaeb183c8fa1b7bd886968e4e2 |
| MD5 | d89211061b203c48f4d9d523ea7735ef |
| BLAKE2b-256 | 6ce74d7798cecd9603d90001d39bebe664b0d121d04bd42a941c5bd5f751e949 |