Differentiable signal processing on the sphere for PyTorch.

torch-harmonics

Overview | Installation | More information | Getting started | Contributors | Cite us | References

Overview

torch-harmonics implements differentiable signal processing on the sphere. This includes differentiable implementations of the spherical harmonic transforms, vector spherical harmonic transforms and discrete-continuous convolutions on the sphere. The package was originally implemented to enable Spherical Fourier Neural Operators (SFNO) [1].

The SHT algorithm uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. For most practical purposes, this algorithm tends to outperform alternatives that have better asymptotic scaling [3].

torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed across multiple ranks, so that the transform can be computed in a spatially distributed fashion.

torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators [1].

Installation

Prebuilt wheels

Prebuilt Linux wheels with compiled CUDA extensions are available on pypi.nvidia.com. Pick the package matching your CUDA toolkit version:

| CUDA | Package | Min PyTorch | Install command |
|------|---------|-------------|-----------------|
| 12.6 | torch-harmonics-cu126 | 2.6.0 | pip install torch-harmonics-cu126 --extra-index-url https://pypi.nvidia.com |
| 12.8 | torch-harmonics-cu128 | 2.7.0 | pip install torch-harmonics-cu128 --extra-index-url https://pypi.nvidia.com |
| 12.9 | torch-harmonics-cu129 | 2.8.0 | pip install torch-harmonics-cu129 --extra-index-url https://pypi.nvidia.com |
| 13.0 | torch-harmonics-cu130 | 2.9.1 | pip install torch-harmonics-cu130 --extra-index-url https://pypi.nvidia.com |

If you don't need a specific CUDA version, use one of the rolling aliases:

# latest CUDA build
pip install torch-harmonics-cuda-latest --extra-index-url https://pypi.nvidia.com

# CPU only
pip install torch-harmonics-cpu-latest --extra-index-url https://pypi.nvidia.com

Tip: Run nvidia-smi to check your driver's CUDA version.

PyPI

The vanilla torch-harmonics package on PyPI ships a CPU-only prebuilt wheel. This version is built for the newest PyTorch release. For GPU support, use the NVIDIA PyPI packages above.

pip install torch-harmonics

Building from source

If your OS, PyTorch or CUDA toolkit version is not covered by the available wheels, we recommend building torch-harmonics from the GitHub repository. Use --no-build-isolation so that custom CPU and CUDA kernels compile against your existing torch installation:

git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
pip install --no-build-isolation -e .

If CUDA devices are not detected automatically (e.g. inside a container), set the FORCE_CUDA_EXTENSION flag. Set TORCH_CUDA_ARCH_LIST to only the architectures you need to reduce compilation time:

export FORCE_CUDA_EXTENSION=1
export TORCH_CUDA_ARCH_LIST="8.0 8.6 9.0 10.0+PTX"
pip install --no-build-isolation -e .

Warning: Custom CUDA extensions require architectures >= 7.0.

Alternatively, build a Docker container:

git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
docker build . -t torch_harmonics
docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 torch_harmonics

More about torch-harmonics

Spherical harmonics

The spherical harmonics are special functions defined on the two-dimensional sphere $S^2$ (embedded in three dimensions). They form an orthonormal basis of the space of square-integrable functions defined on the sphere $L^2(S^2)$ and are comparable to the harmonic functions defined on a circle/torus. The spherical harmonics are defined as

$$ Y_l^m(\theta, \lambda) = \sqrt{\frac{(2l + 1)}{4 \pi} \frac{(l - m)!}{(l + m)!}} P_l^m(\cos \theta) \exp(im\lambda), $$

where $\theta$ and $\lambda$ are colatitude and longitude, respectively, and $P_l^m$ are the normalized associated Legendre polynomials.


Spherical harmonics up to degree 5
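
The definition above can be evaluated numerically. The following is a minimal NumPy sketch, not the torch-harmonics implementation. It assumes that $P_l^m$ is the standard unnormalized associated Legendre function with Condon-Shortley phase, so that the square-root prefactor supplies the normalization, and it is restricted to $0 \le m \le l$. The helper names `assoc_legendre`, `sph_harm_sketch` and `inner` are illustrative and not part of the package.

```python
import math
import numpy as np


def assoc_legendre(l, m, x):
    """P_l^m(x) for 0 <= m <= l (assumed convention: unnormalized,
    with Condon-Shortley phase), via the standard upward recurrence in l."""
    x = np.asarray(x, dtype=np.float64)
    # start value P_m^m(x) = (-1)^m (2m-1)!! (1 - x^2)^(m/2)
    pmm = np.ones_like(x)
    for k in range(1, m + 1):
        pmm = -pmm * (2 * k - 1) * np.sqrt(1.0 - x * x)
    if l == m:
        return pmm
    # P_{m+1}^m(x) = x (2m + 1) P_m^m(x)
    pm1 = x * (2 * m + 1) * pmm
    # (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
    for ll in range(m + 2, l + 1):
        pmm, pm1 = pm1, ((2 * ll - 1) * x * pm1 - (ll + m - 1) * pmm) / (ll - m)
    return pm1


def sph_harm_sketch(l, m, theta, lam):
    """Y_l^m(theta, lam) following the formula above (colatitude, longitude)."""
    norm = math.sqrt(
        (2 * l + 1) / (4 * math.pi) * math.factorial(l - m) / math.factorial(l + m)
    )
    return norm * assoc_legendre(l, m, np.cos(theta)) * np.exp(1j * m * lam)


# Gauss-Legendre nodes in cos(theta) and an equispaced longitude grid
nlat, nlon = 16, 32
x, w = np.polynomial.legendre.leggauss(nlat)
theta = np.arccos(x)[:, None]
lam = np.linspace(0.0, 2.0 * np.pi, nlon, endpoint=False)[None, :]


def inner(l1, m1, l2, m2):
    """Quadrature approximation of the L^2(S^2) inner product of two harmonics."""
    y1 = sph_harm_sketch(l1, m1, theta, lam)
    y2 = sph_harm_sketch(l2, m2, theta, lam)
    return np.sum(w[:, None] * np.conj(y1) * y2) * (2.0 * np.pi / nlon)


print(abs(inner(2, 1, 2, 1)))  # close to 1 (normalized)
print(abs(inner(2, 1, 3, 1)))  # close to 0 (orthogonal)
```

Under these assumptions, the quadrature check reproduces the orthonormality stated above up to floating point error.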

Spherical harmonic transform

The spherical harmonic transform (SHT)

$$ f_l^m = \int_{S^2} \overline{Y_{l}^{m}}(\theta, \lambda) f(\theta, \lambda) \mathrm{d} \mu(\theta, \lambda) $$

realizes the projection of a signal $f(\theta, \lambda)$ on $S^2$ onto the spherical harmonics basis. The SHT generalizes the Fourier transform to the sphere. Conversely, a truncated series expansion of a function $f$ can be written in terms of spherical harmonics as

$$ f (\theta, \lambda) = \sum_{m=-M}^{M} \exp(im\lambda) \sum_{l=|m|}^{M} \hat f_l^m P_l^m (\cos \theta), $$

where $\hat{f}_l^m$ are the expansion coefficients associated with the mode $(l, m)$.

The implementation of the SHT follows the algorithm presented in [3]. A direct spherical harmonic transform can be accomplished by a Fourier transform

$$ \hat f^m(\theta) = \frac{1}{2 \pi} \int_{0}^{2\pi} f(\theta, \lambda) \exp(-im\lambda) \mathrm{d} \lambda $$

in longitude and a Legendre transform

$$ \hat f_l^m = \frac{1}{2} \int^{\pi}_0 \hat f^{m} (\theta) P_l^m (\cos \theta) \sin \theta \mathrm{d} \theta $$

in latitude.
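
The Fourier step can be sketched in NumPy as follows. On an equispaced longitude grid, the trapezoidal rule turns the longitudinal integral into $\frac{1}{N_\lambda} \sum_k f(\theta, \lambda_k) \exp(-im\lambda_k)$, which is what a real FFT with forward normalization computes. This is only an illustration under that assumption, not the library's internal code; the grid sizes and the test signal are arbitrary choices.

```python
import numpy as np

nlat, nlon = 8, 16
theta = np.linspace(0.0, np.pi, nlat)[:, None]                       # colatitude
lam = np.linspace(0.0, 2.0 * np.pi, nlon, endpoint=False)[None, :]   # longitude

# test signal containing a single zonal wavenumber m = 2
f = np.sin(theta) ** 2 * np.cos(2.0 * lam)

# f_hat[j, m] approximates (1 / 2 pi) * int f(theta_j, lam) exp(-i m lam) d lam
f_hat = np.fft.rfft(f, axis=-1, norm="forward")

# since cos(2 lam) = (exp(2i lam) + exp(-2i lam)) / 2, only m = 2 is populated,
# with the latitude profile 0.5 * sin(theta)^2
print(np.allclose(f_hat[:, 2], 0.5 * np.sin(theta[:, 0]) ** 2))  # True
```

Each column of `f_hat` is then a function of colatitude only, which is the input to the Legendre transform.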

Discrete Legendre transform

The second integral, which computes the projection onto the Legendre polynomials, is realized with quadrature. On the Gaussian grid, we use Gaussian quadrature in the $\cos \theta$ domain. The integral

$$ \hat f_l^m = \frac{1}{2} \int_{-1}^1 \hat{f}^m(\arccos x) P_l^m (x) \mathrm{d} x $$

is obtained with the substitution $x = \cos \theta$ and then approximated by the sum

$$ \hat f_l^m = \sum_{j=1}^{N_\theta} \hat{f}^m(\arccos x_j) P_l^m(x_j) w_j. $$

Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.
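
As an illustration of the quadrature step, the following NumPy sketch treats the zonal case $m = 0$, where $P_l^0$ reduces to the Legendre polynomial $P_l$. It assumes unnormalized Legendre polynomials and the standard Gauss-Legendre weights returned by `numpy.polynomial.legendre.leggauss`, which sum to 2; the factor $\frac{1}{2}$ of the integral is therefore kept explicit rather than folded into the weights. The helper `legendre_p` is illustrative and not part of torch-harmonics.

```python
import numpy as np
from numpy.polynomial import legendre

nlat = 8
x, w = legendre.leggauss(nlat)  # nodes x_j in (-1, 1) and weights w_j


def legendre_p(l, x):
    """Evaluate the (unnormalized) Legendre polynomial P_l at x."""
    c = np.zeros(l + 1)
    c[l] = 1.0
    return legendre.legval(x, c)


# zonal test signal: f_hat^0(arccos x) = P_2(x)
f_hat0 = legendre_p(2, x)

# discrete Legendre transform for l = 0, ..., 4
coeffs = np.array(
    [0.5 * np.sum(f_hat0 * legendre_p(l, x) * w) for l in range(5)]
)

# orthogonality leaves only l = 2, with value 0.5 * 2 / (2 * 2 + 1) = 0.2
print(np.round(coeffs, 12))
```

With 8 nodes the rule integrates polynomials up to degree 15 exactly, so the products of Legendre polynomials used here are integrated without quadrature error.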

Discrete-continuous convolutions on the sphere

torch-harmonics now provides local discrete-continuous (DISCO) convolutions on the sphere, as outlined in [5]. These are used in local neural operators [2] to generalize convolutions to structured and unstructured meshes on the sphere.

Spherical (neighborhood) attention

torch-harmonics introduces spherical attention mechanisms, which generalize the attention mechanism to the sphere [6]. The use of quadrature rules makes the resulting operations approximately equivariant on the discrete grid and equivariant in the continuous limit. Moreover, neighborhood attention is generalized to the sphere by using the geodesic distance to determine the size of the neighborhood.
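
To make the geodesic neighborhood concrete, the following NumPy sketch computes great-circle distances on the unit sphere from a query point to all grid points and keeps those within a cutoff radius. It is only an illustration of the idea: the function `geodesic_distance`, the grid and the cutoff value are hypothetical and do not reflect the torch-harmonics API.

```python
import numpy as np


def geodesic_distance(theta0, lam0, theta, lam):
    """Great-circle distance on the unit sphere between (theta0, lam0) and
    the points (theta, lam); colatitude and longitude in radians."""
    cos_d = (
        np.cos(theta0) * np.cos(theta)
        + np.sin(theta0) * np.sin(theta) * np.cos(lam - lam0)
    )
    # clip guards against round-off pushing the argument outside [-1, 1]
    return np.arccos(np.clip(cos_d, -1.0, 1.0))


nlat, nlon = 16, 32
theta = np.linspace(0.0, np.pi, nlat)[:, None]                       # colatitude
lam = np.linspace(0.0, 2.0 * np.pi, nlon, endpoint=False)[None, :]   # longitude

# neighborhood of the north pole with a hypothetical cutoff radius (radians)
cutoff = 0.5
mask = geodesic_distance(0.0, 0.0, theta, lam) <= cutoff

print(mask.shape)  # (16, 32)
print(mask.sum())  # number of grid points inside the neighborhood
```

Near the poles such a neighborhood spans every longitude of the first few latitude rings, whereas a fixed window in grid indices would cover a much smaller physical area there.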

Getting started

The main functionality of torch_harmonics is provided in the form of torch.nn.Modules for composability. A minimal example is given by:

import torch
import torch_harmonics as th

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

nlat = 512
nlon = 2*nlat
batch_size = 32
signal = torch.randn(batch_size, nlat, nlon, device=device)

# transform data on an equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device)

coeffs = sht(signal)

To enable scalable model-parallelism, torch-harmonics implements a distributed variant of the SHT located in torch_harmonics.distributed.

Detailed usage of torch-harmonics, alongside helpful analysis, is provided in a series of notebooks:

  1. Getting started
  2. Quadrature
  3. Visualizing the spherical harmonics
  4. Spectral fitting vs. SHT
  5. Conditioning of the Gramian
  6. Solving the Helmholtz equation
  7. Solving the shallow water equations
  8. Training Spherical Fourier Neural Operators (SFNO)
  9. Resampling signals on the sphere
  10. Computing partial derivatives with the SHT

Examples and reproducibility

The examples folder contains training scripts for three distinct tasks:

Results from the papers can generally be reproduced by running python train.py. In the case of some older results the number of epochs and learning-rate may need to be adjusted by passing the corresponding command line argument.

Remarks on automatic mixed precision (AMP) support

Note that torch-harmonics uses Fourier transforms from torch.fft, which in turn uses kernels from the optimized cuFFT library. This library supports Fourier transforms of float32 and float64 (i.e. single and double precision) tensors for all input sizes. For float16 (i.e. half precision) and bfloat16 inputs, however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when torch.autocast is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonic transform specifically:

import torch
import torch_harmonics as th

sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

# some_math and some_more_math are placeholders for your own operations on x
with torch.autocast(device_type="cuda", enabled=True):
    # do some AMP converted math here
    x = some_math(x)
    # convert tensor to float32
    x = x.to(torch.float32)
    # now disable autocast specifically for the transform,
    # making sure that the tensors are not converted
    # back to reduced precision internally
    with torch.autocast(device_type="cuda", enabled=False):
        xt = sht(x)

    # continue operating on the transformed tensor
    xt = some_more_math(xt)

Depending on the problem, it might be beneficial to upcast data to float64 instead of float32 precision for numerical stability.
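
Whether a given grid is affected can be checked up front. The sketch below uses two hypothetical helpers, `is_power_of_two` and `needs_full_precision`, which are not part of torch-harmonics; they only encode the rule stated above that reduced precision FFTs require the transformed dimensions to be powers of two.

```python
def is_power_of_two(n: int) -> bool:
    """True if n is a positive power of two (1, 2, 4, 8, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def needs_full_precision(*fft_sizes: int) -> bool:
    """True if any transformed dimension is not a power of two, in which
    case autocast should be disabled around the transform."""
    return not all(is_power_of_two(n) for n in fft_sizes)


print(needs_full_precision(1024))  # False
print(needs_full_precision(1440))  # True
```

A check like this can be used to decide, when a model is constructed, whether the transform needs to be wrapped in a disabled autocast region.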

Contributors

Boris Bonev (bbonev@nvidia.com), Thorsten Kurth (tkurth@nvidia.com), Max Rietmann, Mauro Bisson, Andrea Paris, Alberto Carpentieri, Massimiliano Fatica, Nikola Kovachki, Jean Kossaifi, Christian Hundt

Cite us

If you use torch-harmonics in an academic paper, please cite [1]:

@misc{bonev2023spherical,
      title={Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere},
      author={Boris Bonev and Thorsten Kurth and Christian Hundt and Jaideep Pathak and Maximilian Baust and Karthik Kashinath and Anima Anandkumar},
      year={2023},
      eprint={2306.03838},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

References

[1] Bonev B., Kurth T., Hundt C., Pathak J., Baust M., Kashinath K., Anandkumar A.; Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere; International Conference on Machine Learning, 2023. arXiv:2306.03838

[2] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.; Neural Operators with Localized Integral and Differential Kernels; International Conference on Machine Learning, 2024.

[3] Schaeffer N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations; G3: Geochemistry, Geophysics, Geosystems, 2013.

[4] Wang B., Wang L., Xie Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math, 2018.

[5] Ocampo, Price, McEwen; Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions; International Conference on Learning Representations, 2023. arXiv:2209.13603

[6] Bonev B., Rietmann M., Paris A., Carpentieri A., Kurth T.; Attention on the Sphere; arXiv.
