CUDA-accelerated equivariant operations

Project description

cuEquivariance

cuEquivariance is an NVIDIA Python library designed to facilitate the construction of high-performance equivariant neural networks using segmented tensor products. cuEquivariance provides a comprehensive API for describing segmented tensor products and optimized CUDA kernels for their execution. Additionally, cuEquivariance offers bindings for both PyTorch and JAX, ensuring broad compatibility and ease of integration.
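
To make "segmented tensor product" a little more concrete, here is a toy pure-Python sketch. It assumes a simplified model: each operand is a flat list of numbers split into segments given as (offset, size) pairs, and each path picks one segment of each input and one of the output, together with a coefficient tensor C, so that out[c] += C[a][b][c] * x[a] * y[b] within those segments. This is not cuEquivariance's API or kernel (the library has its own descriptors and CUDA kernels; see the documentation), and the helper name segmented_tensor_product is hypothetical. The example packs a scalar and a 3-vector into each operand and uses four paths: scalar times scalar, dot product, scalar times vector, and a cross product whose coefficients are the Levi-Civita symbol.

```python
from itertools import product


def levi_civita(i, j, k):
    """Levi-Civita symbol for indices in {0, 1, 2}."""
    return (i - j) * (j - k) * (k - i) // 2


def segmented_tensor_product(segments, paths, x, y):
    """Toy segmented tensor product with two inputs and one output (hypothetical helper).

    segments: three lists of (offset, size) pairs, for x, y and the output.
    paths: (seg_x, seg_y, seg_out, C) tuples; C[a][b][c] weights
           x[seg_x][a] * y[seg_y][b] into out[seg_out][c].
    """
    seg_x, seg_y, seg_out = segments
    out = [0.0] * max(off + size for off, size in seg_out)
    for sx, sy, so, C in paths:
        (ox, nx), (oy, ny), (oo, no) = seg_x[sx], seg_y[sy], seg_out[so]
        for a, b, c in product(range(nx), range(ny), range(no)):
            w = C[a][b][c]
            if w:  # skip structural zeros in the coefficient tensor
                out[oo + c] += w * x[ox + a] * y[oy + b]
    return out


def coeffs(n_a, n_b, n_c, fn):
    """Build a nested n_a x n_b x n_c coefficient tensor from fn(a, b, c)."""
    return [[[fn(a, b, c) for c in range(n_c)] for b in range(n_b)] for a in range(n_a)]


# Each operand: a scalar segment (offset 0, size 1) then a 3-vector segment (offset 1, size 3).
LAYOUT = [(0, 1), (1, 3)]
SCALAR, VECTOR = 0, 1

PATHS = [
    (SCALAR, SCALAR, SCALAR, coeffs(1, 1, 1, lambda a, b, c: 1)),           # s * s -> s
    (VECTOR, VECTOR, SCALAR, coeffs(3, 3, 1, lambda a, b, c: int(a == b))),  # dot -> s
    (SCALAR, VECTOR, VECTOR, coeffs(1, 3, 3, lambda a, b, c: int(b == c))),  # s * v -> v
    (VECTOR, VECTOR, VECTOR, coeffs(3, 3, 3, levi_civita)),                  # cross -> v
]

x = [2.0, 1.0, 0.0, 0.0]  # scalar 2, vector e_x
y = [3.0, 0.0, 1.0, 0.0]  # scalar 3, vector e_y
print(segmented_tensor_product([LAYOUT] * 3, PATHS, x, y))  # [6.0, 0.0, 2.0, 1.0]
```

Describing the operation as segments plus sparse paths, rather than one dense tensor, is what lets many small blocks share a single description; how the library actually schedules such paths on the GPU is not shown here.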

Equivariance is the mathematical formalization of the concept of “respecting symmetries.” Robust physical models exhibit equivariance with respect to rotations and translations in three-dimensional space. Artificial intelligence models that incorporate equivariance are often more data-efficient, because the symmetry is built into the architecture instead of having to be learned from examples.
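
For rotations, "respecting symmetries" can be checked numerically: a function f of two 3-vectors is rotation-equivariant when rotating the inputs and then applying f gives the same result as applying f and then rotating the output, f(Rx, Ry) = R f(x, y). The pure-Python sketch below tests this for the cross product, which is equivariant under proper rotations (det R = +1), and for a componentwise product, which depends on the choice of axes and is not. It only illustrates the definition, does not use cuEquivariance, and leaves translations out.

```python
import math


def rot_z(theta):
    """3x3 rotation matrix about the z-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return ((c, -s, 0.0), (s, c, 0.0), (0.0, 0.0, 1.0))


def rot_x(theta):
    """3x3 rotation matrix about the x-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return ((1.0, 0.0, 0.0), (0.0, c, -s), (0.0, s, c))


def matmul(A, B):
    """Product of two 3x3 matrices."""
    return tuple(
        tuple(sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3))
        for i in range(3)
    )


def apply(R, v):
    """Apply a 3x3 matrix to a 3-vector."""
    return tuple(sum(R[i][j] * v[j] for j in range(3)) for i in range(3))


def cross(a, b):
    """Cross product: equivariant under proper rotations."""
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def elementwise(a, b):
    """Componentwise product: tied to the coordinate axes, so not equivariant."""
    return tuple(p * q for p, q in zip(a, b))


def is_equivariant(f, R, x, y, tol=1e-9):
    """Check f(Rx, Ry) == R f(x, y) up to floating-point tolerance."""
    lhs = f(apply(R, x), apply(R, y))  # rotate the inputs, then apply f
    rhs = apply(R, f(x, y))            # apply f, then rotate the output
    return all(abs(l - r) < tol for l, r in zip(lhs, rhs))


R = matmul(rot_z(1.0), rot_x(0.5))
x, y = (1.0, 2.0, 3.0), (4.0, 5.0, 6.0)
print(is_equivariant(cross, R, x, y))        # True
print(is_equivariant(elementwise, R, x, y))  # False
```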

Documentation

For more information, refer to the project documentation: https://docs.nvidia.com/cuda/cuequivariance/

Installation

# Choose the frontend you want to use
pip install cuequivariance-jax
pip install cuequivariance-torch
pip install cuequivariance  # Installs only the core non-ML components

# CUDA kernels for different CUDA versions
pip install cuequivariance-ops-torch-cu11
pip install cuequivariance-ops-torch-cu12
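
After installing, one way to see which of these distributions are present is to query package metadata with the standard library's importlib.metadata. This sketch uses only the distribution names from the pip commands above; it does not import the packages or check that the CUDA kernels load on your GPU. Normally only one of the -cu11 / -cu12 kernel packages would be installed, matching your CUDA version.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip commands above.
DISTRIBUTIONS = [
    "cuequivariance",
    "cuequivariance-jax",
    "cuequivariance-torch",
    "cuequivariance-ops-torch-cu11",
    "cuequivariance-ops-torch-cu12",
]


def installed_versions(names=DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```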

License

All files hosted in this repository are subject to the Apache 2.0 license.

Disclaimer

cuEquivariance is in a Beta state. Beta products may not be fully functional, may contain errors or design flaws, and may be changed at any time without notice. We appreciate your feedback to improve and iterate on our Beta products.

Download files

Source Distribution

cuequivariance_jax-0.2.0.tar.gz (26.2 kB)

Uploaded Source

Built Distribution

cuequivariance_jax-0.2.0-py3-none-any.whl (40.6 kB)

Uploaded Python 3

File details

Details for the file cuequivariance_jax-0.2.0.tar.gz.

File metadata

  • Download URL: cuequivariance_jax-0.2.0.tar.gz
  • Size: 26.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for cuequivariance_jax-0.2.0.tar.gz
Algorithm     Hash digest
SHA256        836b90bfcd8f0c258e3d72ef7265226e7936b57deb11996cb9248d5634f25d3c
MD5           4431b04855881e80868827b7af532c23
BLAKE2b-256   d8bbfbf017f7b008202e5b60c0d8e849569c0c03fcbcfcba4d0f5db625eca961

File details

Details for the file cuequivariance_jax-0.2.0-py3-none-any.whl.

File hashes

Hashes for cuequivariance_jax-0.2.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        b6b1e937d64fa32ef93197ab1af276d0322a4d3edfb77172133b53a278707834
MD5           6948e82e74d4c01e616f3c0fb108dd9d
BLAKE2b-256   91a7741e026b844f3f0d0bb16228523a12a65a7e7d0d0f93ff2f4eb39b18adba
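
To check a downloaded file against the published digests, you can recompute its SHA256 with Python's hashlib and compare. A minimal sketch, assuming the files keep their original names so they can be looked up in a small table copied from the hashes listed for the two 0.2.0 files; the helper names sha256_of and verify are illustrative, not part of any package.

```python
import hashlib
import sys
from pathlib import Path

# SHA256 digests copied from the published hashes for the two 0.2.0 files.
EXPECTED_SHA256 = {
    "cuequivariance_jax-0.2.0.tar.gz": "836b90bfcd8f0c258e3d72ef7265226e7936b57deb11996cb9248d5634f25d3c",
    "cuequivariance_jax-0.2.0-py3-none-any.whl": "b6b1e937d64fa32ef93197ab1af276d0322a4d3edfb77172133b53a278707834",
}


def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 of a file, read in chunks so large files need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """True if the file matches its published digest; KeyError for unknown names."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no published digest for {path.name}")
    return sha256_of(path) == expected


if __name__ == "__main__":
    for arg in sys.argv[1:]:
        print(arg, "OK" if verify(arg) else "MISMATCH")
```

pip can also enforce this during installation in hash-checking mode: add --hash=sha256:<digest> to a pinned requirement line and run pip install --require-hashes -r requirements.txt. In that mode every dependency must be pinned with a hash as well.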
