
PyTorch-based implementation of non-local means with GPU support.

Project description

torch_nlm - memory-efficient non-local means in PyTorch

Context

Non-local means takes a long time to compute because its cost is quadratic in the number of pixels in an image. For small images this is fine, but for larger or three-dimensional images (common in medical imaging) it becomes prohibitive. Here, I introduce a PyTorch-based solution which uses convolutions to extract neighbours (non-local means here does not use the complete image but rather a neighbourhood of $n \times n$ pixels or $n \times n \times n$ voxels) and calculates the non-local means average. By porting this to PyTorch, we can make easy use of efficient GPU parallelization and speed up what is often a very time-consuming algorithm.
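A minimal sketch of the convolution trick, assuming a single-channel 2d floating-point tensor. `extract_neighbours_2d` is a hypothetical helper for illustration, not part of torch_nlm's API: each one-hot kernel copies exactly one pixel of the window, so a single `conv2d` call returns every neighbour of every pixel, with zero padding at the borders.

```python
import torch
import torch.nn.functional as F


def extract_neighbours_2d(image: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Return all kernel_size**2 neighbours of every pixel as a (k*k, H, W) tensor.

    Hypothetical helper (not torch_nlm's API). Channel c holds the image shifted
    by the c-th offset (row-major) of the k x k window; outside values are 0.
    """
    if kernel_size % 2 != 1:
        raise ValueError("kernel_size must be odd")
    n = kernel_size * kernel_size
    # One-hot kernels: kernel c has a single 1 at window position c.
    weight = torch.eye(n, dtype=image.dtype, device=image.device)
    weight = weight.reshape(n, 1, kernel_size, kernel_size)
    # conv2d expects (batch, channels, H, W); this padding keeps the output H x W.
    out = F.conv2d(image[None, None], weight, padding=kernel_size // 2)
    return out[0]
```

Channel `kernel_size**2 // 2` corresponds to the zero offset, so it is the image itself.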

When should you use it?

  • You want to run NLM for small images: just use scikit-image
  • You want to run NLM for bigger images AND you have a GPU: use this
  • You want to run NLM for relatively big images and you DO NOT have a GPU: good luck

Benchmark

I only benchmarked torch_nlm against scikit-image in 2d because the latter is prohibitively slow in 3d. Results below.

Brief explanation of non-local means (NLM)

For an image $I \in \mathbb{R}^{h \times w}$ consider pixel $I_{i,j}$ with coordinates $i,j$.

To obtain the non-local mean of this pixel:

$$\frac{1}{W}\sum_{a,b=1}^{h,w} w(I_{i,j},I_{a,b})\, I_{a,b}$$

where $w(I_{i,j},I_{a,b})$ is the weight of pixel $I_{a,b}$ given $I_{i,j}$ and $W=\sum_{a,b=1}^{h,w} w(I_{i,j},I_{a,b})$.

In other words, the non-local mean of a given pixel is a weighted average of all pixels. Here, the weights are $w(I_{i,j},I_{a,b}) = \exp\left(-\frac{(I_{i,j} - I_{a,b})^2}{\sigma^2}\right)$, where $\sigma$ is a constant (akin to a standard deviation). Because summing over all $h \times w$ pixels for every pixel is quadratic in the number of pixels, a simple way to make the computation tractable is to restrict the sum to a small square neighbourhood centred on the pixel; this is the solution used here.
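To make the formula concrete, here is a minimal NumPy sketch for a single pixel, assuming a square window of half-width `half_window` that is clipped at the image borders and using the pixel-wise weight defined above. `nlm_pixel` is a hypothetical name, and the sketch does not model any local averaging the library may perform (such as `kernel_size_mean`).

```python
import numpy as np


def nlm_pixel(image: np.ndarray, i: int, j: int, half_window: int, sigma: float) -> float:
    """Windowed non-local mean of pixel (i, j) (hypothetical helper).

    Computes (1/W) * sum_{a,b} w(I_ij, I_ab) * I_ab with
    w = exp(-(I_ij - I_ab)**2 / sigma**2), summing only over pixels
    within `half_window` of (i, j); the window is clipped at the borders.
    """
    h, w = image.shape
    a0, a1 = max(0, i - half_window), min(h, i + half_window + 1)
    b0, b1 = max(0, j - half_window), min(w, j + half_window + 1)
    window = image[a0:a1, b0:b1].astype(np.float64)
    centre = float(image[i, j])
    weights = np.exp(-((centre - window) ** 2) / sigma ** 2)
    # Normalising by weights.sum() is the division by W in the formula.
    return float((weights * window).sum() / weights.sum())
```

For a 3x3 image that is 0 everywhere except a 1 at the centre, with `half_window=1` and `sigma=1`, the centre pixel gets weight 1 and each of its 8 neighbours gets weight $e^{-1}$, so the result is $1/(1 + 8e^{-1}) \approx 0.254$.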

Usage

Installation

To use this package, clone the repository and install it (a pyproject.toml is provided, so you can easily install it with poetry). Alternatively, install the dependencies with pip (i.e. pip install -r requirements.txt).

Installation with pip: this is probably the least painful option:

pip install torch_nlm

Or, if you already have all the dependencies:

pip install torch_nlm --no-deps

Installation with setup.py: also easy to use:

python setup.py install

torch_nlm usage

Two main functions are exported, nlm2d and nlm3d, which are aliases for the most efficient torch-based NLM versions (apply_nonlocal_means_2d_mem_efficient and apply_nonlocal_means_3d_mem_efficient, respectively). So, if you want to apply it to your favourite image and have a CUDA-compatible GPU:

import torch  # necessary for obvious reasons
from torch_nlm import nlm2d

image = ...  # here you define your image (a 2d array)

# convert to a floating-point tensor (convolutions need a float dtype)
# and allocate it to your favourite device
image_torch = torch.as_tensor(image, dtype=torch.float32).to("cuda")

image_nlm = nlm2d(image_torch,  # the image
                  kernel_size=11,  # neighbourhood size
                  std=1.0,  # the sigma
                  kernel_size_mean=3,  # the kernel used to compute the average pixel intensity
                  sub_filter_size=32  # how many neighbourhoods are computed per iteration
                  )

sub_filter_size is what makes large neighbourhoods possible: users with GPUs that have relatively little memory can opt for a smaller sub_filter_size, which loads much smaller sets of neighbourhoods at a time for the distance/weight calculations. You may want to run a few tests to find the best sub_filter_size before deploying this en masse.
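One way to run those tests is a small timing loop over candidate values. This is a minimal sketch, assuming `run(size)` calls the denoiser with `sub_filter_size=size`; `time_sub_filter_sizes` is a hypothetical helper, and sizes that fail with a `RuntimeError` (PyTorch's CUDA out-of-memory error is a subclass of it) are recorded as infinitely slow.

```python
import time
from typing import Callable, Dict, Iterable, Optional


def time_sub_filter_sizes(
    run: Callable[[int], object],
    candidates: Iterable[int],
    sync: Optional[Callable[[], None]] = None,
    repeats: int = 3,
) -> Dict[int, float]:
    """Return the best wall-clock time in seconds for each candidate size."""
    timings: Dict[int, float] = {}
    for size in candidates:
        best = float("inf")
        try:
            for _ in range(repeats):
                start = time.perf_counter()
                run(size)
                if sync is not None:
                    sync()  # wait for queued GPU kernels before stopping the clock
                best = min(best, time.perf_counter() - start)
        except RuntimeError:  # e.g. CUDA out of memory for this size
            best = float("inf")
        timings[size] = best
    return timings


# Possible usage with the example above (assumes nlm2d accepts these keywords):
# timings = time_sub_filter_sizes(
#     lambda s: nlm2d(image_torch, kernel_size=11, std=1.0,
#                     kernel_size_mean=3, sub_filter_size=s),
#     candidates=[8, 16, 32, 64],
#     sync=torch.cuda.synchronize,
# )
# best_size = min(timings, key=timings.get)
```

Synchronising matters because CUDA kernels run asynchronously; without it the timer may stop before the GPU has finished.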

Since GPU allocation can be time-consuming and you may have many images to process, it is a good idea to process them as batches within a single script rather than running a separate script per image.
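A minimal sketch of that pattern, assuming `denoise` is any tensor-to-tensor function (for example, a wrapper around nlm2d with your chosen parameters); `denoise_many` is a hypothetical helper. The CUDA context and PyTorch's memory cache are set up once per process and then reused for every image in the loop.

```python
from typing import Callable, Iterable, List

import numpy as np
import torch


def denoise_many(
    images: Iterable[np.ndarray],
    denoise: Callable[[torch.Tensor], torch.Tensor],
    device: str = "cuda",
) -> List[np.ndarray]:
    """Denoise several images in one process, reusing the same device."""
    results = []
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        for image in images:
            x = torch.as_tensor(image, dtype=torch.float32, device=device)
            results.append(denoise(x).cpu().numpy())  # copy result back to host
    return results
```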

Implementation details

This code was optimized for speed. Three main functions are provided here: apply_nonlocal_means_2d, apply_windowed_nonlocal_means_2d and apply_nonlocal_means_2d_mem_efficient. The first two are development versions; the last is the one you should use (exposed as nlm2d).

apply_nonlocal_means_2d

Retrieves all neighbours as a large tensor and calculates the NLM of the image.

Problems: large neighbourhoods will lead to out-of-memory (OOM) errors.
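A back-of-the-envelope estimate shows why. Assuming the all-neighbours tensor holds one float32 value per neighbour per pixel (an assumption about the layout, not taken from the code), its size grows with the square of the kernel size; `neighbour_tensor_gib` is a hypothetical helper.

```python
def neighbour_tensor_gib(height: int, width: int, kernel_size: int,
                         bytes_per_element: int = 4) -> float:
    """Size in GiB of a tensor holding all kernel_size**2 neighbours of every pixel."""
    n_elements = height * width * kernel_size ** 2
    return n_elements * bytes_per_element / 2 ** 30
```

For a 1024x1127 image with a 51x51 neighbourhood this gives about 11.2 GiB, before counting a weights tensor of the same shape; with an 11x11 neighbourhood it drops to about 0.52 GiB.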

apply_windowed_nonlocal_means_2d

Does the same as apply_nonlocal_means_2d but uses strided patches to do this, thus reducing memory requirements.

Problems: leads to visible artifacts from the strided patches.

apply_nonlocal_means_2d_mem_efficient

Does the same as apply_nonlocal_means_2d but loops over sets of neighbourhoods to calculate weights, updating $W$ as an accumulator. This version requires defining a "batch size" analogue (sub_filter_size) so that only a few neighbours are calculated at a time. This makes it possible to process very large neighbourhoods while keeping each step highly parallel.
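The accumulator idea can be sketched as follows. This is a minimal illustration, not the library's implementation: `nlm2d_chunked`, `chunk` (playing the role of sub_filter_size) and `sigma` are hypothetical names, it uses the pixel-wise weight from the formula above, omits `kernel_size_mean`, and masks out neighbours that fall outside the image instead of treating them as zeros.

```python
import torch
import torch.nn.functional as F


def nlm2d_chunked(image: torch.Tensor, kernel_size: int, sigma: float,
                  chunk: int) -> torch.Tensor:
    """Windowed NLM of a 2d float tensor, `chunk` neighbour offsets at a time."""
    k2 = kernel_size * kernel_size
    pad = kernel_size // 2
    # One-hot kernels, one per offset in the window (row-major order).
    eye = torch.eye(k2, dtype=image.dtype, device=image.device)
    eye = eye.reshape(k2, 1, kernel_size, kernel_size)
    x = image[None, None]
    valid = torch.ones_like(x)  # 1 inside the image; the padding is 0
    numerator = torch.zeros_like(image)
    W = torch.zeros_like(image)  # running sum of weights (the accumulator W)
    for start in range(0, k2, chunk):
        kernels = eye[start:start + chunk]  # only this chunk's offsets
        neighbours = F.conv2d(x, kernels, padding=pad)[0]  # (c, H, W)
        inside = F.conv2d(valid, kernels, padding=pad)[0]  # 1 where neighbour exists
        weights = torch.exp(-((image - neighbours) ** 2) / sigma ** 2) * inside
        numerator += (weights * neighbours).sum(0)
        W += weights.sum(0)
    # The zero offset is always inside the image with weight 1, so W >= 1.
    return numerator / W
```

Peak memory scales with `chunk` rather than with `kernel_size**2`, while the result is identical for any chunk size.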

Problems: none for now! But time is a teacher to us all.

Generalising to 3D

A nice aspect of this approach is that generalising these functions to 3D takes very little effort. The 3D versions are available under the same names as above, with 2d replaced by 3d. The version you want to use is nlm3d.
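As a sketch of why the generalisation is cheap: under the same one-hot-kernel assumption used for 2d, the only changes are `conv3d` instead of `conv2d` and $k^3$ kernels instead of $k^2$. `extract_neighbours_3d` is a hypothetical helper, not part of torch_nlm's API. Because the number of neighbours grows with $k^3$, processing them in chunks matters even more in 3D.

```python
import torch
import torch.nn.functional as F


def extract_neighbours_3d(volume: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Return all kernel_size**3 neighbours of every voxel as a (k**3, D, H, W) tensor."""
    n = kernel_size ** 3
    # One-hot 3D kernels: kernel c has a single 1 at window position c.
    weight = torch.eye(n, dtype=volume.dtype, device=volume.device)
    weight = weight.reshape(n, 1, kernel_size, kernel_size, kernel_size)
    # conv3d expects (batch, channels, D, H, W); zero padding keeps the shape.
    return F.conv3d(volume[None, None], weight, padding=kernel_size // 2)[0]
```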

Expected runtime

For a large image such as assets/threegorges-1024x1127.jpg (source; size: 1024x1127), apply_nonlocal_means_2d_mem_efficient takes ~3-5 seconds with a 51x51-pixel neighbourhood when running on a GPU.

Example below (obtained by running python test.py assets/threegorges-1024x1127.jpg):

  • First panel: original image
  • Second panel: original image + noise
  • Third panel: original image + noise + NLM
  • Fourth panel: difference between original image and original image + noise + NLM

Note on benchmarking: while 2d benchmarks are reasonable, 3d benchmarks will take a lot of time because of scikit-image's implementation. Expect times of ~4,000 seconds for a $256 \times 256 \times 256$ image with a neighbourhood size of 17 (torch_nlm ran in ~70-80 seconds 😊). You will need scikit-image for benchmarking.



Download files


Source Distribution

nlm_torch_new-0.1.3.tar.gz (8.9 kB)

Uploaded Source

Built Distribution


nlm_torch_new-0.1.3-py3-none-any.whl (11.5 kB)

Uploaded Python 3

File details

Details for the file nlm_torch_new-0.1.3.tar.gz.

File metadata

  • Download URL: nlm_torch_new-0.1.3.tar.gz
  • Upload date:
  • Size: 8.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for nlm_torch_new-0.1.3.tar.gz
Algorithm Hash digest
SHA256 8b99242bc547a4b633059a911d4d90a6eef62019a66e1bf6cf4c9c43c7fc8633
MD5 c764c9104164f1fd32ba82904b76185a
BLAKE2b-256 6a92ea672269dc0524920fee49ba1f51db371a8d67bda27be97c84635dd177c2


Provenance

The following attestation bundles were made for nlm_torch_new-0.1.3.tar.gz:

Publisher: pypi.yml on lucianchauvin/torch_nlm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file nlm_torch_new-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: nlm_torch_new-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 11.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for nlm_torch_new-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 f6bc972567360803d7953ad9dc4b93f8bc3f1e2fb037247b3c9785c07af68c33
MD5 244df43e2f02f90a36e3848cfed2583e
BLAKE2b-256 7f34631a2a0217283f6847269b11e6448a716f0c1d3ceaba29789c92350949b3


Provenance

The following attestation bundles were made for nlm_torch_new-0.1.3-py3-none-any.whl:

Publisher: pypi.yml on lucianchauvin/torch_nlm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
