PyTorch Implementation of Spatial Grouping Attention Layer
Project description
spatial-grouping-attention
PyTorch Implementation of Spatial Grouping Attention
Inspired by the spatial grouping layer in Native Segmentation Vision Transformers (https://arxiv.org/abs/2505.16993), implemented in PyTorch with a modified rotary position embedding that is generalized to N dimensions and incorporates real-world pixel spacing.
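The position-encoding idea can be illustrated with a small, self-contained sketch: standard rotary embeddings whose angles are driven by physical coordinates (grid index times pixel spacing) rather than raw indices, with the feature dimension split across the spatial axes. This is a conceptual sketch only; the helper names nd_rope_angles and apply_rope are hypothetical and do not reflect the package's internal API.

import torch

def nd_rope_angles(grid_shape, spacing, dim, base=10000.0):
    # Rotary angles for an N-D grid, using physical coordinates (index * spacing).
    # The feature dimension is split evenly across axes: dim // (2 * ndim) frequencies per axis.
    ndim = len(grid_shape)
    freqs_per_axis = dim // (2 * ndim)
    inv_freq = base ** (-torch.arange(freqs_per_axis) / freqs_per_axis)
    coords = torch.meshgrid(
        *[torch.arange(n) * s for n, s in zip(grid_shape, spacing)], indexing="ij"
    )
    # One block of angles per axis, concatenated along the frequency dimension
    angles = [c.reshape(-1, 1) * inv_freq for c in coords]
    return torch.cat(angles, dim=-1)  # (num_points, dim // 2)

def apply_rope(x, angles):
    # Rotate consecutive feature pairs of x by the given angles
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

# Anisotropic 2D grid: 4x8 pixels at 0.5 x 0.1 microns per pixel, 16-dim features
angles = nd_rope_angles((4, 8), (0.5, 0.1), dim=16)
x = torch.randn(2, 4 * 8, 16)   # (batch, num_points, features)
x_rot = apply_rope(x, angles)   # same shape, position-encoded
print(x_rot.shape)              # torch.Size([2, 32, 16])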
Installation
From PyPI
You will first need to install PyTorch separately, as it is required for building one of our dependencies (natten). We recommend installing within a virtual environment, such as venv or mamba:
# create a virtual environment
mamba create -n spatial-attention -y python=3.11 pytorch ninja cmake
# activate the virtual environment
mamba activate spatial-attention
# install the package(s)
pip install spatial-grouping-attention
pip install natten==0.17.5 # requires python 3.11
From source
To install the latest development version directly from GitHub, first create and activate a virtual environment as above, then run:
pip install git+https://github.com/rhoadesScholar/spatial-grouping-attention.git
Usage
The spatial grouping attention mechanism automatically computes the query grid parameters (q_spacing and q_grid_shape) from the input key grid and the convolution parameters. This makes it easy to use: you only specify the input resolution, and the layer handles the spatial downsampling.
Basic 2D Dense Attention
import torch
from spatial_grouping_attention import DenseSpatialGroupingAttention

# Create attention module with 7x7 grouping kernel, stride=1 (no downsampling)
attention = DenseSpatialGroupingAttention(
    feature_dims=128,
    spatial_dims=2,
    kernel_size=7,   # 7x7 spatial grouping
    stride=1,        # No spatial downsampling
    num_heads=8,
    mlp_ratio=4,
)

# Input: 32x32 image with 128 features per pixel
batch_size, height, width = 2, 32, 32
x = torch.randn(batch_size, height * width, 128)

# Only specify input (key) grid - query grid computed automatically
input_spacing = (0.5, 0.5)    # 0.5 microns per pixel
input_grid_shape = (32, 32)   # Input resolution

output = attention(x=x, input_spacing=input_spacing, input_grid_shape=input_grid_shape)

# Auto-computed: q_spacing = (0.5, 0.5) * 1 = (0.5, 0.5)
# Auto-computed: q_grid_shape = (32, 32) (no downsampling with stride=1)
print(f"Output shape: {output['x_out'].shape}")  # (2, 1024, 128)
2D Attention with Spatial Downsampling
# Create attention with spatial downsampling for efficiency
downsampling_attention = DenseSpatialGroupingAttention(
    feature_dims=256,
    spatial_dims=2,
    kernel_size=5,   # 5x5 grouping kernel
    stride=2,        # 2x spatial downsampling
    padding=2,       # Maintain spatial coverage
    num_heads=16,
)

# High resolution input: 128x128 image
x_hires = torch.randn(1, 128 * 128, 256)
input_spacing_hires = (0.1, 0.1)      # 0.1 mm per pixel (input)
input_grid_shape_hires = (128, 128)   # High-res input grid

output_hires = downsampling_attention(
    x=x_hires,
    input_spacing=input_spacing_hires,
    input_grid_shape=input_grid_shape_hires,
)

# Auto-computed: q_spacing = (0.1, 0.1) * 2 = (0.2, 0.2)
# Auto-computed: q_grid_shape = (128 + 2*2 - 5) // 2 + 1 = (64, 64)
print(f"Downsampled output: {output_hires['x_out'].shape}")  # (1, 4096, 256)
print(f"Compression ratio: {128*128 / (64*64)}x")            # 4x fewer points
3D Sparse Attention (GPU Required)
# 3D sparse attention for volumetric data (requires CUDA + natten)
try:
    from spatial_grouping_attention import SparseSpatialGroupingAttention

    sparse_3d = SparseSpatialGroupingAttention(
        feature_dims=128,
        spatial_dims=3,
        kernel_size=(3, 5, 5),   # Anisotropic: 3x5x5 grouping
        stride=(1, 2, 2),        # Downsample only in x,y
        num_heads=8,
        neighborhood_kernel=9,   # Local attention window
    )

    # 3D volume: 16x64x64 voxels
    depth, height, width = 16, 64, 64
    x_3d = torch.randn(1, depth * height * width, 128).cuda()

    # Anisotropic spacing (e.g., confocal microscopy)
    input_spacing_3d = (0.5, 0.1, 0.1)   # z, y, x spacing in microns
    input_grid_shape_3d = (16, 64, 64)

    output_3d = sparse_3d(
        x=x_3d,
        input_spacing=input_spacing_3d,
        input_grid_shape=input_grid_shape_3d,
    )

    # Auto-computed: q_spacing = (0.5*1, 0.1*2, 0.1*2) = (0.5, 0.2, 0.2)
    # Auto-computed: q_grid_shape = (16, 32, 32) - downsampled in x,y only
    print(f"3D sparse output: {output_3d['x_out'].shape}")  # (1, 16384, 128)
except ImportError:
    print("SparseSpatialGroupingAttention requires CUDA and natten package")
Multi-Scale Processing
# Process same input at multiple scales efficiently
multiscale_attention = DenseSpatialGroupingAttention(
    feature_dims=64,
    spatial_dims=2,
    kernel_size=9,
    stride=4,    # 4x downsampling for global context
    num_heads=4,
    iters=3,     # Multiple attention iterations
)

# Input image
x_input = torch.randn(1, 64 * 64, 64)
input_spacing = (1.0, 1.0)   # 1 micron per pixel
input_grid_shape = (64, 64)

# Global context via 4x downsampling
global_output = multiscale_attention(
    x=x_input,
    input_spacing=input_spacing,
    input_grid_shape=input_grid_shape,
)

# Auto-computed: q_spacing = (4.0, 4.0), q_grid_shape = (16, 16)
print(f"Global context: {global_output['x_out'].shape}")  # (1, 256, 64)

# Fine-scale processing with stride=1
fine_attention = DenseSpatialGroupingAttention(
    feature_dims=64,
    spatial_dims=2,
    kernel_size=5,
    stride=1,    # Full resolution
    num_heads=4,
)

fine_output = fine_attention(
    x=x_input,
    input_spacing=input_spacing,
    input_grid_shape=input_grid_shape,
)

# Auto-computed: q_spacing = (1.0, 1.0), q_grid_shape = (64, 64)
print(f"Fine details: {fine_output['x_out'].shape}")  # (1, 4096, 64)
Integration with Neural Networks
class HierarchicalSpatialNet(torch.nn.Module):
    """Multi-scale spatial processing network"""

    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        # Input embedding
        self.embed = torch.nn.Linear(input_channels, 128)

        # Coarse-scale attention (4x downsampling)
        self.coarse_attention = DenseSpatialGroupingAttention(
            feature_dims=128,
            spatial_dims=2,
            kernel_size=7,
            stride=4,    # 4x spatial compression
            num_heads=8,
        )

        # Fine-scale attention (2x downsampling)
        self.fine_attention = DenseSpatialGroupingAttention(
            feature_dims=128,
            spatial_dims=2,
            kernel_size=5,
            stride=2,    # 2x spatial compression
            num_heads=8,
        )

        # Cross-scale fusion
        self.fusion = torch.nn.Linear(256, 128)
        self.classifier = torch.nn.Linear(128, num_classes)

    def forward(self, images, pixel_spacing=(1.0, 1.0)):
        B, C, H, W = images.shape

        # Flatten and embed
        x = images.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.embed(x)

        # Multi-scale attention
        coarse_out = self.coarse_attention(
            x=x,
            input_spacing=pixel_spacing,
            input_grid_shape=(H, W),
        )['x_out']   # (B, H*W/16, 128) - 4x downsampling
        fine_out = self.fine_attention(
            x=x,
            input_spacing=pixel_spacing,
            input_grid_shape=(H, W),
        )['x_out']   # (B, H*W/4, 128) - 2x downsampling

        # Upsample coarse to match fine resolution for fusion
        coarse_upsampled = torch.nn.functional.interpolate(
            coarse_out.transpose(1, 2).reshape(B, 128, H // 4, W // 4),
            size=(H // 2, W // 2),
            mode='bilinear',
            align_corners=False,
        ).reshape(B, 128, -1).transpose(1, 2)

        # Fuse multi-scale features
        fused = self.fusion(torch.cat([fine_out, coarse_upsampled], dim=-1))

        # Global pooling and classification
        global_features = fused.mean(dim=1)
        return self.classifier(global_features)


# Usage example
net = HierarchicalSpatialNet(input_channels=3, num_classes=1000)
sample_images = torch.randn(4, 3, 128, 128)  # ImageNet-style input
pixel_spacing = (0.1, 0.1)                   # 0.1 mm per pixel

predictions = net(sample_images, pixel_spacing)
print(f"Predictions: {predictions.shape}")   # (4, 1000)
Key Principles
- Automatic Grid Calculation: You only specify the input grid (input_spacing, input_grid_shape); the query grid is computed as q_spacing = input_spacing * stride and q_grid_shape = (input_grid_shape + 2*padding - kernel_size) // stride + 1 (see the sketch after this list).
- Spatial Grouping: The kernel_size determines how many neighboring points are grouped together for attention computation.
- Multi-Scale Processing: Use different stride values to process the same input at multiple spatial scales efficiently.
- Memory Efficiency: Larger strides reduce the number of query points, making attention computation more efficient for large inputs.
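For reference, the grid arithmetic above can be written out as a small standalone helper. This is just the formula from the first principle, not part of the package API; the function name query_grid is hypothetical.

def query_grid(input_grid_shape, input_spacing, kernel_size, stride, padding):
    # Query grid implied by the key grid and the convolution parameters
    q_grid_shape = tuple(
        (g + 2 * p - k) // s + 1
        for g, p, k, s in zip(input_grid_shape, padding, kernel_size, stride)
    )
    q_spacing = tuple(sp * s for sp, s in zip(input_spacing, stride))
    return q_grid_shape, q_spacing

# The downsampling example above: 128x128 keys, kernel 5, stride 2, padding 2
print(query_grid((128, 128), (0.1, 0.1), (5, 5), (2, 2), (2, 2)))
# ((64, 64), (0.2, 0.2)) -> 4x fewer query points than keys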
Contributing
- Fork the repository
- Create a feature branch (git checkout -b feature/amazing-feature)
- Make your changes
- Run the test suite (make test)
- Commit your changes (git commit -m 'Add some amazing feature')
- Push to the branch (git push origin feature/amazing-feature)
- Open a Pull Request
License
BSD 3-Clause License. See LICENSE for details.
Citation
If you use this software in your research, please cite it using the information in CITATION.cff.
Download files
- Source Distribution: spatial_grouping_attention-2025.9.11.1930.tar.gz (16.0 kB)
- Built Distribution: spatial_grouping_attention-2025.9.11.1930-py3-none-any.whl (11.6 kB)
File details
Details for the file spatial_grouping_attention-2025.9.11.1930.tar.gz.
File metadata
- Download URL: spatial_grouping_attention-2025.9.11.1930.tar.gz
- Upload date:
- Size: 16.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9426e3d2815bf5a82e1729fbaa917a31a46e51c91c4d4d2ac90311e51c11f1eb |
| MD5 | c5c8a18c36d82475fcc694f3e1274010 |
| BLAKE2b-256 | 3697b51b358b7eb598125d2204b1473913d23da232059df937747f1f699b067f |
Provenance
The following attestation bundles were made for spatial_grouping_attention-2025.9.11.1930.tar.gz:
Publisher: ci-cd.yml on rhoadesScholar/spatial-grouping-attention

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: spatial_grouping_attention-2025.9.11.1930.tar.gz
- Subject digest: 9426e3d2815bf5a82e1729fbaa917a31a46e51c91c4d4d2ac90311e51c11f1eb
- Sigstore transparency entry: 501031144
- Sigstore integration time:
- Permalink: rhoadesScholar/spatial-grouping-attention@b130e5f2ffe70fceeb6434210a065780d526a5d4
- Branch / Tag: refs/heads/main
- Owner: https://github.com/rhoadesScholar
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci-cd.yml@b130e5f2ffe70fceeb6434210a065780d526a5d4
- Trigger Event: push
File details
Details for the file spatial_grouping_attention-2025.9.11.1930-py3-none-any.whl.
File metadata
- Download URL: spatial_grouping_attention-2025.9.11.1930-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f94fab44298be9262455dcdc616805fb8153d01df980bc6229d0a721ff09bf2a |
| MD5 | 5764dba286a6bc2db8a4fc2c4e521f29 |
| BLAKE2b-256 | e0e2abb20b9ef9992943b058ef314c60c758229b8507877e6dfc1ce3c3ed9322 |
Provenance
The following attestation bundles were made for spatial_grouping_attention-2025.9.11.1930-py3-none-any.whl:
Publisher: ci-cd.yml on rhoadesScholar/spatial-grouping-attention

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: spatial_grouping_attention-2025.9.11.1930-py3-none-any.whl
- Subject digest: f94fab44298be9262455dcdc616805fb8153d01df980bc6229d0a721ff09bf2a
- Sigstore transparency entry: 501031157
- Sigstore integration time:
- Permalink: rhoadesScholar/spatial-grouping-attention@b130e5f2ffe70fceeb6434210a065780d526a5d4
- Branch / Tag: refs/heads/main
- Owner: https://github.com/rhoadesScholar
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci-cd.yml@b130e5f2ffe70fceeb6434210a065780d526a5d4
- Trigger Event: push