
JAXIP

JAX-based Interatomic Potential


JAXIP is a Python library built on JAX that helps with the development of emerging machine learning interatomic potentials for use in computational physics and chemistry. These potentials are necessary for conducting large-scale molecular dynamics simulations of complex materials at the atomic level with ab initio accuracy.

JAXIP is a tool for developing interatomic potentials for molecular dynamics simulations; it is not itself a package for performing those simulations.

Why JAXIP?

  • The design of JAXIP is simple and flexible, which makes it easy to incorporate new atomic descriptors and potentials.

  • It uses autograd to make defining new descriptors straightforward (see the sketch after this list).

  • JAXIP is written purely in Python and optimized with just-in-time (JIT) compilation.

  • It also supports GPU computing, which can significantly speed up preprocessing and model training.
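
As an illustration of the autograd point above, here is a minimal sketch in plain JAX (the pair_distance function is a hypothetical toy descriptor, not part of the JAXIP API) showing how jax.grad yields derivatives with respect to atomic positions without any manual derivation:

import jax
import jax.numpy as jnp

def pair_distance(positions):
    # Toy "descriptor": the distance between the first two atoms
    return jnp.linalg.norm(positions[0] - positions[1])

# Autograd produces the derivative w.r.t. all atomic coordinates
grad_fn = jax.grad(pair_distance)

positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
print(pair_distance(positions))  # 1.7320508
print(grad_fn(positions))        # gradient array with the same shape as positions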

Examples

Defining an atomic environment descriptor

The following example shows how to create an array of atom-centered symmetry functions (ACSF) for a specific element. The descriptor can then be applied to a given structure to produce the descriptor values required for building machine learning potentials.
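
For reference, the symmetry functions used below correspond to the standard Behler-Parrinello forms (written here from their usual definitions, not extracted from the JAXIP source; f_c denotes the cutoff function):

G^2_i = \sum_j \exp(-\eta (r_{ij} - r_s)^2) \, f_c(r_{ij})

G^3_i = 2^{1-\zeta} \sum_{j,k \neq i} (1 + \lambda \cos\theta_{ijk})^\zeta \exp(-\eta (r_{ij}^2 + r_{ik}^2 + r_{jk}^2)) \, f_c(r_{ij}) \, f_c(r_{ik}) \, f_c(r_{jk})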

from jaxip.datasets import RunnerStructureDataset
from jaxip.descriptors import ACSF
from jaxip.descriptors.acsf import CutoffFunction, G2, G3


# Read atomic structure dataset (e.g. water molecules)
structures = RunnerStructureDataset('input.data')
structure = structures[0]

# Define ACSF descriptor for hydrogen element
descriptor = ACSF(element='H')

# Add radial and angular symmetry functions
cfn = CutoffFunction(r_cutoff=12.0, cutoff_type='tanh')
descriptor.add(G2(cfn, eta=0.5, r_shift=0.0), 'H')
descriptor.add(G3(cfn, eta=0.001, zeta=2.0, lambda0=1.0, r_shift=12.0), 'H', 'O')

# Calculate descriptor values
values = descriptor(structure)

Output:

>> values.shape
(128, 2)

>> values[:3]
DeviceArray([[1.9689142e-03, 3.3253882e+00],
             [1.9877939e-03, 3.5034561e+00],
             [1.5204106e-03, 3.5458331e+00]], dtype=float32)
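
Descriptor values are typically standardized before being fed to a neural network; a minimal sketch of such a scaling step (ordinary per-column standardization, not the JAXIP scaler used in the next example):

import jax.numpy as jnp

# Standardize each descriptor column to zero mean and unit variance
mean = jnp.mean(values, axis=0)
std = jnp.std(values, axis=0)
scaled_values = (values - mean) / (std + 1e-8)  # small epsilon avoids division by zero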

Training a machine learning potential

This example illustrates how to quickly create a high-dimensional neural network potential (HDNNP) and train it on input structures. The trained potential can then be used to evaluate the energy and force components for new structures.

from jaxip.datasets import RunnerStructureDataset
from jaxip.potentials import NeuralNetworkPotential

# Read the atomic structure dataset
structures = RunnerStructureDataset("input.data")

# Instantiate the potential from its settings file
nnp = NeuralNetworkPotential("input.nn")

# Fit the descriptor scaler to the dataset
nnp.fit_scaler(structures)
# nnp.load_scaler()  # alternatively, load previously fitted scaling parameters

# Train the model
nnp.fit_model(structures)
# nnp.load_model()  # alternatively, load a previously trained model

# Predict energy and force components
structure = structures[0]
energy = nnp(structure)
force = nnp.compute_force(structure)
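
The same calls extend naturally over a whole dataset; a short sketch (using only the indexing and methods demonstrated above):

# Evaluate the trained potential on the first few structures
for index in range(3):
    structure = structures[index]
    energy = nnp(structure)
    force = nnp.compute_force(structure)
    print(index, energy)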

License

This project is licensed under the GNU General Public License (GPL) version 3 - see the LICENSE file for details.

History

v0.4.0 (2023-01-03)

  • Applied extensive refactoring

  • Replaced PyTorch main dependency with JAX

v0.3.0 (2022-12-07)

  • JAX optimization of ACSF descriptor

v0.2.0 (2022-11-11)

  • Small optimizations using torch.jit.script

v0.1.0 (2022-10-28)

  • Primary implementation and validation

v0.0.1 (2022-01-01)

  • Start
