This library contains direct translations of the exchange-correlation functionals in libxc into JAX. The core calculations in libxc are implemented in Maple, which lets us translate them directly into Python with the help of Maple's CodeGeneration package.

Usage

Installation

pip install jax-xc

Invoking the Functionals

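jax_xc’s API is functional: rather than density values precomputed on a grid, it takes the density $\rho$ itself as a Python callable and returns the exchange-correlation energy per particle $\varepsilon_{xc}$, either evaluated at a given point $r$ or, in the experimental API, as another callable. The exchange-correlation energy is

$$E_{xc} = \int \rho(r)\, \varepsilon_{xc}(r)\, dr$$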
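As a concrete instance of this integral, the sketch below evaluates the exchange energy for the LDA (Slater) exchange functional, whose energy per particle has the closed form $\varepsilon_x = -\tfrac{3}{4}(3/\pi)^{1/3}\rho^{1/3}$, using the same normalized Gaussian density as the examples in this README. It is plain NumPy with a simple grid quadrature, independent of jax_xc, and checks the numerical result against the analytic value for this density.

```python
import numpy as np

# Slater (LDA) exchange, spin-unpolarized:
#   eps_x(rho) = -(3/4) * (3/pi)**(1/3) * rho**(1/3)
C_X = 0.75 * (3.0 / np.pi) ** (1.0 / 3.0)


def eps_x_lda(rho):
  """LDA exchange energy per particle (the analogue of libxc's 'zk')."""
  return -C_X * np.cbrt(rho)


def rho_gaussian(r):
  """Normalized 3D Gaussian density (one electron); r has shape (..., 3)."""
  return (2.0 * np.pi) ** -1.5 * np.exp(-0.5 * np.sum(r**2, axis=-1))


# Uniform cubic grid; the integrand decays fast, so a Riemann sum suffices.
h = 0.2
axis = np.arange(-8.0, 8.0 + h / 2, h)
grid = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), axis=-1)
weights = h**3

rho = rho_gaussian(grid)
E_x = np.sum(rho * eps_x_lda(rho)) * weights

# Closed form for this density: E_x = -C_X * (2*pi)**-2 * (3*pi/2)**1.5
E_x_exact = -C_X * (2.0 * np.pi) ** -2 * (1.5 * np.pi) ** 1.5
print(E_x, E_x_exact)
```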

LDA and GGA

Unlike libxc, which takes precomputed densities and their derivatives at given coordinates, jax_xc takes the density function directly.

import jax
import jax.numpy as jnp
import jax_xc


def rho(r):
  """Electron number density. We take a Gaussian as an example.

  A function that takes a real coordinate, and returns a scalar
  indicating the number density of electrons at coordinate r.

  Args:
    r: a 3D coordinate.
  Returns:
    rho: If it is unpolarized, it is a scalar.
         If it is polarized, it is an array of shape (2,).
  """
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))

# create a density functional (PBE exchange)
gga_x_pbe = jax_xc.gga_x_pbe(polarized=False)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# pass rho and r to the functional to compute epsilon_xc (energy per particle) at r,
# corresponding to the 'zk' in libxc
epsilon_xc_r = gga_x_pbe(rho, r)
print(epsilon_xc_r)
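The practical difference from libxc is who computes the inputs. For a GGA, libxc expects the density $\rho$ and $\sigma = |\nabla\rho|^2$ already evaluated at each point; with a callable density these can be derived on demand. The sketch below illustrates this with NumPy central finite differences; `gga_inputs` is a hypothetical helper, not part of jax_xc, and a JAX implementation would more naturally use automatic differentiation (e.g. `jax.grad`) instead.

```python
import numpy as np


def rho(r):
  """NumPy version of the Gaussian density used above."""
  return (2.0 * np.pi) ** -1.5 * np.exp(-0.5 * np.dot(r, r))


def gga_inputs(density, r, h=1e-5):
  """Hypothetical helper: libxc-style GGA inputs (rho, sigma) at point r.

  sigma = |grad rho|^2, with the gradient from central finite differences.
  """
  r = np.asarray(r, dtype=float)
  grad = np.array([
      (density(r + h * e) - density(r - h * e)) / (2.0 * h) for e in np.eye(3)
  ])
  return density(r), float(np.dot(grad, grad))


r = np.array([0.1, 0.2, 0.3])
rho_r, sigma_r = gga_inputs(rho, r)
# Analytic check for this Gaussian: grad rho = -r * rho, so sigma = |r|^2 rho^2.
sigma_exact = np.dot(r, r) * rho_r**2
print(rho_r, sigma_r, sigma_exact)
```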

mGGA

Unlike LDA and GGA functionals, which depend only on the density, mGGA functionals also depend on the molecular orbitals (e.g. through the kinetic energy density).

import jax
import jax.numpy as jnp
import jax_xc


def mo(r):
  """Molecular orbitals. We take Gaussians as an example.

  A function that takes a real coordinate, and returns the values of
  the molecular orbitals at this coordinate.

  Args:
    r: a 3D coordinate.
  Returns:
    mo: If it is unpolarized, it is an array of shape (N,).
        If it is polarized, it is an array of shape (N, 2).
  """
  # Assume we have 3 molecular orbitals
  return jnp.array([
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)),
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0.5, scale=1)),
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=-0.5, scale=1))
  ])


rho = lambda r: jnp.sum(mo(r)**2, axis=0)
mgga_xc_cc06 = jax_xc.mgga_xc_cc06(polarized=False)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# evaluate the exchange-correlation energy per particle at this point
# corresponding to the 'zk' in libxc
print(mgga_xc_cc06(rho, r, mo))
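One orbital-dependent quantity an mGGA can use is the kinetic energy density. Assuming the libxc convention $\tau = \tfrac{1}{2}\sum_i |\nabla\phi_i|^2$ (unpolarized case), the sketch below shows how $\tau$ can be obtained from an orbital callable like `mo` above. It uses NumPy finite differences purely for illustration and is not jax_xc's internal implementation; `tau` is a hypothetical helper.

```python
import numpy as np


def gaussian(r, loc):
  """Product of 1D unit normal pdfs, as in mo() above."""
  return np.prod(np.exp(-0.5 * (r - loc) ** 2) / np.sqrt(2.0 * np.pi))


def mo(r):
  """Three Gaussian 'orbitals' centred at 0, 0.5 and -0.5; shape (3,)."""
  return np.array([gaussian(r, 0.0), gaussian(r, 0.5), gaussian(r, -0.5)])


def tau(orbitals, r, h=1e-5):
  """Unpolarized kinetic energy density tau = 1/2 * sum_i |grad phi_i|^2."""
  r = np.asarray(r, dtype=float)
  grads = np.stack(
      [(orbitals(r + h * e) - orbitals(r - h * e)) / (2.0 * h) for e in np.eye(3)],
      axis=-1)  # shape (n_orbitals, 3): one column per Cartesian direction
  return 0.5 * np.sum(grads**2)


r = np.array([0.1, 0.2, 0.3])
# Analytic check: grad phi_i = -(r - loc_i) * phi_i for these Gaussians.
locs = np.array([0.0, 0.5, -0.5])
phi = mo(r)
tau_exact = 0.5 * sum(np.sum((r - c) ** 2) * p**2 for c, p in zip(locs, phi))
print(tau(mo, r), tau_exact)
```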

Hybrid Functionals

Hybrid functionals expose the same API, plus extra attributes that give users access to parameters needed outside of libxc/jax_xc (e.g. the fraction of exact exchange).

import jax
import jax.numpy as jnp
import jax_xc


def rho(r):
  """Electron number density. We take a Gaussian as an example.

  A function that takes a real coordinate, and returns a scalar
  indicating the number density of electrons at coordinate r.

  Args:
    r: a 3D coordinate.
  Returns:
    rho: If it is unpolarized, it is a scalar.
         If it is polarized, it is an array of shape (2,).
  """
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))


hyb_gga_xc_pbeb0 = jax_xc.hyb_gga_xc_pbeb0(polarized=False)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# evaluate the exchange-correlation energy per particle at this point
# corresponding to the 'zk' in libxc
print(hyb_gga_xc_pbeb0(rho, r))

# access to extra attributes
cam_alpha = hyb_gga_xc_pbeb0.cam_alpha  # fraction of full Hartree-Fock exchange

The complete list of extra attributes can be found below:

cam_alpha: float
cam_beta: float
cam_omega: float
nlc_b: float
nlc_C: float

Each attribute has the same meaning as in libxc:

  • cam_alpha: fraction of full Hartree-Fock exchange, used both for usual hybrids as well as range-separated ones

  • cam_beta: fraction of short-range only(!) exchange in range-separated hybrids

  • cam_omega: range separation constant

  • nlc_b: non-local correlation, b parameter

  • nlc_C: non-local correlation, C parameter
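To make the roles of `cam_alpha`, `cam_beta` and `cam_omega` concrete: assuming an error-function split of the Coulomb operator, with `cam_alpha` scaling full-range exact exchange and `cam_beta` adding short-range-only exact exchange, the exact-exchange kernel is $[\alpha + \beta\,\mathrm{erfc}(\omega r_{12})]/r_{12}$. The sketch below evaluates the bracketed weight in plain Python; `hf_exchange_weight` is a hypothetical helper (not a jax_xc attribute) and the parameter values are illustrative, not taken from any specific functional.

```python
import math


def hf_exchange_weight(r12, cam_alpha, cam_beta, cam_omega):
  """Weight multiplying the bare Coulomb kernel 1/r12 in the exact-exchange term.

  Assumes: full-range fraction cam_alpha plus short-range fraction cam_beta,
  with the short-range part attenuated by erfc(cam_omega * r12).
  """
  return cam_alpha + cam_beta * math.erfc(cam_omega * r12)


# Global hybrid (cam_beta = 0): the same weight at every distance.
print(hf_exchange_weight(1.0, 0.25, 0.0, 0.0))  # 0.25

# Short-range-only exact exchange: weight cam_beta at r12 = 0, decaying to 0.
for r12 in (0.0, 1.0, 10.0, 100.0):
  print(r12, hf_exchange_weight(r12, 0.0, 0.25, 0.11))
```

At $r_{12} = 0$ the weight is always $\alpha + \beta$, and at large separation it tends to $\alpha$, which is why `cam_beta` is described as a short-range-only fraction.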

Experimental

jax_xc also offers experimental support for automatic functional derivatives via the autofd package:

import jax
import jax_xc
import autofd.operators as o
from autofd import function
import jax.numpy as jnp
from jaxtyping import Array, Float32

@function
def rho(r: Float32[Array, "3"]) -> Float32[Array, ""]:
  """Electron number density. We take a Gaussian as an example.

  A function that takes a real coordinate, and returns a scalar
  indicating the number density of electrons at coordinate r.

  Args:
    r: a 3D coordinate.
  Returns:
    rho: If it is unpolarized, it is a scalar.
         If it is polarized, it is an array of shape (2,).
  """
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))

# create the functional; applying it to rho returns epsilon_xc as a function of r
gga_x_pbe = jax_xc.experimental.gga_x_pbe
epsilon_xc = gga_x_pbe(rho)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# epsilon_xc is itself a function
print(f"The function signature of epsilon_xc is {epsilon_xc}")

# evaluate epsilon_xc (energy per particle) at r,
# corresponding to the 'zk' in libxc
epsilon_xc_r = epsilon_xc(r)
print(f"epsilon_xc(r) = {epsilon_xc_r}")

# vxc is the functional derivative of E_xc = integral of rho * epsilon_xc
# with respect to rho, obtained by applying jax.grad to the integrated functional
vxc = jax.grad(lambda rho: o.integrate(rho * gga_x_pbe(rho)))(rho)
print(f"The function signature of vxc is {vxc}")
print(vxc(r))
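The `vxc` above is the functional derivative $v_{xc}(\mathbf{r}) = \delta E_{xc}/\delta\rho(\mathbf{r})$. For LDA exchange it has a simple closed form, $v_x = \tfrac{4}{3}\varepsilon_x$, which makes a convenient sanity check. The NumPy sketch below is independent of jax_xc and autofd: it discretizes $E_x$ on a grid with quadrature weight $w$, where the functional derivative at grid point $i$ equals $\frac{1}{w}\,\partial E/\partial\rho_i$, and compares a finite-difference estimate against the closed form. The random density values and the weight are illustrative.

```python
import numpy as np

C_X = 0.75 * (3.0 / np.pi) ** (1.0 / 3.0)


def energy_x(rho_values, weight):
  """Discretized LDA exchange energy: sum_i w * rho_i * eps_x(rho_i)."""
  return weight * np.sum(-C_X * rho_values ** (4.0 / 3.0))


def v_x_lda(rho_value):
  """Closed-form derivative d(rho * eps_x)/d rho = (4/3) * eps_x."""
  return -(4.0 / 3.0) * C_X * np.cbrt(rho_value)


rng = np.random.default_rng(0)
rho_values = rng.uniform(0.01, 1.0, size=100)  # density on 100 grid points
weight = 0.05                                  # quadrature weight per point
i, h = 17, 1e-6

# Perturb the density at a single grid point and take a central difference.
bump = np.zeros_like(rho_values)
bump[i] = h
dE = energy_x(rho_values + bump, weight) - energy_x(rho_values - bump, weight)
v_numeric = dE / (2.0 * h) / weight
print(v_numeric, v_x_lda(rho_values[i]))
```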

Supported Functionals

Please refer to the functionals section in jax_xc’s documentation for the complete list of supported functionals.

Numerical Correctness

We test all functionals auto-generated from the Maple files against reference values from libxc, comparing the outputs of libxc and jax_xc and requiring them to agree within atol=2e-10 and rtol=2e-10.
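A combined absolute/relative tolerance like this is commonly interpreted with NumPy's convention, $|a - b| \le \mathrm{atol} + \mathrm{rtol}\,|b|$ elementwise. Assuming the test follows that convention, the sketch below shows what it accepts and rejects; `within_tolerance` is a hypothetical helper and the reference values are made up for illustration.

```python
import numpy as np

ATOL, RTOL = 2e-10, 2e-10


def within_tolerance(jax_xc_values, libxc_values, atol=ATOL, rtol=RTOL):
  """NumPy-style closeness: |a - b| <= atol + rtol * |b| for every element."""
  a = np.asarray(jax_xc_values, dtype=float)
  b = np.asarray(libxc_values, dtype=float)
  return bool(np.all(np.abs(a - b) <= atol + rtol * np.abs(b)))


reference = np.array([-0.5, -1.2e-3, -3.0e2])
print(within_tolerance(reference * (1 + 1e-11), reference))  # True
print(within_tolerance(reference * (1 + 1e-8), reference))   # False
```

The same check is available directly as `np.allclose(a, b, rtol=RTOL, atol=ATOL)`; note the relative term scales with the libxc (reference) value, so large-magnitude outputs get proportionally more slack.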

Performance Benchmark

We report the performance benchmark of jax_xc against libxc on a 64-core machine with Intel(R) Xeon(R) Silver 4216 CPU @ 2.10GHz.

We vary the number of evaluation points from 1 to $10^7$ and measure the runtime of each functional. For jax_xc, the reported runtime excludes just-in-time compilation time.
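Excluding compilation time usually means making one untimed warm-up call before timing. The harness below is a minimal sketch of that methodology, not the benchmark script used for the figures; it times a NumPy LDA-exchange workload as a stand-in. For JAX functions, pass `sync=lambda out: out.block_until_ready()`, because JAX dispatches work asynchronously and an unsynchronized timer would stop before the computation finishes.

```python
import time

import numpy as np


def benchmark(fn, *args, repeats=5, sync=lambda out: out):
  """Mean wall-clock time of fn(*args) over `repeats` runs, after one warm-up.

  The untimed warm-up call absorbs one-off costs such as JIT compilation.
  `sync` is applied to each result inside the timed region (identity here).
  """
  sync(fn(*args))  # warm-up, not timed
  times = []
  for _ in range(repeats):
    start = time.perf_counter()
    sync(fn(*args))
    times.append(time.perf_counter() - start)
  return sum(times) / len(times)


def eps_x_lda(rho):
  """Illustrative workload: LDA exchange energy per particle."""
  return -0.75 * (3.0 / np.pi) ** (1.0 / 3.0) * np.cbrt(rho)


rng = np.random.default_rng(0)
for n in (1, 10**3, 10**5):
  rho = rng.uniform(0.01, 1.0, size=n)
  print(f"{n:>7d} points: {benchmark(eps_x_lda, rho):.3e} s")
```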

The figure below shows the mean runtime of jax_xc and libxc (averaged over the polarized and unpolarized variants) on a log-scale y-axis.

jax_xc’s runtime is consistently below libxc’s for all batch sizes. The speedup ranges from 3x to 10x and grows with batch size.

We hypothesize that the speedup comes from JAX’s JIT compiler optimizing the functionals (e.g. vectorization, parallel execution, instruction fusion, constant folding of floating-point expressions) more effectively than libxc does.

Figure: mean runtime of jax_xc and libxc versus batch size (https://raw.githubusercontent.com/sail-sg/jax_xc/main/figures/jax_xc_speed.svg)

The next figure shows the distribution of the jax_xc/libxc runtime ratio. The ratio is consistently below 1.0 and approaches 0.1 (a ~10x speedup) for large batch sizes.

Figure: distribution of the jax_xc/libxc runtime ratio (https://raw.githubusercontent.com/sail-sg/jax_xc/main/figures/jax_xc_ratio.svg)

Note that we exclude one data point, mgga_x_2d_prhg07, from the runtime-ratio visualization: it is an outlier because JAX does not support the lambertw function, so we use tensorflow_probability.substrates.jax.math.lambertw instead.
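Lambert W, the inverse of $w \mapsto w e^w$, has no elementary closed form and is typically evaluated iteratively. For illustration only, here is a minimal sketch of the principal branch $W_0(x)$ for $x \ge 0$ using Halley's method; `lambertw0` is a hypothetical name and this is not a reproduction of TensorFlow Probability's implementation (for reference values, SciPy provides `scipy.special.lambertw`).

```python
import math


def lambertw0(x, tol=1e-15, max_iter=50):
  """Principal branch W0(x) for x >= 0: the w >= 0 with w * exp(w) = x.

  Halley's method on f(w) = w * exp(w) - x; a sketch, not production code.
  """
  if x < 0:
    raise ValueError("this sketch only handles x >= 0")
  w = math.log1p(x)  # starting guess; W0(x) <= log(1 + x) for x >= 0
  for _ in range(max_iter):
    ew = math.exp(w)
    f = w * ew - x
    step = f / (ew * (w + 1) - (w + 2) * f / (2 * w + 2))  # Halley update
    w -= step
    if abs(step) <= tol * (1 + abs(w)):
      break
  return w


for x in (0.0, 1.0, 10.0):
  w = lambertw0(x)
  print(x, w, w * math.exp(w))
```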

Caveats

The following libxc functionals are not available in jax_xc, for the reason noted next to each:

gga_x_fd_lb94          # Becke-Roussel: no closed-form expression
gga_x_fd_revlb94       # Becke-Roussel: no closed-form expression
gga_x_gg99             # Becke-Roussel: no closed-form expression
gga_x_kgg99            # Becke-Roussel: no closed-form expression
hyb_gga_xc_case21      # Becke-Roussel: no closed-form expression
hyb_mgga_xc_b94_hyb    # Becke-Roussel: no closed-form expression
hyb_mgga_xc_br3p86     # Becke-Roussel: no closed-form expression
lda_x_1d_exponential   # requires explicit 1D integration
lda_x_1d_soft          # requires explicit 1D integration
mgga_c_b94             # Becke-Roussel: no closed-form expression
mgga_x_b00             # Becke-Roussel: no closed-form expression
mgga_x_bj06            # Becke-Roussel: no closed-form expression
mgga_x_br89            # Becke-Roussel: no closed-form expression
mgga_x_br89_1          # Becke-Roussel: no closed-form expression
mgga_x_mbr             # Becke-Roussel: no closed-form expression
mgga_x_mbrxc_bg        # Becke-Roussel: no closed-form expression
mgga_x_mbrxh_bg        # Becke-Roussel: no closed-form expression
mgga_x_mggac           # Becke-Roussel: no closed-form expression
mgga_x_rpp09           # Becke-Roussel: no closed-form expression
mgga_x_tb09            # Becke-Roussel: no closed-form expression
gga_x_wpbeh            # JIT compilation of E1_scaled takes too long
gga_c_ft97             # JIT compilation of E1_scaled takes too long
lda_xc_tih             # vxc-only functional
gga_c_pbe_jrgx         # vxc-only functional
gga_x_lb               # vxc-only functional
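To see why the Becke-Roussel functionals resist a closed-form translation: assuming the standard BR89 form, the model defines an auxiliary variable $x$ implicitly through $x e^{-2x/3}/(x-2) = y$, where $y$ is built from the density and kinetic-energy-type quantities. This transcendental equation has no closed-form solution, so each evaluation needs a numerical root-finder. The sketch below treats $y$ as a given number and solves by bisection; `br_lhs` and `solve_br` are hypothetical helpers, not libxc or jax_xc code.

```python
import math


def br_lhs(x):
  """Left-hand side of the Becke-Roussel equation, x * exp(-2x/3) / (x - 2)."""
  return x * math.exp(-2.0 * x / 3.0) / (x - 2.0)


def solve_br(y, tol=1e-13, max_iter=200):
  """Solve br_lhs(x) = y for x by bisection.

  br_lhs decreases on both (0, 2) and (2, inf), so the root is unique:
  it lies in (0, 2) when y < 0 and in (2, inf) when y > 0.
  """
  if y == 0.0:
    return 0.0
  if y < 0.0:
    lo, hi = 0.0, 2.0 - 1e-12
  else:
    lo, hi = 2.0 + 1e-12, 4.0
    while br_lhs(hi) > y:  # move hi right until the root is bracketed
      hi *= 2.0
  # Invariant: br_lhs(lo) > y > br_lhs(hi).
  for _ in range(max_iter):
    mid = 0.5 * (lo + hi)
    if br_lhs(mid) > y:
      lo = mid
    else:
      hi = mid
    if hi - lo < tol:
      break
  return 0.5 * (lo + hi)


for y in (0.5, 0.01, -0.5, -10.0):
  x = solve_br(y)
  print(f"y = {y:+.2f} -> x = {x:.12f}, residual = {br_lhs(x) - y:.1e}")
```

Bisection is used here only because it is easy to verify; a production implementation would typically prefer a faster iteration with a good initial guess.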

Building from Source Code

Fill in your environment variables in .env.example, rename it to .env, then run source .env to load them into your shell.

  • OUTPUT_USER_ROOT: the directory where the Bazel cache is stored. Setting it is useful if you are building on a shared filesystem.

  • MAPLE_PATH: the directory containing the maple binary (it is appended to PATH in the commands below).

  • TMP_INSTALL_PATH: a temporary directory where the wheel will be installed. This is useful if you are building on a shared filesystem.

Make sure Bazel and Maple are installed, and that your Python environment has the dependencies in requirements.txt installed.

To build the Python wheel:

bazel --output_user_root=$OUTPUT_USER_ROOT build --action_env=PATH=$PATH:$MAPLE_PATH @jax_xc_repo//:jax_xc_wheel

Once the build finishes, the wheel can be found under bazel-bin/external/jax_xc_repo. For example, the wheel for version 0.0.7 is named jax_xc-0.0.7-cp310-cp310-manylinux_2_17_x86_64.whl.
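The wheel filename encodes which interpreter and platform it targets, following the PEP 427 pattern name-version-pythontag-abitag-platformtag.whl (with an optional build tag, which this sketch does not handle). The hypothetical helper below splits the example name above into those fields; for robust parsing, the `packaging` library offers `packaging.utils.parse_wheel_filename`.

```python
def parse_wheel_name(filename):
  """Split a wheel filename into its PEP 427 fields (assumes no build tag)."""
  if not filename.endswith(".whl"):
    raise ValueError("not a wheel file")
  name, version, python_tag, abi_tag, platform_tag = filename[:-4].split("-")
  return {"name": name, "version": version, "python": python_tag,
          "abi": abi_tag, "platform": platform_tag}


info = parse_wheel_name("jax_xc-0.0.7-cp310-cp310-manylinux_2_17_x86_64.whl")
print(info)  # "cp310" means CPython 3.10; the platform tag requires glibc 2.17+
```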

Then install the wheel, optionally specifying an install path with --target:

pip install {{wheel_name}} --target $TMP_INSTALL_PATH

Running Tests

The tests can be run without first building the wheel with the command above, although building all the components the tests need may take longer. To run all tests:

bazel --output_user_root=$OUTPUT_USER_ROOT test --action_env=PATH=$PATH:$MAPLE_PATH //tests/...

To run a specific test, for example test_impl:

bazel --output_user_root=$OUTPUT_USER_ROOT test --action_env=PATH=$PATH:$MAPLE_PATH //tests:test_impl

The output of tests:test_impl is written to bazel-testlogs/tests/test_impl/test.log, and likewise for the other tests. To print test output to the command line instead, add --test_output=all to the command above.

License

Aligned with libxc, jax_xc is licensed under the Mozilla Public License 2.0. See LICENSE for the full license text.
