
NVIDIA cuQuantum Python JAX


cuQuantum Python JAX

cuQuantum Python JAX is a JAX extension for cuQuantum Python. It exposes selected functionality of the cuQuantum SDK through a JAX-compatible interface, so that JAX-based frameworks can call the exposed cuQuantum APIs directly. In the current release, cuQuantum Python JAX exposes a JAX interface to the Operator Action API of the cuDensityMat library.

Documentation

Please visit the NVIDIA cuQuantum Python documentation.

Building and installing cuQuantum Python JAX

Requirements

The install-time dependencies of the cuQuantum Python JAX package include:

  • cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
  • jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
  • pybind11
  • setuptools>=77.0.3

Note:

  1. cuQuantum Python JAX is only supported with CUDA 12 and CUDA 13.
  2. Installing cuQuantum Python JAX with build isolation is not supported. The user must pass --no-build-isolation to pip when installing cuQuantum Python JAX.
  3. cuQuantum Python JAX wheels are CUDA-versioned: cuquantum-python-jax-cu12 for CUDA 12 and cuquantum-python-jax-cu13 for CUDA 13.
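
The version constraints listed above use two specifier forms: the compatible-release operator (~=26.1.0, meaning at least 26.1.0 and still within the 26.1 series) and an explicit range (>=0.5,<0.7). The following is a minimal sketch of what these constraints accept, assuming plain dotted numeric versions only. The helper names are hypothetical and this is not how pip itself evaluates specifiers.

```python
def parse_version(text, width=4):
    """Turn '26.1.0' into a zero-padded tuple of ints such as (26, 1, 0, 0)."""
    parts = [int(part) for part in text.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def satisfies_compatible_release(version, base):
    """Check `version ~= base`; ~=26.1.0 means >=26.1.0 within the 26.1 series."""
    prefix_len = len(base.split(".")) - 1  # drop the last component of the base
    v, b = parse_version(version), parse_version(base)
    return v >= b and v[:prefix_len] == b[:prefix_len]


def satisfies_range(version, lower, upper):
    """Check `>=lower,<upper`, for example jax>=0.5,<0.7."""
    return parse_version(lower) <= parse_version(version) < parse_version(upper)


if __name__ == "__main__":
    print(satisfies_compatible_release("26.1.3", "26.1.0"))  # inside the 26.1 series
    print(satisfies_range("0.7.0", "0.5", "0.7"))            # upper bound is exclusive
```

Pre-release and local version tags are deliberately not handled here.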

Installation using jax[cudaXX-local]

cuquantum-python-jax-cu12 (or cuquantum-python-jax-cu13) depends explicitly on jax[cudaXX-local]. Installing the package will also install jax[cudaXX-local].

Using jax[cudaXX-local] assumes that the user provides both cuDNN and the CUDA Toolkit. cuDNN is not part of the CUDA Toolkit and must be installed separately. The user must also set LD_LIBRARY_PATH so that it includes the directories containing libcudnn.so and libcupti.so.

libcupti.so is provided by the CUDA Toolkit. If the CUDA Toolkit is installed under /usr/local/cuda, libcupti.so is located under /usr/local/cuda/extras/CUPTI/lib64 and LD_LIBRARY_PATH should contain this path.

libcudnn.so is installed separately from the CUDA Toolkit. The default installation location is /usr/local/cuda/lib64, and LD_LIBRARY_PATH should contain this path.
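
Before installing, it can help to confirm that LD_LIBRARY_PATH really points at both libraries. The sketch below scans the directories on LD_LIBRARY_PATH for files whose names start with libcupti.so or libcudnn.so. The function name is hypothetical, and the check only looks at LD_LIBRARY_PATH; it does not consult the system linker cache.

```python
import glob
import os


def find_library_dirs(library_name, search_path=None):
    """Return the directories on `search_path` that contain `library_name`.

    `search_path` defaults to the current LD_LIBRARY_PATH. Versioned files
    such as libcudnn.so.9 match as well as the bare .so name.
    """
    if search_path is None:
        search_path = os.environ.get("LD_LIBRARY_PATH", "")
    hits = []
    for directory in search_path.split(os.pathsep):
        # Skip empty entries produced by leading, trailing, or doubled separators.
        if directory and glob.glob(os.path.join(directory, library_name + "*")):
            hits.append(directory)
    return hits


if __name__ == "__main__":
    for name in ("libcupti.so", "libcudnn.so"):
        dirs = find_library_dirs(name)
        print(name, "->", dirs if dirs else "not found on LD_LIBRARY_PATH")
```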

Both libcudnn.so and libcupti.so can also be installed with pip. For CUDA 12:

pip install nvidia-cudnn-cu12
pip install nvidia-cuda-cupti-cu12

After installing cuDNN and CUPTI, the user may install cuQuantum Python JAX with pip using one of:

pip install --no-build-isolation cuquantum-python-jax-cu12   # for CUDA 12
pip install --no-build-isolation cuquantum-python-jax-cu13   # for CUDA 13

or one of

pip install --no-build-isolation cuquantum-python-cu12[jax]
pip install --no-build-isolation cuquantum-python-cu13[jax]

where the CUDA version is selected through the cuquantum-python package name.

Note:

  1. If cuDNN and CUPTI are installed with pip, the user does not need to add their library directories to LD_LIBRARY_PATH.
  2. When the latter form (pip install --no-build-isolation cuquantum-python-cu12[jax] or pip install --no-build-isolation cuquantum-python-cu13[jax]) is used, --no-build-isolation applies to both cuquantum-python and cuquantum-python-jax. The user must therefore make sure that the build dependencies of cuquantum-python are installed beforehand.
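
For scripted environments, the install commands above can be assembled programmatically. This is a minimal sketch: the function name and its via_extra parameter are hypothetical, and only the package names and the --no-build-isolation flag come from the instructions above.

```python
import sys

SUPPORTED_CUDA_MAJORS = (12, 13)


def pip_install_command(cuda_major, via_extra=False):
    """Build the pip command line for the given CUDA major version.

    With via_extra=True the [jax] extra of cuquantum-python is requested
    instead of the standalone cuquantum-python-jax wheel.
    """
    if cuda_major not in SUPPORTED_CUDA_MAJORS:
        raise ValueError(f"unsupported CUDA major version: {cuda_major}")
    if via_extra:
        target = f"cuquantum-python-cu{cuda_major}[jax]"
    else:
        target = f"cuquantum-python-jax-cu{cuda_major}"
    # Build isolation is not supported, so the flag is always included.
    return [sys.executable, "-m", "pip", "install", "--no-build-isolation", target]


if __name__ == "__main__":
    print(" ".join(pip_install_command(12)))
    print(" ".join(pip_install_command(13, via_extra=True)))
```

The returned list can be passed to subprocess.check_call. Invoking pip through sys.executable keeps the install tied to the interpreter that runs the script.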

Installing from source

To install cuQuantum Python JAX from source, first compile cuQuantum Python from source using the instructions on GitHub. Once complete, navigate to python/extensions, then:

pip install --no-build-isolation .

The CUDA version is detected automatically from $CUDA_PATH and the wheel will be named accordingly (cuquantum-python-jax-cu12 or cuquantum-python-jax-cu13).
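
How the build detects the CUDA version is not described here. As an illustration only, the sketch below derives the wheel name from the version.json file that recent CUDA Toolkit installations place at their root. Both the function name and the use of version.json are assumptions; the actual build script may work differently.

```python
import json
import os


def wheel_name_from_cuda_path(cuda_path=None):
    """Guess the wheel name from the toolkit found under `cuda_path`.

    Assumes <cuda_path>/version.json exists and has the layout
    {"cuda": {"version": "12.4.1", ...}, ...}.
    """
    if cuda_path is None:
        cuda_path = os.environ["CUDA_PATH"]  # raises KeyError if unset
    with open(os.path.join(cuda_path, "version.json")) as handle:
        version = json.load(handle)["cuda"]["version"]
    major = int(version.split(".")[0])
    if major not in (12, 13):
        raise RuntimeError(f"unsupported CUDA version: {version}")
    return f"cuquantum-python-jax-cu{major}"
```

Calling wheel_name_from_cuda_path("/usr/local/cuda") on a CUDA 12 installation would then yield cuquantum-python-jax-cu12.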

Running

Requirements

Runtime dependencies of the cuQuantum Python JAX package include:

  • An NVIDIA GPU with compute capability 7.5+
  • cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
  • jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
  • pybind11
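
The compute capability requirement can be checked from Python before running any workload. The sketch below is hedged: the helper names are hypothetical, and check_visible_gpus assumes that JAX GPU devices expose a compute_capability attribute formatted like "8.0". It also requires a working JAX GPU backend.

```python
def meets_minimum_capability(capability, minimum=(7, 5)):
    """Return True if a 'major.minor' capability string is at least `minimum`."""
    major, minor = (int(part) for part in str(capability).split("."))
    # Compare as integer tuples so that "10.0" ranks above "7.5".
    return (major, minor) >= minimum


def check_visible_gpus():
    """Map each visible JAX GPU device to whether it meets the requirement."""
    import jax  # imported lazily so the helper above works without JAX

    report = {}
    for device in jax.devices("gpu"):  # raises RuntimeError without a GPU backend
        # Assumption: GPU devices carry a `compute_capability` string.
        report[str(device)] = meets_minimum_capability(device.compute_capability)
    return report
```

Comparing integer tuples rather than strings avoids ranking a capability such as 10.0 below 7.5.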

Developer Notes

  • cuQuantum Python JAX does not support editable installation.
  • Both cuQuantum Python and cuQuantum Python JAX need to be installed into site-packages for proper import of the library.
  • cuQuantum Python JAX assumes cuQuantum Python will be available under the current site-packages directory.
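
To confirm that cuQuantum Python resolves from site-packages rather than from a source checkout or an editable install, a check along the following lines may help. The function name is hypothetical, and the sketch assumes the import name cuquantum for cuQuantum Python.

```python
import importlib.util
import os
import sysconfig


def installed_in_site_packages(module_name):
    """Return True if `module_name` resolves to a file under site-packages."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return False  # not installed, or a namespace package without a file
    site_dirs = {
        os.path.realpath(sysconfig.get_paths()[key])
        for key in ("purelib", "platlib")
    }
    origin = os.path.realpath(spec.origin)
    return any(os.path.commonpath([origin, d]) == d for d in site_dirs)


if __name__ == "__main__":
    print("cuquantum in site-packages:", installed_in_site_packages("cuquantum"))
```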


Source Distribution

cuquantum_python_jax_cu13-0.0.5.tar.gz (58.0 kB)

Hashes for cuquantum_python_jax_cu13-0.0.5.tar.gz
Algorithm Hash digest
SHA256 7f8ed62ec57190fd74f3f3469a1862bdcaffac5ab9c4dadaed1ba895b3985ca9
MD5 ebe50fe54e42dd905652f3e4521de87e
BLAKE2b-256 d94d750392235bcd168ccc53b6e26024a8c9da03b807c7dd1181dcecaf031375
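
A downloaded source distribution can be checked against the SHA256 digest listed above. The following is a minimal sketch using only the standard library. The function name is hypothetical and the file path in the usage note is a placeholder for wherever the archive was saved.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Compute the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Comparing sha256_of_file("cuquantum_python_jax_cu13-0.0.5.tar.gz") with the listed SHA256 value confirms that the download is intact.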
