
Torch-StainTools

A fast, GPU-friendly PyTorch toolkit for stain normalization and augmentation of histopathological images.

Torch-StainTools implements GPU-accelerated stain augmentation and normalization algorithms (Reinhard, Macenko, Vahadane) with batch processing and caching for on-the-fly large-scale computational pathology pipelines.


What's New (v1.0.7)

  • Full vectorization support and dynamic shape tracking via TorchDynamo.

  • Alternative linear concentration solvers: 'qr' (QR decomposition) and 'pinv' (Moore-Penrose pseudoinverse); see the sketch after this list.

  • Color/texture-based hash as cache key when no unique identifiers (e.g., filenames) are available.
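
For intuition, both new solvers target the same linear least-squares problem in optical-density space, $OD \approx C \cdot S$, where $S$ is the stain matrix and $C$ holds the per-pixel concentrations. Below is a minimal sketch with illustrative shapes and variable names, not the library's internals:

```python
import torch

# illustrative shapes: od holds N optical-density pixels, stain_matrix is an H&E basis
od = torch.rand(4096, 3)         # N x 3
stain_matrix = torch.rand(2, 3)  # 2 x 3 (H and E rows)

# 'qr': least squares via QR decomposition of stain_matrix^T
q, r = torch.linalg.qr(stain_matrix.T)                              # (3, 2), (2, 2)
conc_qr = torch.linalg.solve_triangular(r, q.T @ od.T, upper=True)  # 2 x N

# 'pinv': least squares via the Moore-Penrose pseudoinverse
conc_pinv = torch.linalg.pinv(stain_matrix.T) @ od.T                # 2 x N
```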

What It Does

  • GPU acceleration and vectorized execution for batch inputs.

  • Optional TorchDynamo graph compilation (torch.compile) for high-throughput execution.

  • On-the-fly stain normalization and augmentation.

  • Stain matrix caching to avoid redundant computation across tiles.

  • Encapsulation as nn.Module, making normalizers and augmentors easy to plug into existing neural network pipelines (see the sketch after this list).

  • Optional, customizable tissue masking support.
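
Since each normalizer is an nn.Module, it composes with other modules. Below is a hypothetical wrapper; the backbone is a placeholder, and the normalizer is assumed to be fit to a target beforehand:

```python
import torch
import torch.nn as nn
from torch_staintools.normalizer import NormalizerBuilder

class NormalizedPipeline(nn.Module):
    """Hypothetical wrapper: normalize stains on the fly, then run a model."""

    def __init__(self, normalizer: nn.Module, model: nn.Module):
        super().__init__()
        self.normalizer = normalizer
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():  # normalization is preprocessing, not learned
            x = self.normalizer(x)
        return self.model(x)

# placeholder backbone; in practice, any network taking B x 3 x H x W input
backbone = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.Flatten())
normalizer = NormalizerBuilder.build('macenko', use_cache=True)
# normalizer.fit(target_tensor) must be called before use
pipeline = NormalizedPipeline(normalizer, backbone)
```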

Citation

If this toolkit helps your publication, please cite it with the following BibTeX entry:

```bibtex
@software{zhou_2024_10496083,
  author       = {Zhou, Yufei},
  title        = {CielAl/torch-staintools: V1.0.4 Release},
  month        = jan,
  year         = 2024,
  publisher    = {Zenodo},
  version      = {v1.0.4},
  doi          = {10.5281/zenodo.10496083},
  url          = {https://doi.org/10.5281/zenodo.10496083}
}
```

Normalization Showcase

Torch-StainTools

[Figure: normalization results produced by Torch-StainTools]

Comparison: StainTools

[Figure: corresponding normalization results produced by StainTools]

Augmentation

Torch-StainTools

[Figure: augmentation results produced by Torch-StainTools]

Comparison: StainTools

[Figure: corresponding augmentation results produced by StainTools]

Performance Benchmark

Single Large ROI (2500 $\times$ 2500 $\times$ 3; No Caching)

  • Representative preprocessing scenario for large tissue ROIs.
  • GPU execution with TorchDynamo (torch.compile) enabled.
Transformation

| Method   | CPU [s] | GPU [s] | StainTools [s] |
|----------|---------|---------|----------------|
| Vahadane | 119.00  | 4.60    | 20.90          |
| Macenko  | 5.57    | 0.48    | 20.70          |
| Reinhard | 0.84    | 0.02    | 0.41           |

Fitting (one-time cost)

| Method   | CPU [s] | GPU [s] | StainTools [s] |
|----------|---------|---------|----------------|
| Vahadane | 132.00  | 5.20    | 19.10          |
| Macenko  | 6.99    | 0.06    | 20.00          |
| Reinhard | 0.42    | 0.01    | 0.08           |

Batched Small Tiles (81 tiles, 256$\times$256$\times$3)

  • Splitting a 2500 $\times$ 2500 $\times$ 3 ROI into a batch of 81 smaller patches (256 $\times$ 256 $\times$ 3); see the tiling sketch after this list.

  • Representative on-the-fly processing scenario for training and inference.

  • TorchDynamo (torch.compile) enabled.
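
As referenced above, one way to produce such a batch from a large ROI, sketched with plain tensor ops (unfold drops the remainder, so 2500 pixels crop to 9 $\times$ 9 = 81 tiles):

```python
import torch

roi = torch.rand(3, 2500, 2500)  # C x H x W float ROI in [0., 1.]
tile = 256

# unfold height and width into non-overlapping 256-pixel windows;
# the leftover border (2500 - 9 * 256 = 196 px) is dropped automatically
tiles = roi.unfold(1, tile, tile).unfold(2, tile, tile)          # 3 x 9 x 9 x 256 x 256
batch = tiles.permute(1, 2, 0, 3, 4).reshape(-1, 3, tile, tile)  # 81 x 3 x 256 x 256
```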

Batch Transformation

| Method   | No Cache [s] | Stain Matrix Cached [s] | Speedup      |
|----------|--------------|-------------------------|--------------|
| Vahadane | 6.62         | 0.019                   | 348x faster  |
| Macenko  | 0.023        | 0.020                   | 1.15x faster |

Batched Concentration Computation

  • Split the sample image under ./test_images (2500 $\times$ 2500 $\times$ 3) into a batch of 81 non-overlapping 256 $\times$ 256 $\times$ 3 tiles.
  • For the StainTools baseline, concentrations are computed in a for-loop over the 81 individual NumPy tile arrays.
  • torch.compile enabled.
| Method                               | CPU [s] | GPU [s] |
|--------------------------------------|---------|---------|
| FISTA (concentration_solver='fista') | 1.47    | 0.24    |
| ISTA (concentration_solver='ista')   | 3.12    | 0.31    |
| CD (concentration_solver='cd')       | 29.30   | 4.87    |
| LS (concentration_solver='ls')       | 0.22    | 0.097   |
| StainTools (SPAMS)                   | 16.60   | N/A     |
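
Switching among the solvers benchmarked above is a single builder argument; a brief sketch using the same names as the table:

```python
from torch_staintools.normalizer import NormalizerBuilder

# 'ls' was the fastest option in the benchmark above; 'fista' is the sparse,
# iterative option
normalizer_ls = NormalizerBuilder.build('vahadane', concentration_solver='ls')
normalizer_fista = NormalizerBuilder.build('vahadane', concentration_solver='fista')
```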

Installation

  • From Repository:

```bash
pip install git+https://github.com/CielAl/torch-staintools.git
```

  • From PyPI:

```bash
pip install torch-staintools
```

Documentation

Detailed documentation of the code base can be found on the GitHub Pages site.

Minimal Usage and Tips

  • For details, follow the example in demo.py.
  • Normalizers are implemented as torch.nn.Module and can be integrated like a standalone network component.
  • The qr and pinv concentration solvers are on par with ls for batch concentration computation, but ls (i.e., torch.linalg.lstsq) may fail on GPU for a single large input image (in width and height) with the default cusolver backend. Try the magma backend instead:
```python
import torch
# switch torch.linalg's CUDA backend from the default cusolver to magma
torch.backends.cuda.preferred_linalg_library('magma')
```

Example

```python
from torch_staintools.normalizer import NormalizerBuilder

# ######### Vahadane normalization
# Note: torch.compile graph compilation is enabled by default in the builder.
target_tensor = ...  # any batched float input in B x C x H x W, value range [0., 1.]
norm_tensor = ...    # any batched float input in B x C x H x W, value range [0., 1.]
target_tensor = target_tensor.cuda()
norm_tensor = norm_tensor.cuda()

normalizer_vahadane = NormalizerBuilder.build('vahadane',
                                              concentration_solver='qr',
                                              use_cache=True)
# move the normalizer to the corresponding device
normalizer_vahadane = normalizer_vahadane.cuda()
# fit to the target image(s) once, then normalize any input batch
normalizer_vahadane.fit(target_tensor)
norm_out = normalizer_vahadane(norm_tensor)

# ######### Augmentation
# Augment by alpha * concentration + beta, where alpha is sampled uniformly from
# (1 - sigma_alpha, 1 + sigma_alpha) and beta uniformly from (-sigma_beta, sigma_beta).
from torch_staintools.augmentor import AugmentorBuilder

augmentor = AugmentorBuilder.build('vahadane', use_cache=True)
# move the augmentor to the corresponding device
augmentor = augmentor.cuda()

num_augment = 5
# multiple differently randomized augmentations of the same tile may be generated
for _ in range(num_augment):
    aug_out = augmentor(norm_tensor)

# dump the cache of stain matrices for future use
augmentor.dump_cache('./cache.pickle')
```
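
For intuition, here is a self-contained sketch of the sampling rule described in the comments above; the function name, shapes, and sigma values are illustrative, not the library's API:

```python
import torch

def perturb_concentrations(concentration: torch.Tensor,
                           sigma_alpha: float = 0.2,
                           sigma_beta: float = 0.2) -> torch.Tensor:
    """Return alpha * concentration + beta, sampled per image and stain channel.

    alpha ~ U(1 - sigma_alpha, 1 + sigma_alpha); beta ~ U(-sigma_beta, sigma_beta).
    concentration: B x num_stains x (H * W).
    """
    b, num_stains, _ = concentration.shape
    def uniform(lo: float, hi: float) -> torch.Tensor:
        return torch.empty(b, num_stains, 1,
                           device=concentration.device).uniform_(lo, hi)
    alpha = uniform(1 - sigma_alpha, 1 + sigma_alpha)
    beta = uniform(-sigma_beta, sigma_beta)
    return concentration * alpha + beta
```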

Stain Matrix Caching

Stain matrix estimation can dominate runtime (especially for Vahadane). To reduce overhead, Normalizer and Augmentor support an in-memory, device-specific cache for stain matrices (typically 2×3 for H&E/RGB).

Why it matters: cached stain matrices can be reused across images, yielding substantial speedups in batch and on-the-fly pipelines.
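
Conceptually, the cache acts as a mapping from a per-image key to its estimated stain matrix, so the expensive estimation runs once per image. A simplified sketch, not the library's internal class:

```python
import torch
from typing import Callable, Dict

# simplified stand-in for the in-memory cache: key -> 2 x 3 stain matrix
stain_cache: Dict[str, torch.Tensor] = {}

def cached_stain_matrix(key: str, image: torch.Tensor,
                        estimate: Callable[[torch.Tensor], torch.Tensor]
                        ) -> torch.Tensor:
    # estimate() is the expensive step (e.g., Vahadane dictionary learning);
    # repeated calls with the same key become dictionary lookups
    if key not in stain_cache:
        stain_cache[key] = estimate(image)
    return stain_cache[key]
```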

How it works

  • Enable with use_cache=True when constructing a Normalizer or Augmentor.
  • Cached entries are keyed per image (e.g., by filename or slide identifier).
  • For batched inputs (B×C×H×W), provide one key per image in the batch.
  • Cache contents can be exported and reloaded for future reuse.

Fallback behavior

  • If caching is enabled but no cache_key is provided, a texture- and color-based hash is computed automatically.
  • Visually similar images are likely to reuse stain matrices, while collisions across dissimilar images are minimized.
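
One hypothetical way such a key can be derived; this illustrates the idea of a coarse perceptual signature, not the library's actual hash:

```python
import hashlib
import torch

def color_texture_key(image: torch.Tensor, size: int = 8) -> str:
    """Hypothetical perceptual-style cache key: downsample, quantize, hash.

    image: C x H x W float tensor in [0., 1.]. Visually similar tiles map to
    the same coarse signature; dissimilar tiles almost never collide.
    """
    small = torch.nn.functional.adaptive_avg_pool2d(image.unsqueeze(0), size)
    quantized = (small * 15).round().to(torch.uint8)  # 16 levels per channel
    return hashlib.sha1(quantized.cpu().numpy().tobytes()).hexdigest()
```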

Enable Cache / Loading

```python
# set `use_cache` to True
# specify `load_path` to read from existing cache data
normalizer = NormalizerBuilder.build('vahadane',
                                     concentration_solver='qr',
                                     use_cache=True,
                                     load_path='path_to_cache')
# alternatively, read the cache manually
normalizer.load_cache('path_to_cache')
```

If unique identifiers (UIDs) of images are available

```python
# explicitly set cache_keys in normalization/augmentation passes
normalizer(input_batch, cache_keys=list_of_uid)
augmentor(input_batch, cache_keys=list_of_uid)
```

If cache_keys are not available

```python
# color/texture-based hash keys are computed internally
normalizer(input_batch)
augmentor(input_batch)
```

Export cache

```python
# dump the cache to a path
normalizer.dump_cache("/folder/cache.tch")
```
