JAX-based Interatomic Potential
Description
JAXIP is an optimized Python library built on JAX for developing emerging machine learning interatomic potentials for use in computational physics and chemistry. These potentials are necessary for running large-scale molecular dynamics simulations of complex materials at the atomic level with ab initio accuracy.
JAXIP is a package for developing potentials for use in molecular dynamics simulations, not for performing the simulations themselves.
Documentation: https://jaxip.readthedocs.io.
Main features
JAXIP has a simple and flexible design that makes it easy to incorporate atomic descriptors and potentials.
It uses automatic differentiation (autograd), so defining a new descriptor does not require deriving its gradients by hand.
JAXIP is written purely in Python and optimized with just-in-time (JIT) compilation.
It also supports GPU-accelerated computing, which can significantly speed up preprocessing and model training.
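The autograd and JIT points above can be illustrated with plain JAX, independently of JAXIP's own classes. The sketch below writes a tanh-type cutoff function with jax.numpy, compiles it with jax.jit, and obtains its derivative from jax.grad instead of deriving it by hand; a hand-written derivative is included only to check the result. The form tanh^3(1 - r/r_c) is an assumption based on the cutoff commonly used in Behler-type potentials, not something taken from JAXIP's source, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

R_CUTOFF = 12.0  # cutoff radius, matching r_cutoff=12.0 used in the examples

def cutoff_tanh(r, r_cutoff=R_CUTOFF):
    """Assumed tanh-type cutoff: tanh^3(1 - r/r_cutoff) inside the cutoff, 0 outside."""
    value = jnp.tanh(1.0 - r / r_cutoff) ** 3
    return jnp.where(r < r_cutoff, value, 0.0)

# JIT-compile the function and its autograd derivative (jax.grad needs a scalar output)
fc = jax.jit(cutoff_tanh)
dfc = jax.jit(jax.grad(cutoff_tanh))

def dfc_analytic(r, r_cutoff=R_CUTOFF):
    """Hand-derived derivative, used only to cross-check the autograd result."""
    t = jnp.tanh(1.0 - r / r_cutoff)
    return jnp.where(r < r_cutoff, -3.0 * t**2 * (1.0 - t**2) / r_cutoff, 0.0)

r = 3.5
print(fc(r), dfc(r), dfc_analytic(r))
```

The point of the design is that only the forward expression is written; the derivative, which a descriptor needs for force calculations, comes from jax.grad.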
Examples
Defining an atomic environment descriptor
The following example shows how to define a set of atom-centered symmetry functions (ACSF) for a specific element. Applied to a structure, this descriptor produces the per-atom descriptor values that are required to build machine learning potentials.
from jaxip.datasets import RunnerStructureDataset
from jaxip.descriptors import ACSF
from jaxip.descriptors.acsf import CutoffFunction, G2, G3
# Read atomic structure dataset (e.g. water molecules)
structures = RunnerStructureDataset('input.data')
structure = structures[0]
# Define ACSF descriptor for hydrogen element
descriptor = ACSF(element='H')
# Add radial and angular symmetry functions
cfn = CutoffFunction(r_cutoff=12.0, cutoff_type='tanh')
descriptor.add(G2(cfn, eta=0.5, r_shift=0.0), 'H')
descriptor.add(G3(cfn, eta=0.001, zeta=2.0, lambda0=1.0, r_shift=12.0), 'H', 'O')
# Calculate descriptor values
dsc = descriptor(structure)
Outputs:
>> descriptor
ACSF(element='H', num_symmetry_functions=2, max_r_cutoff=12.0)
>> dsc[:3]
Array([[1.9689146e-03, 3.3253896e+00],
[1.9877951e-03, 3.5034575e+00],
[1.5204106e-03, 3.5458338e+00]], dtype=float32)
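For reference, the radial (G2) and angular (G3) symmetry functions configured above are usually written in the Behler-Parrinello form below, reading eta as η, r_shift as r_s, zeta as ζ, lambda0 as λ, r_cutoff as r_c, and f_c as the tanh-type cutoff. This is the common convention (as used, for example, in n2p2), assumed rather than taken from JAXIP: in particular, how r_shift enters G3 and whether each neighbor pair in G3 is counted once or twice may differ in JAXIP. The element arguments passed to add ('H', or 'H' and 'O') presumably restrict the neighbors j and k to those elements.

```latex
f_c(r) =
\begin{cases}
  \tanh^3\!\left(1 - \dfrac{r}{r_c}\right), & r \le r_c \\
  0, & r > r_c
\end{cases}

G^{2}_i = \sum_{j \ne i} e^{-\eta\,(r_{ij} - r_s)^2}\, f_c(r_{ij})

G^{3}_i = 2^{1-\zeta} \sum_{j \ne i} \sum_{k \ne i,j}
  \left(1 + \lambda \cos\theta_{ijk}\right)^{\zeta}
  e^{-\eta\left[(r_{ij} - r_s)^2 + (r_{ik} - r_s)^2 + (r_{jk} - r_s)^2\right]}
  f_c(r_{ij})\, f_c(r_{ik})\, f_c(r_{jk})
```

Each added symmetry function contributes one column of dsc, which is why the two functions above give two values per hydrogen atom.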
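The gradient of the descriptor values for a single atom can be obtained with the grad method. The result has one row per symmetry function and three columns, presumably the x, y and z components.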
>> descriptor.grad(structure, atom_index=0)
Array([[-0.04337483, 0.22992024, -0.04233539],
[-0.07089673, 0.03088031, -0.16785064]], dtype=float32)
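One plausible reading of descriptor.grad is the derivative of an atom's symmetry-function values with respect to its own Cartesian coordinates, which would match the 2 x 3 shape above (two symmetry functions, three components). The sketch below computes that kind of quantity with jax.jacobian for two radial G2 functions. The helper radial_g2, the tanh^3 cutoff form and the geometry are hypothetical illustrations, and JAXIP may define its gradient differently.

```python
import jax
import jax.numpy as jnp

def cutoff_tanh(r, r_cutoff):
    # Assumed tanh-type cutoff: tanh^3(1 - r/r_cutoff) inside, 0 outside
    return jnp.where(r < r_cutoff, jnp.tanh(1.0 - r / r_cutoff) ** 3, 0.0)

def radial_g2(positions, atom_index, etas, r_shift=0.0, r_cutoff=12.0):
    """Hypothetical helper: radial G2 values of one atom, one value per eta."""
    n_atoms = positions.shape[0]
    disp = positions - positions[atom_index]          # (N, 3) vectors from the central atom
    is_neighbor = jnp.arange(n_atoms) != atom_index   # exclude the central atom itself
    r2 = jnp.sum(disp**2, axis=1)
    # Dummy distance for the self term so sqrt(0) never enters the gradient
    r = jnp.sqrt(jnp.where(is_neighbor, r2, 1.0))
    terms = jnp.exp(-etas[:, None] * (r[None, :] - r_shift) ** 2) * cutoff_tanh(r, r_cutoff)[None, :]
    return jnp.sum(jnp.where(is_neighbor[None, :], terms, 0.0), axis=1)

# Illustrative water-like geometry (arbitrary length units) and two eta values
positions = jnp.array([[0.0, 0.0, 0.0],
                       [1.8, 0.0, 0.0],
                       [-0.45, 1.75, 0.0]])
etas = jnp.array([0.5, 0.01])
atom_index = 0

# Jacobian with respect to every coordinate: shape (num_etas, N, 3)
jac_full = jax.jacobian(radial_g2)(positions, atom_index, etas)
# Derivative of the central atom's descriptor with respect to its own position: (num_etas, 3)
grad0 = jac_full[:, atom_index, :]
print(grad0)
```

Because G2 depends only on interatomic distances, summing jac_full over all atoms gives zero, which is a convenient sanity check for any hand-written descriptor.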
Training a machine learning potential
This example illustrates how to quickly create a high-dimensional neural network potential (HDNNP) instance from an input settings file and train it on a set of input structures. The trained potential can then be used to predict the total energy and the force components of new structures.
from jaxip.datasets import RunnerStructureDataset
from jaxip.potentials import NeuralNetworkPotential
# Atomic data
structures = RunnerStructureDataset("input.data")
structure = structures[0]
# Potential
nnp = NeuralNetworkPotential.create_from("input.nn")
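# Fit the descriptor scaler on the training structures
nnp.fit_scaler(structures)
# nnp.load_scaler()  # or load a previously fitted scaler instead
# Train the neural network model
nnp.fit_model(structures)
# nnp.load_model()  # or load a previously trained model instead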
# Predict the total energy and force components
total_energy = nnp(structure)
force = nnp.compute_force(structure)
Outputs:
>> nnp
NeuralNetworkPotential(atomic_potential={'C': AtomicPotential(
descriptor=ACSF(element='C', num_symmetry_functions=30, r_cutoff_max=12.0),
scaler=DescriptorScaler(scale_type='center', scale_min=0.0, scale_max=1.0),
model=NeuralNetworkModel(hidden_layers=((15, 'tanh'), (15, 'tanh'))),
)})
>> total_energy
Array(-8.16754983, dtype=float64)
>> force
{'C': Array([[-4.1423317e-02, -1.7819289e-02, 6.5731630e-03],
[-5.2372105e-03, 1.3765628e-03, -1.5538651e-05],
[-5.7118265e-03, 6.4179506e-03, 3.0147154e-02],
...], dtype=float32)}
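The HDNNP idea behind NeuralNetworkPotential can be sketched in a few lines of JAX: each atom's descriptor vector is passed through a small atomic network (here two hidden layers of 15 tanh units, mirroring the model printed above), the atomic energies are summed to give the total energy E = sum_i E_i(G_i), and the forces are the negative gradient of E with respect to the atomic positions. This is a toy illustration under loudly stated assumptions: a single element, random untrained weights, only radial G2 descriptors with an assumed tanh^3 cutoff, and no descriptor scaling. None of these function names belong to JAXIP's API.

```python
import jax
import jax.numpy as jnp

def cutoff_tanh(r, r_cutoff=12.0):
    # Assumed tanh-type cutoff, zero beyond r_cutoff
    return jnp.where(r < r_cutoff, jnp.tanh(1.0 - r / r_cutoff) ** 3, 0.0)

def all_radial_g2(positions, etas, r_cutoff=12.0):
    """Hypothetical radial G2 descriptors for every atom, shape (N, num_etas)."""
    n = positions.shape[0]
    diff = positions[:, None, :] - positions[None, :, :]   # (N, N, 3)
    off_diag = ~jnp.eye(n, dtype=bool)                      # exclude self pairs
    r = jnp.sqrt(jnp.where(off_diag, jnp.sum(diff**2, axis=-1), 1.0))
    terms = jnp.exp(-etas[None, None, :] * r[..., None] ** 2) * cutoff_tanh(r, r_cutoff)[..., None]
    return jnp.sum(jnp.where(off_diag[..., None], terms, 0.0), axis=1)

def init_mlp(key, sizes):
    """Random (untrained) weights for a fully connected network."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
            for k, (m, n) in zip(keys, zip(sizes[:-1], sizes[1:]))]

def atomic_energy(params, g):
    """Atomic network: tanh hidden layers, linear output, one energy per atom."""
    h = g
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[..., 0]

def total_energy(positions, params, etas):
    # HDNNP ansatz: total energy is the sum of atomic energy contributions
    return jnp.sum(atomic_energy(params, all_radial_g2(positions, etas)))

def forces(positions, params, etas):
    # Forces are the negative gradient of the total energy w.r.t. positions
    return -jax.grad(total_energy)(positions, params, etas)

etas = jnp.array([0.01, 0.1, 0.5])
params = init_mlp(jax.random.PRNGKey(0), [3, 15, 15, 1])
positions = jnp.array([[0.0, 0.0, 0.0],
                       [1.4, 0.0, 0.0],
                       [0.0, 1.4, 0.0],
                       [1.4, 1.4, 0.3]])
E = total_energy(positions, params, etas)
F = forces(positions, params, etas)
print(E, F)
```

Deriving forces from the energy by autograd keeps them consistent with the energy by construction; as a check, the forces sum to zero because the energy depends only on interatomic distances.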
License
This project is licensed under the GNU General Public License (GPL) version 3 - see the LICENSE file for details.