
Efficient and batched PyTorch-based neighbour list builder for atomistic simulations


batch_nl — Batched neighbour-list builder in PyTorch

batch_nl provides portable, fully vectorised, GPU-accelerated batched neighbour-list construction for periodic atomistic systems using PyTorch. All configurations are processed together in a single tensor batch, enabling fast neighbour searches and seamless integration with machine-learned interatomic potentials (MLIPs) and other batched workflows.

The package is at an early stage, so contributions and suggestions to improve API coverage are very welcome.

Performance benchmarks on RTX 6000

Summary. batch_nl outperforms CPU-based neighbour-list implementations and previous batched GPU-based approaches, but is superseded by NVIDIA’s ALCHEMI implementation.

When you might use batch_nl for performance. As a fallback when ALCHEMI is not available, or on non-NVIDIA accelerators (e.g. AMD/Intel GPUs or TPUs).

What the benchmark figure does not show. batch_nl is more memory-intensive than the O(N) codes and NVIDIA’s O(N²) implementation, due to large intermediate tensors created during broadcasting. Reducing this memory footprint is a target for future versions.

Benchmark timings for batch_nl on RTX 6000

For the full benchmark against currently available neighbour lists, see examples/benchmark_multiple_structure.ipynb.


Installation

Install from PyPI (recommended)

pip install batch-nl-torch

Install from source (for development)

git clone https://github.com/venkatkapil24/batch_nl.git
cd batch_nl
pip install -e .

Quick start (batched usage)

from ase.build import bulk
from batch_nl import NeighbourList

cutoff = 3.0
device = "cuda:0"

base = bulk("C", "diamond", a=3.57)

configs = [
    base * (2, 2, 2),   # config 0
    base * (3, 3, 3),   # config 1
    base * (4, 4, 4),   # config 2
]

list_of_positions = [atoms.positions for atoms in configs]

list_of_cells = [atoms.cell.array for atoms in configs]

nl = NeighbourList(
    list_of_positions=list_of_positions,
    list_of_cells=list_of_cells,
    cutoff=cutoff,
    device=device,
)

nl.load_data()

# Output with global batch indices
output = nl.calculate_neighbourlist(
    use_torch_compile=True
)

r_edges, r_S_int, r_S_cart, r_distances = output

# Or the familiar matscipy-style output
(
    atom_index_list,
    neighbor_index_list,
    int_shift_list,
    cart_shift_list,
    distance_list,
) = nl.get_matscipy_output_from_batch_output(*output)

for cfg in range(len(configs)):
    print(f"Configuration {cfg}: {len(atom_index_list[cfg])} neighbour pairs")

Understanding the output

1. Global batched output (calculate_neighbourlist)

r_edges                 (2, n_edges)
r_S_int                 (n_edges, 3)
r_S_cart                (n_edges, 3)
r_distances             (n_edges,)

Atoms from different configurations are concatenated into a global index:

  • Config 0: atoms [0 … N0−1]
  • Config 1: atoms [N0 … N0+N1−1]
  • Config 2: atoms [N0+N1 … N0+N1+N2−1]
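These offsets are simply a cumulative sum over the per-configuration atom counts. A minimal sketch (the counts below assume ASE's 2-atom primitive diamond cell, so the quick-start configurations contain 16, 54, and 128 atoms):

```python
import numpy as np

# atoms per configuration in the quick-start example (2-atom primitive cell)
n_atoms = [16, 54, 128]

# offset of each configuration's first atom in the global index
offsets = np.concatenate(([0], np.cumsum(n_atoms)[:-1]))
print(offsets)  # -> [ 0 16 70]

# global index of local atom i in configuration c is offsets[c] + i
```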

Thus a pair like:

r_edges[:, k] = [42, 99]

means that global atom 42 has global atom 99 as a neighbour under the lattice shift r_S_int[k] (an integer triplet indicating the periodic image), or equivalently r_S_cart[k] (the same shift represented as a Cartesian displacement). This output representation is ideal for constructing graphs or atomic features over a batch of configurations.
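As a sanity check, the pair distances can always be reconstructed from the positions and the Cartesian shifts. A standalone sketch on toy data (not actual batch_nl output):

```python
import torch

# two atoms in a box that is periodic along x with length 5.0 (toy data)
positions = torch.tensor([[0.0, 0.0, 0.0],
                          [4.0, 0.0, 0.0]])
r_edges = torch.tensor([[0], [1]])           # atom 0 -> atom 1
r_S_cart = torch.tensor([[-5.0, 0.0, 0.0]])  # neighbour is the image shifted by -1 cell

i, j = r_edges
# displacement from source atom i to the periodic image of atom j
vec = positions[j] + r_S_cart - positions[i]
dist = vec.norm(dim=-1)
print(dist)  # nearest-image distance: tensor([1.])
```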


2. Per‑configuration (matscipy-style) output

get_matscipy_output_from_batch_output(...) produces lists of length n_configs:

  • atom_index_list[cfg]      (n_edges_cfg,)     local source indices
  • neighbor_index_list[cfg]  (n_edges_cfg,)     local neighbour indices
  • int_shift_list[cfg]       (n_edges_cfg, 3)   integer shifts
  • cart_shift_list[cfg]      (n_edges_cfg, 3)   Cartesian shifts
  • distance_list[cfg]        (n_edges_cfg,)     distances

Local indices always run from 0 … n_atoms_in_cfg−1.
This matches the standard matscipy interface.
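The global-to-local conversion this method performs can be sketched with a searchsorted over the configuration offsets (hypothetical atom counts, not the library's internals):

```python
import numpy as np

n_atoms = [16, 54]                                   # atoms per configuration (hypothetical)
offsets = np.concatenate(([0], np.cumsum(n_atoms)))  # [0, 16, 70]

global_idx = np.array([3, 20, 69])                   # global atom indices
cfg = np.searchsorted(offsets, global_idx, side="right") - 1  # configuration of each atom
local = global_idx - offsets[cfg]                    # local index within that configuration
print(cfg, local)  # -> [0 1 1] [ 3  4 53]
```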


batch_nl API Overview

NeighbourList

nl = NeighbourList(
    list_of_positions=list_of_positions,
    list_of_cells=list_of_cells,
    cutoff=cutoff,
    float_dtype=torch.float32,
    device=device,
)

Parameters

list_of_positions      (list[(n_i, 3)])       Cartesian coordinates per configuration
list_of_cells          (list[(3, 3)])         Cell matrices per configuration
cutoff                 (scalar)               Cutoff radius (float, int, or tensor)
float_dtype            (torch.dtype)          One of {float16, float32, float64, bfloat16}
device                 (str | torch.device)   "cpu", "cuda", or explicit device

load_data()

Produces (internal tensors)

batch_positions_tensor   (n_configs, n_max, 3)
batch_mask_tensor        (n_configs, n_max)
batch_cell_tensor        (n_configs, 3, 3)
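The padded layout implied by these shapes can be illustrated in plain PyTorch. This is a sketch of the shape convention only, not the library's actual implementation: configurations with fewer than n_max atoms are zero-padded, and the boolean mask marks real atoms.

```python
import torch

# two configurations with 4 and 7 atoms (toy data)
list_of_positions = [torch.rand(4, 3), torch.rand(7, 3)]
n_configs = len(list_of_positions)
n_max = max(p.shape[0] for p in list_of_positions)

batch_positions = torch.zeros(n_configs, n_max, 3)
batch_mask = torch.zeros(n_configs, n_max, dtype=torch.bool)
for c, pos in enumerate(list_of_positions):
    batch_positions[c, : pos.shape[0]] = pos  # copy real atoms, leave padding at zero
    batch_mask[c, : pos.shape[0]] = True      # mark real atoms

print(batch_positions.shape, batch_mask.sum(dim=1))
# -> torch.Size([2, 7, 3]) tensor([4, 7])
```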

calculate_neighbourlist(use_torch_compile=True)

Parameters

use_torch_compile        (bool)    Enable torch.compile acceleration

Returns

r_edges                  (2, n_edges)        Global source/target atom indices
r_S_int                  (n_edges, 3)        Integer lattice shifts
r_S_cart                 (n_edges, 3)        Cartesian lattice shifts
r_distances              (n_edges,)          Pair distances

get_matscipy_output_from_batch_output(...)

Converts global indexing → local indexing

Parameters

r_edges                  (2, n_edges)
r_integer_lattice_shifts (n_edges, 3)
r_cartesian_lattice_shifts (n_edges, 3)
r_distances              (n_edges,)
device                   ("cpu" | "cuda" | None)

Returns (lists, length = n_configs)

atom_index_list[c]        (n_edges_c,)       Local source atom indices
neighbor_index_list[c]    (n_edges_c,)       Local neighbour atom indices
int_shift_list[c]         (n_edges_c, 3)     Integer shifts per pair
cart_shift_list[c]        (n_edges_c, 3)     Cartesian shifts per pair
distance_list[c]          (n_edges_c,)       Distances per pair

Testing

Tests cover unit cells with varying degrees of skew, taken from torch-sim, using matscipy as the ground truth.

pytest

License

batch_nl is distributed under the Apache License 2.0.
Users are free to use, modify, and redistribute the software, provided attribution and license terms are preserved.
