
https://github.com/sail-sg/jax_xc/raw/main/figures/logo.png https://img.shields.io/pypi/v/jax-xc.svg https://readthedocs.org/projects/ansicolortags/badge/?version=latest

This library contains direct translations of the exchange-correlation functionals in libxc to JAX. The core calculations in libxc are implemented in Maple, which lets us translate them directly into Python with the help of Maple's CodeGeneration package.

Usage

Installation

pip install jax-xc

Invoking the Functionals

jax_xc's API is functional: it receives the density $\rho$ as a Callable and returns the exchange-correlation energy per particle $\varepsilon_{xc}$, also as a Callable. Together they give the exchange-correlation energy:

$$E_{xc} = \int \rho(r) \, \varepsilon_{xc}(r) \, dr$$
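The sketch below makes the relation concrete by approximating $E_{xc}$ with a sum over a uniform 3D grid, $E_{xc} \approx h^3 \sum_i \rho(r_i)\,\varepsilon_{xc}(r_i)$. It does not call jax_xc. Instead, it uses a hand-written LDA Slater exchange, $\varepsilon_x = -\tfrac{3}{4}(3/\pi)^{1/3}\rho^{1/3}$, as a stand-in for $\varepsilon_{xc}$, so the result can be checked against a closed form. The helper names `eps_x_slater` and `integrate_exc` are hypothetical.

```python
import jax
import jax.numpy as jnp


def rho(r):
  # Normalized 3D Gaussian density (integrates to one electron).
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))


def eps_x_slater(rho_fn, r):
  # Stand-in for a functional: LDA Slater exchange energy per particle,
  # eps_x = -(3/4) (3/pi)^(1/3) rho^(1/3). Hand-written, not from jax_xc.
  return -0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0) * rho_fn(r) ** (1.0 / 3.0)


def integrate_exc(eps_fn, rho_fn, half_width=6.0, n=61):
  # Uniform cubic grid; every point carries the voxel volume h^3 as weight.
  axis = jnp.linspace(-half_width, half_width, n)
  h = axis[1] - axis[0]
  grid = jnp.stack(jnp.meshgrid(axis, axis, axis, indexing="ij"), axis=-1)
  points = grid.reshape(-1, 3)
  # vmap evaluates rho(r) * eps(r) at all grid points in one call.
  integrand = jax.vmap(lambda r: rho_fn(r) * eps_fn(rho_fn, r))(points)
  return jnp.sum(integrand) * h**3


e_x = integrate_exc(eps_x_slater, rho)
print(e_x)
```

For this unit Gaussian, the exact value is $-\tfrac{3}{4}(3/\pi)^{1/3}(2\pi)^{-2}(3\pi/2)^{3/2} \approx -0.191$. Because `eps_fn` is called as `eps_fn(rho, r)`, a functional with the same call shape as the examples below could presumably be passed in its place.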

LDA and GGA

Unlike libxc, which takes pre-computed densities and their derivatives at given coordinates, jax_xc's API is designed to take the density function directly.

import jax
import jax.numpy as jnp
import jax_xc


def rho(r):
  """Electron number density. We take a Gaussian as an example.

  A function that takes a real coordinate and returns a scalar
  indicating the electron number density at coordinate r.

  Args:
    r: a 3D coordinate.
  Returns:
    rho: If it is unpolarized, it is a scalar.
        If it is polarized, it is an array of shape (2,).
  """
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))

# create a density functional (PBE exchange, unpolarized)
gga_x_pbe = jax_xc.gga_x_pbe(polarized=False)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# pass rho and r to the functional to compute epsilon_xc (energy per particle) at r,
# corresponding to the 'zk' in libxc
epsilon_xc_r = gga_x_pbe(rho, r)
print(epsilon_xc_r)
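Because the density is passed as a function, quantities that GGAs need, such as $\nabla\rho$, can in principle come from automatic differentiation instead of being precomputed. The sketch below does not rely on jax_xc internals. It only shows how `jax.grad` recovers $\nabla\rho$ and the PBE-style reduced gradient $s = |\nabla\rho| / (2(3\pi^2)^{1/3}\rho^{4/3})$ from the same Gaussian `rho`. `reduced_gradient` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def rho(r):
  # Same Gaussian density as in the example above.
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))


# jax.grad turns the density function into its gradient function.
grad_rho = jax.grad(rho)


def reduced_gradient(rho_fn, r):
  # Dimensionless reduced gradient: s = |grad rho| / (2 (3 pi^2)^(1/3) rho^(4/3)).
  g = jax.grad(rho_fn)(r)
  return jnp.linalg.norm(g) / (2.0 * (3.0 * jnp.pi**2) ** (1.0 / 3.0) * rho_fn(r) ** (4.0 / 3.0))


r = jnp.array([0.1, 0.2, 0.3])
print(grad_rho(r))  # for this Gaussian, analytically equal to -r * rho(r)
print(reduced_gradient(rho, r))
```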

mGGA

Unlike LDA and GGA functionals, which depend only on the density function, mGGA functionals also depend on the molecular orbitals.

import jax
import jax.numpy as jnp
import jax_xc


def mo(r):
  """Molecular orbitals. We take Gaussians as an example.

  A function that takes a real coordinate and returns the values of
  the molecular orbitals at this coordinate.

  Args:
    r: a 3D coordinate.
  Returns:
    mo: If it is unpolarized, it is an array of shape (N,).
        If it is polarized, it is an array of shape (N, 2).
  """
  # Assume we have 3 molecular orbitals
  return jnp.array([
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)),
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0.5, scale=1)),
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=-0.5, scale=1))
  ])


rho = lambda r: jnp.sum(mo(r)**2, axis=0)
mgga_xc_cc06 = jax_xc.mgga_xc_cc06(polarized=False)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# evaluate the exchange correlation energy per particle at this point
# corresponding to the 'zk' in libxc
print(mgga_xc_cc06(rho, r, mo))
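Meta-GGAs typically depend on the kinetic energy density $\tau$ built from the orbitals, which is presumably why `mo` is passed. The sketch below is not jax_xc's internal code. It shows how $\tau(r) = \tfrac{1}{2}\sum_i |\nabla\phi_i(r)|^2$ can be obtained from an `mo` callable with `jax.jacfwd`. The factor 1/2 is an assumed convention, and conventions differ between codes. `kinetic_energy_density` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def mo(r):
  # Three Gaussian "orbitals", as in the example above (unpolarized, shape (3,)).
  return jnp.array([
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)),
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0.5, scale=1)),
      jnp.prod(jax.scipy.stats.norm.pdf(r, loc=-0.5, scale=1)),
  ])


def kinetic_energy_density(mo_fn, r):
  # jacfwd gives the Jacobian d phi_i / d r_j with shape (N, 3).
  jac = jax.jacfwd(mo_fn)(r)
  # tau = 1/2 * sum_i |grad phi_i|^2 (positive-definite convention).
  return 0.5 * jnp.sum(jac**2)


r = jnp.array([0.1, 0.2, 0.3])
print(kinetic_energy_density(mo, r))
```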

Hybrid Functionals

Hybrid functionals expose the same API, with extra attributes that give users access to parameters needed outside of libxc/jax_xc (e.g. the fraction of exact exchange).

import jax
import jax.numpy as jnp
import jax_xc


def rho(r):
  """Electron number density. We take a Gaussian as an example.

  A function that takes a real coordinate and returns a scalar
  indicating the electron number density at coordinate r.

  Args:
    r: a 3D coordinate.
  Returns:
    rho: If it is unpolarized, it is a scalar.
        If it is polarized, it is an array of shape (2,).
  """
  return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))


hyb_gga_xc_pbeb0 = jax_xc.hyb_gga_xc_pbeb0(polarized=False)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# evaluate the exchange correlation energy per particle at this point
# corresponding to the 'zk' in libxc
print(hyb_gga_xc_pbeb0(rho, r))

# access to extra attributes
cam_alpha = hyb_gga_xc_pbeb0.cam_alpha  # fraction of full Hartree-Fock exchange

The complete list of extra attributes can be found below:

cam_alpha: float
cam_beta: float
cam_omega: float
nlc_b: float
nlc_C: float

The meaning of each attribute is the same as in libxc:

  • cam_alpha: fraction of full Hartree-Fock exchange, used both for usual hybrids as well as range-separated ones

  • cam_beta: fraction of short-range only(!) exchange in range-separated hybrids

  • cam_omega: range separation constant

  • nlc_b: non-local correlation, b parameter

  • nlc_C: non-local correlation, C parameter
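Read literally, the cam_alpha and cam_beta descriptions above suggest that the exact-exchange interaction is weighted as $[\alpha + \beta\,\mathrm{erfc}(\omega r_{12})]/r_{12}$. Here $\alpha$ applies at all ranges and $\beta$ only at short range. The sketch below illustrates that reading and is not code from jax_xc. Consult libxc's documentation for the authoritative convention. `hf_exchange_kernel` is a hypothetical helper, and the parameter values in the usage line are illustrative, not read from any functional.

```python
import jax.numpy as jnp
from jax.scipy.special import erfc


def hf_exchange_kernel(r12, cam_alpha, cam_beta, cam_omega):
  # Interaction used for the exact (Hartree-Fock) exchange part:
  #   cam_alpha / r12                        -> full-range fraction
  # + cam_beta * erfc(cam_omega * r12) / r12 -> short-range-only fraction
  return (cam_alpha + cam_beta * erfc(cam_omega * r12)) / r12


# Illustrative values: 25% short-range exchange, no full-range part.
r12 = jnp.array([0.01, 1.0, 100.0])
# Multiplying by r12 leaves the effective exact-exchange fraction at each distance.
print(hf_exchange_kernel(r12, cam_alpha=0.0, cam_beta=0.25, cam_omega=0.11) * r12)
```

At short range the effective fraction tends to cam_alpha + cam_beta. At long range it tends to cam_alpha.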

Supported Functionals

Please refer to the functionals section in jax_xc’s documentation for the complete list of supported functionals.

Numerical Correctness

We test all the functionals that are auto-generated from Maple files against reference values from libxc. The test compares the outputs of libxc and jax_xc and checks that they agree within atol=2e-10 and rtol=2e-10.
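The sketch below shows the tolerance check, assuming the usual elementwise criterion of `numpy.testing.assert_allclose`: $|a - b| \le \mathrm{atol} + \mathrm{rtol}\,|b|$. The README does not say which comparison routine the test suite uses. The sketch also shows why tolerances around 2e-10 only make sense in double precision, since JAX computes in float32 unless `jax_enable_x64` is set. `within_tolerance` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np

# 1e-10-level agreement needs double precision; JAX defaults to float32.
jax.config.update("jax_enable_x64", True)


def within_tolerance(actual, reference, atol=2e-10, rtol=2e-10):
  # Same elementwise criterion as numpy.testing.assert_allclose:
  # |actual - reference| <= atol + rtol * |reference|
  actual = np.asarray(actual)
  reference = np.asarray(reference)
  return bool(np.all(np.abs(actual - reference) <= atol + rtol * np.abs(reference)))


rho = jnp.array([0.1, 1.0, 10.0], dtype=jnp.float64)
# LDA Slater exchange per particle as a stand-in quantity.
eps64 = -0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0) * rho ** (1.0 / 3.0)
eps32 = eps64.astype(jnp.float32)
print(within_tolerance(eps64, eps64))  # identical values pass
print(within_tolerance(eps32, eps64))  # float32 rounding exceeds 2e-10
```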

Performance Benchmark

We benchmark jax_xc against libxc on a 64-core machine with an Intel(R) Xeon(R) Silver 4216 CPU @ 2.10GHz.

We vary the number of points at which the functionals are evaluated from 1 to $10^7$ and measure the runtime of each functional. The runtime of jax_xc excludes just-in-time compilation time.
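The sketch below shows one way to measure runtime without JIT compilation. It makes one warm-up call to trigger compilation, then times later calls. Each call ends with `block_until_ready()`, because JAX dispatches work asynchronously. This is not the benchmark script used for the figures. The jitted Slater exchange is a hypothetical stand-in workload, and `time_excluding_compilation` is a hypothetical helper.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def slater_exchange(rho):
  # Stand-in workload: LDA Slater exchange per particle on a batch of densities.
  return -0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0) * jnp.cbrt(rho)


def time_excluding_compilation(fn, x, repeats=10):
  # The first call traces and compiles for this input shape; keep it outside the timer.
  fn(x).block_until_ready()
  start = time.perf_counter()
  for _ in range(repeats):
    # block_until_ready waits for JAX's asynchronous dispatch to finish.
    fn(x).block_until_ready()
  return (time.perf_counter() - start) / repeats


for n in [1, 10**3, 10**6]:
  rho = jnp.linspace(0.01, 10.0, n)
  print(n, time_excluding_compilation(slater_exchange, rho))
```

Each new batch size is a new input shape, so it triggers a fresh compilation. The warm-up call absorbs that cost for every size.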

We visualize the mean runtime of jax_xc and libxc (averaged over the polarized and unpolarized cases) in the following figure. The y-axis uses a log scale.

jax_xc's runtime is consistently below libxc's for all batch sizes. The speedup ranges from 3x to 10x and is more pronounced for larger batch sizes.

We hypothesize that the speedup comes from JAX's JIT compiler optimizing the functionals (e.g. through vectorization, parallel execution, instruction fusion, and constant folding of floating-point expressions) better than libxc does.

https://raw.githubusercontent.com/sail-sg/jax_xc/main/figures/jax_xc_speed.svg
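One way to probe the compiler hypothesis above is to inspect the optimized HLO that XLA produces for a functional. The sketch below is not part of jax_xc's benchmark. It uses a hand-written PBE exchange enhancement factor, $F_x(s) = 1 + \kappa - \kappa / (1 + \mu s^2/\kappa)$, as a stand-in and prints the compiled module. Whether and how operations are fused depends on the XLA version and backend.

```python
import jax
import jax.numpy as jnp


def pbe_like_enhancement(s, kappa=0.804, mu=0.2195149727645171):
  # PBE exchange enhancement factor F_x(s) = 1 + kappa - kappa / (1 + mu s^2 / kappa).
  return 1.0 + kappa - kappa / (1.0 + mu * s**2 / kappa)


s = jnp.linspace(0.0, 3.0, 1024)
compiled = jax.jit(pbe_like_enhancement).lower(s).compile()
# Optimized HLO after XLA's passes; look for "fusion" to see combined elementwise ops.
hlo_text = compiled.as_text()
print(hlo_text)
```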

We visualize the distribution of the runtime ratio (jax_xc / libxc) in the following figure. The ratio approaches 0.1 for large batch sizes (about a 10x speedup) and is consistently below 1.0.

https://raw.githubusercontent.com/sail-sg/jax_xc/main/figures/jax_xc_ratio.svg

Note that we exclude one data point, mgga_x_2d_prhg07, from the runtime ratio visualization. It is an outlier because JAX lacks a native Lambert W function, so we use tensorflow_probability.substrates.jax.math.lambertw instead.
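For background, Lambert W is defined implicitly by $w e^w = x$ and is usually evaluated iteratively. The sketch below is a hypothetical pure-JAX version of the principal branch for $x \ge 0$ using Halley's method with a fixed number of steps. It is neither jax_xc's code nor the tensorflow_probability implementation. It only illustrates what such a routine involves.

```python
import jax
import jax.numpy as jnp


def lambertw_halley(x, iterations=20):
  # Principal branch W0 for x >= 0: the w >= 0 solving w * exp(w) = x.
  # Hypothetical helper for illustration only.
  def step(_, w):
    ew = jnp.exp(w)
    f = w * ew - x
    # Halley update for f(w) = w e^w - x.
    return w - f / (ew * (w + 1.0) - (w + 2.0) * f / (2.0 * w + 2.0))

  # log1p(x) ~ x for small x and grows like log(x) for large x, a reasonable start.
  w0 = jnp.log1p(x)
  return jax.lax.fori_loop(0, iterations, step, w0)


x = jnp.array([0.0, 0.5, 1.0, 10.0, 100.0])
w = lambertw_halley(x)
print(w)
```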

Caveats

The following libxc functionals are not available in jax_xc. The comment next to each one gives the reason.

gga_x_fd_lb94          # Becke-Roussel has no closed-form expression
gga_x_fd_revlb94       # Becke-Roussel has no closed-form expression
gga_x_gg99             # Becke-Roussel has no closed-form expression
gga_x_kgg99            # Becke-Roussel has no closed-form expression
hyb_gga_xc_case21      # Becke-Roussel has no closed-form expression
hyb_mgga_xc_b94_hyb    # Becke-Roussel has no closed-form expression
hyb_mgga_xc_br3p86     # Becke-Roussel has no closed-form expression
lda_x_1d_exponential   # Requires explicit 1D integration
lda_x_1d_soft          # Requires explicit 1D integration
mgga_c_b94             # Becke-Roussel has no closed-form expression
mgga_x_b00             # Becke-Roussel has no closed-form expression
mgga_x_bj06            # Becke-Roussel has no closed-form expression
mgga_x_br89            # Becke-Roussel has no closed-form expression
mgga_x_br89_1          # Becke-Roussel has no closed-form expression
mgga_x_mbr             # Becke-Roussel has no closed-form expression
mgga_x_mbrxc_bg        # Becke-Roussel has no closed-form expression
mgga_x_mbrxh_bg        # Becke-Roussel has no closed-form expression
mgga_x_mggac           # Becke-Roussel has no closed-form expression
mgga_x_rpp09           # Becke-Roussel has no closed-form expression
mgga_x_tb09            # Becke-Roussel has no closed-form expression
gga_x_wpbeh            # JIT compilation takes too long for E1_scaled
gga_c_ft97             # JIT compilation takes too long for E1_scaled
lda_xc_tih             # vxc functional
gga_c_pbe_jrgx         # vxc functional
gga_x_lb               # vxc functional

Building from Source Code

Modify .env.example to fill in your environment variables, rename it to .env, and run source .env to load them into your shell.

  • OUTPUT_USER_ROOT: The path where the Bazel cache will be stored. This is useful if you are building on a shared filesystem.

  • MAPLE_PATH: The directory containing the maple binary (it is appended to PATH in the build command).

  • TMP_INSTALL_PATH: The path to a temporary directory where the wheel will be installed. This is useful if you are building on a shared filesystem.
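A filled-in .env might look like the following. This is a hypothetical sketch. The variable names come from the list above, but the paths are placeholders, and the actual layout of .env.example may differ. It assumes plain `export` lines so that `source .env` makes the variables available to the build command.

```shell
# Hypothetical .env: variable names from this README, paths are placeholders.
export OUTPUT_USER_ROOT="$HOME/.cache/bazel_jax_xc"   # Bazel output root (cache)
export MAPLE_PATH="/opt/maple/bin"                    # directory containing the maple binary
export TMP_INSTALL_PATH="/tmp/jax_xc_install"         # scratch directory for installing the wheel
```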

To build the wheel, run:

bazel --output_user_root=$OUTPUT_USER_ROOT build --action_env=PATH=$PATH:$MAPLE_PATH @maple2jax//:jax_xc_wheel

License

Aligned with libxc, jax_xc is licensed under the Mozilla Public License 2.0. See LICENSE for the full license text.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

jax_xc-0.0.7-cp311-cp311-manylinux_2_17_x86_64.whl (3.8 MB): CPython 3.11, manylinux: glibc 2.17+ x86-64
jax_xc-0.0.7-cp310-cp310-manylinux_2_17_x86_64.whl (3.8 MB): CPython 3.10, manylinux: glibc 2.17+ x86-64
jax_xc-0.0.7-cp39-cp39-manylinux_2_17_x86_64.whl (3.8 MB): CPython 3.9, manylinux: glibc 2.17+ x86-64
jax_xc-0.0.7-cp38-cp38-manylinux_2_17_x86_64.whl (3.8 MB): CPython 3.8, manylinux: glibc 2.17+ x86-64
jax_xc-0.0.7-cp37-cp37m-manylinux_2_17_x86_64.whl (3.8 MB): CPython 3.7m, manylinux: glibc 2.17+ x86-64

File hashes

jax_xc-0.0.7-cp311-cp311-manylinux_2_17_x86_64.whl
SHA256 26e47e92b9935be3266f05763d9df3c6c76649d29e4ca2113bf551c50a22d1d7
MD5 6b7f827004ecef18658a2c0e79c37e17
BLAKE2b-256 f8c6bae9be4d01e4bf8e7d6c656422a9e9f8e512060bb1806a629fff0415f6af

jax_xc-0.0.7-cp310-cp310-manylinux_2_17_x86_64.whl
SHA256 b1447ff1f00597eafc00731d9a0f046feac16d83801b0442b286bc12d871de2a
MD5 4df3a05db9f3af9ac14e453ccfc35ea5
BLAKE2b-256 3c1630f7b9f7e4ad789bb12b98e41685976bdbb549a76d57eb97b4e2be70aa2d

jax_xc-0.0.7-cp39-cp39-manylinux_2_17_x86_64.whl
SHA256 dec5d49fd5857079615ec455e232c30fdbb49a95e971fe1dd4812c4d712e1b69
MD5 b12a04cf6ab8f7c29694f6f668ceb51e
BLAKE2b-256 7cc87f4e0597151667f137d9e840dc014287b4a0c58d6e0217287b540fa4823a

jax_xc-0.0.7-cp38-cp38-manylinux_2_17_x86_64.whl
SHA256 2c049ca17c584231ccd8ef30deba7601871d6c8a299cd9dfa9bd6facbf426fdb
MD5 2e69c1530be025aebc32b44c2225d025
BLAKE2b-256 247434260cb6c57fb46087f9d7c6aaefc8e3464a1a7fe2acf7ec43b1b8b15cc3

jax_xc-0.0.7-cp37-cp37m-manylinux_2_17_x86_64.whl
SHA256 4b0719f16aad8934a40c2482d7b105767d701818db226100c2cbb05049d341ee
MD5 a8807ac25675a50557889350875710f8
BLAKE2b-256 b701e577394d71c5f6573ac6558c248ba1b3a7f6e96a83879ab26c90f25a0c96
