
JAX implementation of the Skala neural exchange-correlation functional


skalax


Skalax Logo

pypi: https://pypi.org/project/skalax/

JAX/Equinox implementation of the Skala neural exchange-correlation functional for density functional theory (DFT) calculations.

Overview

skalax is a pure JAX port of the Skala neural XC functional. It reproduces the PyTorch reference to within ~2 kcal/mol non-parallelity error (NPE) on the tested systems (the residual discrepancy comes from the custom JAX PySCF wrapper) and exposes the usual JAX machinery: jax.grad, jax.jit, and jax.vmap, with XLA compilation.

The goal is to make this functional usable, fine-tunable, and modifiable from JAX-based DFT codes.

[!WARNING] Work in progress, tested on CPU only so far. The PySCF JAX wrapper I wrote is not optimal and is therefore slower than the original PySCF Skala in Torch, but the model itself is comparable in training and inference performance.

Performance

JAX JIT (XLA) matches PyTorch on tested grid sizes. All variants use radius_cutoff=5.0 and are benchmarked in steady state (post-compilation, CPU). GPU validation is the next step.

Performance benchmark

Left: forward pass only. Right: forward + backward. Eager: op-by-op execution (no compilation). JIT / traced: compiled graph; steady state = timed after a warm-up call, so compile cost is excluded. At 32k grid points on CPU, JAX JIT forward is ~1.4× faster than PyTorch traced, and JAX JIT fwd+grad is ~1.6× faster than PyTorch traced fwd+backward.
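The warm-up/steady-state protocol above can be sketched as follows. This is an illustrative timing harness with a toy stand-in function, not the actual benchmark code:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_forward(x):
    # Stand-in for the functional's forward pass.
    return jnp.sum(jnp.tanh(x) * x)


x = jnp.ones((32_768,))

# Warm-up call triggers tracing + XLA compilation.
toy_forward(x).block_until_ready()

# Steady-state timing: compile cost is excluded, exactly as in the plots.
t0 = time.perf_counter()
for _ in range(10):
    toy_forward(x).block_until_ready()
steady = (time.perf_counter() - t0) / 10
print(f"steady-state forward: {steady * 1e6:.1f} us")
```

Calling block_until_ready() matters because JAX dispatches asynchronously; without it the loop times only the dispatch, not the computation.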

Installation

Requires gfortran and cmake (for dftd3/pyscf via skala).

pip install skalax

GPU Support

pip install skalax
pip install --upgrade "jax[cuda12]"

Development (no Fortran compiler needed)

git clone https://github.com/Brogis1/skalax
cd skalax
pip install -e .[dev]
pip install --no-deps microsoft-skala

Quick Start

Basic Usage

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from skalax import SkalaFunctional, load_weights_from_npz, load_config, get_default_weights_dir

# Load pretrained weights (bundled with package)
weights_dir = get_default_weights_dir()
config = load_config(weights_dir)

key = jax.random.PRNGKey(0)
model = SkalaFunctional(
    lmax=config["lmax"],
    non_local=config["non_local"],
    non_local_hidden_nf=config["non_local_hidden_nf"],
    radius_cutoff=config["radius_cutoff"],
    key=key,
)
model = load_weights_from_npz(model, weights_dir)

n_points, n_atoms = 100, 3
mol = {
    "density": jnp.ones((2, n_points)) * 0.1,
    "grad": jnp.zeros((2, 3, n_points)),
    "kin": jnp.ones((2, n_points)) * 0.05,
    "grid_coords": jnp.zeros((n_points, 3)),
    "grid_weights": jnp.ones(n_points) * 0.01,
    "coarse_0_atomic_coords": jnp.zeros((n_atoms, 3)),
}

E_xc = model.get_exc(mol)
print(f"E_xc = {E_xc:.10f} Ha")

JAX Autodiff

# Gradient of E_xc with respect to all inputs in one line
grads = jax.grad(model.get_exc)(mol)
print(f"dE/d(density): {grads['density'].shape}")
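The same pattern works with jax.value_and_grad to get the energy and all gradients in one pass. Sketched here with a toy stand-in for get_exc (a dict of arrays mapped to a scalar), since the call structure is identical:

```python
import jax
import jax.numpy as jnp


# Toy stand-in with the same signature as get_exc: feature dict -> scalar.
def toy_get_exc(mol):
    rho = mol["density"]        # (2, n_points)
    w = mol["grid_weights"]     # (n_points,)
    return jnp.sum(w * jnp.sum(rho, axis=0) ** (4 / 3))


mol = {
    "density": jnp.full((2, 100), 0.1),
    "grid_weights": jnp.full((100,), 0.01),
}

# Energy and gradients w.r.t. every array in the dict, in one pass.
E, grads = jax.value_and_grad(toy_get_exc)(mol)
print(float(E), grads["density"].shape)
```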

JIT Compilation

import equinox as eqx

@eqx.filter_jit
def get_exc_jit(m, mol):
    return m.get_exc(mol)

# First call compiles; subsequent calls are fast
E = get_exc_jit(model, mol)

PySCF Integration

import jax
jax.config.update("jax_enable_x64", True)

from pyscf import gto
from skalax import SkalaFunctional, load_weights_from_npz, load_config, get_default_weights_dir
from skalax.pyscf import JaxSkalaKS

weights_dir = get_default_weights_dir()
config = load_config(weights_dir)
key = jax.random.PRNGKey(0)
model = SkalaFunctional(
    lmax=config["lmax"],
    non_local=config["non_local"],
    non_local_hidden_nf=config["non_local_hidden_nf"],
    radius_cutoff=config["radius_cutoff"],
    key=key,
)
model = load_weights_from_npz(model, weights_dir)

mol = gto.M(
    atom="O 0 0 0; H 0.757 0.586 0; H -0.757 0.586 0",
    basis="sto-3g",
    verbose=0,
)
ks = JaxSkalaKS(mol, xc=model)
energy = ks.kernel()
print(f"Total energy: {energy:.8f} Ha")

Input/Output Specification

Input Features

| Feature | Shape | Description |
|---------|-------|-------------|
| density | (2, n_points) | Spin densities [α, β] |
| grad | (2, 3, n_points) | Density gradients [spin, xyz, points] |
| kin | (2, n_points) | Kinetic energy densities |
| grid_coords | (n_points, 3) | Grid coordinates (Bohr) |
| grid_weights | (n_points,) | Integration weights |
| coarse_0_atomic_coords | (n_atoms, 3) | Atomic positions (Bohr) |
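A small shape-checking helper makes the table concrete. This is illustrative only and not part of the skalax API:

```python
import jax.numpy as jnp


def check_mol_shapes(mol, n_points, n_atoms):
    """Validate a feature dict against the expected shapes (helper sketch)."""
    expected = {
        "density": (2, n_points),
        "grad": (2, 3, n_points),
        "kin": (2, n_points),
        "grid_coords": (n_points, 3),
        "grid_weights": (n_points,),
        "coarse_0_atomic_coords": (n_atoms, 3),
    }
    for name, shape in expected.items():
        if mol[name].shape != shape:
            raise ValueError(f"{name}: got {mol[name].shape}, want {shape}")


n_points, n_atoms = 100, 3
mol = {
    "density": jnp.full((2, n_points), 0.1),
    "grad": jnp.zeros((2, 3, n_points)),
    "kin": jnp.full((2, n_points), 0.05),
    "grid_coords": jnp.zeros((n_points, 3)),
    "grid_weights": jnp.full((n_points,), 0.01),
    "coarse_0_atomic_coords": jnp.zeros((n_atoms, 3)),
}
check_mol_shapes(mol, n_points, n_atoms)  # passes silently
```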

Outputs

| Method | Shape | Description |
|--------|-------|-------------|
| model.get_exc(mol) | () | Scalar E_xc (Hartree) |
| model.get_exc_density(mol) | (n_points,) | Energy density per grid point |

Model Architecture

Roughly 276k parameters, in three stages:

  1. Input MLP. Per grid point, the 7 scalar features (spin densities, gradient norms, kinetic densities, and the α+β gradient norm) go through Linear(7→256) → SiLU → Linear(256→256) → SiLU. Spin-swapped features are pushed through the same MLP and averaged, so the model is symmetric under α↔β.

  2. Non-local branch (optional, lmax=3, radius_cutoff≈5 Bohr). The 256-dim scalar features are squeezed to 16 channels (pre_down_linear), then a fine→coarse tensor product (tp_down) aggregates to atomic centers, a coarse→fine tensor product (tp_up) broadcasts back to the grid, and a final post_up_linear (SiLU) mixes channels. Edge features use an exponential radial basis and spherical harmonics up to lmax. The non-local output is damped by exp(-ρ) and concatenated to the scalar features.

  3. Output MLP. The 256+16 dim features go through three Linear(→256) → SiLU layers, a final Linear(→1), and a ScaledSigmoid(scale=2.0). The scalar output is an enhancement factor multiplied against the LDA exchange density to give the per-point XC energy density.
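The final step (enhancement factor times LDA exchange) can be illustrated numerically. This is a minimal sketch, not the exact skalax code; with f = 1 everywhere it reduces to plain LDA exchange:

```python
import jax.numpy as jnp


def lda_exchange_density(rho):
    # LDA exchange energy density for total density rho (atomic units):
    # e_x = -(3/4) * (3/pi)^(1/3) * rho^(4/3)
    return -(3.0 / 4.0) * (3.0 / jnp.pi) ** (1.0 / 3.0) * rho ** (4.0 / 3.0)


def exc_from_enhancement(rho_total, f, grid_weights):
    # Per-point XC energy density = enhancement factor * LDA exchange density;
    # quadrature with the grid weights gives the scalar E_xc.
    return jnp.sum(grid_weights * f * lda_exchange_density(rho_total))


rho = jnp.full((100,), 0.1)   # total density at each grid point
f = jnp.ones((100,))          # enhancement factor of 1 recovers LDA exchange
w = jnp.full((100,), 0.01)
print(exc_from_enhancement(rho, f, w))
```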

Numerical Equivalence

On a handful of test cases the JAX implementation matches the PyTorch reference to machine precision:

| Test | Max \|ΔE\| |
|------|------------|
| get_exc (local only) | 0.00e+00 Ha |
| get_exc (with non-local) | 1.14e-13 Ha |
| get_exc_density | 1.17e-13 Ha |

More comprehensive benchmarks follow below.

Benchmarks

A few tests to check the correctness of the implementation. Note that results can be affected (positively or negatively) by the JAX PySCF wrapper I included, which explains the imperfect match with the PyTorch reference.

Forward pass equivalence

Forward pass equivalence: relative error on total XC energy and max absolute error on per-point XC density, both well below threshold across system sizes

Energy profiles

The non-parallelity error (NPE) is what matters for chemical predictions: relative energies along a reaction or dissociation coordinate count more than absolute offsets. Here I compare the JAX implementation against the PyTorch reference:

Simple system:

H2 dissociation curve: total energy vs H-H distance (PyTorch vs JAX) and absolute energy difference

And more challenging:

CH4 symmetric stretch: total energy vs C-H distance (PyTorch vs JAX) and non-parallelity error

The curves agree well given that the two implementations share the same parameters but run on completely different backends (PyTorch vs JAX).
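The NPE reported in these plots is the spread of the pointwise energy difference along a curve. A minimal sketch of the computation, using made-up illustrative curves rather than real data:

```python
import numpy as np


def non_parallelity_error(e_ref, e_test):
    """NPE: spread of the pointwise energy difference along a curve.
    Zero when the two curves are parallel (differ by a constant shift)."""
    diff = np.asarray(e_test) - np.asarray(e_ref)
    return float(diff.max() - diff.min())


# Illustrative curves: the second is a constant-shifted copy of the first,
# so the NPE is ~0 despite a nonzero absolute energy difference.
d = np.linspace(0.5, 3.0, 20)
e_ref = -1.0 + 0.2 * (d - 0.74) ** 2
e_shifted = e_ref + 1e-3
print(non_parallelity_error(e_ref, e_shifted))
```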

Dependencies

Tested versions

The following versions are known to work together (tested on CPU, Python 3.12):

| Package | Tested version | Role |
|---------|----------------|------|
| jax | 0.9.2 | Core |
| jaxlib | 0.9.2 | Core |
| equinox | 0.13.6 | Core |
| e3nn-jax | 0.20.8 | Core |
| numpy | 2.4.3 | Core |
| skala (microsoft-skala) | 1.1.1 | Full install |
| torch | 2.10.0 | Full install (via skala) |
| pyscf | 2.12.1 | Full install (via skala) |
| e3nn | 0.6.0 | Full install (via skala) |
| dftd3 | | Full install (via skala, requires gfortran) |
| huggingface_hub | 1.7.1 | Full install (via skala) |
| opt_einsum_fx | 0.1.4 | Full install (via skala) |
| ase | | Full install (via skala) |
| qcelemental | | Full install (via skala) |

Cluster / custom installs

To install only the JAX core without PyTorch or Fortran dependencies:

pip install --no-deps skalax
pip install "jax>=0.4.0" "jaxlib>=0.4.0" "equinox>=0.11.0" "e3nn-jax>=0.20.0" "numpy>=1.21.0"

To pin specific versions for reproducibility (e.g. on a cluster):

pip install --no-deps skalax
pip install jax==0.9.2 jaxlib==0.9.2 equinox==0.13.6 "e3nn-jax==0.20.8" numpy==2.4.3

For a full install with PyTorch reference and PySCF integration, see Installation.

Tests

pytest tests/ -v

License

MIT License, see LICENSE.txt.

The pretrained weights bundled with this package are derived from the Skala model originally released by Microsoft Corporation under the MIT License (github.com/microsoft/skala).

Citation

If you use skalax, please cite the Skala paper:

@misc{luise2025skala,
  title={Accurate and scalable exchange-correlation with deep learning},
  author={Giulia Luise and Chin-Wei Huang and Thijs Vogels and Derk P. Kooi
          and Sebastian Ehlert and Stephanie Lanius and Klaas J. H. Giesbertz
          and Amir Karton and Deniz Gunceler and Megan Stanley
          and Wessel P. Bruinsma and Lin Huang and Xinran Wei
          and Jose Garrido Torres and Abylay Katbashev and Rodrigo Chavez Zavaleta
          and B{\'a}lint M{\'a}t{\'e} and Roberto Sordillo and Yingrong Chen
          and David B. Williams-Young and Christopher M. Bishop
          and Jan Hermann and Rianne van den Berg and Paola Gori-Giorgi},
  year={2025},
  eprint={2506.14665},
  archivePrefix={arXiv},
  url={https://arxiv.org/abs/2506.14665}
}

If you additionally use this JAX implementation, please also cite:

@software{sokolov2025skalax,
  author  = {Sokolov, Igor O.},
  title   = {skalax: {JAX} implementation of the {Skala} neural exchange-correlation functional},
  year    = {2025},
  url     = {https://github.com/Brogis1/skalax},
  license = {MIT}
}

Related Projects



