Pure JAX implementation of Non-Uniform FFT
nufftax
Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT).
nufftax provides fully differentiable NUFFT operations in pure JAX. There are no C++ bindings or XLA custom calls: every transform is built from standard JAX operations, so it composes directly with jit, grad, vmap, jvp, and vjp.
Features
- Pure JAX implementation: full compatibility with JAX transformations
- All three NUFFT types:
  - Type 1: nonuniform to uniform (spreading)
  - Type 2: uniform to nonuniform (interpolation)
  - Type 3: nonuniform to nonuniform
- 1D, 2D, and 3D transforms
- Differentiable with respect to both point coordinates and values
- GPU acceleration: runs on CPU and GPU without code changes
- Configurable precision: requested tolerance from 1e-2 to 1e-14
Installation
```bash
pip install nufftax
```
From source:
```bash
git clone https://github.com/geoffroyO/nufftax.git
cd nufftax
pip install -e ".[dev]"
```
Quick Start
```python
import jax.numpy as jnp
from nufftax import nufft1d1, nufft1d2

# Nonuniform points in [-pi, pi)
x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
# Complex strengths at those points
c = jnp.array([1+1j, 2-1j, 0.5, 1j, -1+0.5j])

# Type 1: nonuniform points -> Fourier modes
f = nufft1d1(x, c, n_modes=64, eps=1e-6)

# Type 2: Fourier modes -> nonuniform points
c_interp = nufft1d2(x, f, eps=1e-6)
```
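The example assumes the points already lie in [-pi, pi). Because the type 1 and type 2 sums only use integer frequencies k, they are 2*pi-periodic in each coordinate, so coordinates outside that range can be folded back without changing the result. A minimal sketch; `wrap_to_pi` and `direct_type1` are hypothetical helpers (not part of nufftax), and the brute-force sum assumes the centered mode ordering k = -n//2, ..., (n-1)//2:

```python
import jax.numpy as jnp


def wrap_to_pi(x):
    """Fold real coordinates into [-pi, pi). Hypothetical helper, not part of nufftax."""
    return jnp.mod(x + jnp.pi, 2 * jnp.pi) - jnp.pi


def direct_type1(x, c, n_modes):
    """Brute-force f[k] = sum_j c[j] exp(i k x[j]) in O(n_modes * len(x)).

    Assumes the centered ordering k = -n_modes//2, ..., (n_modes-1)//2.
    """
    k = jnp.arange(n_modes) - n_modes // 2
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c


x_raw = jnp.array([0.1, 4.0, -7.0, 3.5])   # some points lie outside [-pi, pi)
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j])
x = wrap_to_pi(x_raw)

# Wrapping does not change the modes, up to float32 round-off.
err = jnp.max(jnp.abs(direct_type1(x, c, 16) - direct_type1(x_raw, c, 16)))
```

Whether nufftax itself accepts points outside [-pi, pi) is not stated here, so wrapping first is the conservative choice.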
Autodifferentiation
Gradients work out of the box:
```python
import jax
import jax.numpy as jnp
from nufftax import nufft1d1

# x and c are the arrays defined in the Quick Start example.

# Gradient w.r.t. strengths
def loss_c(c):
    f = nufft1d1(x, c, n_modes=64, eps=1e-6)
    return jnp.sum(jnp.abs(f) ** 2)

grad_c = jax.grad(loss_c)(c)

# Gradient w.r.t. point coordinates
def loss_x(x):
    f = nufft1d1(x, c, n_modes=64, eps=1e-6)
    return jnp.sum(jnp.abs(f) ** 2)

grad_x = jax.grad(loss_x)(x)

# Batched transforms: vmap over sets of points sharing the same strengths
batched_nufft = jax.vmap(lambda xi: nufft1d1(xi, c, n_modes=64))
x_batch = jnp.stack([x, x + 0.1, x + 0.2])
f_batch = batched_nufft(x_batch)  # Shape: (3, 64)
```
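Since loss_x is a smooth function of the coordinates, the gradient from jax.grad should agree with the closed-form derivative dL/dx_j = 2 Re( sum_k conj(F_k) * i k c_j exp(i k x_j) ). The sketch below checks this on a brute-force stand-in for nufft1d1 (isign=+1, centered mode ordering assumed), not on nufftax itself; with nufftax's transform swapped in, one would expect agreement on the order of eps.

```python
import jax
import jax.numpy as jnp

N_MODES = 16
K = jnp.arange(N_MODES) - N_MODES // 2      # assumed centered mode ordering


def type1_direct(x, c):
    """Brute-force stand-in for nufft1d1 (isign=+1): F[k] = sum_j c[j] exp(i k x[j])."""
    return jnp.exp(1j * K[:, None] * x[None, :]) @ c


def loss_x(x, c):
    f = type1_direct(x, c)
    # Same value as sum |f|^2, but smooth even where some f[k] happens to be 0.
    return jnp.sum(jnp.real(f * jnp.conj(f)))


def grad_x_analytic(x, c):
    """dL/dx_j = 2 * Re(sum_k conj(F_k) * i * k * c_j * exp(i k x_j))."""
    f = type1_direct(x, c)
    phase = jnp.exp(1j * K[:, None] * x[None, :])                   # (N, M)
    terms = jnp.conj(f)[:, None] * (1j * K[:, None]) * c[None, :] * phase
    return 2 * jnp.real(jnp.sum(terms, axis=0))


x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j, -1 + 0.5j])
g_auto = jax.grad(loss_x)(x, c)     # autodiff through the exponentials
g_manual = grad_x_analytic(x, c)    # closed-form derivative
```

Using real(f * conj(f)) instead of abs(f) ** 2 gives the same loss value while avoiding the non-differentiable point of abs at f = 0.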
API Reference
Type 1: Nonuniform to Uniform
Computes: f[k] = sum_j c[j] * exp(i * isign * k * x[j])
```python
from nufftax import nufft1d1, nufft2d1, nufft3d1

f = nufft1d1(x, c, n_modes, eps=1e-6, isign=1)
f = nufft2d1(x, y, c, n_modes, eps=1e-6, isign=1)
f = nufft3d1(x, y, z, c, n_modes, eps=1e-6, isign=1)
```
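The 1D formula extends to higher dimensions by summing the coordinate products in the exponent. Assuming the usual tensor-product convention (as in FINUFFT), the 2D type 1 transform is f[k1, k2] = sum_j c[j] exp(i * isign * (k1 * x[j] + k2 * y[j])). A brute-force version plus a relative l2 error helper gives a reference for checking nufft2d1 at small sizes. The helper names, the output layout (first axis for the x modes), and the centered mode ordering are assumptions:

```python
import jax.numpy as jnp


def centered_modes(n):
    """Assumed mode ordering: -n//2, ..., (n-1)//2."""
    return jnp.arange(n) - n // 2


def direct_type1_2d(x, y, c, n_modes, isign=1):
    """Brute-force 2D type 1: f[k1, k2] = sum_j c[j] exp(i*isign*(k1*x[j] + k2*y[j]))."""
    n1, n2 = (n_modes, n_modes) if isinstance(n_modes, int) else n_modes
    k1 = centered_modes(n1)[:, None, None]              # (n1, 1, 1)
    k2 = centered_modes(n2)[None, :, None]              # (1, n2, 1)
    phase = jnp.exp(1j * isign * (k1 * x + k2 * y))     # (n1, n2, M)
    return jnp.sum(phase * c, axis=-1)


def rel_l2_error(approx, exact):
    """Relative l2 error, a reasonable yardstick for the requested eps."""
    return jnp.linalg.norm(approx - exact) / jnp.linalg.norm(exact)


x = jnp.array([0.3, -1.2, 2.5])
y = jnp.array([2.0, 0.7, -0.4])
c = jnp.array([1.0 + 0.5j, -0.3j, 2.0])
f_ref = direct_type1_2d(x, y, c, (8, 6))   # shape (8, 6)
```

If the conventions line up, rel_l2_error(nufft2d1(x, y, c, (8, 6)), f_ref) should come out on the order of the requested eps.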
Type 2: Uniform to Nonuniform
Computes: c[j] = sum_k f[k] * exp(i * isign * k * x[j])
```python
from nufftax import nufft1d2, nufft2d2, nufft3d2

c = nufft1d2(x, f, eps=1e-6, isign=1)
c = nufft2d2(x, y, f, eps=1e-6, isign=1)
c = nufft3d2(x, y, z, f, eps=1e-6, isign=1)
```
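Type 2 uses the same exponentials as type 1 with the roles of modes and points swapped. With opposite signs the two are adjoints: <type1(c), f> = <c, type2(f)> when type 1 uses isign and type 2 uses -isign. A small brute-force check; the helper names and the centered mode ordering are assumptions, not nufftax API:

```python
import jax.numpy as jnp


def modes(n):
    """Assumed centered ordering: -n//2, ..., (n-1)//2."""
    return jnp.arange(n) - n // 2


def direct_type1(x, c, n_modes, isign=1):
    """f[k] = sum_j c[j] exp(i*isign*k*x[j])."""
    return jnp.exp(1j * isign * modes(n_modes)[:, None] * x[None, :]) @ c


def direct_type2(x, f, isign=1):
    """c[j] = sum_k f[k] exp(i*isign*k*x[j])."""
    return jnp.exp(1j * isign * x[:, None] * modes(f.shape[0])[None, :]) @ f


x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j, -1 + 0.5j])
f = jnp.linspace(-1.0, 1.0, 16) * (1 - 0.5j)

# Adjoint identity with opposite signs; vdot conjugates its first argument.
lhs = jnp.vdot(direct_type1(x, c, 16, isign=1), f)
rhs = jnp.vdot(c, direct_type2(x, f, isign=-1))
```

With equal signs the identity does not hold in general, and even with opposite signs type 2 is an adjoint, not an inverse: a round trip through type 1 and type 2 does not in general recover c.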
Type 3: Nonuniform to Nonuniform
Computes: f[k] = sum_j c[j] * exp(i * isign * s[k] * x[j])
```python
from nufftax import nufft1d3, nufft2d3, nufft3d3
from nufftax import compute_type3_grid_size

# Compute grid size from data bounds (required for JIT);
# x_extent and s_extent are the bounds of the source points and target frequencies
n_modes = compute_type3_grid_size(x_extent, s_extent, eps=1e-6)

f = nufft1d3(x, c, s, n_modes, eps=1e-6, isign=1)
f = nufft2d3(x, y, c, s, t, n_modes, eps=1e-6, isign=1)
f = nufft3d3(x, y, z, c, s, t, u, n_modes, eps=1e-6, isign=1)
```
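Type 3 drops the integer-frequency grid on both sides. A brute-force sum makes the definition concrete and shows that type 1 is the special case of integer targets (assuming the centered ordering k = -n//2, ..., (n-1)//2). The fast algorithm needs an internal grid sized from the bounds of x and s, which is what compute_type3_grid_size is for; the direct sum below needs no grid at all. The helper name is hypothetical:

```python
import jax.numpy as jnp


def direct_type3(x, c, s, isign=1):
    """Brute-force 1D type 3: f[k] = sum_j c[j] exp(i*isign*s[k]*x[j])."""
    return jnp.exp(1j * isign * s[:, None] * x[None, :]) @ c


x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j, -1 + 0.5j])

# Arbitrary real target frequencies: no uniform grid, no integer constraint.
s = jnp.array([-3.7, 0.0, 0.25, 12.5])
f3 = direct_type3(x, c, s)

# Integer targets s = -n//2, ..., (n-1)//2 turn type 3 into a type 1 transform.
n = 8
s_int = (jnp.arange(n) - n // 2).astype(jnp.float32)
f_as_type1 = direct_type3(x, c, s_int)
```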
Parameters
| Parameter | Description |
|---|---|
| x, y, z | Nonuniform source points in [-pi, pi) |
| s, t, u | Nonuniform target frequencies (Type 3 only) |
| c | Complex strengths at source points |
| f | Fourier mode coefficients |
| n_modes | Number of output modes (int or tuple) |
| eps | Requested precision (1e-2 to 1e-14) |
| isign | Sign of exponent: +1 or -1 |
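The eps parameter mainly controls how wide the spreading kernel must be. As a rough guide, FINUFFT uses about one kernel grid point per requested digit plus one for 2x upsampling, width = ceil(-log10(eps)) + 1, with an ES shape parameter beta of about 2.30 * width. The sketch below encodes that rule of thumb; it is an assumption that nufftax uses the same mapping, and FINUFFT's small adjustments to beta at very small widths are omitted:

```python
import math


def es_width_and_beta(eps, max_width=16):
    """Kernel width and ES shape parameter for a requested eps (2x upsampling).

    Rule of thumb modeled on FINUFFT, assumed rather than taken from nufftax:
    width = ceil(-log10(eps)) + 1, beta = 2.30 * width.
    """
    # The 1e-12 guard keeps exact powers of ten (e.g. 1e-6 -> 6 digits) from rounding up.
    digits = math.ceil(-math.log10(eps) - 1e-12)
    width = min(max(digits + 1, 2), max_width)
    return width, 2.30 * width


widths = {eps: es_width_and_beta(eps) for eps in (1e-2, 1e-6, 1e-14)}
```

Under this rule, tightening eps from 1e-6 to 1e-14 roughly doubles the kernel width (7 to 15 points per dimension), and in 3D each point touches width**3 grid cells, which is why tight tolerances cost noticeably more spreading work.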
Algorithm
nufftax implements the NUFFT using:
- Spreading/interpolation: convolution with the "exponential of semicircle" (ES) kernel
- FFT: a standard FFT on an oversampled grid (2x by default)
- Deconvolution: division by the kernel's Fourier coefficients
For type 1 the steps run in the order listed; type 2 runs them in reverse (deconvolution, FFT, interpolation), as in FINUFFT. The ES kernel gives near-optimal accuracy for a given support width. Because every step is implemented in pure JAX, automatic differentiation works through the entire transform.
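To make the three steps concrete, here is a deliberately simple 1D type 1 sketch (isign=+1) in plain JAX, using the ES kernel phi(z) = exp(beta * (sqrt(1 - z^2) - 1)) for |z| <= 1. It is not nufftax's implementation: the function name, the kernel width, beta = 2.30 * width (close to FINUFFT's choice for 2x upsampling), the Gauss-Legendre quadrature used for the kernel's Fourier transform, and the centered mode ordering are all assumptions.

```python
import numpy as np
import jax.numpy as jnp


def es_kernel(z, beta):
    """Exponential of semicircle kernel, zero outside |z| <= 1."""
    root = jnp.sqrt(jnp.maximum(1.0 - z**2, 0.0))
    return jnp.where(jnp.abs(z) <= 1.0, jnp.exp(beta * (root - 1.0)), 0.0)


def toy_nufft1d1(x, c, n_modes, width=8, sigma=2):
    """Toy 1D type 1 NUFFT (isign=+1): spread, FFT, deconvolve."""
    n_fine = sigma * n_modes               # oversampled grid size
    h = 2 * jnp.pi / n_fine                # fine-grid spacing
    alpha = width * h / 2                  # kernel half-width in x units
    beta = 2.30 * width                    # assumed ES shape parameter

    # 1) Spreading: each point adds kernel-weighted strength to `width` cells.
    start = jnp.ceil(x / h - width / 2).astype(jnp.int32)
    idx = start[:, None] + jnp.arange(width)[None, :]       # (M, width)
    z = (idx * h - x[:, None]) / alpha                      # scaled offsets in [-1, 1)
    vals = c[:, None] * es_kernel(z, beta)
    grid = jnp.zeros(n_fine, dtype=vals.dtype)
    grid = grid.at[jnp.mod(idx, n_fine).ravel()].add(vals.ravel())  # periodic wrap

    # 2) FFT: ifft computes (1/n) * sum_l b[l] * exp(+2*pi*i*k*l/n).
    k = jnp.arange(n_modes) - n_modes // 2                  # assumed centered ordering
    g = jnp.fft.ifft(grid)[jnp.mod(k, n_fine)]

    # 3) Deconvolution by the kernel's Fourier transform,
    #    psi_hat(k) = alpha * integral_{-1}^{1} phi(z) cos(k * alpha * z) dz.
    zq, wq = np.polynomial.legendre.leggauss(100)
    zq, wq = jnp.asarray(zq), jnp.asarray(wq)
    psi_hat = alpha * jnp.sum(
        wq * es_kernel(zq, beta) * jnp.cos(k[:, None] * alpha * zq), axis=1
    )
    # h * n_fine = 2*pi turns the ifft normalization into the trapezoid rule.
    return 2 * jnp.pi * g / psi_hat


x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j, -1 + 0.5j])
f_toy = toy_nufft1d1(x, c, n_modes=64)
```

Each stage is an ordinary JAX operation (a scatter-add, an FFT, an elementwise division), so the whole function can be traced by jit and differentiated with grad. With width=8 the result matches a direct O(N*M) sum to several digits in float32.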
Running Tests
```bash
pip install -e ".[dev]"
pytest tests/ -v
```
License
MIT
Acknowledgments
Algorithm based on FINUFFT by the Flatiron Institute.
File details
Details for the file nufftax-0.1.0.tar.gz.
File metadata
- Download URL: nufftax-0.1.0.tar.gz
- Upload date:
- Size: 42.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8b9ee7b1a35e3e1b107cf4a54b629611536373467ea44cf1f062eb67d68bef60 |
| MD5 | 415bc04379cedc663addb2bdb4e86e73 |
| BLAKE2b-256 | be7572bcc9e12596fa9c52ba42ea03c724c7e193a6950362e5b4243e036bb226 |
Provenance
The following attestation bundles were made for nufftax-0.1.0.tar.gz:
- Publisher: ci.yml on geoffroyO/nufftax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nufftax-0.1.0.tar.gz
- Subject digest: 8b9ee7b1a35e3e1b107cf4a54b629611536373467ea44cf1f062eb67d68bef60
- Sigstore transparency entry: 788773057
- Sigstore integration time:
- Permalink: geoffroyO/nufftax@7ce4ac9dcabc36be20e2be40ce36111476db2cc2
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/geoffroyO
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yml@7ce4ac9dcabc36be20e2be40ce36111476db2cc2
- Trigger Event: release
File details
Details for the file nufftax-0.1.0-py3-none-any.whl.
File metadata
- Download URL: nufftax-0.1.0-py3-none-any.whl
- Upload date:
- Size: 30.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 873a16ab7afd0d4f418cf9544e482dcc0524168e7a112764740fb5892898f025 |
| MD5 | 799fae0ee7d0d408cabc8616475449af |
| BLAKE2b-256 | f506183d17ffd095dbe9298ed0625eea8c4b50b4597ece90227513c72346f475 |
Provenance
The following attestation bundles were made for nufftax-0.1.0-py3-none-any.whl:
- Publisher: ci.yml on geoffroyO/nufftax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nufftax-0.1.0-py3-none-any.whl
- Subject digest: 873a16ab7afd0d4f418cf9544e482dcc0524168e7a112764740fb5892898f025
- Sigstore transparency entry: 788773058
- Sigstore integration time:
- Permalink: geoffroyO/nufftax@7ce4ac9dcabc36be20e2be40ce36111476db2cc2
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/geoffroyO
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yml@7ce4ac9dcabc36be20e2be40ce36111476db2cc2
- Trigger Event: release