A Python package for generating ML layers for MLIPs
KLay — Composable layers for KLIFF & MLIPs
KLay takes the “LEGO-block” approach to machine-learning interatomic potentials (MLIPs): every layer is a first-class citizen you can swap, re-wire and re-train — no copy-pasting opaque model code.
- Works out-of-the-box with KLIFF
- Converts a single YAML file → `torch.fx.GraphModule` (ready for `torch.compile`, TorchScript, ONNX, ...)
- Or: returns a dict of instantiated layers when you just want the bricks
- Built-in `layers`, `types`, `validate` (with optional Graphviz visualization) and `export` CLI commands
- Ships NequIP blocks today; MACE, EGNN & M3GNet on the roadmap
Installation
```bash
pip install klay
# or: dev version
pip install git+https://github.com/openkim/klay.git
```
Usage in Python
```python
from klay.io import load_config
from klay.builder import build_model

cfg = load_config("example/new_model.yaml")
model = build_model(cfg)  # torch.fx.GraphModule, because cfg defines model_inputs/model_outputs
```
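To make the "YAML → graph" idea concrete, here is a minimal, stdlib-only sketch of how a config-driven builder *could* wire layers: it orders layers by their declared inputs, exposes tuple outputs under configured names (as `output: {0: vec0, 1: len0, ...}` does in the cheat-sheet), and treats an alias as a second reference to the same layer object. `run_graph`, its arguments and the toy layers are hypothetical; KLay's `build_model` returns a `torch.fx.GraphModule` built from real layers rather than executing plain Python callables.

```python
from graphlib import TopologicalSorter  # stdlib, Python 3.9+
import math
from typing import Any, Callable, Dict


def run_graph(model_inputs: Dict[str, Any],
              layers: Dict[str, Callable[..., Any]],
              wiring: Dict[str, Dict[str, str]],
              output_names: Dict[str, Dict[int, str]],
              model_outputs: Dict[str, str]) -> Dict[str, Any]:
    """Hypothetical executor: run layers in dependency order and collect outputs."""
    # Named tuple elements (e.g. "len0") point back to the layer producing them.
    producer = {name: layer
                for layer, names in output_names.items()
                for name in names.values()}

    # Dependency graph over layers only; model inputs have no producer.
    graph = {}
    for layer in layers:
        deps = set()
        for src in wiring.get(layer, {}).values():
            owner = producer.get(src, src)
            if owner in layers:
                deps.add(owner)
        graph[layer] = deps

    env: Dict[str, Any] = dict(model_inputs)  # value name -> computed value
    for layer in TopologicalSorter(graph).static_order():
        kwargs = {arg: env[src] for arg, src in wiring.get(layer, {}).items()}
        result = layers[layer](**kwargs)
        env[layer] = result
        # Expose tuple elements under their configured names.
        for idx, name in output_names.get(layer, {}).items():
            env[name] = result[idx]

    return {out: env[src] for out, src in model_outputs.items()}


def double(x):
    return 2 * x


# Toy "layers": plain callables standing in for torch modules.
layers = {
    "edge_len": lambda vec: (vec, math.hypot(*vec)),  # returns a tuple
    "double": double,
    "double_again": double,  # alias: the very same object, so "weights" are shared
}
wiring = {
    "edge_len": {"vec": "vector"},
    "double": {"x": "len0"},
    "double_again": {"x": "double"},
}
outputs = run_graph({"vector": (3.0, 4.0)}, layers, wiring,
                    output_names={"edge_len": {0: "vec0", 1: "len0"}},
                    model_outputs={"energy": "double_again"})
print(outputs)  # {'energy': 20.0}
```

Topological ordering is what lets the YAML list layers in any order; sharing one object between two names is the simplest way an alias can reuse parameters.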
Need only the bricks?
```python
from klay.builder import build_layers
from klay.io import load_config

layers = build_layers(load_config("example/new_model_layers.yaml"))
print(layers.keys())  # dict_keys([...])
```
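"Bricks" mode relies on layers being discoverable by their type name. One common way to do that is a decorator-based registry plus a `from_config` classmethod per layer, which is what the `from_config` column of `klay layers` suggests. This is a hypothetical sketch: `register`, `REGISTRY`, `layers_of_type` and `build_layers_sketch` are illustrative names, and the one-hot layer is a simplified stand-in operating on plain lists instead of tensors.

```python
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple


class ModuleCategory(Enum):
    # Two illustrative categories; `klay types` lists the real ones.
    EMBEDDING = "embedding"
    CONVOLUTION = "convolution"


# Hypothetical registry: layer type name -> (category, class)
REGISTRY: Dict[str, Tuple[ModuleCategory, type]] = {}


def register(category: ModuleCategory) -> Callable[[type], type]:
    """Class decorator that records a layer under its class name."""
    def decorator(cls: type) -> type:
        REGISTRY[cls.__name__] = (category, cls)
        return cls
    return decorator


@register(ModuleCategory.EMBEDDING)
class OneHotAtomEncoding:
    """Simplified stand-in: maps element indices to one-hot rows."""

    def __init__(self, num_elems: int):
        self.num_elems = num_elems

    @classmethod
    def from_config(cls, num_elems: int) -> "OneHotAtomEncoding":
        return cls(num_elems=num_elems)

    def __call__(self, x: List[int]) -> List[List[float]]:
        return [[1.0 if j == i else 0.0 for j in range(self.num_elems)] for i in x]


def layers_of_type(category: ModuleCategory) -> List[str]:
    """Rough analogue of filtering by category, as `klay layers --type` does."""
    return sorted(name for name, (cat, _) in REGISTRY.items() if cat is category)


def build_layers_sketch(model_layers: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Library mode: instantiate every configured layer without wiring them."""
    built = {}
    for name, spec in model_layers.items():
        _, cls = REGISTRY[spec["type"]]
        built[name] = cls.from_config(**spec.get("config", {}))
    return built


layers = build_layers_sketch(
    {"element_embedding": {"type": "OneHotAtomEncoding", "config": {"num_elems": 2}}}
)
print(layers.keys())                        # dict_keys(['element_embedding'])
print(layers["element_embedding"]([1, 0]))  # [[0.0, 1.0], [1.0, 0.0]]
```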
Command-line utilities
| Command | What it does | Most common flags |
|---|---|---|
| `klay layers` | Pretty table of every registered layer, showing inputs/outputs and the `from_config` signature, with required and optional args colour-coded. | `--type embedding`, `--all` |
| `klay types` | Lists all `ModuleCategory` values (convolution, embedding, …). | - |
| `klay validate` | Cycle detection, missing sources, dangling layers, alias→alias detection and unused outputs, plus an optional Graphviz diagram. | `--allow-dangling`, `-v/--visualize`, `--fmt svg` |
| `klay export` | Builds the full model and either TorchScripts it (`.pt`) or dumps its `state_dict` (`.pth`). | `-o out.pt`, `--format state_dict`, `-n 10` |
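The checks that `klay validate` reports map onto simple graph operations. A minimal stdlib sketch, assuming the config is already loaded into plain dicts with the cheat-sheet's keys (`model_inputs`, `model_layers`, `inputs`, `output`, `alias`, `model_outputs`). The name `validate_sketch` and the exact rule used for "unused outputs" (a named tuple element that nothing reads) are assumptions, not KLay's implementation.

```python
from graphlib import CycleError, TopologicalSorter  # stdlib, Python 3.9+
from typing import Any, Dict, List, Set


def validate_sketch(cfg: Dict[str, Any], allow_dangling: bool = False) -> List[str]:
    """Hypothetical validator returning human-readable problems (empty list = OK)."""
    model_inputs = set(cfg.get("model_inputs", {}))
    layers: Dict[str, Dict[str, Any]] = cfg["model_layers"]
    # Named tuple elements (e.g. "len0") -> layer that produces them
    produced = {name: layer
                for layer, spec in layers.items()
                for name in spec.get("output", {}).values()}
    problems: List[str] = []
    graph: Dict[str, Set[str]] = {}
    consumed: Set[str] = set()  # every source name some layer or output reads

    for layer, spec in layers.items():
        deps: Set[str] = set()
        for src in spec.get("inputs", {}).values():
            root = src.split(".")[0]  # "conv1.h" -> "conv1"
            consumed.add(root)
            if root in layers:
                deps.add(root)
            elif root in produced:
                deps.add(produced[root])
            elif root not in model_inputs:
                problems.append(f"{layer}: missing source '{src}'")
        graph[layer] = deps

        target = spec.get("alias")
        if target is not None:
            consumed.add(target)  # an alias target counts as used
            if target not in layers:
                problems.append(f"{layer}: alias target '{target}' does not exist")
            elif "alias" in layers[target]:
                problems.append(f"{layer}: alias -> alias ('{target}')")

    try:
        tuple(TopologicalSorter(graph).static_order())
    except CycleError as exc:
        problems.append(f"cycle: {' -> '.join(exc.args[1])}")

    for src in cfg.get("model_outputs", {}).values():
        consumed.add(src.split(".")[0])

    if not allow_dangling:
        for layer, spec in layers.items():
            outputs = set(spec.get("output", {}).values())
            if layer not in consumed and not outputs & consumed:
                problems.append(f"{layer}: dangling (nothing reads its output)")

    for name, layer in produced.items():
        if name not in consumed:
            problems.append(f"{layer}: output '{name}' is never used")
    return problems


cfg = {
    "model_inputs": {"pos": "Tensor (N,3)"},
    "model_layers": {
        "geom": {"inputs": {"x": "pos"}, "output": {0: "vec0", 1: "len0"}},
        "radial": {"inputs": {"edge_length": "len0"}},
        "conv_shared": {"type": "ConvNetLayer"},
        "conv1": {"alias": "conv_shared", "inputs": {"h": "radial"}},
    },
    "model_outputs": {"energy": "conv1.h"},
}
print(validate_sketch(cfg))  # ["geom: output 'vec0' is never used"]
```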
Examples:
```bash
# 1. Inspect all embedding layers
klay layers --type embedding
```
Result:
```text
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer                           ┃ Inputs                 ┃ Outputs              ┃ from_config args      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│ BesselBasis                     │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ trainable=True        │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ BinaryAtomicNumberEncoding      │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ ElectronicConfigurationEncoding │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ OneHotAtomEncoding              │ x                      │ representation       │ num_elems,            │
│                                 │                        │                      │ input_is_atomic_numb… │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ PolynomialCutoff                │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ polynomial_degree=6   │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ RadialBasisEdgeEncoding         │ edge_length            │ edge_length_embedded │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ polynomial_degree=6,  │
│                                 │                        │                      │ radial_basis_trainab… │
│                                 │                        │                      │ basis='BesselBasis',  │
│                                 │                        │                      │ cutoff='PolynomialCu… │
│                                 │                        │                      │ basis_kwargs={},      │
│                                 │                        │                      │ cutoff_kwargs={}      │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ SphericalHarmonicEdgeAttrs      │ pos, edge_index, shift │ 0: edge_vec          │ lmax=1,               │
│                                 │                        │ 1: edge_length       │ normalization='compo… │
│                                 │                        │ 2: edge_sh           │                       │
└─────────────────────────────────┴────────────────────────┴──────────────────────┴───────────────────────┘
```
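The table shows `RadialBasisEdgeEncoding` combining a Bessel basis with a polynomial cutoff, fed by the `edge_length` that `SphericalHarmonicEdgeAttrs` produces. Since KLay ships NequIP blocks, a reasonable assumption is that these follow the NequIP/DimeNet formulas, e_n(r) = (2/r_max) sin(nπr/r_max)/r and a degree-p polynomial envelope in x = r/r_max. The stdlib sketch below evaluates them on plain floats for single edges; the function names are illustrative, the edge-vector convention `pos[j] - pos[i] + shift` is assumed, and trainable Bessel frequencies and spherical harmonics are omitted.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Sequence[float]


def edge_vectors(pos: Sequence[Vec3], edge_index: Tuple[Sequence[int], Sequence[int]],
                 shift: Sequence[Vec3]) -> Tuple[List[List[float]], List[float]]:
    """Edge vectors and lengths, assuming vec_ij = pos[j] - pos[i] + shift_ij."""
    vecs, lengths = [], []
    for i, j, s in zip(edge_index[0], edge_index[1], shift):
        v = [pos[j][k] - pos[i][k] + s[k] for k in range(3)]
        vecs.append(v)
        lengths.append(math.sqrt(sum(c * c for c in v)))
    return vecs, lengths


def bessel_basis(r: float, r_max: float, num_radial_basis: int = 8) -> List[float]:
    """e_n(r) = (2 / r_max) * sin(n * pi * r / r_max) / r, n = 1..num_radial_basis (r > 0)."""
    return [2.0 / r_max * math.sin(n * math.pi * r / r_max) / r
            for n in range(1, num_radial_basis + 1)]


def polynomial_cutoff(r: float, r_max: float, p: int = 6) -> float:
    """Smooth envelope: 1 at r = 0 and 0 for r >= r_max (x = r / r_max)."""
    x = r / r_max
    if x >= 1.0:
        return 0.0
    return (1.0
            - (p + 1) * (p + 2) / 2 * x ** p
            + p * (p + 2) * x ** (p + 1)
            - p * (p + 1) / 2 * x ** (p + 2))


def radial_edge_encoding(r: float, r_max: float, num_radial_basis: int = 8,
                         polynomial_degree: int = 6) -> List[float]:
    """Bessel features damped by the cutoff envelope (NequIP-style)."""
    envelope = polynomial_cutoff(r, r_max, polynomial_degree)
    return [b * envelope for b in bessel_basis(r, r_max, num_radial_basis)]


_, lengths = edge_vectors(pos=[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]],
                          edge_index=([0], [1]), shift=[[0.0, 0.0, 0.0]])
features = radial_edge_encoding(lengths[0], r_max=6.0)
print(lengths, len(features))  # [5.0] 8
```

The envelope and its first two derivatives vanish at r_max, so edges entering or leaving the cutoff sphere do not cause jumps in energy or forces.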
```bash
# 2. Strict validation + PNG diagram
klay validate example/new_model.yaml -v
```
Other examples:
```bash
# 3. Script & save the model
klay export example/new_model.yaml
```
YAML cheat-sheet
```yaml
model_params:
  r_max: 4.0
  n_channels: 32
  num_elems: 2

model_inputs:              # omit for "library mode"
  atomic_numbers: "Tensor (N,)"
  positions: "Tensor (N,3)"
  edge_index: "Tensor (2,E)"
  shifts: "Tensor (E,3)"

model_layers:
  element_embedding:
    type: OneHotAtomEncoding
    config:
      num_elems: ${model_params.num_elems}

  edge_feature0:
    type: SphericalHarmonicEdgeAttrs
    config: {lmax: 1}
    output:
      0: vec0
      1: len0
      2: sh0

  radial_basis_func:
    type: RadialBasisEdgeEncoding
    config:
      r_max: ${model_params.r_max}
    inputs: {edge_length: len0}

  node_features:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.num_elems}
      irreps_out_block:
        - l: 0
          mul: ${model_params.n_channels}

  conv_shared:
    type: ConvNetLayer
    config:
      hidden_irreps_lmax: 1
      edge_sh_lmax: 1
      conv_feature_size: ${model_params.n_channels}

  conv1:
    alias: conv_shared
    inputs:
      h: node_features
      edge_sh: sh0
      edge_length_embeddings: radial_basis_func

  output_projection:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.n_channels}
        - l: 1
          mul: ${model_params.n_channels}
      irreps_out_block:
        - {l: 0, mul: 1}

model_outputs:
  energy: output_projection
  representation: conv1.h
```
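The `${model_params.r_max}` placeholders look like OmegaConf-style interpolation; whether `load_config` delegates to OmegaConf or resolves them itself is not stated here. As an illustration of the mechanism only, this stdlib sketch resolves whole-value `${a.b}` references in an already-parsed config while keeping the referenced value's type (so `mul` stays an `int`). `resolve` and `lookup` are hypothetical helpers; embedded references such as `"r=${x}"` and circular references are deliberately not handled.

```python
import re
from typing import Any, Dict

# Matches a value that is exactly "${dotted.path}"
_REF = re.compile(r"^\$\{([A-Za-z_][\w.]*)\}$")


def lookup(root: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted path such as "model_params.r_max" through nested dicts."""
    node: Any = root
    for key in dotted.split("."):
        node = node[key]
    return node


def resolve(node: Any, root: Dict[str, Any]) -> Any:
    """Return a copy of `node` with every "${...}" string replaced by its target."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root) for v in node]
    if isinstance(node, str):
        match = _REF.match(node)
        if match:
            # Resolve again in case the target is itself a reference.
            return resolve(lookup(root, match.group(1)), root)
    return node


cfg = {
    "model_params": {"r_max": 4.0, "n_channels": 32},
    "model_layers": {
        "radial_basis_func": {"config": {"r_max": "${model_params.r_max}"}},
        "node_features": {"config": {
            "irreps_out_block": [{"l": 0, "mul": "${model_params.n_channels}"}]}},
    },
}
resolved = resolve(cfg, cfg)
print(resolved["model_layers"]["node_features"]["config"]["irreps_out_block"])
# [{'l': 0, 'mul': 32}]
```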
Roadmap
- EGNN & SEGNN blocks
- Pre-trained M3GNet & GemNet-T embeddings
- ONNX export & torch.export backend
- cuEquivariance/OpenEquivariance backend for the layers
Pull requests welcome.
Download files
Source Distribution
klay-0.9.7.tar.gz (40.2 kB)
Built Distribution
klay-0.9.7-py3-none-any.whl (55.2 kB)
File details
Details for the file klay-0.9.7.tar.gz.
File metadata
- Download URL: klay-0.9.7.tar.gz
- Size: 40.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b167ac3021caaa5879e3baea5f85227cfe7c92295131459a99f4490f8afd19dc |
| MD5 | 384392d9938a17bfae22b421677413c8 |
| BLAKE2b-256 | 1c6f3bbccc9cd1126f61c498aee1eba277c2c48e802b77f25c211fb2c3a8d0d5 |
File details
Details for the file klay-0.9.7-py3-none-any.whl.
File metadata
- Download URL: klay-0.9.7-py3-none-any.whl
- Size: 55.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7dbcc731d18c443936da5ddf75a627f88a9c27d89e3a15d85b29b329677f7291 |
| MD5 | fa389a1fe50e3e22dde3e4cb0bc121a3 |
| BLAKE2b-256 | cdadd59c20d8e52557973165d14d58897806f9f07ef56f0db38c542ae36a4fd0 |