
spatial-grouping-attention

PyTorch Implementation of Spatial Grouping Attention


A PyTorch implementation inspired by the spatial grouping layer in Native Segmentation Vision Transformers (https://arxiv.org/abs/2505.16993), featuring a modified rotary position embedding that is generalized to N dimensions and incorporates real-world pixel spacing.
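
As a conceptual illustration of that position embedding (a hedged sketch under assumptions, not this package's actual implementation; the helper names and frequency schedule are hypothetical), an N-dimensional rotary embedding with real-world spacing can rotate channel pairs through angles proportional to physical coordinates, i.e. index * pixel spacing:

import torch

def nd_rotary_angles(grid_shape, spacing, dims_per_axis):
    """Per-axis rotation angle tables from physical coordinates (hypothetical helper)."""
    # Standard RoPE inverse-frequency schedule, one band per rotated channel pair.
    freqs = 1.0 / (10000.0 ** (torch.arange(0, dims_per_axis, 2) / dims_per_axis))
    angles = []
    for size, dx in zip(grid_shape, spacing):
        pos = torch.arange(size) * dx  # physical coordinate = index * real-world spacing
        angles.append(torch.outer(pos, freqs))  # (size, dims_per_axis // 2)
    return angles  # one angle table per spatial dimension

def apply_rotation(x, theta):
    """Rotate consecutive channel pairs of x by angles theta."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = theta.cos(), theta.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

For a 2D grid with 128 feature channels, the channels would be split into two 64-channel groups, one group rotated per spatial axis.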

Installation

From PyPI

You will first need to install PyTorch separately, as it is required for building one of our dependencies (natten). We recommend installing inside a virtual environment, created with e.g. venv or mamba:

# create a virtual environment
mamba create -n spatial-attention -y python=3.11 pytorch ninja cmake

# activate the virtual environment
mamba activate spatial-attention

# install the package(s)
pip install spatial-grouping-attention
pip install natten==0.17.5 # requires Python 3.11
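
A quick import check confirms that the packages resolve (a minimal sanity test, nothing more):

python -c "import torch, natten, spatial_grouping_attention; print('ok')"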

From source

To install the latest development version directly from GitHub, first create and activate a virtual environment as above, then run:

pip install git+https://github.com/rhoadesScholar/spatial-grouping-attention.git

Usage

The spatial grouping attention mechanism automatically computes the query grid parameters (q_spacing and q_grid_shape) from the input key grid and the convolution parameters. This makes it easy to use: you only specify the input resolution, and the module handles the spatial downsampling.

Basic 2D Dense Attention

import torch
from spatial_grouping_attention import DenseSpatialGroupingAttention

# Create attention module with 7x7 grouping kernel, stride=1 (no downsampling)
attention = DenseSpatialGroupingAttention(
    feature_dims=128,
    spatial_dims=2,
    kernel_size=7,        # 7x7 spatial grouping
    stride=1,            # No spatial downsampling
    num_heads=8,
    mlp_ratio=4
)

# Input: 32x32 image with 128 features per pixel
batch_size, height, width = 2, 32, 32
x = torch.randn(batch_size, height * width, 128)

# Only specify input (key) grid - query grid computed automatically
input_spacing = (0.5, 0.5)      # 0.5 microns per pixel
input_grid_shape = (32, 32)     # Input resolution

output = attention(x=x, input_spacing=input_spacing, input_grid_shape=input_grid_shape)
# Auto-computed: q_spacing = (0.5, 0.5) * 1 = (0.5, 0.5)
# Auto-computed: q_grid_shape = (32, 32) (no downsampling with stride=1)
print(f"Output shape: {output['x_out'].shape}")  # (2, 1024, 128)

2D Attention with Spatial Downsampling

# Create attention with spatial downsampling for efficiency
downsampling_attention = DenseSpatialGroupingAttention(
    feature_dims=256,
    spatial_dims=2,
    kernel_size=5,        # 5x5 grouping kernel
    stride=2,            # 2x spatial downsampling
    padding=2,           # Maintain spatial coverage
    num_heads=16
)

# High resolution input: 128x128 image
x_hires = torch.randn(1, 128*128, 256)
input_spacing_hires = (0.1, 0.1)        # 0.1 mm per pixel (input)
input_grid_shape_hires = (128, 128)     # High-res input grid

output_hires = downsampling_attention(
    x=x_hires,
    input_spacing=input_spacing_hires,
    input_grid_shape=input_grid_shape_hires
)
# Auto-computed: q_spacing = (0.1, 0.1) * 2 = (0.2, 0.2)
# Auto-computed: q_grid_shape = (128+2*2-5)//2+1 = (64, 64)
print(f"Downsampled output: {output_hires['x_out'].shape}")  # (1, 4096, 256)
print(f"Compression ratio: {128*128 / (64*64)}x")        # 4x fewer points

3D Sparse Attention (GPU Required)

# 3D sparse attention for volumetric data (requires CUDA + natten)
try:
    from spatial_grouping_attention import SparseSpatialGroupingAttention

    sparse_3d = SparseSpatialGroupingAttention(
        feature_dims=128,
        spatial_dims=3,
        kernel_size=(3, 5, 5),    # Anisotropic: 3x5x5 grouping
        stride=(1, 2, 2),         # Downsample only in x,y
        num_heads=8,
        neighborhood_kernel=9     # Local attention window
    )

    # 3D volume: 16x64x64 voxels
    depth, height, width = 16, 64, 64
    x_3d = torch.randn(1, depth*height*width, 128).cuda()

    # Anisotropic spacing (e.g., confocal microscopy)
    input_spacing_3d = (0.5, 0.1, 0.1)     # z, y, x spacing in microns
    input_grid_shape_3d = (16, 64, 64)

    output_3d = sparse_3d(
        x=x_3d,
        input_spacing=input_spacing_3d,
        input_grid_shape=input_grid_shape_3d
    )
    # Auto-computed: q_spacing = (0.5*1, 0.1*2, 0.1*2) = (0.5, 0.2, 0.2)
    # Auto-computed: q_grid_shape = (16, 32, 32) - downsampled in x,y only
    print(f"3D sparse output: {output_3d['x_out'].shape}")  # (1, 16384, 128)

except ImportError:
    print("SparseSpatialGroupingAttention requires CUDA and natten package")

Multi-Scale Processing

# Process same input at multiple scales efficiently
multiscale_attention = DenseSpatialGroupingAttention(
    feature_dims=64,
    spatial_dims=2,
    kernel_size=9,
    stride=4,            # 4x downsampling for global context
    num_heads=4,
    iters=3              # Multiple attention iterations
)

# Input image
x_input = torch.randn(1, 64*64, 64)
input_spacing = (1.0, 1.0)              # 1 micron per pixel
input_grid_shape = (64, 64)

# Global context via 4x downsampling
global_output = multiscale_attention(
    x=x_input,
    input_spacing=input_spacing,
    input_grid_shape=input_grid_shape
)
# Auto-computed: q_spacing = (4.0, 4.0), q_grid_shape = (16, 16)
print(f"Global context: {global_output['x_out'].shape}")  # (1, 256, 64)

# Fine-scale processing with stride=1
fine_attention = DenseSpatialGroupingAttention(
    feature_dims=64,
    spatial_dims=2,
    kernel_size=5,
    stride=1,            # Full resolution
    num_heads=4
)

fine_output = fine_attention(
    x=x_input,
    input_spacing=input_spacing,
    input_grid_shape=input_grid_shape
)
# Auto-computed: q_spacing = (1.0, 1.0), q_grid_shape = (64, 64)
print(f"Fine details: {fine_output['x_out'].shape}")      # (1, 4096, 64)

Integration with Neural Networks

class HierarchicalSpatialNet(torch.nn.Module):
    """Multi-scale spatial processing network"""

    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()

        # Input embedding
        self.embed = torch.nn.Linear(input_channels, 128)

        # Coarse-scale attention (4x downsampling)
        self.coarse_attention = DenseSpatialGroupingAttention(
            feature_dims=128,
            spatial_dims=2,
            kernel_size=7,
            stride=4,         # 4x spatial compression
            num_heads=8
        )

        # Fine-scale attention (2x downsampling)
        self.fine_attention = DenseSpatialGroupingAttention(
            feature_dims=128,
            spatial_dims=2,
            kernel_size=5,
            stride=2,         # 2x spatial compression
            num_heads=8
        )

        # Cross-scale fusion
        self.fusion = torch.nn.Linear(256, 128)
        self.classifier = torch.nn.Linear(128, num_classes)

    def forward(self, images, pixel_spacing=(1.0, 1.0)):
        B, C, H, W = images.shape

        # Flatten and embed
        x = images.permute(0, 2, 3, 1).reshape(B, H*W, C)
        x = self.embed(x)

        # Multi-scale attention
        coarse_out = self.coarse_attention(
            x=x,
            input_spacing=pixel_spacing,
            input_grid_shape=(H, W)
        )['x_out']  # (B, H*W/16, 128) - 4x downsampling

        fine_out = self.fine_attention(
            x=x,
            input_spacing=pixel_spacing,
            input_grid_shape=(H, W)
        )['x_out']   # (B, H*W/4, 128) - 2x downsampling

        # Upsample coarse to match fine resolution for fusion
        coarse_upsampled = torch.nn.functional.interpolate(
            coarse_out.transpose(1, 2).reshape(B, 128, H//4, W//4),
            size=(H//2, W//2),
            mode='bilinear',
            align_corners=False
        ).reshape(B, 128, -1).transpose(1, 2)

        # Fuse multi-scale features
        fused = self.fusion(torch.cat([fine_out, coarse_upsampled], dim=-1))

        # Global pooling and classification
        global_features = fused.mean(dim=1)
        return self.classifier(global_features)

# Usage example
net = HierarchicalSpatialNet(input_channels=3, num_classes=1000)
sample_images = torch.randn(4, 3, 128, 128)  # ImageNet-style input
pixel_spacing = (0.1, 0.1)  # 0.1 mm per pixel

predictions = net(sample_images, pixel_spacing)
print(f"Predictions: {predictions.shape}")  # (4, 1000)

Key Principles

  1. Automatic Grid Calculation: You only specify the input grid (input_spacing, input_grid_shape); the query grid is computed as follows (see the sketch after this list):

    • q_spacing = input_spacing * stride
    • q_grid_shape = (input_grid_shape + 2*padding - kernel_size) // stride + 1
  2. Spatial Grouping: The kernel_size determines how many neighboring points are grouped together for attention computation.

  3. Multi-Scale Processing: Use different stride values to process the same input at multiple spatial scales efficiently.

  4. Memory Efficiency: Larger strides reduce the number of query points, making attention computation more efficient for large inputs.
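
A minimal sketch of the grid computation from principle 1 (a standalone illustration; the helper name is hypothetical, not the library's internal API):

def query_grid(input_spacing, input_grid_shape, kernel_size, stride, padding=0):
    """Compute query grid spacing and shape per spatial dimension."""
    q_spacing = tuple(s * stride for s in input_spacing)
    q_grid_shape = tuple(
        (n + 2 * padding - kernel_size) // stride + 1 for n in input_grid_shape
    )
    return q_spacing, q_grid_shape

# Matches the 2D downsampling example above: 128x128 input, kernel 5, stride 2, padding 2
print(query_grid((0.1, 0.1), (128, 128), kernel_size=5, stride=2, padding=2))
# ((0.2, 0.2), (64, 64)) -> 64*64 = 4096 query points vs. 16384 inputs (4x fewer, per principle 4)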

Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Make your changes
  4. Run the test suite (make test)
  5. Commit your changes (git commit -m 'Add some amazing feature')
  6. Push to the branch (git push origin feature/amazing-feature)
  7. Open a Pull Request

License

BSD 3-Clause License. See LICENSE for details.

Citation

If you use this software in your research, please cite it using the information in CITATION.cff.
