JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4

n4ax

N4 bias field correction in pure JAX — a fast, GPU-friendly, drop-in match for ITK / SimpleITK's N4BiasFieldCorrectionImageFilter.

n4ax reimplements the N4 algorithm (Tustison et al., 2010: N3 histogram sharpening + multi-resolution B-spline) faithfully enough to match SimpleITK to ~1% on real MRI, while running ~1500× faster on a GPU and ~20× faster on the same CPU.

Figure: raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.

Why

N4 is the de facto standard for bias correction, but ITK's implementation is CPU-only and slow (minutes per volume), so in a GPU MRI pipeline it becomes the bottleneck. n4ax gives N4-quality output on the GPU in tens of milliseconds, with no custom CUDA, just JAX.

Install

uv sync --extra cuda12        # GPU (CUDA 12)
uv sync --extra cpu           # CPU
uv sync --extra cuda12 --extra dev      # + tests/linting
uv sync --extra cuda12 --extra compare  # + SimpleITK/matplotlib for benchmarks

Usage

import nibabel as nib
import n4ax

vol = nib.load("t1w.nii.gz").get_fdata()      # 3D (or 2D) array, intensities >= 0
corrected = n4ax.n4(vol)                        # Otsu mask computed automatically
# or pass your own mask, and/or get the log bias field:
mask = vol > 0   # illustrative placeholder only; substitute your own foreground mask
corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)

The outputs satisfy corrected == vol / exp(log_bias). The default config (iters=(8,12,12,8), over_relax=1.8) is tuned for speed; for the tightest ITK match, use the robust fallback n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3).
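The relation between the corrected image and the log bias field can be checked on a toy 1D profile without n4ax installed. The sketch below is plain Python; `apply_log_bias` and the synthetic data are illustrative assumptions, not part of the n4ax API.

```python
import math

def apply_log_bias(values, log_bias):
    """Divide each intensity by exp(log_bias): corrected = vol / exp(log_bias)."""
    return [v / math.exp(b) for v, b in zip(values, log_bias)]

# Hypothetical 1D "tissue" profile with two flat intensity classes.
true_signal = [100.0] * 4 + [200.0] * 4
# Smooth multiplicative shading, expressed in the log domain.
log_bias = [0.05 * i for i in range(8)]
# What a scanner would record: true signal times exp(log bias).
observed = [s * math.exp(b) for s, b in zip(true_signal, log_bias)]

# Dividing by exp(log_bias) undoes the shading and recovers the flat classes.
corrected = apply_log_bias(observed, log_bias)
```

With n4ax, the same relation should hold between `vol` and the two values returned by n4ax.n4(vol, return_bias=True).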

Benchmark

Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 with per-level iterations [50,50,30,20], same Otsu mask. ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.

Method | Time / volume | Speedup vs ITK
ITK N4 (CPU, 8 cores) | 146 s | 1× (baseline)
n4ax (CPU, 8 cores) | 7.7 s | ~19×
n4ax (A100 GPU) | 93 ms | ~1571×

Accuracy vs ITK (corrected image, global scale removed, since pipelines intensity-normalise anyway): mean difference 1.15 %, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax matches ITK to 0.4 %, and on a single N4 iteration to 0.1 %. The building blocks are therefore essentially exact; the residual comes from N4's own slow iterative convergence (ITK itself only converges after ~30 iters/level).
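The text does not spell out the exact metric behind these percentages. As one plausible reading of "global scale removed", the sketch below fits a single least-squares scale factor between two images and reports the mean absolute difference relative to the mean intensity. The function name `scale_free_error` and the toy values are assumptions for illustration only, not the benchmark script's code.

```python
def scale_free_error(a, b):
    """Relative mean absolute difference between a and s*b,
    where s is the least-squares scalar that best maps b onto a."""
    s = sum(x * y for x, y in zip(a, b)) / sum(y * y for y in b)
    mean_abs_diff = sum(abs(x - s * y) for x, y in zip(a, b)) / len(a)
    mean_abs_a = sum(abs(x) for x in a) / len(a)
    return mean_abs_diff / mean_abs_a, s

# Toy "reference" intensities (flattened voxels inside a mask).
itk_like = [100.0, 120.0, 90.0, 150.0]

# Same image up to a global factor of 2: the error vanishes and s is 0.5.
scaled = [2.0 * v for v in itk_like]
err, s = scale_free_error(itk_like, scaled)

# Perturb one voxel by 1 %: a small non-zero error survives the scale fit.
noisy = list(scaled)
noisy[0] *= 1.01
err_noisy, _ = scale_free_error(itk_like, noisy)
```

Removing the scale first means two corrections that differ only by a constant factor count as identical, which matches the remark that pipelines intensity-normalise anyway.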

Figure: multiple subjects (NKI grid), raw (top) vs n4ax-corrected (bottom).

Reproduce: python scripts/bench_nki.py (GPU) and JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu.

How it's fast (no custom kernels)

  • Separable B-spline fit. N4's per-iteration B-spline least-squares (Lee MBA) is a 94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter). Because the cubic weights depend only on the per-axis index and the Lee denominator factorises, this becomes 3 small dense matmuls per axis (cuBLAS) — identical math, 0.1 ms/iter.
  • Privatised histogram. The N3 sharpening histogram (1.5 M → 200 bins) is privatised over 256 lanes to avoid atomic serialisation.
  • Over-relaxation. At N4's fixed point the update S is zero, so scaling the step (B += α·S) leaves the fixed point unchanged; α ≈ 1.8 reaches ITK's result in far fewer iterations.
  • The whole solve is one fused, jitted program with a device-side convergence loop.
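The separable B-spline fit in the first bullet can be illustrated in 2D with plain Python. This is a minimal sketch under stated assumptions (a full grid with no mask, uniform cubic B-splines, hypothetical function names), not n4ax's actual code. It computes the Lee-style control lattice once by scattering every voxel into every control point, and once through per-axis weight matrices and small matrix products, then the two can be compared.

```python
def cubic_bspline_weights(n_vox, n_ctrl):
    """Per-axis matrix B[i][k]: weight of control point k at voxel i."""
    spans = n_ctrl - 3
    B = [[0.0] * n_ctrl for _ in range(n_vox)]
    for i in range(n_vox):
        u = (i + 0.5) / n_vox * spans          # voxel centre in lattice units
        k0 = min(int(u), spans - 1)
        t = u - k0
        w = [(1 - t) ** 3 / 6,
             (3 * t**3 - 6 * t**2 + 4) / 6,
             (-3 * t**3 + 3 * t**2 + 3 * t + 1) / 6,
             t**3 / 6]
        for a in range(4):
            B[i][k0 + a] = w[a]
    return B

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(r) for r in zip(*A)]

def lee_scatter(Z, Bx, By):
    """Reference: accumulate each voxel into each control point."""
    kx, ky = len(Bx[0]), len(By[0])
    num = [[0.0] * ky for _ in range(kx)]
    den = [[0.0] * ky for _ in range(kx)]
    for i in range(len(Bx)):
        for j in range(len(By)):
            w = [[Bx[i][k] * By[j][l] for l in range(ky)] for k in range(kx)]
            s = sum(v * v for row in w for v in row)   # per-voxel sum of w^2
            for k in range(kx):
                for l in range(ky):
                    num[k][l] += w[k][l] ** 3 * Z[i][j] / s
                    den[k][l] += w[k][l] ** 2
    return [[n / d for n, d in zip(nr, dr)] for nr, dr in zip(num, den)]

def lee_separable(Z, Bx, By):
    """Same result from per-axis factors and dense matrix products."""
    def factors(B):
        s = [sum(b * b for b in row) for row in B]     # factorised denominator
        A = [[b ** 3 / si for b in row] for row, si in zip(B, s)]
        B2 = [[b * b for b in row] for row in B]
        return A, B2
    Ax, Bx2 = factors(Bx)
    Ay, By2 = factors(By)
    num = matmul(matmul(transpose(Ax), Z), Ay)
    ones = [[1.0] * len(By) for _ in Bx]
    den = matmul(matmul(transpose(Bx2), ones), By2)
    return [[n / d for n, d in zip(nr, dr)] for nr, dr in zip(num, den)]

Bx = cubic_bspline_weights(12, 5)
By = cubic_bspline_weights(10, 6)
Z = [[1.0 + 0.1 * i + 0.05 * j * j for j in range(10)] for i in range(12)]
phi_scatter = lee_scatter(Z, Bx, By)
phi_separable = lee_separable(Z, Bx, By)
```

The two lattices agree to rounding error because each voxel's weight is a product of per-axis weights, so both the numerator and the sum of squared weights split into per-axis factors. In 3D the same pattern would add a third weight matrix and one more contraction.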

Two things mattered for correctness: zero-padding the sharpening FFT (circular wraparound otherwise breaks convergence), and confirming that float32 gives the same result as float64 here (verified).
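The wraparound problem is generic to FFT-based filtering and can be shown on a toy histogram. The sketch below is plain Python with a naive DFT and a simple smoothing kernel; it is an illustration of circular versus zero-padded convolution, not n4ax's sharpening step, and all names are hypothetical.

```python
import cmath

def dft(x, inverse=False):
    """Naive O(n^2) discrete Fourier transform (inverse includes the 1/n factor)."""
    n = len(x)
    sign = 1 if inverse else -1
    out = [sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
           for k in range(n)]
    return [v / n for v in out] if inverse else out

def fft_convolve(signal, kernel, pad):
    """Convolve via the frequency domain, with or without zero-padding."""
    n = len(signal) + len(kernel) - 1 if pad else len(signal)
    a = list(signal) + [0.0] * (n - len(signal))
    b = list(kernel) + [0.0] * (n - len(kernel))
    prod = [p * q for p, q in zip(dft(a), dft(b))]
    return [v.real for v in dft(prod, inverse=True)]

def direct_convolve(signal, kernel):
    """Plain linear convolution, used as the reference."""
    out = [0.0] * (len(signal) + len(kernel) - 1)
    for i, s in enumerate(signal):
        for j, k in enumerate(kernel):
            out[i + j] += s * k
    return out

hist = [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 4.0]   # mass concentrated in the top bins
kernel = [0.25, 0.5, 0.25]                          # small smoothing kernel

direct = direct_convolve(hist, kernel)
padded = fft_convolve(hist, kernel, pad=True)       # matches the linear result
wrapped = fft_convolve(hist, kernel, pad=False)     # top-bin mass leaks into bin 0
```

Without padding, the tail that should spill past the last bin reappears in the lowest bins, which is the kind of contamination zero-padding prevents.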

Tests

uv run pytest          # basic correctness + ground-truth match vs SimpleITK

tests/test_vs_itk.py asserts n4ax matches SimpleITK's N4 (the reference) within tolerance on a phantom; tests/test_basic.py covers shapes, 2D/3D, the image/exp(bias) identity, bias flattening, and the Otsu mask.

Status

Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before production (the iters=(50,50,30,20), over_relax=1.0 fallback is the conservative choice).

Download files

Source Distribution

n4ax-0.1.0.tar.gz (10.1 kB)

Uploaded Source

Built Distribution

n4ax-0.1.0-py3-none-any.whl (8.7 kB)

Uploaded Python 3

File details

Details for the file n4ax-0.1.0.tar.gz.

File metadata

  • Download URL: n4ax-0.1.0.tar.gz
  • Upload date:
  • Size: 10.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for n4ax-0.1.0.tar.gz
Algorithm Hash digest
SHA256 3dcf3fda31dd70fc8befb55f4f8bb3a2cc41b488f5b6e09a209bff29e2954a15
MD5 cc70dd5ebee2013044110a6311c6fe20
BLAKE2b-256 a6b9ac404f1c08e5657b8d51843dfbebaa1f0708544fad3b5d3dff1edede99d9

Provenance

The following attestation bundles were made for n4ax-0.1.0.tar.gz:

Publisher: release.yml on GragasLab/n4ax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file n4ax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: n4ax-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 8.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for n4ax-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 653fd1b8e27e02b4b0b4d29d3599800d35dde765921a9d36e4bef60dcdf782c6
MD5 fe213ff5dd13756ea567db7587ed59c5
BLAKE2b-256 af85d7b85b2f162f9b098392bdfdb124397798f941fae5c5bebb1f0e57bf3847

Provenance

The following attestation bundles were made for n4ax-0.1.0-py3-none-any.whl:

Publisher: release.yml on GragasLab/n4ax
