Skip to main content

Efficient and batched PyTorch-based neighbour list builder for atomistic simulations

Project description

batch_nl — Batched neighbour-list builder in PyTorch

batch_nl provides fully vectorised, GPU-accelerated batched neighbour-list construction for periodic atomistic systems using PyTorch. All configurations are processed together in a single tensor batch, to enable fast neighbour search and seamless integration with MLIPs and other batched workflows.

The package is in an early stage, so contributions and suggestions for improving API coverage are very welcome.

Performance benchmarks on RTX 6000

A test for batches of structures containing 128 atoms. Benchmark timings for batch_nl on RTX 6000

For the full benchmark against currently available neighbour lists, see

examples/benchmark_multiple_structure.ipynb.


Installation

One-line install (from GitHub)

pip install git+https://github.com/venkatkapil24/batch_nl.git

Install from source (recommended for development)

git clone https://github.com/venkatkapil24/batch_nl.git
cd batch_nl
pip install -e .

Quick start (batched usage)

from ase.build import bulk
from batch_nl import NeighbourList

cutoff = 3.0
device = "cuda:0"

base = bulk("C", "diamond", a=3.57)

configs = [
    base * (2, 2, 2),   # config 0
    base * (3, 3, 3),   # config 1
    base * (4, 4, 4),   # config 2
]

list_of_positions = [
    atoms.positions for atoms in configs
        ]

list_of_cells     = [
    atoms.cell.array for atoms in configs
    ]

nl = NeighbourList(
    list_of_positions=list_of_positions,
    list_of_cells=list_of_cells,
    cutoff=cutoff,
    device=device,
)

nl.load_data()

# Output with global batch indices
output = nl.calculate_neighbourlist(
    use_torch_compile=True
)

r_edges, r_S_int, r_S_cart, r_d = output 

# Or the familiar matscipy-style output
(
    atom_index_list,
    neighbor_index_list,
    int_shift_list,
    cart_shift_list,
    distance_list,
) = nl.get_matscipy_output_from_batch_output(*output)

for cfg in range(len(configs)):
    print(f"Configuration {cfg}: {len(atom_index_list[cfg])} neighbour pairs")

Understanding the output

1. Global batched output (calculate_neighbourlist)

r_edges                 (2, n_edges)
r_S_int                 (n_edges, 3)
r_S_cart                (n_edges, 3)
r_distances             (n_edges,)

Atoms from different configurations are concatenated into a global index:

  • Config 0: atoms [0 … N0−1]
  • Config 1: atoms [N0 … N0+N1−1]
  • Config 2: atoms [N0+N1 … N0+N1+N2−1]

Thus a pair like:

r_edges[:, k] = [42, 99]

means that global atom 42 has global atom 99 as a neighbour under the lattice shift r_S_int[k] (an integer triplet indicating the periodic image), or equivalently r_S_cart[k] (the same shift represented as a Cartesian displacement). This output representation is ideal for constructing graphs or atomic features over a batch of configurations.


2. Per‑configuration (matscipy-style) output

get_matscipy_output_from_batch_output(...) produces lists of length n_configs:

  • atom_index_list[cfg](n_edges_cfg,) local source indices
  • neighbor_index_list[cfg](n_edges_cfg,) local neighbour indices
  • int_shift_list[cfg](n_edges_cfg, 3) integer shifts
  • cart_shift_list[cfg](n_edges_cfg, 3) Cartesian shifts
  • distance_list[cfg](n_edges_cfg,) distances

Local indices always run from 0 … n_atoms_in_cfg−1.
This matches the standard matscipy interface.


batch_nl API Overview

NeighbourList

nl = NeighbourList(
    list_of_positions=list_of_positions,
    list_of_cells=list_of_cells,
    cutoff=cutoff,
    float_dtype=torch.float32,
    device=device,
)

Parameters

list_of_positions      (list[(n_i, 3)])       Cartesian coordinates per configuration
list_of_cells          (list[(3, 3)])         Cell matrices per configuration
cutoff                 (scalar)               Cutoff radius (float, int, or tensor)
float_dtype            (torch.dtype)          One of {float16, float32, float64, bfloat16}
device                 (str | torch.device)   "cpu", "cuda", or explicit device

load_data()

Produces (internal tensors)

batch_positions_tensor   (n_configs, n_max, 3)
batch_mask_tensor        (n_configs, n_max)
batch_cell_tensor        (n_configs, 3, 3)

calculate_neighbourlist(use_torch_compile=True)

Parameters

use_torch_compile        (bool)    Enable torch.compile acceleration

Returns

r_edges                  (2, n_edges)        Global source/target atom indices
r_S_int                  (n_edges, 3)        Integer lattice shifts
r_S_cart                 (n_edges, 3)        Cartesian lattice shifts
r_distances              (n_edges,)          Pair distances

get_matscipy_output_from_batch_output(...)

Converts global indexing → local indexing

Parameters

r_edges                  (2, n_edges)
r_integer_lattice_shifts (n_edges, 3)
r_cartesian_lattice_shifts (n_edges, 3)
r_distances              (n_edges,)
device                   ("cpu" | "cuda" | None)

Returns (lists, length = n_configs)

atom_index_list[c]        (n_edges_c,)       Local source atom indices
neighbor_index_list[c]    (n_edges_c,)       Local neighbour atom indices
int_shift_list[c]         (n_edges_c, 3)     Integer shifts per pair
cart_shift_list[c]        (n_edges_c, 3)     Cartesian shifts per pair
distance_list[c]          (n_edges_c,)       Distances per pair

Testing

Tests include unit cells with varying skews taken from torch-sim with matscipy as the ground truth.

pytest

License

batch_nl is distributed under the Apache License 2.0.
Users are free to use, modify, and redistribute the software, provided attribution and license terms are preserved.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

batch_nl_torch-0.1.0.tar.gz (14.8 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

batch_nl_torch-0.1.0-py3-none-any.whl (12.1 kB view details)

Uploaded Python 3

File details

Details for the file batch_nl_torch-0.1.0.tar.gz.

File metadata

  • Download URL: batch_nl_torch-0.1.0.tar.gz
  • Upload date:
  • Size: 14.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for batch_nl_torch-0.1.0.tar.gz
Algorithm Hash digest
SHA256 da7cae1a353d35f63b2750451af33dd838797b56b9aab0707b5350b659318091
MD5 86881402948e6d3f5799e691b74e3267
BLAKE2b-256 b432ec4322576ec71fac06be868f5dca2b6407f6a8ec43c314d2a271fa94e173

See more details on using hashes here.

File details

Details for the file batch_nl_torch-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: batch_nl_torch-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 12.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for batch_nl_torch-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 94e0204d81da9390b9c8bc02f6944e20d212c7f0a2bf9b681a925bc2ffe4d150
MD5 85e6b270c7d1abb3a16fb8bba8a60699
BLAKE2b-256 1fe29d70d70c0cd85698b2972695e4a743d7f8d6d392bdd2a2f72191606b493f

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page