
Pure JAX implementation of the Non-Uniform FFT



Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)

Requires Python 3.12+. License: MIT.


(Figure: MRI reconstruction example)

Why nufftax?

A JAX package for the NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library through the Foreign Function Interface (FFI), exposing it as custom XLA calls. This approach can lead to:

  • Kernel fusion issues on GPU: custom XLA calls act as optimization barriers that prevent XLA from fusing the wrapped kernels with surrounding operations
  • CUDA version matching: GPU support requires the CUDA version the library was built against to match the one used by JAX

nufftax takes a different approach: the transforms are implemented in pure JAX.

  • Fully differentiable: gradients with respect to both the values and the sample locations
  • Pure JAX: works with jit, grad, vmap, jvp, and vjp, with no FFI barriers
  • GPU ready: runs on CPU and GPU without code changes and benefits from XLA fusion
  • All NUFFT types: Types 1, 2, and 3 in 1D, 2D, and 3D
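Because a pure JAX transform is ordinary jax.numpy code, jax.jit and jax.vmap apply to it without any custom-call machinery. A minimal sketch of this, using a brute-force 1D type 1 sum as a stand-in for a nufftax transform: the function name direct_type1 is hypothetical, and the convention f_k = Σ_j c_j exp(+i k x_j) for k = -N/2, ..., N/2 - 1 is assumed from FINUFFT's defaults rather than taken from the nufftax documentation. The sum costs O(M·N) operations; FINUFFT, on which nufftax is based, instead spreads onto an oversampled grid with an "exponential of semicircle" kernel and applies an FFT.

```python
import jax
import jax.numpy as jnp


def direct_type1(x, c, n_modes):
    """Hypothetical O(M*N) reference for a 1D type 1 NUFFT.

    Assumes the FINUFFT default convention: f_k = sum_j c_j exp(+i k x_j),
    with modes ordered k = -(n_modes // 2), ..., n_modes - 1 - n_modes // 2.
    """
    k = jnp.arange(n_modes) - n_modes // 2          # centered mode indices
    phase = jnp.exp(1j * k[:, None] * x[None, :])   # (n_modes, M) matrix
    return phase @ c                                # (n_modes,) Fourier modes


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

# jit: n_modes fixes the output shape, so it is marked static.
f = jax.jit(direct_type1, static_argnums=2)(x, c, 32)

# vmap: transform a batch of strength vectors that share the same points.
batch = jnp.stack([c, 2 * c, 1j * c])
f_batch = jax.vmap(lambda cc: direct_type1(x, cc, 32))(batch)
print(f.shape, f_batch.shape)
```

The closure keeps x and n_modes fixed so that only the strengths carry the batch axis; a reference like this is also a convenient way to spot-check a fast transform on small inputs.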

JAX Transformation Support

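| Transform         | jit | grad/vjp | jvp | vmap |
|-------------------|-----|----------|-----|------|
| Type 1 (1D/2D/3D) | ✓   | ✓        | ✓   | ✓    |
| Type 2 (1D/2D/3D) | ✓   | ✓        | ✓   | ✓    |
| Type 3 (1D/2D/3D) | ✓   | ✓        | ✓   | ✓    |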
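The jvp and grad/vjp columns are forward- and reverse-mode views of the same Jacobian J, so for any tangent v and cotangent u they must satisfy u · (J v) = (u^T J) · v. The sketch below checks that identity with respect to the coordinates x. It uses a brute-force type 1 power spectrum as an illustrative stand-in, not nufftax itself, and the sign and mode-ordering convention are assumed. The same check can be aimed at any real-valued function of a nufftax transform.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so the identity holds tightly


def power_spectrum(x, c, n_modes=16):
    """|f_k|^2 of a brute-force type 1 sum (illustrative stand-in for a NUFFT).

    Assumed convention: f_k = sum_j c_j exp(+i k x_j), k = -n_modes//2, ...
    """
    k = jnp.arange(n_modes) - n_modes // 2
    f = jnp.exp(1j * k[:, None] * x[None, :]) @ c
    return f.real ** 2 + f.imag ** 2            # real output, shape (n_modes,)


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])
g = lambda x: power_spectrum(x, c)              # differentiate w.r.t. x only

v = jnp.linspace(-1.0, 1.0, x.size)             # tangent for the coordinates
u = jnp.cos(jnp.arange(16.0))                   # cotangent for the output

_, jv = jax.jvp(g, (x,), (v,))                  # forward mode: J @ v
_, vjp_fn = jax.vjp(g, x)
(uj,) = vjp_fn(u)                               # reverse mode: u @ J

lhs, rhs = jnp.dot(u, jv), jnp.dot(uj, v)
print(lhs, rhs)                                 # equal up to rounding
```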

Differentiable inputs:

  • Type 1: grad w.r.t. c (strengths) and x, y, z (coordinates)
  • Type 2: grad w.r.t. f (Fourier modes) and x, y, z (coordinates)
  • Type 3: grad w.r.t. c (strengths), x, y, z (source coordinates), and s, t, u (target frequencies)
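The differentiable inputs follow from the sums each type evaluates. Assuming FINUFFT's default conventions (type 2: c_j = Σ_k f_k exp(-i k x_j); type 3: f_k = Σ_j c_j exp(+i s_k x_j)), brute-force 1D versions take a few lines of jax.numpy, and jax.grad with argnums returns a gradient for each input listed above. direct_type2, direct_type3, and energy are hypothetical helpers for small tests, not part of the nufftax API.

```python
import jax
import jax.numpy as jnp


def direct_type2(x, f):
    """Uniform modes f_k -> values at nonuniform points x (sign -1 assumed)."""
    k = jnp.arange(f.shape[0]) - f.shape[0] // 2
    return jnp.exp(-1j * x[:, None] * k[None, :]) @ f


def direct_type3(x, c, s):
    """Nonuniform points x -> nonuniform frequencies s (sign +1 assumed)."""
    return jnp.exp(1j * s[:, None] * x[None, :]) @ c


def energy(z):
    return jnp.sum(z.real ** 2 + z.imag ** 2)   # real scalar loss


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])
f = jnp.linspace(0.0, 1.0, 8) + 0.5j            # 8 Fourier modes
s = jnp.array([-3.2, 0.0, 1.5, 4.7])            # nonuniform target frequencies

# Type 2: gradients w.r.t. the coordinates x and the Fourier modes f.
g2_x, g2_f = jax.grad(lambda x, f: energy(direct_type2(x, f)), argnums=(0, 1))(x, f)

# Type 3: gradients w.r.t. strengths c, source coordinates x, and frequencies s.
g3_c, g3_x, g3_s = jax.grad(
    lambda c, x, s: energy(direct_type3(x, c, s)), argnums=(0, 1, 2))(c, x, s)
print(g2_f.shape, g2_x.shape, g3_c.shape, g3_x.shape, g3_s.shape)
```

Gradients with respect to complex inputs (f, c) come back complex, while those with respect to real coordinates and frequencies (x, s) come back real.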

Installation

```bash
uv pip install nufftax
```

Quick Example

```python
import jax
import jax.numpy as jnp
from nufftax import nufft1d1

# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
# Complex strengths at those locations
c = jnp.array([1.0+0.5j, 0.3-0.2j, 0.8+0.1j, 0.2+0.4j, 0.5-0.3j])

# Compute 32 Fourier modes (eps sets the requested accuracy)
f = nufft1d1(x, c, n_modes=32, eps=1e-6)

# Differentiate a real-valued loss through the transform w.r.t. the strengths c
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2))(c)
```
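In the example above, jax.grad accepts complex strengths c because the loss is real-valued. For such functions JAX returns ∂L/∂Re(c) - i·∂L/∂Im(c), the complex conjugate of the steepest-ascent direction, so a gradient-descent step on c should move along -conj(grad_c). The sketch below checks this against the closed form 2·conj(A^H A c) for L(c) = Σ_k |f_k|^2 with f = A c. It uses a brute-force type 1 sum in place of nufft1d1 so that it runs without nufftax, and the FINUFFT sign and mode ordering are assumed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # tighter comparison below


def direct_type1(x, c, n_modes):
    """Brute-force stand-in for nufft1d1 (FINUFFT sign and ordering assumed)."""
    k = jnp.arange(n_modes) - n_modes // 2
    A = jnp.exp(1j * k[:, None] * x[None, :])   # (n_modes, M) NUDFT matrix
    return A @ c, A


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

# Same real-valued loss as the Quick Example, on the stand-in transform.
loss = lambda c: jnp.sum(jnp.abs(direct_type1(x, c, 32)[0]) ** 2)
grad_c = jax.grad(loss)(c)

# Closed form under JAX's convention: dL/dRe(c) - i dL/dIm(c) = 2 conj(A^H A c).
f, A = direct_type1(x, c, 32)
expected = 2 * jnp.conj(A.conj().T @ f)
print(jnp.max(jnp.abs(grad_c - expected)))      # rounding-level difference

# One gradient-descent step on c moves along the conjugate of grad_c.
c_next = c - 1e-3 * jnp.conj(grad_c)
```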

Documentation

The full documentation is available online; the source repository is https://github.com/GragasLab/nufftax.

License

MIT. Algorithm based on FINUFFT by the Flatiron Institute.

Citation

If you use nufftax in your research, please cite:

```bibtex
@software{nufftax,
  author = {Gragas and Oudoumanessah, Geoffroy and Iollo, Jacopo},
  title = {nufftax: Pure JAX implementation of the Non-Uniform Fast Fourier Transform},
  url = {https://github.com/GragasLab/nufftax},
  year = {2026}
}

@article{finufft,
  author = {Barnett, Alexander H. and Magland, Jeremy F. and af Klinteberg, Ludvig},
  title = {A parallel non-uniform fast Fourier transform library based on an ``exponential of semicircle'' kernel},
  journal = {SIAM J. Sci. Comput.},
  volume = {41},
  number = {5},
  pages = {C479--C504},
  year = {2019}
}
```

Download files

Source Distribution

nufftax-0.3.1.tar.gz (52.4 kB)

Built Distribution

nufftax-0.3.1-py3-none-any.whl (34.3 kB)

File details

Details for the file nufftax-0.3.1.tar.gz.

File metadata

  • Download URL: nufftax-0.3.1.tar.gz
  • Size: 52.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for nufftax-0.3.1.tar.gz
  • SHA256: 90c77fbb87e0b0ee835869c23ced778457e3eedd844de23e76abfd2504c86c4c
  • MD5: c5ef91c6d5b9f942e8fdbcc87342042a
  • BLAKE2b-256: cf8627d94218b9bf18b1598e927863084a0896ba7221f72e8ae3b44664344cfa

Provenance

The following attestation bundles were made for nufftax-0.3.1.tar.gz:

Publisher: release.yml on GragasLab/nufftax

File details

Details for the file nufftax-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: nufftax-0.3.1-py3-none-any.whl
  • Size: 34.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for nufftax-0.3.1-py3-none-any.whl
  • SHA256: 0d277d96adeac046940d89a5b3335b95b365b01edb47d83e11dd7c3e08df86ac
  • MD5: 4991a053e5fdd731bd119714a3b6ccd5
  • BLAKE2b-256: 50d6f8e43a0eb10f9f6aa4b4f8d0fff0b5e780cd0830eda31749ef5e644eaa17

Provenance

The following attestation bundles were made for nufftax-0.3.1-py3-none-any.whl:

Publisher: release.yml on GragasLab/nufftax
