JAX bindings for the Flatiron Institute Nonuniform Fast Fourier Transform library

Project description

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to expose the FINUFFT library directly to JAX's XLA backend, and it implements differentiation rules for the transforms.

Included features

This library includes CPU and GPU (CUDA) support. GPU support is implemented through the cuFINUFFT interface of the FINUFFT library.

Type 1 and 2 transforms are supported in 1, 2, and 3 dimensions on the CPU, and 2 and 3 dimensions on the GPU. All of these functions support forward, reverse, and higher-order differentiation, as well as batching using vmap.

[!NOTE] The GPU backend does not currently support 1D (#125).

Installation

The easiest way to install jax-finufft is to install a pre-compiled binary from PyPI or conda-forge, but if you need GPU support or want tuned performance, you'll want to follow the instructions to install from source as described below.

Install binary from PyPI

[!NOTE] Only the CPU-enabled build of jax-finufft is available as a binary wheel on PyPI. For a GPU-enabled build, you'll need to build from source as described below.

To install a binary wheel from PyPI using pip, run the following commands:

python -m pip install "jax[cpu]"
python -m pip install jax-finufft

If this fails, you may need to use a conda-forge binary, or install from source.

Install binary from conda-forge

[!NOTE] Only the CPU-enabled build of jax-finufft is available as a binary from conda-forge. For a GPU-enabled build, you'll need to build from source as described below.

To install using mamba (or conda), run:

mamba install -c conda-forge jax-finufft

Install from source

Dependencies

Unsurprisingly, a key dependency is JAX, which can be installed following the directions in the JAX documentation. If you're going to want to run on a GPU, make sure that you install the appropriate JAX build.

The non-Python dependencies that you'll need are:

  • FFTW
  • OpenMP (optional, CPU only)
  • CUDA >= 11.8 (GPU only)

Below we provide some example workflows for installing the required dependencies:

Install CPU dependencies with mamba or conda
mamba create -n jax-finufft -c conda-forge python jax fftw cxx-compiler
mamba activate jax-finufft
Install GPU dependencies with mamba or conda

For a GPU build, while the CUDA libraries and compiler are nominally available through conda, our experience trying to install them this way suggests that the "traditional" way of obtaining the CUDA Toolkit directly from NVIDIA may work best (see related advice for Horovod). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:

mamba create -n gpu-jax-finufft -c conda-forge python numpy scipy fftw 'gxx<12'
mamba activate gpu-jax-finufft
export CMAKE_PREFIX_PATH=$CONDA_PREFIX:$CMAKE_PREFIX_PATH
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Other ways of installing JAX are given on the JAX website; the "local CUDA" install methods are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime. However, this is not required as long as both JAX and jax-finufft use CUDA with the same major version.

Install GPU dependencies using Flatiron module system
ml modules/2.3 \
   gcc \
   python/3.11 \
   fftw \
   cuda/12

export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"

Notes on CUDA versions

While jax-finufft may build with a wide range of CUDA versions, the resulting binaries may not be compatible with JAX (resulting in odd runtime errors, like failed cuDNN or cuBLAS initialization). For the greatest chance of success, we recommend building with the same version as JAX was built with. To discover that, one can look at the requirements in JAX's build directory (be sure to select the git tag for your version of JAX). Similarly, we encourage installing jax[cuda12_local] so that JAX and jax-finufft use the same CUDA libraries.

Depending on how challenging the installation is, users might want to run jax-finufft in a container. The .devcontainer directory is a good starting point for this.

Configuring the build

There are several important CMake variables that control aspects of the jax-finufft and (cu)finufft builds. These include:

  • JAX_FINUFFT_USE_CUDA [disabled by default]: build with GPU support
  • CMAKE_CUDA_ARCHITECTURES [default native]: the target GPU architecture. native means the GPU arch of the build system.
  • FINUFFT_ARCH_FLAGS [default -march=native]: the target CPU architecture. The default is the native CPU arch of the build system.

Each of these can be set as -Ccmake.define.NAME=VALUE arguments to pip install. For example, to build with GPU support from the repo root, run:

pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON .

Use multiple -C arguments to set multiple variables. The -C argument will work with any of the source installation methods (e.g. PyPI source dist, GitHub, etc).

Build options can also be set with the CMAKE_ARGS environment variable. For example:

export CMAKE_ARGS="$CMAKE_ARGS -DJAX_FINUFFT_USE_CUDA=ON"

GPU build configuration

Building with GPU support requires passing JAX_FINUFFT_USE_CUDA=ON to CMake. See Configuring the build.

By default, jax-finufft will build for the GPU of the build machine. If you need to target a different compute capability, such as 8.0 for Ampere, set CMAKE_CUDA_ARCHITECTURES as a CMake define:

pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON -Ccmake.define.CMAKE_CUDA_ARCHITECTURES=80 .

CMAKE_CUDA_ARCHITECTURES also takes a semicolon-separated list.
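Since ; is a command separator in most shells, a multi-architecture list needs to be quoted. A sketch (the 80/90 targets here are just example values):

```shell
# Build for Ampere (80) and Hopper (90); the quotes keep the shell from
# interpreting ';' as a command separator.
pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON \
    "-Ccmake.define.CMAKE_CUDA_ARCHITECTURES=80;90" .
```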

To detect the arch for a specific GPU, one can run:

$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
8.0

The values are also listed on the NVIDIA website.

In some cases, you may also need the following at runtime:

export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"

If CUDA_PATH isn't set, you'll need to replace it with the path to your CUDA installation in the above line, often something like /usr/local/cuda.

Install source from PyPI

The source code for all released versions of jax-finufft is available on PyPI, and it can be installed using:

python -m pip install --no-binary jax-finufft jax-finufft

Install source from GitHub

Alternatively, you can check out the source repository from GitHub:

git clone --recurse-submodules https://github.com/flatironinstitute/jax-finufft
cd jax-finufft

[!NOTE] Don't forget the --recurse-submodules argument when cloning the repo because the upstream FINUFFT library is included as a git submodule. If you do forget, you can run git submodule update --init --recursive in your local copy to checkout the submodule after the initial clone.

After cloning the repository, you can install the local copy using:

python -m pip install -e .

where the optional -e flag installs the package in "editable" mode.

As yet another alternative, the latest development version from GitHub can be installed directly (i.e. without cloning first) with

python -m pip install git+https://github.com/flatironinstitute/jax-finufft.git

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform (only works on CPU):

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed relative to the finufft Python package.
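For reference, nufft1 approximates (to tolerance eps) the direct sum sketched below. This stand-in is for illustration only and assumes FINUFFT's default mode ordering, with frequencies k running from -(N//2) to (N-1)//2:

```python
import numpy as np

def direct_nufft1(N, c, x, iflag=1):
    """Direct O(N*M) evaluation of the type 1 sum that nufft1 approximates:
    f_k = sum_j c_j * exp(1j * iflag * k * x_j), k = -(N//2), ..., (N-1)//2.
    (Assumes FINUFFT's default mode ordering; illustration only.)"""
    k = np.arange(-(N // 2), (N - 1) // 2 + 1)
    return np.exp(1j * iflag * k[:, None] * x[None, :]) @ c

# Sanity check: on an equispaced grid x_j = 2*pi*j/N with iflag=-1, the
# type 1 sum reduces to an ordinary DFT (in fftshift-ed order).
N = 8
rng = np.random.default_rng(0)
c = rng.standard_normal(N) + 1j * rng.standard_normal(N)
x = 2 * np.pi * np.arange(N) / N
assert np.allclose(direct_nufft1(N, c, x, iflag=-1),
                   np.fft.fftshift(np.fft.fft(c)))
```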

The syntax for a 2- or 3-dimensional transform (CPU or GPU) is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
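A type 2 transform goes the other way: it evaluates a given set of Fourier modes at the nonuniform points. A direct-sum stand-in for the 1D case (again assuming FINUFFT's default mode ordering; illustration only):

```python
import numpy as np

def direct_nufft2(f, x, iflag=1):
    """Direct evaluation of the type 2 sum that nufft2 approximates:
    c_j = sum_k f_k * exp(1j * iflag * k * x_j), k = -(N//2), ..., (N-1)//2.
    (Assumes FINUFFT's default mode ordering; illustration only.)"""
    N = len(f)
    k = np.arange(-(N // 2), (N - 1) // 2 + 1)
    return np.exp(1j * iflag * x[:, None] * k[None, :]) @ f

# Sanity check: on an equispaced grid with iflag=+1, type 2 reduces to an
# unnormalized inverse DFT.
N = 8
rng = np.random.default_rng(1)
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)
x = 2 * np.pi * np.arange(N) / N
assert np.allclose(direct_nufft2(f, x, iflag=1),
                   N * np.fft.ifft(np.fft.ifftshift(f)))
```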

All of these functions support batching using vmap, and forward and reverse mode differentiation.
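To show what this looks like in practice without assuming jax-finufft is installed, the sketch below applies jax.vmap and jax.grad to a direct-sum stand-in (nufft1_direct, a name invented here) with the same signature as nufft1; the same jax.vmap and jax.grad calls work with the real function:

```python
import jax
import jax.numpy as jnp

def nufft1_direct(N, c, x, iflag=1):
    # Direct-sum stand-in with the same signature as nufft1 (illustration
    # only; assumes FINUFFT's default mode ordering).
    k = jnp.arange(-(N // 2), (N - 1) // 2 + 1)
    return jnp.exp(1j * iflag * k[:, None] * x[None, :]) @ c

N, M, B = 8, 16, 3
x = 2 * jnp.pi * jax.random.uniform(jax.random.PRNGKey(0), (M,))
c = (jax.random.normal(jax.random.PRNGKey(1), (B, M))
     + 1j * jax.random.normal(jax.random.PRNGKey(2), (B, M)))

# Batch over the leading axis of c while broadcasting the points x:
f = jax.vmap(lambda ci: nufft1_direct(N, ci, x))(c)
assert f.shape == (B, N)

# Reverse-mode gradient of a real scalar loss with respect to the
# nonuniform point locations:
loss = lambda x: jnp.sum(jnp.abs(nufft1_direct(N, c[0], x)) ** 2)
g = jax.grad(loss)(x)
assert g.shape == (M,)
```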

Selecting a platform

If you compiled jax-finufft with GPU support, you can force it to use a particular backend by setting the environment variable JAX_PLATFORMS=cpu or JAX_PLATFORMS=cuda.

Advanced usage

The tuning parameters for the library can be set using the opts parameter to nufft1 and nufft2. For example, to explicitly set the CPU up-sampling factor that FINUFFT should use, you can update the example from above as follows:

from jax_finufft import options

opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)

The corresponding option for the GPU is gpu_upsampfac. In fact, all options for the GPU are prefixed with gpu_.

One complication here is that the vector-Jacobian product for a NUFFT requires evaluating a NUFFT of a different type. This means that you might want to separately tune the options for the forward and backward pass. This can be achieved using the options.NestedOpts interface. For example, to use a different up-sampling factor for the forward and backward passes, the code from above becomes:

import jax

opts = options.NestedOpts(
  forward=options.Opts(upsampfac=2.0),
  backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))

or, in this case equivalently:

opts = options.NestedOpts(
  type1=options.Opts(upsampfac=2.0),
  type2=options.Opts(upsampfac=1.25),
)

See the FINUFFT docs for descriptions of all the CPU tuning parameters. The corresponding GPU parameters are currently only listed in source code form in cufinufft_opts.h.

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021-2025 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

  • jax_finufft-1.0.1.tar.gz (4.1 MB, Source)

Built Distributions

  • jax_finufft-1.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.9 MB, CPython 3.13, manylinux: glibc 2.17+, x86-64)
  • jax_finufft-1.0.1-cp313-cp313-macosx_14_0_arm64.whl (2.1 MB, CPython 3.13, macOS 14.0+, ARM64)
  • jax_finufft-1.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.9 MB, CPython 3.12, manylinux: glibc 2.17+, x86-64)
  • jax_finufft-1.0.1-cp312-cp312-macosx_14_0_arm64.whl (2.1 MB, CPython 3.12, macOS 14.0+, ARM64)
  • jax_finufft-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.9 MB, CPython 3.11, manylinux: glibc 2.17+, x86-64)
  • jax_finufft-1.0.1-cp311-cp311-macosx_14_0_arm64.whl (2.1 MB, CPython 3.11, macOS 14.0+, ARM64)
  • jax_finufft-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.9 MB, CPython 3.10, manylinux: glibc 2.17+, x86-64)
  • jax_finufft-1.0.1-cp310-cp310-macosx_14_0_arm64.whl (2.1 MB, CPython 3.10, macOS 14.0+, ARM64)

File details

Details for the file jax_finufft-1.0.1.tar.gz.

File metadata

  • Download URL: jax_finufft-1.0.1.tar.gz
  • Size: 4.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for jax_finufft-1.0.1.tar.gz:

  • SHA256: 53d1b598bae3e61bf1d545cfbd3bb3463892f64fafd79f089d249dbc1f7f1e10
  • MD5: 9cbc3c15dcf1940ecf2fe06de5f42f61
  • BLAKE2b-256: 464f0ca04aeb148bc609f9d33516fdf792dd5df66aab11f6ae4359a6f93b44be

Provenance

The following attestation bundles were made for jax_finufft-1.0.1.tar.gz:

Publisher: wheels.yml on flatironinstitute/jax-finufft

